v0.0.23 правки в /ws
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
madipo2611 2025-08-19 13:50:14 +03:00
parent 3081416e95
commit aad2fb366f
2 changed files with 22 additions and 11 deletions

View File

@ -148,6 +148,9 @@ func (r *subscriptionResolver) MessageStream(ctx context.Context, userID int) (<
func (r *subscriptionResolver) runMessageStream(ctx context.Context, userID int, messageChan chan<- *domain.Message) error { func (r *subscriptionResolver) runMessageStream(ctx context.Context, userID int, messageChan chan<- *domain.Message) error {
log.Printf("Starting new stream for user %d", userID) log.Printf("Starting new stream for user %d", userID)
// Создаем отдельный контекст для gRPC стрима
grpcCtx := context.Background()
_, err := r.MessageClient.UpdateMessageStatus(ctx, &proto.UpdateMessageStatusRequest{ _, err := r.MessageClient.UpdateMessageStatus(ctx, &proto.UpdateMessageStatusRequest{
MessageId: 0, // 0 = все сообщения для пользователя MessageId: 0, // 0 = все сообщения для пользователя
Status: "DELIVERED", Status: "DELIVERED",
@ -157,10 +160,7 @@ func (r *subscriptionResolver) runMessageStream(ctx context.Context, userID int,
log.Printf("Failed to mark messages as delivered: %v", err) log.Printf("Failed to mark messages as delivered: %v", err)
} }
streamCtx, cancel := context.WithCancel(ctx) stream, err := r.MessageClient.StreamMessages(grpcCtx, &proto.StreamMessagesRequest{
defer cancel()
stream, err := r.MessageClient.StreamMessages(streamCtx, &proto.StreamMessagesRequest{
UserId: int32(userID), UserId: int32(userID),
}) })
if err != nil { if err != nil {

View File

@ -3,8 +3,8 @@ package middleware
import ( import (
"context" "context"
"net/http" "net/http"
"strings"
"tailly_back_v2/pkg/auth" "tailly_back_v2/pkg/auth"
"time"
) )
// WSAuthMiddleware проверяет JWT токен для WebSocket соединений // WSAuthMiddleware проверяет JWT токен для WebSocket соединений
@ -12,27 +12,38 @@ func WSAuthMiddleware(tokenAuth *auth.TokenAuth) func(http.Handler) http.Handler
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") == "websocket" { if r.Header.Get("Upgrade") == "websocket" {
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Minute)
defer cancel()
token := extractTokenFromRequest(r) token := extractTokenFromRequest(r)
if token != "" { if token != "" {
if userID, err := tokenAuth.ValidateAccessToken(token); err == nil { if userID, err := tokenAuth.ValidateAccessToken(token); err == nil {
ctx = context.WithValue(ctx, userIDKey, userID) // Создаем контекст без таймаута для WebSocket
} ctx := context.WithValue(r.Context(), userIDKey, userID)
}
r = r.WithContext(ctx) r = r.WithContext(ctx)
} }
}
}
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }
} }
func extractTokenFromRequest(r *http.Request) string { func extractTokenFromRequest(r *http.Request) string {
// Только проверка кук (как в вашем коде) // Проверяем куки
cookie, err := r.Cookie("accessToken") cookie, err := r.Cookie("accessToken")
if err == nil { if err == nil {
return cookie.Value return cookie.Value
} }
// Проверяем заголовок Authorization
authHeader := r.Header.Get("Authorization")
if authHeader != "" && strings.HasPrefix(authHeader, "Bearer ") {
return strings.TrimPrefix(authHeader, "Bearer ")
}
// Проверяем query параметры
token := r.URL.Query().Get("token")
if token != "" {
return token
}
return "" return ""
} }