From 3b7cc31449803fd94402845ebc403380cceb9295 Mon Sep 17 00:00:00 2001 From: madipo2611 Date: Thu, 14 Aug 2025 12:44:03 +0300 Subject: [PATCH] =?UTF-8?q?v0.0.18.2=20=D0=94=D0=BE=D0=B1=D0=B0=D0=B2?= =?UTF-8?q?=D0=BB=D0=B5=D0=BD=20WSAuthMiddleware?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/http/middleware/auth.go | 29 ++++----------------- internal/http/middleware/cors.go | 17 ++++++------ internal/http/middleware/logging.go | 31 ++++++++++------------ internal/http/middleware/ws_auth.go | 40 +++++++++++++++++++++++++++++ internal/http/server.go | 28 ++++++++++++++++++-- 5 files changed, 94 insertions(+), 51 deletions(-) create mode 100644 internal/http/middleware/ws_auth.go diff --git a/internal/http/middleware/auth.go b/internal/http/middleware/auth.go index 024f1b2..cc8de30 100644 --- a/internal/http/middleware/auth.go +++ b/internal/http/middleware/auth.go @@ -2,7 +2,7 @@ package middleware import ( "context" - "log" + "github.com/gorilla/websocket" "net/http" "strings" "tailly_back_v2/pkg/auth" @@ -11,26 +11,19 @@ import ( const ( authorizationHeader = "Authorization" bearerPrefix = "Bearer " - userIDKey = "userID" + userIDKey = "userID" // Ключ для хранения userID в контексте ) // AuthMiddleware проверяет JWT токен и добавляет userID в контекст func AuthMiddleware(tokenAuth *auth.TokenAuth) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - - log.Printf("Middleware: путь %s", r.URL.Path) - - // Пропускаем WebSocket маршрут - if r.URL.Path == "/ws" { - log.Printf("Middleware: пропускаем /ws") - next.ServeHTTP(w, r) - return - } - if strings.Contains(r.Header.Get("Upgrade"), "websocket") { + if websocket.IsWebSocketUpgrade(r) { next.ServeHTTP(w, r) return } + + // Пропускаем OPTIONS запросы (для CORS) if r.Method == http.MethodOptions { next.ServeHTTP(w, r) return @@ -69,15 +62,3 @@ func AuthMiddleware(tokenAuth *auth.TokenAuth) func(http.Handler) http.Handler { }) } } -func WebSocketMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Upgrade") == "websocket" { - // Используем оригинальный ResponseWriter для WebSocket - next.ServeHTTP(w, r) - return - } - // Для обычных HTTP запросов используем наш кастомный writer - rw := &responseWriter{ResponseWriter: w, status: http.StatusOK} - next.ServeHTTP(rw, r) - }) -} diff --git a/internal/http/middleware/cors.go b/internal/http/middleware/cors.go index 09ba12a..1019451 100644 --- a/internal/http/middleware/cors.go +++ b/internal/http/middleware/cors.go @@ -1,31 +1,30 @@ package middleware import ( + "github.com/gorilla/websocket" "net/http" "strings" ) +// CORS middleware настраивает политику кросс-доменных запросов func CORS(allowedOrigins []string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Особые правила для WebSocket - if r.Header.Get("Upgrade") == "websocket" { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Credentials", "true") - w.Header().Set("Access-Control-Allow-Headers", "*") + if websocket.IsWebSocketUpgrade(r) { next.ServeHTTP(w, r) return } - // Стандартная CORS логика для других запросов origin := r.Header.Get("Origin") - if isOriginAllowed(origin, allowedOrigins) { + + if IsOriginAllowed(origin, allowedOrigins) { w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Access-Control-Allow-Credentials", "true") w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, Authorization, bypass-auth") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + } if r.Method == "OPTIONS" { @@ -39,15 +38,17 @@ func CORS(allowedOrigins []string) func(http.Handler) http.Handler { } // isOriginAllowed проверяет разрешен ли домен для CORS -func isOriginAllowed(origin string, allowedOrigins []string) bool { +func IsOriginAllowed(origin string, allowedOrigins []string) bool { if origin == "" { return false } + // Разрешаем все источники в development if len(allowedOrigins) == 1 && allowedOrigins[0] == "*" { return true } + // Точноe сравнение с разрешенными доменами for _, allowed := range allowedOrigins { if strings.EqualFold(origin, allowed) { return true diff --git a/internal/http/middleware/logging.go b/internal/http/middleware/logging.go index 9ea0906..965ffa3 100644 --- a/internal/http/middleware/logging.go +++ b/internal/http/middleware/logging.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "errors" + "github.com/gorilla/websocket" "io" "log" "net" @@ -15,13 +16,14 @@ import ( func LoggingMiddleware(logger *log.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - - if r.Header.Get("Upgrade") == "websocket" { + // Полностью пропускаем WebSocket запросы + if websocket.IsWebSocketUpgrade(r) { next.ServeHTTP(w, r) return } + start := time.Now() + // Логируем основные параметры запроса logData := map[string]interface{}{ "method": r.Method, @@ -38,26 +40,18 @@ func LoggingMiddleware(logger *log.Logger) func(http.Handler) http.Handler { if len(bodyBytes) > 0 { logData["body_size"] = len(bodyBytes) - // Для JSON-запросов логируем тело if r.Header.Get("Content-Type") == "application/json" { logData["body"] = string(bodyBytes) } } } - // Перехват ответа + // Создаем responseWriter только для НЕ WebSocket запросов rw := &responseWriter{ResponseWriter: w, status: http.StatusOK} - - // Обработка запроса next.ServeHTTP(rw, r) - // Дополняем данные для логирования + // Логирование только для НЕ WebSocket запросов duration := time.Since(start) - logData["status"] = rw.status - logData["duration"] = duration.String() - logData["response_size"] = rw.size - - // Форматированный вывод лога logger.Printf( "%s %s %d %s | IP: %s | Duration: %s | Body: %d bytes", r.Method, @@ -72,7 +66,6 @@ func LoggingMiddleware(logger *log.Logger) func(http.Handler) http.Handler { } } -// Кастомный responseWriter для перехвата статуса и размера ответа type responseWriter struct { http.ResponseWriter status int @@ -80,17 +73,21 @@ type responseWriter struct { } func (rw *responseWriter) WriteHeader(code int) { - rw.status = code - rw.ResponseWriter.WriteHeader(code) + if rw.status == 0 { // Защита от двойного вызова WriteHeader + rw.status = code + rw.ResponseWriter.WriteHeader(code) + } } func (rw *responseWriter) Write(b []byte) (int, error) { + if rw.status == 0 { + rw.WriteHeader(http.StatusOK) + } size, err := rw.ResponseWriter.Write(b) rw.size += size return size, err } -// Добавляем поддержку Hijacker func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { if hijacker, ok := rw.ResponseWriter.(http.Hijacker); ok { return hijacker.Hijack() diff --git a/internal/http/middleware/ws_auth.go b/internal/http/middleware/ws_auth.go new file mode 100644 index 0000000..4926a82 --- /dev/null +++ b/internal/http/middleware/ws_auth.go @@ -0,0 +1,40 @@ +package middleware + +import ( + "context" + "net/http" + "tailly_back_v2/pkg/auth" +) + +// WSAuthMiddleware проверяет JWT токен для WebSocket соединений +func WSAuthMiddleware(tokenAuth *auth.TokenAuth) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Проверяем, что это WebSocket запрос + if r.Header.Get("Upgrade") == "websocket" { + // Извлекаем токен из query параметров или заголовков + token := extractTokenFromRequest(r) + + if token != "" { + userID, err := tokenAuth.ValidateAccessToken(token) + if err == nil { + // Добавляем userID в контекст + ctx := context.WithValue(r.Context(), userIDKey, userID) + r = r.WithContext(ctx) + } + } + } + + next.ServeHTTP(w, r) + }) + } +} + +func extractTokenFromRequest(r *http.Request) string { + // Только проверка кук (как в вашем коде) + cookie, err := r.Cookie("accessToken") + if err == nil { + return cookie.Value + } + return "" +} diff --git a/internal/http/server.go b/internal/http/server.go index b81e763..943548e 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -4,8 +4,10 @@ import ( "context" "database/sql" "github.com/99designs/gqlgen/graphql/handler" + "github.com/99designs/gqlgen/graphql/handler/transport" "github.com/99designs/gqlgen/graphql/playground" "github.com/go-chi/chi/v5" + "github.com/gorilla/websocket" "github.com/prometheus/client_golang/prometheus/promhttp" "log" "net/http" @@ -49,13 +51,20 @@ func (s *Server) configureRouter() { allowedOrigins := []string{ "http://localhost:3000", "https://tailly.ru", + "http://tailly.ru", + "ws://tailly.ru", + "wss://tailly.ru", + "ws://localhost:3000", + "http://localhost:3006", } logger := log.New(os.Stdout, "HTTP: ", log.LstdFlags) - s.router.Use(middleware.WebSocketMiddleware) + s.router.Use(middleware.LoggingMiddleware(logger)) s.router.Use(middleware.MetricsMiddleware) s.router.Use(middleware.CORS(allowedOrigins)) + + s.router.Use(middleware.WSAuthMiddleware(s.tokenAuth)) s.router.Use(middleware.AuthMiddleware(s.tokenAuth)) resolver := graph.NewResolver(s.services, s.db, s.services.Messages) @@ -63,10 +72,25 @@ func (s *Server) configureRouter() { Resolvers: resolver, })) + wsTransport := transport.Websocket{ + Upgrader: websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + origin := r.Header.Get("Origin") + for _, allowed := range allowedOrigins { + if origin == allowed { + return true + } + } + return false + }, + }, + } + + srv.AddTransport(&wsTransport) + s.router.Handle("/", playground.Handler("GraphQL playground", "/query")) s.router.Handle("/query", srv) s.router.Handle("/uploads/*", http.StripPrefix("/uploads/", http.FileServer(http.Dir("./uploads")))) - } func (s *Server) configureMetrics() {