v0.0.17 Переработан websocket
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
madipo2611 2025-08-11 00:22:16 +03:00
parent 62d6c44a1b
commit 6b9cdeba55
8 changed files with 242 additions and 91 deletions

View File

@ -109,11 +109,31 @@ func (r *messageResolver) Sender(ctx context.Context, obj *domain.Message) (*dom
// Receiver - возвращает получателя сообщения // Receiver - возвращает получателя сообщения
func (r *messageResolver) Receiver(ctx context.Context, obj *domain.Message) (*domain.User, error) { func (r *messageResolver) Receiver(ctx context.Context, obj *domain.Message) (*domain.User, error) {
chat, err := r.chatRepo.GetChatByID(ctx, obj.ChatID) // 1. Проверка на nil
if err != nil { if obj == nil {
return nil, fmt.Errorf("ошибка получения чата: %v", err) return nil, fmt.Errorf("message is nil")
} }
// 2. Если есть receiver_id - используем его напрямую
if obj.ReceiverID != 0 {
user, err := r.Services.User.GetByID(ctx, obj.ReceiverID)
if err != nil {
return nil, fmt.Errorf("failed to get receiver by ID: %v", err)
}
return user, nil
}
// 3. Альтернативный вариант через chat_id (если receiver_id не установлен)
if obj.ChatID == 0 {
return nil, fmt.Errorf("both receiver_id and chat_id are not set")
}
chat, err := r.chatRepo.GetChatByID(ctx, obj.ChatID)
if err != nil {
return nil, fmt.Errorf("failed to get chat: %v", err)
}
// Определяем ID получателя
receiverID := chat.User1ID receiverID := chat.User1ID
if obj.SenderID == chat.User1ID { if obj.SenderID == chat.User1ID {
receiverID = chat.User2ID receiverID = chat.User2ID
@ -121,8 +141,9 @@ func (r *messageResolver) Receiver(ctx context.Context, obj *domain.Message) (*d
user, err := r.Services.User.GetByID(ctx, receiverID) user, err := r.Services.User.GetByID(ctx, receiverID)
if err != nil { if err != nil {
return nil, fmt.Errorf("ошибка получения получателя: %v", err) return nil, fmt.Errorf("failed to get receiver user: %v", err)
} }
return user, nil return user, nil
} }
@ -153,30 +174,14 @@ func (r *mutationResolver) SendMessage(ctx context.Context, receiverID int, cont
// MarkAsRead - помечает сообщение как прочитанное // MarkAsRead - помечает сообщение как прочитанное
func (r *mutationResolver) MarkAsRead(ctx context.Context, messageID int) (bool, error) { func (r *mutationResolver) MarkAsRead(ctx context.Context, messageID int) (bool, error) {
userID, err := getUserIDFromContext(ctx) if r.Services == nil || r.Services.Chat == nil {
if err != nil { return false, errors.New("сервис чатов не инициализирован")
return false, errors.New("не авторизован")
} }
// Получаем сообщение напрямую из репозитория // Все операции через сервис
message, err := r.chatRepo.GetMessageByID(ctx, messageID) err := r.Services.Chat.MarkAsRead(ctx, messageID)
if err != nil { if err != nil {
return false, fmt.Errorf("ошибка получения сообщения: %v", err) return false, fmt.Errorf("ошибка отметки как прочитанного: %v", err)
}
// Проверяем доступ к сообщению
chat, err := r.chatRepo.GetChatByID(ctx, message.ChatID)
if err != nil {
return false, fmt.Errorf("ошибка получения чата: %v", err)
}
if userID != chat.User1ID && userID != chat.User2ID {
return false, errors.New("нет доступа к сообщению")
}
err = r.Services.Chat.MarkAsRead(ctx, messageID)
if err != nil {
return false, fmt.Errorf("ошибка обновления статуса: %v", err)
} }
return true, nil return true, nil

View File

@ -2,7 +2,9 @@ package handlers
import ( import (
"context" "context"
"log"
"net/http" "net/http"
"strings"
"tailly_back_v2/internal/domain" "tailly_back_v2/internal/domain"
"tailly_back_v2/internal/service" "tailly_back_v2/internal/service"
"tailly_back_v2/internal/ws" "tailly_back_v2/internal/ws"
@ -15,8 +17,9 @@ var upgrader = websocket.Upgrader{
ReadBufferSize: 1024, ReadBufferSize: 1024,
WriteBufferSize: 1024, WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool { CheckOrigin: func(r *http.Request) bool {
return true // В production заменить на проверку origin return true
}, },
Subprotocols: []string{"graphql-transport-ws"},
} }
type ChatHandler struct { type ChatHandler struct {
@ -34,22 +37,51 @@ func NewChatHandler(chatService service.ChatService, hub *ws.Hub, tokenAuth *aut
} }
func (h *ChatHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request) { func (h *ChatHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
// Аутентификация пользователя requestedProtocol := r.Header.Get("Sec-WebSocket-Protocol")
token := r.URL.Query().Get("token") if requestedProtocol != "" && requestedProtocol != "graphql-transport-ws" {
http.Error(w, "Unsupported WebSocket protocol", http.StatusBadRequest)
return
}
log.Printf("Requested protocols: %v", r.Header["Sec-WebSocket-Protocol"])
// 1. Проверяем куки
var token string
cookie, err := r.Cookie("accessToken")
if err == nil {
token = cookie.Value
log.Printf("WebSocket: токен из куки: %s", token)
}
// 2. Если нет в куках, проверяем заголовок Authorization
if token == "" { if token == "" {
authHeader := r.Header.Get("Authorization")
if authHeader != "" {
token = strings.TrimPrefix(authHeader, "Bearer ")
log.Printf("WebSocket: токен из заголовка: %s", token)
}
}
// 3. Если токен не найден - возвращаем 401
if token == "" {
log.Println("WebSocket: токен не найден ни в куках, ни в заголовках")
http.Error(w, "Token is required", http.StatusUnauthorized) http.Error(w, "Token is required", http.StatusUnauthorized)
return return
} }
// 4. Валидация токена
userID, err := h.tokenAuth.ValidateAccessToken(token) userID, err := h.tokenAuth.ValidateAccessToken(token)
if err != nil { if err != nil {
log.Printf("WebSocket: ошибка валидации токена: %v", err)
http.Error(w, "Invalid token", http.StatusUnauthorized) http.Error(w, "Invalid token", http.StatusUnauthorized)
return return
} }
log.Printf("WebSocket: успешная авторизация, userID=%d", userID)
// 5. Обновление соединения
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
http.Error(w, "Could not open websocket connection", http.StatusBadRequest) log.Printf("WebSocket upgrade error: %v", err)
http.Error(w, "Could not upgrade to WebSocket", http.StatusBadRequest)
return return
} }
@ -60,49 +92,109 @@ func (h *ChatHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
h.hub.RegisterClient(client) h.hub.RegisterClient(client)
// Горутина для чтения сообщений // Добавляем контекст для управления жизненным циклом соединения
go h.readPump(conn, client, userID) ctx, cancel := context.WithCancel(context.Background())
// Горутина для записи сообщений defer cancel()
go h.writePump(conn, client)
// Запускаем горутины с обработкой контекста
go h.readPump(ctx, conn, client, userID)
go h.writePump(ctx, conn, client)
// Ждем завершения
<-ctx.Done()
} }
func (h *ChatHandler) readPump(conn *websocket.Conn, client *ws.Client, userID int) { func (h *ChatHandler) readPump(ctx context.Context, conn *websocket.Conn, client *ws.Client, userID int) {
defer func() { defer func() {
h.hub.UnregisterClient(client) h.hub.UnregisterClient(client)
conn.Close() conn.Close()
}() }()
for { for {
select {
case <-ctx.Done():
return
default:
var msg struct { var msg struct {
ChatID int `json:"chatId"` Type string `json:"type"`
Payload struct {
ReceiverID int `json:"receiverId"`
Content string `json:"content"` Content string `json:"content"`
} `json:"payload"`
} }
if err := conn.ReadJSON(&msg); err != nil { if err := conn.ReadJSON(&msg); err != nil {
break if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Printf("WebSocket read error: %v", err)
}
return
} }
// Используем context.Background() вместо r.Context() if msg.Type != "message" {
message, err := h.chatService.SendMessage(
context.Background(),
userID,
msg.ChatID,
msg.Content,
)
if err != nil {
continue continue
} }
// Проверяем receiverId
if msg.Payload.ReceiverID == 0 {
log.Printf("Invalid receiverId: 0")
conn.WriteJSON(map[string]interface{}{
"type": "error",
"message": "Invalid receiver ID",
})
continue
}
// Логирование для отладки
log.Printf("Received message from %d to %d", userID, msg.Payload.ReceiverID)
// Создаем или находим чат
chat, err := h.chatService.GetOrCreateChat(ctx, userID, msg.Payload.ReceiverID)
if err != nil {
log.Printf("Chat error: %v", err)
conn.WriteJSON(map[string]interface{}{
"type": "error",
"message": "Chat error",
"details": err.Error(),
})
continue
}
// Отправляем сообщение
message, err := h.chatService.SendMessage(
ctx,
userID,
chat.ID,
msg.Payload.Content,
)
if err != nil {
log.Printf("Message send error: %v", err)
continue
}
// Рассылаем сообщение
h.hub.Broadcast(message) h.hub.Broadcast(message)
} }
} }
}
func (h *ChatHandler) writePump(conn *websocket.Conn, client *ws.Client) { func (h *ChatHandler) writePump(ctx context.Context, conn *websocket.Conn, client *ws.Client) {
defer conn.Close() defer conn.Close()
for message := range client.Send { for {
select {
case <-ctx.Done():
return
case message, ok := <-client.Send:
if !ok {
// Канал закрыт
conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
if err := conn.WriteJSON(message); err != nil { if err := conn.WriteJSON(message); err != nil {
break log.Printf("WebSocket write error: %v", err)
return
}
} }
} }
} }

View File

@ -2,6 +2,7 @@ package middleware
import ( import (
"context" "context"
"log"
"net/http" "net/http"
"strings" "strings"
"tailly_back_v2/pkg/auth" "tailly_back_v2/pkg/auth"
@ -10,14 +11,26 @@ import (
const ( const (
authorizationHeader = "Authorization" authorizationHeader = "Authorization"
bearerPrefix = "Bearer " bearerPrefix = "Bearer "
userIDKey = "userID" // Ключ для хранения userID в контексте userIDKey = "userID"
) )
// AuthMiddleware проверяет JWT токен и добавляет userID в контекст // AuthMiddleware проверяет JWT токен и добавляет userID в контекст
func AuthMiddleware(tokenAuth *auth.TokenAuth) func(http.Handler) http.Handler { func AuthMiddleware(tokenAuth *auth.TokenAuth) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Пропускаем OPTIONS запросы (для CORS)
log.Printf("Middleware: путь %s", r.URL.Path)
// Пропускаем WebSocket маршрут
if r.URL.Path == "/ws" {
log.Printf("Middleware: пропускаем /ws")
next.ServeHTTP(w, r)
return
}
if strings.Contains(r.Header.Get("Upgrade"), "websocket") {
next.ServeHTTP(w, r)
return
}
if r.Method == http.MethodOptions { if r.Method == http.MethodOptions {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
@ -56,3 +69,15 @@ func AuthMiddleware(tokenAuth *auth.TokenAuth) func(http.Handler) http.Handler {
}) })
} }
} }
func WebSocketMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") == "websocket" {
// Используем оригинальный ResponseWriter для WebSocket
next.ServeHTTP(w, r)
return
}
// Для обычных HTTP запросов используем наш кастомный writer
rw := &responseWriter{ResponseWriter: w, status: http.StatusOK}
next.ServeHTTP(rw, r)
})
}

View File

@ -5,7 +5,6 @@ import (
"strings" "strings"
) )
// CORS middleware настраивает политику кросс-доменных запросов
func CORS(allowedOrigins []string) func(http.Handler) http.Handler { func CORS(allowedOrigins []string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -41,12 +40,10 @@ func isOriginAllowed(origin string, allowedOrigins []string) bool {
return false return false
} }
// Разрешаем все источники в development
if len(allowedOrigins) == 1 && allowedOrigins[0] == "*" { if len(allowedOrigins) == 1 && allowedOrigins[0] == "*" {
return true return true
} }
// Точноe сравнение с разрешенными доменами
for _, allowed := range allowedOrigins { for _, allowed := range allowedOrigins {
if strings.EqualFold(origin, allowed) { if strings.EqualFold(origin, allowed) {
return true return true

View File

@ -1,9 +1,12 @@
package middleware package middleware
import ( import (
"bufio"
"bytes" "bytes"
"errors"
"io" "io"
"log" "log"
"net"
"net/http" "net/http"
"time" "time"
) )
@ -14,6 +17,11 @@ func LoggingMiddleware(logger *log.Logger) func(http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now() start := time.Now()
if r.Header.Get("Upgrade") == "websocket" {
next.ServeHTTP(w, r)
return
}
// Логируем основные параметры запроса // Логируем основные параметры запроса
logData := map[string]interface{}{ logData := map[string]interface{}{
"method": r.Method, "method": r.Method,
@ -81,3 +89,11 @@ func (rw *responseWriter) Write(b []byte) (int, error) {
rw.size += size rw.size += size
return size, err return size, err
} }
// Добавляем поддержку Hijacker
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hijacker, ok := rw.ResponseWriter.(http.Hijacker); ok {
return hijacker.Hijack()
}
return nil, nil, errors.New("response writer does not support hijacking")
}

View File

@ -25,14 +25,14 @@ type Server struct {
cfg *config.Config cfg *config.Config
services *service.Services services *service.Services
tokenAuth *auth.TokenAuth tokenAuth *auth.TokenAuth
db *sql.DB // Добавляем подключение к БД db *sql.DB
} }
func NewServer( func NewServer(
cfg *config.Config, cfg *config.Config,
services *service.Services, services *service.Services,
tokenAuth *auth.TokenAuth, tokenAuth *auth.TokenAuth,
db *sql.DB, // Добавляем параметр БД db *sql.DB,
) *Server { ) *Server {
s := &Server{ s := &Server{
router: chi.NewRouter(), router: chi.NewRouter(),
@ -50,8 +50,8 @@ func NewServer(
func (s *Server) configureRouter() { func (s *Server) configureRouter() {
allowedOrigins := []string{ allowedOrigins := []string{
"http://localhost:3000", // React dev server "http://localhost:3000",
"https://tailly.ru", // Продакшен домен "https://tailly.ru",
} }
// Инициализация WebSocket хаба // Инициализация WebSocket хаба
hub := ws.NewHub() hub := ws.NewHub()
@ -64,24 +64,20 @@ func (s *Server) configureRouter() {
hub, hub,
) )
// Обновляем сервис чата в services
s.services.Chat = chatService s.services.Chat = chatService
// Добавляем обработчик WebSocket
// Логирование
logger := log.New(os.Stdout, "HTTP: ", log.LstdFlags) logger := log.New(os.Stdout, "HTTP: ", log.LstdFlags)
s.router.Use(middleware.WebSocketMiddleware)
s.router.Use(middleware.LoggingMiddleware(logger)) s.router.Use(middleware.LoggingMiddleware(logger))
s.router.Use(middleware.MetricsMiddleware) s.router.Use(middleware.MetricsMiddleware)
s.router.Use(middleware.CORS(allowedOrigins)) s.router.Use(middleware.CORS(allowedOrigins))
s.router.Use(middleware.AuthMiddleware(s.tokenAuth)) s.router.Use(middleware.AuthMiddleware(s.tokenAuth))
// GraphQL handler
resolver := graph.NewResolver(s.services, s.db) // Теперь передаем оба аргумента resolver := graph.NewResolver(s.services, s.db)
srv := handler.NewDefaultServer(graph.NewExecutableSchema(graph.Config{ srv := handler.NewDefaultServer(graph.NewExecutableSchema(graph.Config{
Resolvers: resolver, Resolvers: resolver,
})) }))
// Routes
s.router.Handle("/", playground.Handler("GraphQL playground", "/query")) s.router.Handle("/", playground.Handler("GraphQL playground", "/query"))
s.router.Handle("/query", srv) s.router.Handle("/query", srv)
s.router.Handle("/uploads/*", http.StripPrefix("/uploads/", http.FileServer(http.Dir("./uploads")))) s.router.Handle("/uploads/*", http.StripPrefix("/uploads/", http.FileServer(http.Dir("./uploads"))))
@ -105,6 +101,5 @@ func (s *Server) Run() error {
} }
func (s *Server) Shutdown(ctx context.Context) error { func (s *Server) Shutdown(ctx context.Context) error {
// Здесь можно добавить логику graceful shutdown
return nil return nil
} }

View File

@ -134,6 +134,14 @@ func (r *chatRepository) CreateChat(ctx context.Context, user1ID, user2ID int) (
return nil, errors.New("cannot create chat with yourself") return nil, errors.New("cannot create chat with yourself")
} }
// Проверяем существование пользователей
if err := r.checkUserExists(ctx, user1ID); err != nil {
return nil, fmt.Errorf("user1 does not exist: %v", err)
}
if err := r.checkUserExists(ctx, user2ID); err != nil {
return nil, fmt.Errorf("user2 does not exist: %v", err)
}
// Упорядочиваем ID пользователей согласно CHECK constraint // Упорядочиваем ID пользователей согласно CHECK constraint
if user1ID > user2ID { if user1ID > user2ID {
user1ID, user2ID = user2ID, user1ID user1ID, user2ID = user2ID, user1ID
@ -156,6 +164,18 @@ func (r *chatRepository) CreateChat(ctx context.Context, user1ID, user2ID int) (
return chat, nil return chat, nil
} }
func (r *chatRepository) checkUserExists(ctx context.Context, userID int) error {
var exists bool
err := r.db.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)", userID).Scan(&exists)
if err != nil {
return err
}
if !exists {
return errors.New("user not found")
}
return nil
}
func (r *chatRepository) GetChatByID(ctx context.Context, id int) (*domain.Chat, error) { func (r *chatRepository) GetChatByID(ctx context.Context, id int) (*domain.Chat, error) {
query := ` query := `
SELECT id, user1_id, user2_id, created_at SELECT id, user1_id, user2_id, created_at

View File

@ -1,4 +1,3 @@
// ws/hub.go
package ws package ws
import ( import (
@ -45,12 +44,16 @@ func (h *Hub) Run() {
select { select {
case client := <-h.register: case client := <-h.register:
h.mu.Lock() h.mu.Lock()
// Закрываем предыдущее соединение если есть
if existing, ok := h.clients[client.UserID]; ok {
close(existing.Send)
}
h.clients[client.UserID] = client h.clients[client.UserID] = client
h.mu.Unlock() h.mu.Unlock()
case client := <-h.unregister: case client := <-h.unregister:
h.mu.Lock() h.mu.Lock()
if c, ok := h.clients[client.UserID]; ok { if c, ok := h.clients[client.UserID]; ok && c == client {
close(c.Send) close(c.Send)
delete(h.clients, client.UserID) delete(h.clients, client.UserID)
} }
@ -58,23 +61,21 @@ func (h *Hub) Run() {
case message := <-h.broadcast: case message := <-h.broadcast:
h.mu.RLock() h.mu.RLock()
// Отправляем сообщение отправителю // Отправляем всем клиентам, кто участвует в этом чате
if sender, ok := h.clients[message.SenderID]; ok { for _, client := range h.clients {
if client.UserID == message.SenderID || client.UserID == message.ReceiverID {
select { select {
case sender.Send <- message: case client.Send <- message:
default: default:
close(sender.Send) // Если канал полон, закрываем соединение
delete(h.clients, sender.UserID) close(client.Send)
h.mu.RUnlock()
h.mu.Lock()
delete(h.clients, client.UserID)
h.mu.Unlock()
h.mu.RLock()
} }
} }
// Отправляем сообщение получателю
if receiver, ok := h.clients[message.ReceiverID]; ok {
select {
case receiver.Send <- message:
default:
close(receiver.Send)
delete(h.clients, receiver.UserID)
}
} }
h.mu.RUnlock() h.mu.RUnlock()
} }