diff --git a/internal/http/handlers/chat.go b/internal/http/handlers/chat.go index 8857526..309fa10 100644 --- a/internal/http/handlers/chat.go +++ b/internal/http/handlers/chat.go @@ -4,7 +4,6 @@ import ( "context" "log" "net/http" - "strings" "tailly_back_v2/internal/domain" "tailly_back_v2/internal/service" "tailly_back_v2/internal/ws" @@ -40,54 +39,8 @@ func NewChatHandler(chatService service.ChatService, hub *ws.Hub, tokenAuth *aut func (h *ChatHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request) { log.Printf("Incoming WebSocket headers: %+v", r.Header) log.Printf("Cookies: %+v", r.Cookies()) - requestedProtocol := r.Header.Get("Sec-WebSocket-Protocol") - if requestedProtocol != "" && requestedProtocol != "graphql-transport-ws" { - http.Error(w, "Unsupported WebSocket protocol", http.StatusBadRequest) - return - } - log.Printf("Requested protocols: %v", r.Header["Sec-WebSocket-Protocol"]) - // 1. Проверяем куки - var token string - cookie, err := r.Cookie("accessToken") - if err == nil { - token = cookie.Value - log.Printf("WebSocket: токен из куки: %s", token) - } - // Из заголовка Authorization - if authHeader := r.Header.Get("Authorization"); authHeader != "" { - token = strings.TrimPrefix(authHeader, "Bearer ") - } - - // Из параметра URL - if token == "" { - token = r.URL.Query().Get("token") - } - - // Из куков - if token == "" { - if cookie, err := r.Cookie("accessToken"); err == nil { - token = cookie.Value - } - } - - if token == "" { - log.Println("WebSocket: токен не найден") - http.Error(w, "Token is required", http.StatusUnauthorized) - return - } - - // Валидация токена - userID, err := h.tokenAuth.ValidateAccessToken(token) - if err != nil { - log.Printf("WebSocket: ошибка валидации токена: %v", err) - http.Error(w, "Invalid token", http.StatusUnauthorized) - return - } - - log.Printf("WebSocket: успешная авторизация, userID=%d", userID) - - // 5. Обновление соединения + // Убираем все проверки токена здесь, так как будем проверять его в первом сообщении conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Printf("WebSocket upgrade error: %v", err) @@ -96,32 +49,64 @@ func (h *ChatHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request) { } client := &ws.Client{ - UserID: userID, + UserID: 0, // Пока не авторизован Send: make(chan *domain.Message, 256), } h.hub.RegisterClient(client) - // Добавляем контекст для управления жизненным циклом соединения ctx, cancel := context.WithCancel(context.Background()) defer cancel() - // Запускаем горутины с обработкой контекста - go h.readPump(ctx, conn, client, userID) + go h.readPump(ctx, conn, client) go h.writePump(ctx, conn, client) - // Ждем завершения <-ctx.Done() } -func (h *ChatHandler) readPump(ctx context.Context, conn *websocket.Conn, client *ws.Client, userID int) { +func (h *ChatHandler) readPump(ctx context.Context, conn *websocket.Conn, client *ws.Client) { ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() + defer conn.Close() + // 1. Ожидаем первое сообщение с токеном + var authMsg struct { + Type string `json:"type"` + Token string `json:"token"` + } + + if err := conn.ReadJSON(&authMsg); err != nil { + log.Printf("Failed to read auth message: %v", err) + return + } + + if authMsg.Type != "auth" || authMsg.Token == "" { + log.Println("First message must be auth with token") + conn.WriteJSON(map[string]string{"type": "error", "message": "Auth required"}) + return + } + + // 2. Валидируем токен + userID, err := h.tokenAuth.ValidateAccessToken(authMsg.Token) + if err != nil { + log.Printf("Token validation error: %v", err) + conn.WriteJSON(map[string]string{"type": "error", "message": "Invalid token"}) + return + } + + client.UserID = userID + log.Printf("WebSocket authenticated, userID=%d", userID) + + // 3. Отправляем подтверждение авторизации + if err := conn.WriteJSON(map[string]string{"type": "auth_success"}); err != nil { + log.Printf("Failed to send auth confirmation: %v", err) + return + } + + // 4. Основной цикл обработки сообщений for { select { case <-ticker.C: - // Отправляем ping if err := conn.WriteJSON(map[string]string{"type": "ping"}); err != nil { log.Printf("Ping error: %v", err) return @@ -144,59 +129,52 @@ func (h *ChatHandler) readPump(ctx context.Context, conn *websocket.Conn, client return } - if msg.Type == "pong" { + switch msg.Type { + case "pong": continue - } - - // Обработка ping/pong - if msg.Type == "ping" { + case "ping": conn.WriteJSON(map[string]string{"type": "pong"}) continue + case "message": + if client.UserID == 0 { + conn.WriteJSON(map[string]string{"type": "error", "message": "Not authenticated"}) + continue + } + + if msg.Payload.ReceiverID == 0 { + conn.WriteJSON(map[string]interface{}{ + "type": "error", + "message": "Invalid receiver ID", + }) + continue + } + + chat, err := h.chatService.GetOrCreateChat(ctx, client.UserID, msg.Payload.ReceiverID) + if err != nil { + log.Printf("Chat error: %v", err) + conn.WriteJSON(map[string]interface{}{ + "type": "error", + "message": "Chat error", + "details": err.Error(), + }) + continue + } + + message, err := h.chatService.SendMessage( + ctx, + client.UserID, + chat.ID, + msg.Payload.Content, + ) + if err != nil { + log.Printf("Message send error: %v", err) + continue + } + + h.hub.Broadcast(message) + default: + conn.WriteJSON(map[string]string{"type": "error", "message": "Unknown message type"}) } - - if msg.Type != "message" { - continue - } - - // Проверяем receiverId - if msg.Payload.ReceiverID == 0 { - log.Printf("Invalid receiverId: 0") - conn.WriteJSON(map[string]interface{}{ - "type": "error", - "message": "Invalid receiver ID", - }) - continue - } - - // Логирование для отладки - log.Printf("Received message from %d to %d", userID, msg.Payload.ReceiverID) - - // Создаем или находим чат - chat, err := h.chatService.GetOrCreateChat(ctx, userID, msg.Payload.ReceiverID) - if err != nil { - log.Printf("Chat error: %v", err) - conn.WriteJSON(map[string]interface{}{ - "type": "error", - "message": "Chat error", - "details": err.Error(), - }) - continue - } - - // Отправляем сообщение - message, err := h.chatService.SendMessage( - ctx, - userID, - chat.ID, - msg.Payload.Content, - ) - if err != nil { - log.Printf("Message send error: %v", err) - continue - } - - // Рассылаем сообщение - h.hub.Broadcast(message) } } }