diff --git a/internal/http/handlers/chat.go b/internal/http/handlers/chat.go index a5a5395..8857526 100644 --- a/internal/http/handlers/chat.go +++ b/internal/http/handlers/chat.go @@ -2,10 +2,9 @@ package handlers import ( "context" - "encoding/json" - "fmt" "log" "net/http" + "strings" "tailly_back_v2/internal/domain" "tailly_back_v2/internal/service" "tailly_back_v2/internal/ws" @@ -41,252 +40,165 @@ func NewChatHandler(chatService service.ChatService, hub *ws.Hub, tokenAuth *aut func (h *ChatHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request) { log.Printf("Incoming WebSocket headers: %+v", r.Header) log.Printf("Cookies: %+v", r.Cookies()) + requestedProtocol := r.Header.Get("Sec-WebSocket-Protocol") + if requestedProtocol != "" && requestedProtocol != "graphql-transport-ws" { + http.Error(w, "Unsupported WebSocket protocol", http.StatusBadRequest) + return + } + log.Printf("Requested protocols: %v", r.Header["Sec-WebSocket-Protocol"]) + // 1. Проверяем куки + var token string + cookie, err := r.Cookie("accessToken") + if err == nil { + token = cookie.Value + log.Printf("WebSocket: токен из куки: %s", token) + } - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - log.Printf("WebSocket upgrade error: %v", err) + // Из заголовка Authorization + if authHeader := r.Header.Get("Authorization"); authHeader != "" { + token = strings.TrimPrefix(authHeader, "Bearer ") + } + + // Из параметра URL + if token == "" { + token = r.URL.Query().Get("token") + } + + // Из куков + if token == "" { + if cookie, err := r.Cookie("accessToken"); err == nil { + token = cookie.Value + } + } + + if token == "" { + log.Println("WebSocket: токен не найден") + http.Error(w, "Token is required", http.StatusUnauthorized) + return + } + + // Валидация токена + userID, err := h.tokenAuth.ValidateAccessToken(token) + if err != nil { + log.Printf("WebSocket: ошибка валидации токена: %v", err) + http.Error(w, "Invalid token", http.StatusUnauthorized) + return + } + + log.Printf("WebSocket: успешная авторизация, userID=%d", userID) + + // 5. Обновление соединения + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("WebSocket upgrade error: %v", err) + http.Error(w, "Could not upgrade to WebSocket", http.StatusBadRequest) return } - // Создаем клиента без userID (он будет установлен при аутентификации) client := &ws.Client{ - Send: make(chan *domain.Message, 256), - LastSeen: time.Now(), - CloseChan: make(chan bool, 1), + UserID: userID, + Send: make(chan *domain.Message, 256), } h.hub.RegisterClient(client) + // Добавляем контекст для управления жизненным циклом соединения ctx, cancel := context.WithCancel(context.Background()) defer cancel() - // Горутина для чтения (теперь включает аутентификацию) - go h.readPump(ctx, conn, client) + // Запускаем горутины с обработкой контекста + go h.readPump(ctx, conn, client, userID) + go h.writePump(ctx, conn, client) - // Горутина для записи с обработкой закрытия - go func() { - defer conn.Close() - defer h.hub.UnregisterClient(client) - - for { - select { - case <-ctx.Done(): - return - case <-client.CloseChan: - return - case message, ok := <-client.Send: - if !ok { - return - } - - // Добавляем таймаут на запись - conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - if err := conn.WriteJSON(message); err != nil { - log.Printf("WebSocket write error: %v", err) - return - } - conn.SetWriteDeadline(time.Time{}) // Сбрасываем таймаут - } - } - }() + // Ждем завершения + <-ctx.Done() } -func (h *ChatHandler) readPump(ctx context.Context, conn *websocket.Conn, client *ws.Client) { +func (h *ChatHandler) readPump(ctx context.Context, conn *websocket.Conn, client *ws.Client, userID int) { ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() - defer conn.Close() - // 1. Аутентификация - if err := h.authenticateConnection(conn, client); err != nil { - log.Printf("Authentication failed: %v", err) - return - } - - // 2. Основной цикл обработки сообщений for { select { case <-ticker.C: - if err := h.sendPing(conn); err != nil { + // Отправляем ping + if err := conn.WriteJSON(map[string]string{"type": "ping"}); err != nil { + log.Printf("Ping error: %v", err) return } case <-ctx.Done(): return default: - if err := h.handleMessage(ctx, conn, client); err != nil { + var msg struct { + Type string `json:"type"` + Payload struct { + ReceiverID int `json:"receiverId"` + Content string `json:"content"` + } `json:"payload"` + } + + if err := conn.ReadJSON(&msg); err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { log.Printf("WebSocket error: %v", err) } return } - } - } -} -func (h *ChatHandler) authenticateConnection(conn *websocket.Conn, client *ws.Client) error { - var authMsg struct { - Type string `json:"type"` - Token string `json:"token"` - } - - // Устанавливаем таймаут для аутентификации - conn.SetReadDeadline(time.Now().Add(10 * time.Second)) - defer conn.SetReadDeadline(time.Time{}) - - if err := conn.ReadJSON(&authMsg); err != nil { - conn.WriteJSON(map[string]string{"type": "error", "message": "Auth required"}) - return fmt.Errorf("failed to read auth message: %v", err) - } - - if authMsg.Type != "auth" || authMsg.Token == "" { - conn.WriteJSON(map[string]string{"type": "error", "message": "Invalid auth message"}) - return fmt.Errorf("first message must be auth with token") - } - - userID, err := h.tokenAuth.ValidateAccessToken(authMsg.Token) - if err != nil { - conn.WriteJSON(map[string]string{"type": "error", "message": "Invalid token"}) - return fmt.Errorf("token validation error: %v", err) - } - - client.UserID = userID - log.Printf("WebSocket authenticated, userID=%d", userID) - - // Отправляем подтверждение аутентификации - if err := conn.WriteJSON(map[string]interface{}{ - "type": "auth_success", - "user": map[string]interface{}{ - "id": userID, - }, - }); err != nil { - return fmt.Errorf("failed to send auth confirmation: %v", err) - } - - return nil -} - -func (h *ChatHandler) sendPing(conn *websocket.Conn) error { - if err := conn.WriteJSON(map[string]string{"type": "ping"}); err != nil { - log.Printf("Ping error: %v", err) - return err - } - return nil -} - -func (h *ChatHandler) handleMessage(ctx context.Context, conn *websocket.Conn, client *ws.Client) error { - var msg struct { - Type string `json:"type"` - Payload json.RawMessage `json:"payload"` // Используем RawMessage для гибкости - } - - if err := conn.ReadJSON(&msg); err != nil { - return err - } - - switch msg.Type { - case "pong": - return nil - case "ping": - return conn.WriteJSON(map[string]string{"type": "pong"}) - case "message": - return h.handleChatMessage(ctx, conn, client, msg.Payload) - default: - log.Printf("Unknown message type: %s", msg.Type) - return conn.WriteJSON(map[string]string{ - "type": "error", - "message": "Unknown message type: " + msg.Type, - }) - } -} - -func (h *ChatHandler) handleChatMessage(ctx context.Context, conn *websocket.Conn, client *ws.Client, payload json.RawMessage) error { - if client.UserID == 0 { - return conn.WriteJSON(map[string]string{ - "type": "error", - "message": "Not authenticated", - }) - } - - var messageData struct { - ReceiverID int `json:"receiverId"` - ChatID int `json:"chatId"` - Content string `json:"content"` - } - - if err := json.Unmarshal(payload, &messageData); err != nil { - return conn.WriteJSON(map[string]interface{}{ - "type": "error", - "message": "Invalid message format", - "details": err.Error(), - }) - } - - // Валидация данных сообщения - if messageData.ReceiverID == 0 && messageData.ChatID == 0 { - return conn.WriteJSON(map[string]interface{}{ - "type": "error", - "message": "Either receiverId or chatId must be provided", - }) - } - - if messageData.Content == "" { - return conn.WriteJSON(map[string]interface{}{ - "type": "error", - "message": "Message content cannot be empty", - }) - } - - // Определяем chatId если не указан - chatID := messageData.ChatID - if chatID == 0 { - chat, err := h.chatService.GetOrCreateChat(ctx, client.UserID, messageData.ReceiverID) - if err != nil { - log.Printf("Chat error: %v", err) - return conn.WriteJSON(map[string]interface{}{ - "type": "error", - "message": "Failed to get or create chat", - "details": err.Error(), - }) - } - chatID = chat.ID - } - - // Отправляем сообщение - message, err := h.chatService.SendMessage( - ctx, - client.UserID, - chatID, - messageData.Content, - ) - if err != nil { - log.Printf("Message send error: %v", err) - return conn.WriteJSON(map[string]interface{}{ - "type": "error", - "message": "Failed to send message", - "details": err.Error(), - }) - } - - // Убедимся, что в сообщении есть получатель - if message.ReceiverID == 0 { - if messageData.ReceiverID != 0 { - message.ReceiverID = messageData.ReceiverID - } else { - // Если receiver не указан, определяем его через чат - chat, err := h.chatService.GetChatByID(ctx, chatID) - if err == nil { - if chat.User1ID == client.UserID { - message.ReceiverID = chat.User2ID - } else { - message.ReceiverID = chat.User1ID - } + if msg.Type == "pong" { + continue } + + // Обработка ping/pong + if msg.Type == "ping" { + conn.WriteJSON(map[string]string{"type": "pong"}) + continue + } + + if msg.Type != "message" { + continue + } + + // Проверяем receiverId + if msg.Payload.ReceiverID == 0 { + log.Printf("Invalid receiverId: 0") + conn.WriteJSON(map[string]interface{}{ + "type": "error", + "message": "Invalid receiver ID", + }) + continue + } + + // Логирование для отладки + log.Printf("Received message from %d to %d", userID, msg.Payload.ReceiverID) + + // Создаем или находим чат + chat, err := h.chatService.GetOrCreateChat(ctx, userID, msg.Payload.ReceiverID) + if err != nil { + log.Printf("Chat error: %v", err) + conn.WriteJSON(map[string]interface{}{ + "type": "error", + "message": "Chat error", + "details": err.Error(), + }) + continue + } + + // Отправляем сообщение + message, err := h.chatService.SendMessage( + ctx, + userID, + chat.ID, + msg.Payload.Content, + ) + if err != nil { + log.Printf("Message send error: %v", err) + continue + } + + // Рассылаем сообщение + h.hub.Broadcast(message) } } - - // Рассылаем сообщение всем подписчикам - h.hub.Broadcast(message) - - return nil } func (h *ChatHandler) writePump(ctx context.Context, conn *websocket.Conn, client *ws.Client) {