// Command tailly_messages runs the gRPC message service: it stores encrypted
// chat messages in PostgreSQL and fans them out to recipients through
// per-user RabbitMQ queues.
package main

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net"
	"net/url"
	"os"
	"sync"
	"time"

	"github.com/jackc/pgx/v4/pgxpool"
	amqp "github.com/rabbitmq/amqp091-go"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/keepalive"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/timestamppb"

	"tailly_messages/crypto"
	"tailly_messages/proto"
)

const (
	// undecryptablePlaceholder is returned as message content when a stored
	// message cannot be decrypted.
	undecryptablePlaceholder = "[не удалось расшифровать]"

	// publishAttempts is how many times SendMessage tries to publish to RabbitMQ.
	publishAttempts = 3

	// Environment variables that override the connection URLs.
	envDatabaseURL = "TAILLY_DB_URL"
	envRabbitURL   = "TAILLY_RABBITMQ_URL"

	// NOTE(review): credentials embedded in source. They are kept only as
	// fallbacks so existing deployments keep working; move them to the
	// environment and rotate them.
	defaultDatabaseURL = "postgres://tailly_v2:i0Oq%2675LA%26M612ceuy@79.174.89.104:15452/tailly_v2"
	defaultRabbitURL   = "amqp://tailly_rabbitmq:o2p2S80MPbl27LUU@89.104.69.222:5673/"

	// chatWithLastMessageSelect selects every chat together with its most
	// recent message (all message columns are NULL when the chat is empty).
	// Callers append their own WHERE / ORDER BY clause.
	chatWithLastMessageSelect = `
		SELECT
			c.id, c.user1_id, c.user2_id, c.created_at, c.updated_at,
			m.id, m.encrypted_content, m.encrypted_key, m.nonce, m.status, m.created_at
		FROM chats c
		LEFT JOIN messages m ON m.id = (
			SELECT id FROM messages
			WHERE chat_id = c.id
			ORDER BY created_at DESC
			LIMIT 1
		)
	`
)

// server implements proto.MessageServiceServer.
type server struct {
	proto.UnimplementedMessageServiceServer
	db         *pgxpool.Pool
	rabbitConn *amqp.Connection // guarded by mu; replaced on reconnect
	mu         sync.Mutex
	logger     *log.Logger
	crypto     *crypto.CryptoService
}

// NewServer builds a server backed by the given database pool and RabbitMQ
// connection.
//
// NOTE(review): if the crypto service cannot be created the error is only
// logged and the server is returned with a nil crypto service, exactly as
// before; any later encrypt/decrypt call will then operate on nil. Confirm
// whether startup should fail instead.
func NewServer(db *pgxpool.Pool, rabbitConn *amqp.Connection) *server {
	cryptoService, err := crypto.NewCryptoService()
	if err != nil {
		log.Printf("ERROR: Failed to create crypto service: %v", err)
	}

	return &server{
		db:         db,
		rabbitConn: rabbitConn,
		crypto:     cryptoService,
		logger:     log.New(os.Stdout, "MESSAGE_SERVICE: ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile),
	}
}

// currentConn returns the RabbitMQ connection currently in use.
func (s *server) currentConn() *amqp.Connection {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.rabbitConn
}

// setConn swaps in a new RabbitMQ connection after a reconnect.
func (s *server) setConn(conn *amqp.Connection) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.rabbitConn = conn
}

// logRequest logs an incoming request.
//
// NOTE(review): the request is printed with %+v, so SendMessage requests are
// logged with their plaintext content.
func (s *server) logRequest(method string, req interface{}) {
	s.logger.Printf("REQUEST: %s - %+v", method, req)
}

// logResponse logs either the error or the response of a finished call.
func (s *server) logResponse(method string, resp interface{}, err error) {
	if err != nil {
		s.logger.Printf("RESPONSE ERROR: %s - %v", method, err)
		return
	}
	s.logger.Printf("RESPONSE: %s - %+v", method, resp)
}

// sleepCtx waits for d or until ctx is done, whichever comes first. It
// returns ctx.Err() when the wait was cut short.
func sleepCtx(ctx context.Context, d time.Duration) error {
	timer := time.NewTimer(d)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}

// lastMessageRow holds the nullable "latest message" columns produced by
// chatWithLastMessageSelect.
type lastMessageRow struct {
	id        sql.NullInt32
	content   []byte
	key       []byte
	nonce     []byte
	status    sql.NullString
	createdAt sql.NullTime
}

// attachLastMessage decrypts row and stores it as chat.LastMessage. It does
// nothing when the chat has no messages (row.id is NULL). A message that
// cannot be decrypted is returned with placeholder content.
func (s *server) attachLastMessage(chat *proto.Chat, row *lastMessageRow) {
	if !row.id.Valid {
		return
	}

	content, err := s.crypto.DecryptMessage(row.content, row.key, row.nonce)
	if err != nil {
		s.logger.Printf("Failed to decrypt last message: %v", err)
		content = undecryptablePlaceholder
	}

	chat.LastMessage = &proto.Message{
		Id:        row.id.Int32,
		ChatId:    chat.Id,
		Content:   content,
		Status:    row.status.String,
		CreatedAt: timestamppb.New(row.createdAt.Time),
	}
}

// CreateChat creates a chat between two existing users, or returns the
// existing chat if the pair already has one. The user pair is normalised so
// that the smaller id is always stored as user1.
func (s *server) CreateChat(ctx context.Context, req *proto.CreateChatRequest) (*proto.ChatResponse, error) {
	s.logRequest("CreateChat", req)
	defer func(start time.Time) {
		s.logger.Printf("CreateChat execution time: %v", time.Since(start))
	}(time.Now())

	user1, user2 := req.GetUser1Id(), req.GetUser2Id()
	if user1 > user2 {
		user1, user2 = user2, user1
	}

	s.logger.Printf("Checking user existence: user1=%d, user2=%d", user1, user2)
	var user1Exists, user2Exists bool
	err := s.db.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)", user1).Scan(&user1Exists)
	if err != nil {
		s.logger.Printf("Error checking user1 existence: %v", err)
		return nil, fmt.Errorf("failed to check user existence")
	}
	err = s.db.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)", user2).Scan(&user2Exists)
	if err != nil {
		s.logger.Printf("Error checking user2 existence: %v", err)
		return nil, fmt.Errorf("failed to check user existence")
	}

	if !user1Exists || !user2Exists {
		errMsg := fmt.Sprintf("One or both users don't exist: user1=%d (%t), user2=%d (%t)",
			user1, user1Exists, user2, user2Exists)
		s.logger.Println(errMsg)
		return nil, errors.New(errMsg)
	}

	s.logger.Printf("Checking chat existence between users %d and %d", user1, user2)
	var chatExists bool
	err = s.db.QueryRow(ctx,
		"SELECT EXISTS(SELECT 1 FROM chats WHERE user1_id = $1 AND user2_id = $2)",
		user1, user2).Scan(&chatExists)
	if err != nil {
		s.logger.Printf("Error checking chat existence: %v", err)
		return nil, fmt.Errorf("failed to check chat existence")
	}

	if chatExists {
		s.logger.Printf("Chat already exists between users %d and %d, returning existing chat", user1, user2)
		return s.GetChat(ctx, &proto.GetChatRequest{
			User1Id: user1,
			User2Id: user2,
		})
	}

	s.logger.Printf("Creating new chat between users %d and %d", user1, user2)
	var (
		chat                 proto.Chat
		createdAt, updatedAt time.Time
	)
	err = s.db.QueryRow(ctx, `
		INSERT INTO chats (user1_id, user2_id)
		VALUES ($1, $2)
		RETURNING id, user1_id, user2_id, created_at, updated_at
	`, user1, user2).Scan(
		&chat.Id, &chat.User1Id, &chat.User2Id, &createdAt, &updatedAt,
	)
	if err != nil {
		s.logger.Printf("Failed to create chat: %v", err)
		return nil, fmt.Errorf("failed to create chat: %w", err)
	}

	s.logger.Printf("Successfully created new chat: id=%d, user1_id=%d, user2_id=%d",
		chat.Id, chat.User1Id, chat.User2Id)

	chat.CreatedAt = timestamppb.New(createdAt)
	chat.UpdatedAt = timestamppb.New(updatedAt)

	resp := &proto.ChatResponse{Chat: &chat}
	s.logResponse("CreateChat", resp, nil)
	return resp, nil
}

// SendMessage encrypts and stores a message, bumps the chat's updated_at and
// publishes the plaintext message to the receiver's RabbitMQ queue. The
// sender must be one of the two chat participants.
func (s *server) SendMessage(ctx context.Context, req *proto.SendMessageRequest) (*proto.MessageResponse, error) {
	s.logRequest("SendMessage", req)
	defer func(start time.Time) {
		s.logger.Printf("SendMessage execution time: %v", time.Since(start))
	}(time.Now())

	// Encrypt the message before it touches the database.
	encryptedContent, encryptedKey, nonce, err := s.crypto.EncryptMessage(req.Content)
	if err != nil {
		s.logger.Printf("Failed to encrypt message: %v", err)
		return nil, fmt.Errorf("failed to encrypt message: %w", err)
	}

	s.logger.Printf("Getting chat info for chat_id=%d", req.ChatId)
	var user1ID, user2ID int32
	err = s.db.QueryRow(ctx,
		"SELECT user1_id, user2_id FROM chats WHERE id = $1",
		req.ChatId).Scan(&user1ID, &user2ID)
	if err != nil {
		s.logger.Printf("Failed to get chat info: %v", err)
		return nil, fmt.Errorf("failed to get chat info: %w", err)
	}

	// The receiver is whichever participant is not the sender.
	var receiverID int32
	switch req.SenderId {
	case user1ID:
		receiverID = user2ID
	case user2ID:
		receiverID = user1ID
	default:
		errMsg := fmt.Sprintf("sender %d is not a participant of chat %d", req.SenderId, req.ChatId)
		s.logger.Println(errMsg)
		return nil, errors.New(errMsg)
	}

	s.logger.Printf("Inserting encrypted message into database: chat_id=%d, sender_id=%d, receiver_id=%d",
		req.ChatId, req.SenderId, receiverID)

	var (
		message   proto.Message
		createdAt time.Time
	)
	err = s.db.QueryRow(ctx, `
		INSERT INTO messages (chat_id, sender_id, receiver_id, encrypted_content, encrypted_key, nonce)
		VALUES ($1, $2, $3, $4, $5, $6)
		RETURNING id, chat_id, sender_id, receiver_id, status, created_at
	`, req.ChatId, req.SenderId, receiverID, encryptedContent, encryptedKey, nonce).Scan(
		&message.Id, &message.ChatId, &message.SenderId, &message.ReceiverId, &message.Status, &createdAt,
	)
	if err != nil {
		s.logger.Printf("Failed to insert encrypted message: %v", err)
		return nil, err
	}
	message.CreatedAt = timestamppb.New(createdAt)

	// Round-trip through DecryptMessage so the published content is exactly
	// what a later read of the stored row would produce.
	decryptedContent, err := s.crypto.DecryptMessage(encryptedContent, encryptedKey, nonce)
	if err != nil {
		s.logger.Printf("Failed to decrypt for RabbitMQ: %v", err)
		return nil, err
	}
	message.Content = decryptedContent

	s.logger.Printf("Updating chat updated_at for chat_id=%d", req.ChatId)
	_, err = s.db.Exec(ctx, `UPDATE chats SET updated_at = NOW() WHERE id = $1`, req.ChatId)
	if err != nil {
		s.logger.Printf("Failed to update chat timestamp: %v", err)
		return nil, err
	}

	s.logger.Printf("Publishing message to RabbitMQ for user_id=%d", receiverID)
	queueName := fmt.Sprintf("user_%d_messages", receiverID)

	// Marshal once, by pointer: proto messages must not be copied by value.
	msgBytes, err := json.Marshal(&message)
	if err != nil {
		s.logger.Printf("Failed to marshal message: %v", err)
		return nil, fmt.Errorf("failed to marshal message: %w", err)
	}

	var lastErr error
	for attempt := 1; attempt <= publishAttempts; attempt++ {
		if attempt > 1 {
			// Linear backoff between attempts; stop early if the caller went away.
			if err := sleepCtx(ctx, time.Duration(attempt-1)*time.Second); err != nil {
				lastErr = err
				break
			}
		}

		s.logger.Printf("Publishing to queue %s: %s", queueName, string(msgBytes))
		lastErr = s.publish(ctx, queueName, msgBytes)
		if lastErr == nil {
			s.logger.Printf("Successfully published message to queue %s", queueName)
			resp := &proto.MessageResponse{Message: &message}
			s.logResponse("SendMessage", resp, nil)
			return resp, nil
		}
		s.logger.Printf("Failed to publish message (attempt %d): %v", attempt, lastErr)
	}

	pubErr := fmt.Errorf("failed to publish message after %d attempts: %w", publishAttempts, lastErr)
	s.logResponse("SendMessage", nil, pubErr)
	return nil, pubErr
}

// publish opens a fresh channel on the current RabbitMQ connection, publishes
// body as a persistent JSON message to queueName via the default exchange and
// closes the channel again.
func (s *server) publish(ctx context.Context, queueName string, body []byte) error {
	ch, err := s.currentConn().Channel()
	if err != nil {
		return fmt.Errorf("failed to open channel: %w", err)
	}
	defer ch.Close()

	return ch.PublishWithContext(ctx,
		"",        // exchange
		queueName, // routing key
		false,     // mandatory
		false,     // immediate
		amqp.Publishing{
			ContentType:  "application/json",
			Body:         msgBytes(body),
			DeliveryMode: amqp.Persistent,
		})
}

// msgBytes is an identity helper that keeps the Publishing literal readable.
func msgBytes(b []byte) []byte { return b }

// GetChat returns the chat between two users together with its most recent
// message, if any. The user pair is normalised the same way CreateChat
// stores it.
func (s *server) GetChat(ctx context.Context, req *proto.GetChatRequest) (*proto.ChatResponse, error) {
	s.logRequest("GetChat", req)
	defer func(start time.Time) {
		s.logger.Printf("GetChat execution time: %v", time.Since(start))
	}(time.Now())

	user1, user2 := req.User1Id, req.User2Id
	if user1 > user2 {
		user1, user2 = user2, user1
	}

	var (
		chat                 proto.Chat
		createdAt, updatedAt time.Time
		last                 lastMessageRow
	)

	s.logger.Printf("Querying chat between users %d and %d", user1, user2)
	err := s.db.QueryRow(ctx,
		chatWithLastMessageSelect+` WHERE c.user1_id = $1 AND c.user2_id = $2`,
		user1, user2).Scan(
		&chat.Id, &chat.User1Id, &chat.User2Id, &createdAt, &updatedAt,
		&last.id, &last.content, &last.key, &last.nonce, &last.status, &last.createdAt,
	)
	if err != nil {
		s.logger.Printf("Failed to get chat: %v", err)
		return nil, err
	}

	chat.CreatedAt = timestamppb.New(createdAt)
	chat.UpdatedAt = timestamppb.New(updatedAt)

	s.attachLastMessage(&chat, &last)
	if chat.LastMessage != nil {
		s.logger.Printf("Found last message for chat %d: message_id=%d", chat.Id, last.id.Int32)
	}

	resp := &proto.ChatResponse{Chat: &chat}
	s.logResponse("GetChat", resp, nil)
	return resp, nil
}

// GetChatMessages returns one page of a chat's messages, newest first,
// decrypted. Messages that cannot be decrypted are returned with placeholder
// content rather than failing the whole call.
func (s *server) GetChatMessages(ctx context.Context, req *proto.GetChatMessagesRequest) (*proto.MessagesResponse, error) {
	s.logRequest("GetChatMessages", req)
	defer func(start time.Time) {
		s.logger.Printf("GetChatMessages execution time: %v", time.Since(start))
	}(time.Now())

	s.logger.Printf("Querying messages for chat_id=%d, limit=%d, offset=%d", req.ChatId, req.Limit, req.Offset)
	rows, err := s.db.Query(ctx, `
		SELECT id, chat_id, sender_id, receiver_id, encrypted_content, encrypted_key, nonce, status, created_at
		FROM messages
		WHERE chat_id = $1
		ORDER BY created_at DESC
		LIMIT $2 OFFSET $3
	`, req.ChatId, req.Limit, req.Offset)
	if err != nil {
		s.logger.Printf("Failed to query messages: %v", err)
		return nil, err
	}
	defer rows.Close()

	var messages []*proto.Message
	for rows.Next() {
		var (
			msg                            proto.Message
			createdAt                      time.Time
			encryptedContent, key, nonce []byte
		)
		if err := rows.Scan(
			&msg.Id, &msg.ChatId, &msg.SenderId, &msg.ReceiverId,
			&encryptedContent, &key, &nonce,
			&msg.Status, &createdAt,
		); err != nil {
			s.logger.Printf("Failed to scan message row: %v", err)
			return nil, err
		}

		content, err := s.crypto.DecryptMessage(encryptedContent, key, nonce)
		if err != nil {
			s.logger.Printf("Failed to decrypt message %d: %v", msg.Id, err)
			content = undecryptablePlaceholder
		}
		msg.Content = content
		msg.CreatedAt = timestamppb.New(createdAt)

		messages = append(messages, &msg)
	}
	if err := rows.Err(); err != nil {
		s.logger.Printf("Rows error: %v", err)
		return nil, err
	}

	s.logger.Printf("Retrieved %d messages for chat_id=%d", len(messages), req.ChatId)
	resp := &proto.MessagesResponse{Messages: messages}
	s.logResponse("GetChatMessages", resp, nil)
	return resp, nil
}

// GetUserChats returns every chat the user participates in, most recently
// updated first, each with its latest message attached when one exists.
func (s *server) GetUserChats(ctx context.Context, req *proto.GetUserChatsRequest) (*proto.UserChatsResponse, error) {
	s.logRequest("GetUserChats", req)
	defer func(start time.Time) {
		s.logger.Printf("GetUserChats execution time: %v", time.Since(start))
	}(time.Now())

	s.logger.Printf("Querying chats for user_id=%d", req.UserId)
	rows, err := s.db.Query(ctx,
		chatWithLastMessageSelect+` WHERE c.user1_id = $1 OR c.user2_id = $1 ORDER BY c.updated_at DESC`,
		req.UserId)
	if err != nil {
		s.logger.Printf("Failed to query user chats: %v", err)
		return nil, err
	}
	defer rows.Close()

	var chats []*proto.Chat
	for rows.Next() {
		var (
			chat                 proto.Chat
			createdAt, updatedAt time.Time
			last                 lastMessageRow
		)
		if err := rows.Scan(
			&chat.Id, &chat.User1Id, &chat.User2Id, &createdAt, &updatedAt,
			&last.id, &last.content, &last.key, &last.nonce, &last.status, &last.createdAt,
		); err != nil {
			s.logger.Printf("Failed to scan chat row: %v", err)
			return nil, err
		}

		chat.CreatedAt = timestamppb.New(createdAt)
		chat.UpdatedAt = timestamppb.New(updatedAt)
		s.attachLastMessage(&chat, &last)

		chats = append(chats, &chat)
	}
	if err := rows.Err(); err != nil {
		s.logger.Printf("Rows error: %v", err)
		return nil, err
	}

	s.logger.Printf("Retrieved %d chats for user_id=%d", len(chats), req.UserId)
	resp := &proto.UserChatsResponse{Chats: chats}
	s.logResponse("GetUserChats", resp, nil)
	return resp, nil
}

// UpdateMessageStatus changes the status of one message, or — when
// req.MessageId is 0 — of every message addressed to req.UserId that is
// still in status 'SENT'. The bulk form returns an empty response; the
// single-message form returns the updated, decrypted message.
func (s *server) UpdateMessageStatus(ctx context.Context, req *proto.UpdateMessageStatusRequest) (*proto.MessageResponse, error) {
	// MessageId == 0 means "update all pending messages for this user".
	if req.MessageId == 0 {
		_, err := s.db.Exec(ctx, `
			UPDATE messages
			SET status = $1
			WHERE receiver_id = $2 AND status = 'SENT'
		`, req.Status, req.UserId)
		if err != nil {
			return nil, fmt.Errorf("failed to bulk update status: %w", err)
		}
		return &proto.MessageResponse{}, nil
	}

	// Regular single-message update. Content is stored encrypted, so return
	// the encrypted columns and decrypt here, and read created_at into a
	// time.Time before converting it to a protobuf timestamp.
	var (
		msg                            proto.Message
		createdAt                      time.Time
		encryptedContent, key, nonce []byte
	)
	err := s.db.QueryRow(ctx, `
		UPDATE messages
		SET status = $1
		WHERE id = $2
		RETURNING id, chat_id, sender_id, receiver_id, encrypted_content, encrypted_key, nonce, status, created_at
	`, req.Status, req.MessageId).Scan(
		&msg.Id, &msg.ChatId, &msg.SenderId, &msg.ReceiverId,
		&encryptedContent, &key, &nonce,
		&msg.Status, &createdAt,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to update status: %w", err)
	}

	content, err := s.crypto.DecryptMessage(encryptedContent, key, nonce)
	if err != nil {
		s.logger.Printf("Failed to decrypt message %d: %v", msg.Id, err)
		content = undecryptablePlaceholder
	}
	msg.Content = content
	msg.CreatedAt = timestamppb.New(createdAt)

	return &proto.MessageResponse{Message: &msg}, nil
}

// StreamMessages streams a user's incoming messages from RabbitMQ to the
// client. A failed stream is restarted up to five times with exponential
// backoff; cancellation of the stream context ends the call with a nil error.
func (s *server) StreamMessages(req *proto.StreamMessagesRequest, stream proto.MessageService_StreamMessagesServer) error {
	s.logRequest("StreamMessages", req)
	defer func(start time.Time) {
		s.logger.Printf("StreamMessages execution time: %v", time.Since(start))
	}(time.Now())

	if req.UserId == 0 {
		err := status.Error(codes.InvalidArgument, "userID cannot be 0")
		s.logResponse("StreamMessages", nil, err)
		return err
	}

	// Cap the lifetime of a single stream at 24 hours.
	ctx, cancel := context.WithTimeout(stream.Context(), 24*time.Hour)
	defer cancel()

	const maxRetries = 5
	retryDelay := time.Second

	for i := 0; i < maxRetries; i++ {
		if ctx.Err() != nil {
			return nil
		}

		err := s.runStream(ctx, req, stream)
		if err == nil {
			return nil
		}
		s.logger.Printf("Stream error (attempt %d/%d): %v", i+1, maxRetries, err)

		// Back off, but do not keep a cancelled call alive just to sleep.
		if sleepCtx(ctx, retryDelay) != nil {
			return nil
		}
		retryDelay *= 2
	}

	err := fmt.Errorf("max retries (%d) exceeded", maxRetries)
	s.logResponse("StreamMessages", nil, err)
	return err
}

// runStream declares and consumes the user's queue and forwards every
// delivery to the gRPC stream until ctx is done (returns nil) or an error
// occurs. Deliveries are acknowledged only after they were sent to the
// client; undecodable ones are rejected without requeue, and deliveries
// without an id or content are acknowledged and skipped.
func (s *server) runStream(ctx context.Context, req *proto.StreamMessagesRequest, stream proto.MessageService_StreamMessagesServer) error {
	queueName := fmt.Sprintf("user_%d_messages", req.UserId)

	s.logger.Printf("Opening RabbitMQ channel for queue %s", queueName)
	ch, err := s.currentConn().Channel()
	if err != nil {
		return fmt.Errorf("failed to open channel: %w", err)
	}
	defer func() {
		s.logger.Printf("Closing RabbitMQ channel for queue %s", queueName)
		ch.Close()
	}()

	s.logger.Printf("Declaring queue %s with persistence", queueName)
	_, err = ch.QueueDeclare(
		queueName,
		true,  // durable
		false, // autoDelete
		false, // exclusive
		false, // noWait
		amqp.Table{
			"x-message-ttl":            int32(86400000),
			"x-expires":                int32(86400000),
			"x-single-active-consumer": false,
		},
	)
	if err != nil {
		return fmt.Errorf("failed to declare queue: %w", err)
	}

	s.logger.Printf("Setting QoS for queue %s", queueName)
	err = ch.Qos(
		1,     // prefetch count
		0,     // prefetch size
		false, // global
	)
	if err != nil {
		return fmt.Errorf("failed to set QoS: %w", err)
	}

	s.logger.Printf("Starting consumer for queue %s", queueName)
	msgs, err := ch.Consume(
		queueName,
		"",    // consumer
		false, // auto-ack
		false, // exclusive
		false, // noLocal
		false, // noWait
		nil,   // args
	)
	if err != nil {
		return fmt.Errorf("failed to consume: %w", err)
	}

	s.logger.Printf("Starting message stream for user %d", req.UserId)
	defer s.logger.Printf("Stopping message stream for user %d", req.UserId)

	// Both tickers send an empty MessageResponse; only the 60s one is logged.
	heartbeat := time.NewTicker(60 * time.Second)
	defer heartbeat.Stop()

	keepaliveTicker := time.NewTicker(15 * time.Second)
	defer keepaliveTicker.Stop()

	for {
		select {
		case <-keepaliveTicker.C:
			if err := stream.Send(&proto.MessageResponse{}); err != nil {
				return err
			}

		case <-ctx.Done():
			s.logger.Printf("Context canceled for user %d: %v", req.UserId, ctx.Err())
			return nil

		case <-heartbeat.C:
			s.logger.Printf("Sending heartbeat for user %d", req.UserId)
			if err := stream.Send(&proto.MessageResponse{}); err != nil {
				s.logger.Printf("Failed to send heartbeat: %v", err)
				return err
			}

		case d, ok := <-msgs:
			if !ok {
				s.logger.Printf("Message channel closed for user %d", req.UserId)
				return fmt.Errorf("message channel closed")
			}

			s.logger.Printf("Received message from RabbitMQ for user %d: %s", req.UserId, string(d.Body))

			var msg proto.Message
			if err := json.Unmarshal(d.Body, &msg); err != nil {
				s.logger.Printf("Failed to unmarshal message: %v", err)
				if nackErr := d.Nack(false, false); nackErr != nil {
					s.logger.Printf("Failed to nack message: %v", nackErr)
				}
				continue
			}

			if msg.Id == 0 || msg.Content == "" {
				if ackErr := d.Ack(false); ackErr != nil {
					s.logger.Printf("Failed to ack message: %v", ackErr)
				}
				continue
			}

			s.logger.Printf("Sending message to stream for user %d: %+v", req.UserId, &msg)
			if err := stream.Send(&proto.MessageResponse{Message: &msg}); err != nil {
				s.logger.Printf("Failed to send message to stream: %v", err)
				// Requeue so the message is redelivered on the next stream.
				if nackErr := d.Nack(false, true); nackErr != nil {
					s.logger.Printf("Failed to nack message: %v", nackErr)
				}
				return err
			}

			s.logger.Printf("Acknowledging message for user %d", req.UserId)
			if ackErr := d.Ack(false); ackErr != nil {
				s.logger.Printf("Failed to ack message: %v", ackErr)
			}
		}
	}
}

// envOr returns the value of the environment variable key, or fallback when
// it is unset or empty.
func envOr(key, fallback string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return fallback
}

// redactURL returns raw with any password replaced, so connection strings can
// be logged safely. Unparseable input is not echoed at all.
func redactURL(raw string) string {
	u, err := url.Parse(raw)
	if err != nil {
		return "<unparseable URL>"
	}
	return u.Redacted()
}

// watchRabbitConn blocks until conn is closed. When the close carries an
// error it tries up to five times to dial a replacement and hands the new
// connection to srv, then keeps watching the replacement. It returns when a
// connection's close notification arrives without an error or when
// reconnecting fails.
func watchRabbitConn(logger *log.Logger, srv *server, conn *amqp.Connection, rabbitURL string, cfg amqp.Config) {
	for {
		// Buffered so the library's notification send cannot block.
		reason, ok := <-conn.NotifyClose(make(chan *amqp.Error, 1))
		if !ok {
			logger.Println("RabbitMQ connection closed normally")
			return
		}
		logger.Printf("RabbitMQ connection closed: %v", reason)

		reconnected := false
		for i := 0; i < 5; i++ {
			time.Sleep(time.Second * time.Duration(i+1))
			logger.Printf("Attempting to reconnect to RabbitMQ (attempt %d)", i+1)

			newConn, err := amqp.DialConfig(rabbitURL, cfg)
			if err != nil {
				logger.Printf("Failed to reconnect to RabbitMQ (attempt %d): %v", i+1, err)
				continue
			}

			conn = newConn
			srv.setConn(newConn)
			logger.Println("Successfully reconnected to RabbitMQ")
			reconnected = true
			break
		}
		if !reconnected {
			logger.Println("Giving up on RabbitMQ reconnection after 5 attempts")
			return
		}
	}
}

// main connects to PostgreSQL and RabbitMQ, starts the reconnect watcher and
// serves the gRPC API on :50052.
func main() {
	logger := log.New(os.Stdout, "MAIN: ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
	logger.Println("Starting message service")

	// Database connection.
	dbURL := envOr(envDatabaseURL, defaultDatabaseURL)
	logger.Printf("Connecting to database at %s", redactURL(dbURL))
	pool, err := pgxpool.Connect(context.Background(), dbURL)
	if err != nil {
		logger.Fatalf("Unable to connect to database: %v", err)
	}
	defer func() {
		logger.Println("Closing database connection")
		pool.Close()
	}()

	// RabbitMQ connection, with a few retries at startup.
	rabbitURL := envOr(envRabbitURL, defaultRabbitURL)
	rabbitCfg := amqp.Config{
		Heartbeat: 10 * time.Second,
		Locale:    "en_US",
	}
	logger.Printf("Connecting to RabbitMQ at %s", redactURL(rabbitURL))

	var rabbitConn *amqp.Connection
	for i := 0; i < 5; i++ {
		rabbitConn, err = amqp.DialConfig(rabbitURL, rabbitCfg)
		if err == nil {
			break
		}
		logger.Printf("Failed to connect to RabbitMQ (attempt %d): %v", i+1, err)
		time.Sleep(time.Second * time.Duration(i+1))
	}
	if err != nil {
		logger.Fatalf("Failed to connect to RabbitMQ after 5 attempts: %v", err)
	}

	srv := NewServer(pool, rabbitConn)
	defer func() {
		logger.Println("Closing RabbitMQ connection")
		srv.currentConn().Close()
	}()

	// Watch the connection and hand replacements to the server, so handlers
	// never keep using a connection that has been closed.
	go watchRabbitConn(logger, srv, rabbitConn, rabbitURL, rabbitCfg)

	// gRPC server with keepalive and call logging.
	grpcServer := grpc.NewServer(
		grpc.KeepaliveParams(keepalive.ServerParameters{
			MaxConnectionAge:      24 * time.Hour,
			MaxConnectionAgeGrace: 5 * time.Minute,
			Time:                  30 * time.Second,
			Timeout:               10 * time.Second,
		}),
		grpc.UnaryInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
			logger.Printf("Unary call: %s, request: %+v", info.FullMethod, req)
			start := time.Now()
			defer func() {
				logger.Printf("Unary call %s completed in %v, response: %+v, error: %v",
					info.FullMethod, time.Since(start), resp, err)
			}()
			return handler(ctx, req)
		}),
		grpc.StreamInterceptor(func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
			logger.Printf("Stream call: %s", info.FullMethod)
			start := time.Now()
			defer func() {
				logger.Printf("Stream call %s completed in %v", info.FullMethod, time.Since(start))
			}()
			return handler(srv, ss)
		}),
	)

	proto.RegisterMessageServiceServer(grpcServer, srv)

	// Start serving.
	port := ":50052"
	logger.Printf("Starting gRPC server on port %s", port)
	lis, err := net.Listen("tcp", port)
	if err != nil {
		logger.Fatalf("failed to listen: %v", err)
	}

	logger.Println("Server is ready to accept connections")
	if err := grpcServer.Serve(lis); err != nil {
		logger.Fatalf("failed to serve: %v", err)
	}
}