diff --git a/internal/http/graph/message_resolvers.go b/internal/http/graph/message_resolvers.go index 9208acd..9953c60 100644 --- a/internal/http/graph/message_resolvers.go +++ b/internal/http/graph/message_resolvers.go @@ -109,11 +109,31 @@ func (r *messageResolver) Sender(ctx context.Context, obj *domain.Message) (*dom // Receiver - возвращает получателя сообщения func (r *messageResolver) Receiver(ctx context.Context, obj *domain.Message) (*domain.User, error) { - chat, err := r.chatRepo.GetChatByID(ctx, obj.ChatID) - if err != nil { - return nil, fmt.Errorf("ошибка получения чата: %v", err) + // 1. Проверка на nil + if obj == nil { + return nil, fmt.Errorf("message is nil") } + // 2. Если есть receiver_id - используем его напрямую + if obj.ReceiverID != 0 { + user, err := r.Services.User.GetByID(ctx, obj.ReceiverID) + if err != nil { + return nil, fmt.Errorf("failed to get receiver by ID: %v", err) + } + return user, nil + } + + // 3. Альтернативный вариант через chat_id (если receiver_id не установлен) + if obj.ChatID == 0 { + return nil, fmt.Errorf("both receiver_id and chat_id are not set") + } + + chat, err := r.chatRepo.GetChatByID(ctx, obj.ChatID) + if err != nil { + return nil, fmt.Errorf("failed to get chat: %v", err) + } + + // Определяем ID получателя receiverID := chat.User1ID if obj.SenderID == chat.User1ID { receiverID = chat.User2ID @@ -121,8 +141,9 @@ func (r *messageResolver) Receiver(ctx context.Context, obj *domain.Message) (*d user, err := r.Services.User.GetByID(ctx, receiverID) if err != nil { - return nil, fmt.Errorf("ошибка получения получателя: %v", err) + return nil, fmt.Errorf("failed to get receiver user: %v", err) } + return user, nil } @@ -153,30 +174,14 @@ func (r *mutationResolver) SendMessage(ctx context.Context, receiverID int, cont // MarkAsRead - помечает сообщение как прочитанное func (r *mutationResolver) MarkAsRead(ctx context.Context, messageID int) (bool, error) { - userID, err := getUserIDFromContext(ctx) - if err != nil { - return false, errors.New("не авторизован") + if r.Services == nil || r.Services.Chat == nil { + return false, errors.New("сервис чатов не инициализирован") } - // Получаем сообщение напрямую из репозитория - message, err := r.chatRepo.GetMessageByID(ctx, messageID) + // Все операции через сервис + err := r.Services.Chat.MarkAsRead(ctx, messageID) if err != nil { - return false, fmt.Errorf("ошибка получения сообщения: %v", err) - } - - // Проверяем доступ к сообщению - chat, err := r.chatRepo.GetChatByID(ctx, message.ChatID) - if err != nil { - return false, fmt.Errorf("ошибка получения чата: %v", err) - } - - if userID != chat.User1ID && userID != chat.User2ID { - return false, errors.New("нет доступа к сообщению") - } - - err = r.Services.Chat.MarkAsRead(ctx, messageID) - if err != nil { - return false, fmt.Errorf("ошибка обновления статуса: %v", err) + return false, fmt.Errorf("ошибка отметки как прочитанного: %v", err) } return true, nil diff --git a/internal/http/handlers/chat.go b/internal/http/handlers/chat.go index cc5d722..f36ab09 100644 --- a/internal/http/handlers/chat.go +++ b/internal/http/handlers/chat.go @@ -2,7 +2,9 @@ package handlers import ( "context" + "log" "net/http" + "strings" "tailly_back_v2/internal/domain" "tailly_back_v2/internal/service" "tailly_back_v2/internal/ws" @@ -15,8 +17,9 @@ var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, CheckOrigin: func(r *http.Request) bool { - return true // В production заменить на проверку origin + return true }, + Subprotocols: []string{"graphql-transport-ws"}, } type ChatHandler struct { @@ -34,22 +37,51 @@ func NewChatHandler(chatService service.ChatService, hub *ws.Hub, tokenAuth *aut } func (h *ChatHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request) { - // Аутентификация пользователя - token := r.URL.Query().Get("token") + requestedProtocol := r.Header.Get("Sec-WebSocket-Protocol") + if requestedProtocol != "" && requestedProtocol != "graphql-transport-ws" { + http.Error(w, "Unsupported WebSocket protocol", http.StatusBadRequest) + return + } + log.Printf("Requested protocols: %v", r.Header["Sec-WebSocket-Protocol"]) + // 1. Проверяем куки + var token string + cookie, err := r.Cookie("accessToken") + if err == nil { + token = cookie.Value + log.Printf("WebSocket: токен из куки: %s", token) + } + + // 2. Если нет в куках, проверяем заголовок Authorization if token == "" { + authHeader := r.Header.Get("Authorization") + if authHeader != "" { + token = strings.TrimPrefix(authHeader, "Bearer ") + log.Printf("WebSocket: токен из заголовка: %s", token) + } + } + + // 3. Если токен не найден - возвращаем 401 + if token == "" { + log.Println("WebSocket: токен не найден ни в куках, ни в заголовках") http.Error(w, "Token is required", http.StatusUnauthorized) return } + // 4. Валидация токена userID, err := h.tokenAuth.ValidateAccessToken(token) if err != nil { + log.Printf("WebSocket: ошибка валидации токена: %v", err) http.Error(w, "Invalid token", http.StatusUnauthorized) return } + log.Printf("WebSocket: успешная авторизация, userID=%d", userID) + + // 5. Обновление соединения conn, err := upgrader.Upgrade(w, r, nil) if err != nil { - http.Error(w, "Could not open websocket connection", http.StatusBadRequest) + log.Printf("WebSocket upgrade error: %v", err) + http.Error(w, "Could not upgrade to WebSocket", http.StatusBadRequest) return } @@ -60,49 +92,109 @@ func (h *ChatHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request) { h.hub.RegisterClient(client) - // Горутина для чтения сообщений - go h.readPump(conn, client, userID) - // Горутина для записи сообщений - go h.writePump(conn, client) + // Добавляем контекст для управления жизненным циклом соединения + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Запускаем горутины с обработкой контекста + go h.readPump(ctx, conn, client, userID) + go h.writePump(ctx, conn, client) + + // Ждем завершения + <-ctx.Done() } -func (h *ChatHandler) readPump(conn *websocket.Conn, client *ws.Client, userID int) { +func (h *ChatHandler) readPump(ctx context.Context, conn *websocket.Conn, client *ws.Client, userID int) { defer func() { h.hub.UnregisterClient(client) conn.Close() }() for { - var msg struct { - ChatID int `json:"chatId"` - Content string `json:"content"` - } + select { + case <-ctx.Done(): + return + default: + var msg struct { + Type string `json:"type"` + Payload struct { + ReceiverID int `json:"receiverId"` + Content string `json:"content"` + } `json:"payload"` + } - if err := conn.ReadJSON(&msg); err != nil { - break - } + if err := conn.ReadJSON(&msg); err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("WebSocket read error: %v", err) + } + return + } - // Используем context.Background() вместо r.Context() - message, err := h.chatService.SendMessage( - context.Background(), - userID, - msg.ChatID, - msg.Content, - ) - if err != nil { - continue - } + if msg.Type != "message" { + continue + } - h.hub.Broadcast(message) + // Проверяем receiverId + if msg.Payload.ReceiverID == 0 { + log.Printf("Invalid receiverId: 0") + conn.WriteJSON(map[string]interface{}{ + "type": "error", + "message": "Invalid receiver ID", + }) + continue + } + + // Логирование для отладки + log.Printf("Received message from %d to %d", userID, msg.Payload.ReceiverID) + + // Создаем или находим чат + chat, err := h.chatService.GetOrCreateChat(ctx, userID, msg.Payload.ReceiverID) + if err != nil { + log.Printf("Chat error: %v", err) + conn.WriteJSON(map[string]interface{}{ + "type": "error", + "message": "Chat error", + "details": err.Error(), + }) + continue + } + + // Отправляем сообщение + message, err := h.chatService.SendMessage( + ctx, + userID, + chat.ID, + msg.Payload.Content, + ) + if err != nil { + log.Printf("Message send error: %v", err) + continue + } + + // Рассылаем сообщение + h.hub.Broadcast(message) + } } } -func (h *ChatHandler) writePump(conn *websocket.Conn, client *ws.Client) { +func (h *ChatHandler) writePump(ctx context.Context, conn *websocket.Conn, client *ws.Client) { defer conn.Close() - for message := range client.Send { - if err := conn.WriteJSON(message); err != nil { - break + for { + select { + case <-ctx.Done(): + return + case message, ok := <-client.Send: + if !ok { + // Канал закрыт + conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + + if err := conn.WriteJSON(message); err != nil { + log.Printf("WebSocket write error: %v", err) + return + } } } } diff --git a/internal/http/middleware/auth.go b/internal/http/middleware/auth.go index b8ca21d..024f1b2 100644 --- a/internal/http/middleware/auth.go +++ b/internal/http/middleware/auth.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "log" "net/http" "strings" "tailly_back_v2/pkg/auth" @@ -10,14 +11,26 @@ import ( const ( authorizationHeader = "Authorization" bearerPrefix = "Bearer " - userIDKey = "userID" // Ключ для хранения userID в контексте + userIDKey = "userID" ) // AuthMiddleware проверяет JWT токен и добавляет userID в контекст func AuthMiddleware(tokenAuth *auth.TokenAuth) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Пропускаем OPTIONS запросы (для CORS) + + log.Printf("Middleware: путь %s", r.URL.Path) + + // Пропускаем WebSocket маршрут + if r.URL.Path == "/ws" { + log.Printf("Middleware: пропускаем /ws") + next.ServeHTTP(w, r) + return + } + if strings.Contains(r.Header.Get("Upgrade"), "websocket") { + next.ServeHTTP(w, r) + return + } if r.Method == http.MethodOptions { next.ServeHTTP(w, r) return @@ -56,3 +69,15 @@ func AuthMiddleware(tokenAuth *auth.TokenAuth) func(http.Handler) http.Handler { }) } } +func WebSocketMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Upgrade") == "websocket" { + // Используем оригинальный ResponseWriter для WebSocket + next.ServeHTTP(w, r) + return + } + // Для обычных HTTP запросов используем наш кастомный writer + rw := &responseWriter{ResponseWriter: w, status: http.StatusOK} + next.ServeHTTP(rw, r) + }) +} diff --git a/internal/http/middleware/cors.go b/internal/http/middleware/cors.go index 1e5497f..804ee76 100644 --- a/internal/http/middleware/cors.go +++ b/internal/http/middleware/cors.go @@ -5,7 +5,6 @@ import ( "strings" ) -// CORS middleware настраивает политику кросс-доменных запросов func CORS(allowedOrigins []string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -41,12 +40,10 @@ func isOriginAllowed(origin string, allowedOrigins []string) bool { return false } - // Разрешаем все источники в development if len(allowedOrigins) == 1 && allowedOrigins[0] == "*" { return true } - // Точноe сравнение с разрешенными доменами for _, allowed := range allowedOrigins { if strings.EqualFold(origin, allowed) { return true diff --git a/internal/http/middleware/logging.go b/internal/http/middleware/logging.go index 54d9d0f..9ea0906 100644 --- a/internal/http/middleware/logging.go +++ b/internal/http/middleware/logging.go @@ -1,9 +1,12 @@ package middleware import ( + "bufio" "bytes" + "errors" "io" "log" + "net" "net/http" "time" ) @@ -14,6 +17,11 @@ func LoggingMiddleware(logger *log.Logger) func(http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() + if r.Header.Get("Upgrade") == "websocket" { + next.ServeHTTP(w, r) + return + } + // Логируем основные параметры запроса logData := map[string]interface{}{ "method": r.Method, @@ -81,3 +89,11 @@ func (rw *responseWriter) Write(b []byte) (int, error) { rw.size += size return size, err } + +// Добавляем поддержку Hijacker +func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacker, ok := rw.ResponseWriter.(http.Hijacker); ok { + return hijacker.Hijack() + } + return nil, nil, errors.New("response writer does not support hijacking") +} diff --git a/internal/http/server.go b/internal/http/server.go index 7f1eca5..d71ccc1 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -25,14 +25,14 @@ type Server struct { cfg *config.Config services *service.Services tokenAuth *auth.TokenAuth - db *sql.DB // Добавляем подключение к БД + db *sql.DB } func NewServer( cfg *config.Config, services *service.Services, tokenAuth *auth.TokenAuth, - db *sql.DB, // Добавляем параметр БД + db *sql.DB, ) *Server { s := &Server{ router: chi.NewRouter(), @@ -50,8 +50,8 @@ func NewServer( func (s *Server) configureRouter() { allowedOrigins := []string{ - "http://localhost:3000", // React dev server - "https://tailly.ru", // Продакшен домен + "http://localhost:3000", + "https://tailly.ru", } // Инициализация WebSocket хаба hub := ws.NewHub() @@ -64,24 +64,20 @@ func (s *Server) configureRouter() { hub, ) - // Обновляем сервис чата в services s.services.Chat = chatService - // Добавляем обработчик WebSocket - - // Логирование logger := log.New(os.Stdout, "HTTP: ", log.LstdFlags) + s.router.Use(middleware.WebSocketMiddleware) s.router.Use(middleware.LoggingMiddleware(logger)) s.router.Use(middleware.MetricsMiddleware) s.router.Use(middleware.CORS(allowedOrigins)) s.router.Use(middleware.AuthMiddleware(s.tokenAuth)) - // GraphQL handler - resolver := graph.NewResolver(s.services, s.db) // Теперь передаем оба аргумента + + resolver := graph.NewResolver(s.services, s.db) srv := handler.NewDefaultServer(graph.NewExecutableSchema(graph.Config{ Resolvers: resolver, })) - // Routes s.router.Handle("/", playground.Handler("GraphQL playground", "/query")) s.router.Handle("/query", srv) s.router.Handle("/uploads/*", http.StripPrefix("/uploads/", http.FileServer(http.Dir("./uploads")))) @@ -105,6 +101,5 @@ func (s *Server) Run() error { } func (s *Server) Shutdown(ctx context.Context) error { - // Здесь можно добавить логику graceful shutdown return nil } diff --git a/internal/repository/chat_repository.go b/internal/repository/chat_repository.go index e1caffd..e5aa667 100644 --- a/internal/repository/chat_repository.go +++ b/internal/repository/chat_repository.go @@ -134,6 +134,14 @@ func (r *chatRepository) CreateChat(ctx context.Context, user1ID, user2ID int) ( return nil, errors.New("cannot create chat with yourself") } + // Проверяем существование пользователей + if err := r.checkUserExists(ctx, user1ID); err != nil { + return nil, fmt.Errorf("user1 does not exist: %v", err) + } + if err := r.checkUserExists(ctx, user2ID); err != nil { + return nil, fmt.Errorf("user2 does not exist: %v", err) + } + // Упорядочиваем ID пользователей согласно CHECK constraint if user1ID > user2ID { user1ID, user2ID = user2ID, user1ID @@ -156,6 +164,18 @@ func (r *chatRepository) CreateChat(ctx context.Context, user1ID, user2ID int) ( return chat, nil } +func (r *chatRepository) checkUserExists(ctx context.Context, userID int) error { + var exists bool + err := r.db.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)", userID).Scan(&exists) + if err != nil { + return err + } + if !exists { + return errors.New("user not found") + } + return nil +} + func (r *chatRepository) GetChatByID(ctx context.Context, id int) (*domain.Chat, error) { query := ` SELECT id, user1_id, user2_id, created_at diff --git a/internal/ws/hub.go b/internal/ws/hub.go index 15c5046..2eb35ed 100644 --- a/internal/ws/hub.go +++ b/internal/ws/hub.go @@ -1,4 +1,3 @@ -// ws/hub.go package ws import ( @@ -45,12 +44,16 @@ func (h *Hub) Run() { select { case client := <-h.register: h.mu.Lock() + // Закрываем предыдущее соединение если есть + if existing, ok := h.clients[client.UserID]; ok { + close(existing.Send) + } h.clients[client.UserID] = client h.mu.Unlock() case client := <-h.unregister: h.mu.Lock() - if c, ok := h.clients[client.UserID]; ok { + if c, ok := h.clients[client.UserID]; ok && c == client { close(c.Send) delete(h.clients, client.UserID) } @@ -58,22 +61,20 @@ func (h *Hub) Run() { case message := <-h.broadcast: h.mu.RLock() - // Отправляем сообщение отправителю - if sender, ok := h.clients[message.SenderID]; ok { - select { - case sender.Send <- message: - default: - close(sender.Send) - delete(h.clients, sender.UserID) - } - } - // Отправляем сообщение получателю - if receiver, ok := h.clients[message.ReceiverID]; ok { - select { - case receiver.Send <- message: - default: - close(receiver.Send) - delete(h.clients, receiver.UserID) + // Отправляем всем клиентам, кто участвует в этом чате + for _, client := range h.clients { + if client.UserID == message.SenderID || client.UserID == message.ReceiverID { + select { + case client.Send <- message: + default: + // Если канал полон, закрываем соединение + close(client.Send) + h.mu.RUnlock() + h.mu.Lock() + delete(h.clients, client.UserID) + h.mu.Unlock() + h.mu.RLock() + } } } h.mu.RUnlock()