v0.0.17 Переработан websocket
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
parent
62d6c44a1b
commit
6b9cdeba55
@ -109,11 +109,31 @@ func (r *messageResolver) Sender(ctx context.Context, obj *domain.Message) (*dom
|
|||||||
|
|
||||||
// Receiver - возвращает получателя сообщения
|
// Receiver - возвращает получателя сообщения
|
||||||
func (r *messageResolver) Receiver(ctx context.Context, obj *domain.Message) (*domain.User, error) {
|
func (r *messageResolver) Receiver(ctx context.Context, obj *domain.Message) (*domain.User, error) {
|
||||||
chat, err := r.chatRepo.GetChatByID(ctx, obj.ChatID)
|
// 1. Проверка на nil
|
||||||
if err != nil {
|
if obj == nil {
|
||||||
return nil, fmt.Errorf("ошибка получения чата: %v", err)
|
return nil, fmt.Errorf("message is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 2. Если есть receiver_id - используем его напрямую
|
||||||
|
if obj.ReceiverID != 0 {
|
||||||
|
user, err := r.Services.User.GetByID(ctx, obj.ReceiverID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get receiver by ID: %v", err)
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Альтернативный вариант через chat_id (если receiver_id не установлен)
|
||||||
|
if obj.ChatID == 0 {
|
||||||
|
return nil, fmt.Errorf("both receiver_id and chat_id are not set")
|
||||||
|
}
|
||||||
|
|
||||||
|
chat, err := r.chatRepo.GetChatByID(ctx, obj.ChatID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get chat: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Определяем ID получателя
|
||||||
receiverID := chat.User1ID
|
receiverID := chat.User1ID
|
||||||
if obj.SenderID == chat.User1ID {
|
if obj.SenderID == chat.User1ID {
|
||||||
receiverID = chat.User2ID
|
receiverID = chat.User2ID
|
||||||
@ -121,8 +141,9 @@ func (r *messageResolver) Receiver(ctx context.Context, obj *domain.Message) (*d
|
|||||||
|
|
||||||
user, err := r.Services.User.GetByID(ctx, receiverID)
|
user, err := r.Services.User.GetByID(ctx, receiverID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("ошибка получения получателя: %v", err)
|
return nil, fmt.Errorf("failed to get receiver user: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -153,30 +174,14 @@ func (r *mutationResolver) SendMessage(ctx context.Context, receiverID int, cont
|
|||||||
|
|
||||||
// MarkAsRead - помечает сообщение как прочитанное
|
// MarkAsRead - помечает сообщение как прочитанное
|
||||||
func (r *mutationResolver) MarkAsRead(ctx context.Context, messageID int) (bool, error) {
|
func (r *mutationResolver) MarkAsRead(ctx context.Context, messageID int) (bool, error) {
|
||||||
userID, err := getUserIDFromContext(ctx)
|
if r.Services == nil || r.Services.Chat == nil {
|
||||||
if err != nil {
|
return false, errors.New("сервис чатов не инициализирован")
|
||||||
return false, errors.New("не авторизован")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Получаем сообщение напрямую из репозитория
|
// Все операции через сервис
|
||||||
message, err := r.chatRepo.GetMessageByID(ctx, messageID)
|
err := r.Services.Chat.MarkAsRead(ctx, messageID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("ошибка получения сообщения: %v", err)
|
return false, fmt.Errorf("ошибка отметки как прочитанного: %v", err)
|
||||||
}
|
|
||||||
|
|
||||||
// Проверяем доступ к сообщению
|
|
||||||
chat, err := r.chatRepo.GetChatByID(ctx, message.ChatID)
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("ошибка получения чата: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if userID != chat.User1ID && userID != chat.User2ID {
|
|
||||||
return false, errors.New("нет доступа к сообщению")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.Services.Chat.MarkAsRead(ctx, messageID)
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("ошибка обновления статуса: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
|
|||||||
@ -2,7 +2,9 @@ package handlers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"tailly_back_v2/internal/domain"
|
"tailly_back_v2/internal/domain"
|
||||||
"tailly_back_v2/internal/service"
|
"tailly_back_v2/internal/service"
|
||||||
"tailly_back_v2/internal/ws"
|
"tailly_back_v2/internal/ws"
|
||||||
@ -15,8 +17,9 @@ var upgrader = websocket.Upgrader{
|
|||||||
ReadBufferSize: 1024,
|
ReadBufferSize: 1024,
|
||||||
WriteBufferSize: 1024,
|
WriteBufferSize: 1024,
|
||||||
CheckOrigin: func(r *http.Request) bool {
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
return true // В production заменить на проверку origin
|
return true
|
||||||
},
|
},
|
||||||
|
Subprotocols: []string{"graphql-transport-ws"},
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatHandler struct {
|
type ChatHandler struct {
|
||||||
@ -34,22 +37,51 @@ func NewChatHandler(chatService service.ChatService, hub *ws.Hub, tokenAuth *aut
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
func (h *ChatHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||||
// Аутентификация пользователя
|
requestedProtocol := r.Header.Get("Sec-WebSocket-Protocol")
|
||||||
token := r.URL.Query().Get("token")
|
if requestedProtocol != "" && requestedProtocol != "graphql-transport-ws" {
|
||||||
|
http.Error(w, "Unsupported WebSocket protocol", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("Requested protocols: %v", r.Header["Sec-WebSocket-Protocol"])
|
||||||
|
// 1. Проверяем куки
|
||||||
|
var token string
|
||||||
|
cookie, err := r.Cookie("accessToken")
|
||||||
|
if err == nil {
|
||||||
|
token = cookie.Value
|
||||||
|
log.Printf("WebSocket: токен из куки: %s", token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Если нет в куках, проверяем заголовок Authorization
|
||||||
if token == "" {
|
if token == "" {
|
||||||
|
authHeader := r.Header.Get("Authorization")
|
||||||
|
if authHeader != "" {
|
||||||
|
token = strings.TrimPrefix(authHeader, "Bearer ")
|
||||||
|
log.Printf("WebSocket: токен из заголовка: %s", token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Если токен не найден - возвращаем 401
|
||||||
|
if token == "" {
|
||||||
|
log.Println("WebSocket: токен не найден ни в куках, ни в заголовках")
|
||||||
http.Error(w, "Token is required", http.StatusUnauthorized)
|
http.Error(w, "Token is required", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 4. Валидация токена
|
||||||
userID, err := h.tokenAuth.ValidateAccessToken(token)
|
userID, err := h.tokenAuth.ValidateAccessToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Printf("WebSocket: ошибка валидации токена: %v", err)
|
||||||
http.Error(w, "Invalid token", http.StatusUnauthorized)
|
http.Error(w, "Invalid token", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Printf("WebSocket: успешная авторизация, userID=%d", userID)
|
||||||
|
|
||||||
|
// 5. Обновление соединения
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "Could not open websocket connection", http.StatusBadRequest)
|
log.Printf("WebSocket upgrade error: %v", err)
|
||||||
|
http.Error(w, "Could not upgrade to WebSocket", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -60,49 +92,109 @@ func (h *ChatHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
h.hub.RegisterClient(client)
|
h.hub.RegisterClient(client)
|
||||||
|
|
||||||
// Горутина для чтения сообщений
|
// Добавляем контекст для управления жизненным циклом соединения
|
||||||
go h.readPump(conn, client, userID)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
// Горутина для записи сообщений
|
defer cancel()
|
||||||
go h.writePump(conn, client)
|
|
||||||
|
// Запускаем горутины с обработкой контекста
|
||||||
|
go h.readPump(ctx, conn, client, userID)
|
||||||
|
go h.writePump(ctx, conn, client)
|
||||||
|
|
||||||
|
// Ждем завершения
|
||||||
|
<-ctx.Done()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatHandler) readPump(conn *websocket.Conn, client *ws.Client, userID int) {
|
func (h *ChatHandler) readPump(ctx context.Context, conn *websocket.Conn, client *ws.Client, userID int) {
|
||||||
defer func() {
|
defer func() {
|
||||||
h.hub.UnregisterClient(client)
|
h.hub.UnregisterClient(client)
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
var msg struct {
|
var msg struct {
|
||||||
ChatID int `json:"chatId"`
|
Type string `json:"type"`
|
||||||
|
Payload struct {
|
||||||
|
ReceiverID int `json:"receiverId"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
|
} `json:"payload"`
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := conn.ReadJSON(&msg); err != nil {
|
if err := conn.ReadJSON(&msg); err != nil {
|
||||||
break
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||||
|
log.Printf("WebSocket read error: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Используем context.Background() вместо r.Context()
|
if msg.Type != "message" {
|
||||||
message, err := h.chatService.SendMessage(
|
|
||||||
context.Background(),
|
|
||||||
userID,
|
|
||||||
msg.ChatID,
|
|
||||||
msg.Content,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Проверяем receiverId
|
||||||
|
if msg.Payload.ReceiverID == 0 {
|
||||||
|
log.Printf("Invalid receiverId: 0")
|
||||||
|
conn.WriteJSON(map[string]interface{}{
|
||||||
|
"type": "error",
|
||||||
|
"message": "Invalid receiver ID",
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Логирование для отладки
|
||||||
|
log.Printf("Received message from %d to %d", userID, msg.Payload.ReceiverID)
|
||||||
|
|
||||||
|
// Создаем или находим чат
|
||||||
|
chat, err := h.chatService.GetOrCreateChat(ctx, userID, msg.Payload.ReceiverID)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Chat error: %v", err)
|
||||||
|
conn.WriteJSON(map[string]interface{}{
|
||||||
|
"type": "error",
|
||||||
|
"message": "Chat error",
|
||||||
|
"details": err.Error(),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Отправляем сообщение
|
||||||
|
message, err := h.chatService.SendMessage(
|
||||||
|
ctx,
|
||||||
|
userID,
|
||||||
|
chat.ID,
|
||||||
|
msg.Payload.Content,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Message send error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Рассылаем сообщение
|
||||||
h.hub.Broadcast(message)
|
h.hub.Broadcast(message)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *ChatHandler) writePump(conn *websocket.Conn, client *ws.Client) {
|
func (h *ChatHandler) writePump(ctx context.Context, conn *websocket.Conn, client *ws.Client) {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
for message := range client.Send {
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case message, ok := <-client.Send:
|
||||||
|
if !ok {
|
||||||
|
// Канал закрыт
|
||||||
|
conn.WriteMessage(websocket.CloseMessage, []byte{})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err := conn.WriteJSON(message); err != nil {
|
if err := conn.WriteJSON(message); err != nil {
|
||||||
break
|
log.Printf("WebSocket write error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,6 +2,7 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"tailly_back_v2/pkg/auth"
|
"tailly_back_v2/pkg/auth"
|
||||||
@ -10,14 +11,26 @@ import (
|
|||||||
const (
|
const (
|
||||||
authorizationHeader = "Authorization"
|
authorizationHeader = "Authorization"
|
||||||
bearerPrefix = "Bearer "
|
bearerPrefix = "Bearer "
|
||||||
userIDKey = "userID" // Ключ для хранения userID в контексте
|
userIDKey = "userID"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AuthMiddleware проверяет JWT токен и добавляет userID в контекст
|
// AuthMiddleware проверяет JWT токен и добавляет userID в контекст
|
||||||
func AuthMiddleware(tokenAuth *auth.TokenAuth) func(http.Handler) http.Handler {
|
func AuthMiddleware(tokenAuth *auth.TokenAuth) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Пропускаем OPTIONS запросы (для CORS)
|
|
||||||
|
log.Printf("Middleware: путь %s", r.URL.Path)
|
||||||
|
|
||||||
|
// Пропускаем WebSocket маршрут
|
||||||
|
if r.URL.Path == "/ws" {
|
||||||
|
log.Printf("Middleware: пропускаем /ws")
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.Contains(r.Header.Get("Upgrade"), "websocket") {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
if r.Method == http.MethodOptions {
|
if r.Method == http.MethodOptions {
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
@ -56,3 +69,15 @@ func AuthMiddleware(tokenAuth *auth.TokenAuth) func(http.Handler) http.Handler {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func WebSocketMiddleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Header.Get("Upgrade") == "websocket" {
|
||||||
|
// Используем оригинальный ResponseWriter для WebSocket
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Для обычных HTTP запросов используем наш кастомный writer
|
||||||
|
rw := &responseWriter{ResponseWriter: w, status: http.StatusOK}
|
||||||
|
next.ServeHTTP(rw, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@ -5,7 +5,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CORS middleware настраивает политику кросс-доменных запросов
|
|
||||||
func CORS(allowedOrigins []string) func(http.Handler) http.Handler {
|
func CORS(allowedOrigins []string) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@ -41,12 +40,10 @@ func isOriginAllowed(origin string, allowedOrigins []string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Разрешаем все источники в development
|
|
||||||
if len(allowedOrigins) == 1 && allowedOrigins[0] == "*" {
|
if len(allowedOrigins) == 1 && allowedOrigins[0] == "*" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Точноe сравнение с разрешенными доменами
|
|
||||||
for _, allowed := range allowedOrigins {
|
for _, allowed := range allowedOrigins {
|
||||||
if strings.EqualFold(origin, allowed) {
|
if strings.EqualFold(origin, allowed) {
|
||||||
return true
|
return true
|
||||||
|
|||||||
@ -1,9 +1,12 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -14,6 +17,11 @@ func LoggingMiddleware(logger *log.Logger) func(http.Handler) http.Handler {
|
|||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
|
if r.Header.Get("Upgrade") == "websocket" {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Логируем основные параметры запроса
|
// Логируем основные параметры запроса
|
||||||
logData := map[string]interface{}{
|
logData := map[string]interface{}{
|
||||||
"method": r.Method,
|
"method": r.Method,
|
||||||
@ -81,3 +89,11 @@ func (rw *responseWriter) Write(b []byte) (int, error) {
|
|||||||
rw.size += size
|
rw.size += size
|
||||||
return size, err
|
return size, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Добавляем поддержку Hijacker
|
||||||
|
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
if hijacker, ok := rw.ResponseWriter.(http.Hijacker); ok {
|
||||||
|
return hijacker.Hijack()
|
||||||
|
}
|
||||||
|
return nil, nil, errors.New("response writer does not support hijacking")
|
||||||
|
}
|
||||||
|
|||||||
@ -25,14 +25,14 @@ type Server struct {
|
|||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
services *service.Services
|
services *service.Services
|
||||||
tokenAuth *auth.TokenAuth
|
tokenAuth *auth.TokenAuth
|
||||||
db *sql.DB // Добавляем подключение к БД
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(
|
func NewServer(
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
services *service.Services,
|
services *service.Services,
|
||||||
tokenAuth *auth.TokenAuth,
|
tokenAuth *auth.TokenAuth,
|
||||||
db *sql.DB, // Добавляем параметр БД
|
db *sql.DB,
|
||||||
) *Server {
|
) *Server {
|
||||||
s := &Server{
|
s := &Server{
|
||||||
router: chi.NewRouter(),
|
router: chi.NewRouter(),
|
||||||
@ -50,8 +50,8 @@ func NewServer(
|
|||||||
|
|
||||||
func (s *Server) configureRouter() {
|
func (s *Server) configureRouter() {
|
||||||
allowedOrigins := []string{
|
allowedOrigins := []string{
|
||||||
"http://localhost:3000", // React dev server
|
"http://localhost:3000",
|
||||||
"https://tailly.ru", // Продакшен домен
|
"https://tailly.ru",
|
||||||
}
|
}
|
||||||
// Инициализация WebSocket хаба
|
// Инициализация WebSocket хаба
|
||||||
hub := ws.NewHub()
|
hub := ws.NewHub()
|
||||||
@ -64,24 +64,20 @@ func (s *Server) configureRouter() {
|
|||||||
hub,
|
hub,
|
||||||
)
|
)
|
||||||
|
|
||||||
// Обновляем сервис чата в services
|
|
||||||
s.services.Chat = chatService
|
s.services.Chat = chatService
|
||||||
|
|
||||||
// Добавляем обработчик WebSocket
|
|
||||||
|
|
||||||
// Логирование
|
|
||||||
logger := log.New(os.Stdout, "HTTP: ", log.LstdFlags)
|
logger := log.New(os.Stdout, "HTTP: ", log.LstdFlags)
|
||||||
|
s.router.Use(middleware.WebSocketMiddleware)
|
||||||
s.router.Use(middleware.LoggingMiddleware(logger))
|
s.router.Use(middleware.LoggingMiddleware(logger))
|
||||||
s.router.Use(middleware.MetricsMiddleware)
|
s.router.Use(middleware.MetricsMiddleware)
|
||||||
s.router.Use(middleware.CORS(allowedOrigins))
|
s.router.Use(middleware.CORS(allowedOrigins))
|
||||||
s.router.Use(middleware.AuthMiddleware(s.tokenAuth))
|
s.router.Use(middleware.AuthMiddleware(s.tokenAuth))
|
||||||
// GraphQL handler
|
|
||||||
resolver := graph.NewResolver(s.services, s.db) // Теперь передаем оба аргумента
|
resolver := graph.NewResolver(s.services, s.db)
|
||||||
srv := handler.NewDefaultServer(graph.NewExecutableSchema(graph.Config{
|
srv := handler.NewDefaultServer(graph.NewExecutableSchema(graph.Config{
|
||||||
Resolvers: resolver,
|
Resolvers: resolver,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// Routes
|
|
||||||
s.router.Handle("/", playground.Handler("GraphQL playground", "/query"))
|
s.router.Handle("/", playground.Handler("GraphQL playground", "/query"))
|
||||||
s.router.Handle("/query", srv)
|
s.router.Handle("/query", srv)
|
||||||
s.router.Handle("/uploads/*", http.StripPrefix("/uploads/", http.FileServer(http.Dir("./uploads"))))
|
s.router.Handle("/uploads/*", http.StripPrefix("/uploads/", http.FileServer(http.Dir("./uploads"))))
|
||||||
@ -105,6 +101,5 @@ func (s *Server) Run() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Shutdown(ctx context.Context) error {
|
func (s *Server) Shutdown(ctx context.Context) error {
|
||||||
// Здесь можно добавить логику graceful shutdown
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -134,6 +134,14 @@ func (r *chatRepository) CreateChat(ctx context.Context, user1ID, user2ID int) (
|
|||||||
return nil, errors.New("cannot create chat with yourself")
|
return nil, errors.New("cannot create chat with yourself")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Проверяем существование пользователей
|
||||||
|
if err := r.checkUserExists(ctx, user1ID); err != nil {
|
||||||
|
return nil, fmt.Errorf("user1 does not exist: %v", err)
|
||||||
|
}
|
||||||
|
if err := r.checkUserExists(ctx, user2ID); err != nil {
|
||||||
|
return nil, fmt.Errorf("user2 does not exist: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Упорядочиваем ID пользователей согласно CHECK constraint
|
// Упорядочиваем ID пользователей согласно CHECK constraint
|
||||||
if user1ID > user2ID {
|
if user1ID > user2ID {
|
||||||
user1ID, user2ID = user2ID, user1ID
|
user1ID, user2ID = user2ID, user1ID
|
||||||
@ -156,6 +164,18 @@ func (r *chatRepository) CreateChat(ctx context.Context, user1ID, user2ID int) (
|
|||||||
return chat, nil
|
return chat, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *chatRepository) checkUserExists(ctx context.Context, userID int) error {
|
||||||
|
var exists bool
|
||||||
|
err := r.db.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)", userID).Scan(&exists)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
return errors.New("user not found")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *chatRepository) GetChatByID(ctx context.Context, id int) (*domain.Chat, error) {
|
func (r *chatRepository) GetChatByID(ctx context.Context, id int) (*domain.Chat, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, user1_id, user2_id, created_at
|
SELECT id, user1_id, user2_id, created_at
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
// ws/hub.go
|
|
||||||
package ws
|
package ws
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -45,12 +44,16 @@ func (h *Hub) Run() {
|
|||||||
select {
|
select {
|
||||||
case client := <-h.register:
|
case client := <-h.register:
|
||||||
h.mu.Lock()
|
h.mu.Lock()
|
||||||
|
// Закрываем предыдущее соединение если есть
|
||||||
|
if existing, ok := h.clients[client.UserID]; ok {
|
||||||
|
close(existing.Send)
|
||||||
|
}
|
||||||
h.clients[client.UserID] = client
|
h.clients[client.UserID] = client
|
||||||
h.mu.Unlock()
|
h.mu.Unlock()
|
||||||
|
|
||||||
case client := <-h.unregister:
|
case client := <-h.unregister:
|
||||||
h.mu.Lock()
|
h.mu.Lock()
|
||||||
if c, ok := h.clients[client.UserID]; ok {
|
if c, ok := h.clients[client.UserID]; ok && c == client {
|
||||||
close(c.Send)
|
close(c.Send)
|
||||||
delete(h.clients, client.UserID)
|
delete(h.clients, client.UserID)
|
||||||
}
|
}
|
||||||
@ -58,23 +61,21 @@ func (h *Hub) Run() {
|
|||||||
|
|
||||||
case message := <-h.broadcast:
|
case message := <-h.broadcast:
|
||||||
h.mu.RLock()
|
h.mu.RLock()
|
||||||
// Отправляем сообщение отправителю
|
// Отправляем всем клиентам, кто участвует в этом чате
|
||||||
if sender, ok := h.clients[message.SenderID]; ok {
|
for _, client := range h.clients {
|
||||||
|
if client.UserID == message.SenderID || client.UserID == message.ReceiverID {
|
||||||
select {
|
select {
|
||||||
case sender.Send <- message:
|
case client.Send <- message:
|
||||||
default:
|
default:
|
||||||
close(sender.Send)
|
// Если канал полон, закрываем соединение
|
||||||
delete(h.clients, sender.UserID)
|
close(client.Send)
|
||||||
|
h.mu.RUnlock()
|
||||||
|
h.mu.Lock()
|
||||||
|
delete(h.clients, client.UserID)
|
||||||
|
h.mu.Unlock()
|
||||||
|
h.mu.RLock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Отправляем сообщение получателю
|
|
||||||
if receiver, ok := h.clients[message.ReceiverID]; ok {
|
|
||||||
select {
|
|
||||||
case receiver.Send <- message:
|
|
||||||
default:
|
|
||||||
close(receiver.Send)
|
|
||||||
delete(h.clients, receiver.UserID)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
h.mu.RUnlock()
|
h.mu.RUnlock()
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user