743 lines
23 KiB
Go
743 lines
23 KiB
Go
package main
|
||
|
||
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net"
	"net/url"
	"os"
	"sync"
	"time"

	"github.com/jackc/pgx/v4/pgxpool"
	amqp "github.com/rabbitmq/amqp091-go"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/keepalive"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/timestamppb"

	"tailly_messages/crypto"
	"tailly_messages/proto"
)
|
||
|
||
type server struct {
|
||
proto.UnimplementedMessageServiceServer
|
||
db *pgxpool.Pool
|
||
rabbitConn *amqp.Connection
|
||
mu sync.Mutex
|
||
logger *log.Logger
|
||
crypto *crypto.CryptoService
|
||
}
|
||
|
||
func NewServer(db *pgxpool.Pool, rabbitConn *amqp.Connection) *server {
|
||
cryptoService, err := crypto.NewCryptoService()
|
||
if err != nil {
|
||
log.Printf("ERROR: Failed to create crypto service: %v", err)
|
||
}
|
||
return &server{
|
||
db: db,
|
||
rabbitConn: rabbitConn,
|
||
crypto: cryptoService,
|
||
logger: log.New(os.Stdout, "MESSAGE_SERVICE: ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile),
|
||
}
|
||
}
|
||
|
||
func (s *server) logRequest(method string, req interface{}) {
|
||
s.logger.Printf("REQUEST: %s - %+v", method, req)
|
||
}
|
||
|
||
func (s *server) logResponse(method string, resp interface{}, err error) {
|
||
if err != nil {
|
||
s.logger.Printf("RESPONSE ERROR: %s - %v", method, err)
|
||
} else {
|
||
s.logger.Printf("RESPONSE: %s - %+v", method, resp)
|
||
}
|
||
}
|
||
|
||
func (s *server) CreateChat(ctx context.Context, req *proto.CreateChatRequest) (*proto.ChatResponse, error) {
|
||
s.logRequest("CreateChat", req)
|
||
defer func(start time.Time) {
|
||
s.logger.Printf("CreateChat execution time: %v", time.Since(start))
|
||
}(time.Now())
|
||
|
||
user1, user2 := req.GetUser1Id(), req.GetUser2Id()
|
||
if user1 > user2 {
|
||
user1, user2 = user2, user1
|
||
}
|
||
|
||
s.logger.Printf("Checking user existence: user1=%d, user2=%d", user1, user2)
|
||
var user1Exists, user2Exists bool
|
||
err := s.db.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)", user1).Scan(&user1Exists)
|
||
if err != nil {
|
||
s.logger.Printf("Error checking user1 existence: %v", err)
|
||
return nil, fmt.Errorf("failed to check user existence")
|
||
}
|
||
err = s.db.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)", user2).Scan(&user2Exists)
|
||
if err != nil {
|
||
s.logger.Printf("Error checking user2 existence: %v", err)
|
||
return nil, fmt.Errorf("failed to check user existence")
|
||
}
|
||
|
||
if !user1Exists || !user2Exists {
|
||
errMsg := fmt.Sprintf("One or both users don't exist: user1=%d (%t), user2=%d (%t)",
|
||
user1, user1Exists, user2, user2Exists)
|
||
s.logger.Println(errMsg)
|
||
return nil, fmt.Errorf("%w", errors.New(errMsg))
|
||
}
|
||
|
||
var chat proto.Chat
|
||
var createdAt, updatedAt time.Time
|
||
|
||
s.logger.Printf("Checking chat existence between users %d and %d", user1, user2)
|
||
var chatExists bool
|
||
err = s.db.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM chats WHERE user1_id = $1 AND user2_id = $2)",
|
||
user1, user2).Scan(&chatExists)
|
||
if err != nil {
|
||
s.logger.Printf("Error checking chat existence: %v", err)
|
||
return nil, fmt.Errorf("failed to check chat existence")
|
||
}
|
||
|
||
if chatExists {
|
||
s.logger.Printf("Chat already exists between users %d and %d, returning existing chat", user1, user2)
|
||
return s.GetChat(ctx, &proto.GetChatRequest{
|
||
User1Id: user1,
|
||
User2Id: user2,
|
||
})
|
||
}
|
||
|
||
s.logger.Printf("Creating new chat between users %d and %d", user1, user2)
|
||
err = s.db.QueryRow(ctx, `
|
||
INSERT INTO chats (user1_id, user2_id)
|
||
VALUES ($1, $2)
|
||
RETURNING id, user1_id, user2_id, created_at, updated_at
|
||
`, user1, user2).Scan(
|
||
&chat.Id, &chat.User1Id, &chat.User2Id, &createdAt, &updatedAt,
|
||
)
|
||
if err != nil {
|
||
s.logger.Printf("Failed to create chat: %v", err)
|
||
return nil, fmt.Errorf("failed to create chat: %v", err)
|
||
}
|
||
|
||
s.logger.Printf("Successfully created new chat: id=%d, user1_id=%d, user2_id=%d",
|
||
chat.Id, chat.User1Id, chat.User2Id)
|
||
|
||
chat.CreatedAt = timestamppb.New(createdAt)
|
||
chat.UpdatedAt = timestamppb.New(updatedAt)
|
||
|
||
resp := &proto.ChatResponse{Chat: &chat}
|
||
s.logResponse("CreateChat", resp, nil)
|
||
return resp, nil
|
||
}
|
||
|
||
func (s *server) SendMessage(ctx context.Context, req *proto.SendMessageRequest) (*proto.MessageResponse, error) {
|
||
s.logRequest("SendMessage", req)
|
||
defer func(start time.Time) {
|
||
s.logger.Printf("SendMessage execution time: %v", time.Since(start))
|
||
}(time.Now())
|
||
|
||
// Шифруем сообщение
|
||
encryptedContent, encryptedKey, nonce, err := s.crypto.EncryptMessage(req.Content)
|
||
if err != nil {
|
||
s.logger.Printf("Failed to encrypt message: %v", err)
|
||
return nil, fmt.Errorf("failed to encrypt message: %v", err)
|
||
}
|
||
|
||
s.logger.Printf("Getting chat info for chat_id=%d", req.ChatId)
|
||
var user1Id, user2Id int32
|
||
err = s.db.QueryRow(ctx, "SELECT user1_id, user2_id FROM chats WHERE id = $1", req.ChatId).Scan(&user1Id, &user2Id)
|
||
if err != nil {
|
||
s.logger.Printf("Failed to get chat info: %v", err)
|
||
return nil, fmt.Errorf("failed to get chat info: %v", err)
|
||
}
|
||
|
||
var receiverId int32
|
||
if req.SenderId == user1Id {
|
||
receiverId = user2Id
|
||
} else if req.SenderId == user2Id {
|
||
receiverId = user1Id
|
||
} else {
|
||
errMsg := fmt.Sprintf("sender %d is not a participant of chat %d", req.SenderId, req.ChatId)
|
||
s.logger.Println(errMsg)
|
||
return nil, errors.New(errMsg)
|
||
}
|
||
|
||
s.logger.Printf("Inserting encrypted message into database: chat_id=%d, sender_id=%d, receiver_id=%d",
|
||
req.ChatId, req.SenderId, receiverId)
|
||
var message proto.Message
|
||
var createdAt time.Time
|
||
|
||
err = s.db.QueryRow(ctx, `
|
||
INSERT INTO messages (chat_id, sender_id, receiver_id, encrypted_content, encrypted_key, nonce)
|
||
VALUES ($1, $2, $3, $4, $5, $6)
|
||
RETURNING id, chat_id, sender_id, receiver_id, status, created_at
|
||
`, req.ChatId, req.SenderId, receiverId, encryptedContent, encryptedKey, nonce).Scan(
|
||
&message.Id, &message.ChatId, &message.SenderId, &message.ReceiverId, &message.Status, &createdAt,
|
||
)
|
||
if err != nil {
|
||
s.logger.Printf("Failed to insert encrypted message: %v", err)
|
||
return nil, err
|
||
}
|
||
|
||
message.CreatedAt = timestamppb.New(createdAt)
|
||
|
||
|
||
decryptedContent, err := s.crypto.DecryptMessage(encryptedContent, encryptedKey, nonce)
|
||
if err != nil {
|
||
s.logger.Printf("Failed to decrypt for RabbitMQ: %v", err)
|
||
return nil, err
|
||
}
|
||
message.Content = decryptedContent
|
||
|
||
s.logger.Printf("Updating chat updated_at for chat_id=%d", req.ChatId)
|
||
_, err = s.db.Exec(ctx, `UPDATE chats SET updated_at = NOW() WHERE id = $1`, req.ChatId)
|
||
if err != nil {
|
||
s.logger.Printf("Failed to update chat timestamp: %v", err)
|
||
return nil, err
|
||
}
|
||
|
||
s.logger.Printf("Publishing message to RabbitMQ for user_id=%d", receiverId)
|
||
var lastErr error
|
||
for i := 0; i < 3; i++ {
|
||
ch, err := s.rabbitConn.Channel()
|
||
if err != nil {
|
||
lastErr = fmt.Errorf("failed to open channel (attempt %d): %v", i+1, err)
|
||
s.logger.Printf("RabbitMQ channel error: %v", lastErr)
|
||
time.Sleep(time.Second * time.Duration(i+1))
|
||
continue
|
||
}
|
||
|
||
queueName := fmt.Sprintf("user_%d_messages", receiverId)
|
||
msgBytes, _ := json.Marshal(message)
|
||
|
||
s.logger.Printf("Publishing to queue %s: %s", queueName, string(msgBytes))
|
||
err = ch.PublishWithContext(ctx,
|
||
"", // exchange
|
||
queueName, // routing key
|
||
false, // mandatory
|
||
false, // immediate
|
||
amqp.Publishing{
|
||
ContentType: "application/json",
|
||
Body: msgBytes,
|
||
DeliveryMode: amqp.Persistent,
|
||
})
|
||
|
||
ch.Close()
|
||
if err == nil {
|
||
s.logger.Printf("Successfully published message to queue %s", queueName)
|
||
resp := &proto.MessageResponse{Message: &message}
|
||
s.logResponse("SendMessage", resp, nil)
|
||
return resp, nil
|
||
}
|
||
lastErr = err
|
||
s.logger.Printf("Failed to publish message (attempt %d): %v", i+1, err)
|
||
}
|
||
|
||
errMsg := fmt.Errorf("failed to publish message after 3 attempts: %v", lastErr)
|
||
s.logResponse("SendMessage", nil, errMsg)
|
||
return nil, errMsg
|
||
}
|
||
|
||
func (s *server) GetChat(ctx context.Context, req *proto.GetChatRequest) (*proto.ChatResponse, error) {
|
||
s.logRequest("GetChat", req)
|
||
defer func(start time.Time) {
|
||
s.logger.Printf("GetChat execution time: %v", time.Since(start))
|
||
}(time.Now())
|
||
|
||
var chat proto.Chat
|
||
var createdAt, updatedAt time.Time
|
||
var lastMessageID sql.NullInt32
|
||
var lastMessageEncryptedContent, lastMessageEncryptedKey, lastMessageNonce []byte
|
||
var lastMessageStatus sql.NullString
|
||
var lastMessageCreatedAt sql.NullTime
|
||
|
||
user1, user2 := req.User1Id, req.User2Id
|
||
if user1 > user2 {
|
||
user1, user2 = user2, user1
|
||
}
|
||
|
||
s.logger.Printf("Querying chat between users %d and %d", user1, user2)
|
||
err := s.db.QueryRow(ctx, `
|
||
SELECT c.id, c.user1_id, c.user2_id, c.created_at, c.updated_at,
|
||
m.id, m.encrypted_content, m.encrypted_key, m.nonce, m.status, m.created_at
|
||
FROM chats c
|
||
LEFT JOIN messages m ON m.id = (
|
||
SELECT id FROM messages WHERE chat_id = c.id
|
||
ORDER BY created_at DESC LIMIT 1
|
||
)
|
||
WHERE c.user1_id = $1 AND c.user2_id = $2
|
||
`, user1, user2).Scan(
|
||
&chat.Id, &chat.User1Id, &chat.User2Id, &createdAt, &updatedAt,
|
||
&lastMessageID, &lastMessageEncryptedContent, &lastMessageEncryptedKey,
|
||
&lastMessageNonce, &lastMessageStatus, &lastMessageCreatedAt,
|
||
)
|
||
if err != nil {
|
||
s.logger.Printf("Failed to get chat: %v", err)
|
||
return nil, err
|
||
}
|
||
|
||
chat.CreatedAt = timestamppb.New(createdAt)
|
||
chat.UpdatedAt = timestamppb.New(updatedAt)
|
||
|
||
if lastMessageID.Valid {
|
||
// Расшифровываем последнее сообщение
|
||
decryptedContent, err := s.crypto.DecryptMessage(
|
||
lastMessageEncryptedContent,
|
||
lastMessageEncryptedKey,
|
||
lastMessageNonce,
|
||
)
|
||
if err != nil {
|
||
s.logger.Printf("Failed to decrypt last message: %v", err)
|
||
decryptedContent = "[не удалось расшифровать]"
|
||
}
|
||
|
||
chat.LastMessage = &proto.Message{
|
||
Id: lastMessageID.Int32,
|
||
ChatId: chat.Id,
|
||
Content: decryptedContent,
|
||
Status: lastMessageStatus.String,
|
||
CreatedAt: timestamppb.New(lastMessageCreatedAt.Time),
|
||
}
|
||
s.logger.Printf("Found last message for chat %d: message_id=%d", chat.Id, lastMessageID.Int32)
|
||
}
|
||
|
||
resp := &proto.ChatResponse{Chat: &chat}
|
||
s.logResponse("GetChat", resp, nil)
|
||
return resp, nil
|
||
}
|
||
|
||
func (s *server) GetChatMessages(ctx context.Context, req *proto.GetChatMessagesRequest) (*proto.MessagesResponse, error) {
|
||
s.logRequest("GetChatMessages", req)
|
||
defer func(start time.Time) {
|
||
s.logger.Printf("GetChatMessages execution time: %v", time.Since(start))
|
||
}(time.Now())
|
||
|
||
s.logger.Printf("Querying messages for chat_id=%d, limit=%d, offset=%d", req.ChatId, req.Limit, req.Offset)
|
||
rows, err := s.db.Query(ctx, `
|
||
SELECT id, chat_id, sender_id, receiver_id, encrypted_content, encrypted_key, nonce, status, created_at
|
||
FROM messages
|
||
WHERE chat_id = $1
|
||
ORDER BY created_at DESC
|
||
LIMIT $2 OFFSET $3
|
||
`, req.ChatId, req.Limit, req.Offset)
|
||
if err != nil {
|
||
s.logger.Printf("Failed to query messages: %v", err)
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var messages []*proto.Message
|
||
for rows.Next() {
|
||
var msg proto.Message
|
||
var createdAt time.Time
|
||
var encryptedContent, encryptedKey, nonce []byte
|
||
|
||
err := rows.Scan(
|
||
&msg.Id, &msg.ChatId, &msg.SenderId, &msg.ReceiverId,
|
||
&encryptedContent, &encryptedKey, &nonce,
|
||
&msg.Status, &createdAt,
|
||
)
|
||
if err != nil {
|
||
s.logger.Printf("Failed to scan message row: %v", err)
|
||
return nil, err
|
||
}
|
||
|
||
// Расшифровываем сообщение
|
||
decryptedContent, err := s.crypto.DecryptMessage(encryptedContent, encryptedKey, nonce)
|
||
if err != nil {
|
||
s.logger.Printf("Failed to decrypt message %d: %v", msg.Id, err)
|
||
msg.Content = "[не удалось расшифровать]"
|
||
} else {
|
||
msg.Content = decryptedContent
|
||
}
|
||
|
||
msg.CreatedAt = timestamppb.New(createdAt)
|
||
messages = append(messages, &msg)
|
||
}
|
||
|
||
if err := rows.Err(); err != nil {
|
||
s.logger.Printf("Rows error: %v", err)
|
||
return nil, err
|
||
}
|
||
|
||
s.logger.Printf("Retrieved %d messages for chat_id=%d", len(messages), req.ChatId)
|
||
resp := &proto.MessagesResponse{Messages: messages}
|
||
s.logResponse("GetChatMessages", resp, nil)
|
||
return resp, nil
|
||
}
|
||
|
||
func (s *server) GetUserChats(ctx context.Context, req *proto.GetUserChatsRequest) (*proto.UserChatsResponse, error) {
|
||
s.logRequest("GetUserChats", req)
|
||
defer func(start time.Time) {
|
||
s.logger.Printf("GetUserChats execution time: %v", time.Since(start))
|
||
}(time.Now())
|
||
|
||
s.logger.Printf("Querying chats for user_id=%d", req.UserId)
|
||
rows, err := s.db.Query(ctx, `
|
||
SELECT c.id, c.user1_id, c.user2_id, c.created_at, c.updated_at,
|
||
m.id, m.encrypted_content, m.encrypted_key, m.nonce, m.status, m.created_at
|
||
FROM chats c
|
||
LEFT JOIN messages m ON m.id = (
|
||
SELECT id FROM messages WHERE chat_id = c.id
|
||
ORDER BY created_at DESC LIMIT 1
|
||
)
|
||
WHERE c.user1_id = $1 OR c.user2_id = $1
|
||
ORDER BY c.updated_at DESC
|
||
`, req.UserId)
|
||
if err != nil {
|
||
s.logger.Printf("Failed to query user chats: %v", err)
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var chats []*proto.Chat
|
||
for rows.Next() {
|
||
var chat proto.Chat
|
||
var createdAt, updatedAt time.Time
|
||
var lastMessageID sql.NullInt32
|
||
var lastMessageEncryptedContent, lastMessageEncryptedKey, lastMessageNonce []byte
|
||
var lastMessageStatus sql.NullString
|
||
var lastMessageCreatedAt sql.NullTime
|
||
|
||
err := rows.Scan(
|
||
&chat.Id, &chat.User1Id, &chat.User2Id, &createdAt, &updatedAt,
|
||
&lastMessageID, &lastMessageEncryptedContent, &lastMessageEncryptedKey,
|
||
&lastMessageNonce, &lastMessageStatus, &lastMessageCreatedAt,
|
||
)
|
||
if err != nil {
|
||
s.logger.Printf("Failed to scan chat row: %v", err)
|
||
return nil, err
|
||
}
|
||
|
||
chat.CreatedAt = timestamppb.New(createdAt)
|
||
chat.UpdatedAt = timestamppb.New(updatedAt)
|
||
|
||
if lastMessageID.Valid {
|
||
// Расшифровываем последнее сообщение
|
||
decryptedContent, err := s.crypto.DecryptMessage(
|
||
lastMessageEncryptedContent,
|
||
lastMessageEncryptedKey,
|
||
lastMessageNonce,
|
||
)
|
||
if err != nil {
|
||
s.logger.Printf("Failed to decrypt last message: %v", err)
|
||
decryptedContent = "[не удалось расшифровать]"
|
||
}
|
||
|
||
chat.LastMessage = &proto.Message{
|
||
Id: lastMessageID.Int32,
|
||
ChatId: chat.Id,
|
||
Content: decryptedContent,
|
||
Status: lastMessageStatus.String,
|
||
CreatedAt: timestamppb.New(lastMessageCreatedAt.Time),
|
||
}
|
||
}
|
||
|
||
chats = append(chats, &chat)
|
||
}
|
||
|
||
if err := rows.Err(); err != nil {
|
||
s.logger.Printf("Rows error: %v", err)
|
||
return nil, err
|
||
}
|
||
|
||
s.logger.Printf("Retrieved %d chats for user_id=%d", len(chats), req.UserId)
|
||
resp := &proto.UserChatsResponse{Chats: chats}
|
||
s.logResponse("GetUserChats", resp, nil)
|
||
return resp, nil
|
||
}
|
||
|
||
func (s *server) UpdateMessageStatus(ctx context.Context, req *proto.UpdateMessageStatusRequest) (*proto.MessageResponse, error) {
|
||
// Если MessageId = 0, обновляем все сообщения для пользователя
|
||
if req.MessageId == 0 {
|
||
_, err := s.db.Exec(ctx, `
|
||
UPDATE messages
|
||
SET status = $1
|
||
WHERE receiver_id = $2 AND status = 'SENT'
|
||
`, req.Status, req.UserId)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to bulk update status: %w", err)
|
||
}
|
||
return &proto.MessageResponse{}, nil
|
||
}
|
||
|
||
// Обычное обновление одного сообщения
|
||
var msg proto.Message
|
||
err := s.db.QueryRow(ctx, `
|
||
UPDATE messages
|
||
SET status = $1
|
||
WHERE id = $2
|
||
RETURNING id, chat_id, sender_id, receiver_id, content, status, created_at
|
||
`, req.Status, req.MessageId).Scan(
|
||
&msg.Id, &msg.ChatId, &msg.SenderId, &msg.ReceiverId,
|
||
&msg.Content, &msg.Status, &msg.CreatedAt,
|
||
)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to update status: %w", err)
|
||
}
|
||
return &proto.MessageResponse{Message: &msg}, nil
|
||
}
|
||
|
||
func (s *server) StreamMessages(req *proto.StreamMessagesRequest, stream proto.MessageService_StreamMessagesServer) error {
|
||
s.logRequest("StreamMessages", req)
|
||
defer func(start time.Time) {
|
||
s.logger.Printf("StreamMessages execution time: %v", time.Since(start))
|
||
}(time.Now())
|
||
|
||
if req.UserId == 0 {
|
||
err := status.Error(codes.InvalidArgument, "userID cannot be 0")
|
||
s.logResponse("StreamMessages", nil, err)
|
||
return err
|
||
}
|
||
|
||
// Создаем контекст с увеличенным таймаутом
|
||
ctx, cancel := context.WithTimeout(stream.Context(), 24*time.Hour)
|
||
defer cancel()
|
||
|
||
retryDelay := time.Second
|
||
const maxRetries = 5
|
||
|
||
for i := 0; i < maxRetries; i++ {
|
||
select {
|
||
case <-ctx.Done():
|
||
return nil
|
||
default:
|
||
err := s.runStream(ctx, req, stream)
|
||
if err == nil {
|
||
return nil
|
||
}
|
||
|
||
s.logger.Printf("Stream error (attempt %d/%d): %v", i+1, maxRetries, err)
|
||
|
||
time.Sleep(retryDelay)
|
||
retryDelay *= 2
|
||
}
|
||
}
|
||
|
||
err := fmt.Errorf("max retries (%d) exceeded", maxRetries)
|
||
s.logResponse("StreamMessages", nil, err)
|
||
return err
|
||
}
|
||
|
||
func (s *server) runStream(ctx context.Context, req *proto.StreamMessagesRequest, stream proto.MessageService_StreamMessagesServer) error {
|
||
queueName := fmt.Sprintf("user_%d_messages", req.UserId)
|
||
|
||
s.logger.Printf("Opening RabbitMQ channel for queue %s", queueName)
|
||
ch, err := s.rabbitConn.Channel()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to open channel: %v", err)
|
||
}
|
||
defer func() {
|
||
s.logger.Printf("Closing RabbitMQ channel for queue %s", queueName)
|
||
ch.Close()
|
||
}()
|
||
|
||
s.logger.Printf("Declaring queue %s with persistence", queueName)
|
||
_, err = ch.QueueDeclare(
|
||
queueName,
|
||
true, // durable
|
||
false, // autoDelete
|
||
false, // exclusive
|
||
false, // noWait
|
||
amqp.Table{
|
||
"x-message-ttl": int32(86400000),
|
||
"x-expires": int32(86400000),
|
||
"x-single-active-consumer": false,
|
||
},
|
||
)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to declare queue: %v", err)
|
||
}
|
||
|
||
s.logger.Printf("Setting QoS for queue %s", queueName)
|
||
err = ch.Qos(
|
||
1, // prefetch count
|
||
0, // prefetch size
|
||
false, // global
|
||
)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to set QoS: %v", err)
|
||
}
|
||
|
||
s.logger.Printf("Starting consumer for queue %s", queueName)
|
||
msgs, err := ch.Consume(
|
||
queueName,
|
||
"", // consumer
|
||
false, // auto-ack
|
||
false, // exclusive
|
||
false, // noLocal
|
||
false, // noWait
|
||
nil, // args
|
||
)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to consume: %v", err)
|
||
}
|
||
|
||
s.logger.Printf("Starting message stream for user %d", req.UserId)
|
||
defer s.logger.Printf("Stopping message stream for user %d", req.UserId)
|
||
|
||
heartbeat := time.NewTicker(60 * time.Second)
|
||
defer heartbeat.Stop()
|
||
keepaliveTicker := time.NewTicker(15 * time.Second)
|
||
defer keepaliveTicker.Stop()
|
||
for {
|
||
select {
|
||
case <-keepaliveTicker.C:
|
||
if err := stream.Send(&proto.MessageResponse{}); err != nil {
|
||
return err
|
||
}
|
||
case <-ctx.Done():
|
||
s.logger.Printf("Context canceled for user %d: %v", req.UserId, ctx.Err())
|
||
return nil
|
||
case <-heartbeat.C:
|
||
s.logger.Printf("Sending heartbeat for user %d", req.UserId)
|
||
if err := stream.Send(&proto.MessageResponse{}); err != nil {
|
||
s.logger.Printf("Failed to send heartbeat: %v", err)
|
||
return err
|
||
}
|
||
case d, ok := <-msgs:
|
||
if !ok {
|
||
s.logger.Printf("Message channel closed for user %d", req.UserId)
|
||
return fmt.Errorf("message channel closed")
|
||
}
|
||
|
||
s.logger.Printf("Received message from RabbitMQ for user %d: %s", req.UserId, string(d.Body))
|
||
var msg proto.Message
|
||
if err := json.Unmarshal(d.Body, &msg); err != nil {
|
||
s.logger.Printf("Failed to unmarshal message: %v", err)
|
||
d.Nack(false, false)
|
||
continue
|
||
}
|
||
|
||
if msg.Id == 0 || msg.Content == "" {
|
||
d.Ack(false)
|
||
continue
|
||
}
|
||
|
||
s.logger.Printf("Sending message to stream for user %d: %+v", req.UserId, msg)
|
||
if err := stream.Send(&proto.MessageResponse{Message: &msg}); err != nil {
|
||
s.logger.Printf("Failed to send message to stream: %v", err)
|
||
d.Nack(false, true)
|
||
return err
|
||
}
|
||
|
||
s.logger.Printf("Acknowledging message for user %d", req.UserId)
|
||
d.Ack(false)
|
||
}
|
||
}
|
||
}
|
||
|
||
func main() {
|
||
logger := log.New(os.Stdout, "MAIN: ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||
logger.Println("Starting message service")
|
||
|
||
// Инициализация подключения к БД
|
||
dbURL := "postgres://tailly_v2:i0Oq%2675LA%26M612ceuy@79.174.89.104:15452/tailly_v2"
|
||
logger.Printf("Connecting to database at %s", dbURL)
|
||
pool, err := pgxpool.Connect(context.Background(), dbURL)
|
||
if err != nil {
|
||
logger.Fatalf("Unable to connect to database: %v", err)
|
||
}
|
||
defer func() {
|
||
logger.Println("Closing database connection")
|
||
pool.Close()
|
||
}()
|
||
|
||
// Инициализация подключения к RabbitMQ
|
||
rabbitURL := "amqp://tailly_rabbitmq:o2p2S80MPbl27LUU@89.104.69.222:5673/"
|
||
logger.Printf("Connecting to RabbitMQ at %s", rabbitURL)
|
||
var rabbitConn *amqp.Connection
|
||
|
||
for i := 0; i < 5; i++ {
|
||
rabbitConn, err = amqp.DialConfig(rabbitURL, amqp.Config{
|
||
Heartbeat: 10 * time.Second,
|
||
Locale: "en_US",
|
||
})
|
||
if err == nil {
|
||
break
|
||
}
|
||
logger.Printf("Failed to connect to RabbitMQ (attempt %d): %v", i+1, err)
|
||
time.Sleep(time.Second * time.Duration(i+1))
|
||
}
|
||
if err != nil {
|
||
logger.Fatalf("Failed to connect to RabbitMQ after 5 attempts: %v", err)
|
||
}
|
||
defer func() {
|
||
logger.Println("Closing RabbitMQ connection")
|
||
rabbitConn.Close()
|
||
}()
|
||
|
||
// Обработка событий соединения
|
||
connMutex := &sync.Mutex{}
|
||
go func() {
|
||
for {
|
||
reason, ok := <-rabbitConn.NotifyClose(make(chan *amqp.Error))
|
||
if !ok {
|
||
logger.Println("RabbitMQ connection closed normally")
|
||
break
|
||
}
|
||
logger.Printf("RabbitMQ connection closed: %v", reason)
|
||
|
||
// Попытка переподключения
|
||
for i := 0; i < 5; i++ {
|
||
time.Sleep(time.Second * time.Duration(i+1))
|
||
logger.Printf("Attempting to reconnect to RabbitMQ (attempt %d)", i+1)
|
||
newConn, err := amqp.DialConfig(rabbitURL, amqp.Config{
|
||
Heartbeat: 10 * time.Second,
|
||
Locale: "en_US",
|
||
})
|
||
if err == nil {
|
||
connMutex.Lock()
|
||
rabbitConn = newConn
|
||
connMutex.Unlock()
|
||
logger.Println("Successfully reconnected to RabbitMQ")
|
||
break
|
||
}
|
||
logger.Printf("Failed to reconnect to RabbitMQ (attempt %d): %v", i+1, err)
|
||
}
|
||
}
|
||
}()
|
||
|
||
// Создаем gRPC сервер
|
||
grpcServer := grpc.NewServer(
|
||
grpc.KeepaliveParams(keepalive.ServerParameters{
|
||
MaxConnectionAge: 24 * time.Hour,
|
||
MaxConnectionAgeGrace: 5 * time.Minute,
|
||
Time: 30 * time.Second,
|
||
Timeout: 10 * time.Second,
|
||
}),
|
||
grpc.UnaryInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||
logger.Printf("Unary call: %s, request: %+v", info.FullMethod, req)
|
||
start := time.Now()
|
||
defer func() {
|
||
logger.Printf("Unary call %s completed in %v, response: %+v, error: %v",
|
||
info.FullMethod, time.Since(start), resp, err)
|
||
}()
|
||
return handler(ctx, req)
|
||
}),
|
||
grpc.StreamInterceptor(func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||
logger.Printf("Stream call: %s", info.FullMethod)
|
||
start := time.Now()
|
||
defer func() {
|
||
logger.Printf("Stream call %s completed in %v", info.FullMethod, time.Since(start))
|
||
}()
|
||
return handler(srv, ss)
|
||
}),
|
||
)
|
||
proto.RegisterMessageServiceServer(grpcServer, NewServer(pool, rabbitConn))
|
||
|
||
// Запускаем сервер
|
||
port := ":50052"
|
||
logger.Printf("Starting gRPC server on port %s", port)
|
||
lis, err := net.Listen("tcp", port)
|
||
if err != nil {
|
||
logger.Fatalf("failed to listen: %v", err)
|
||
}
|
||
|
||
logger.Println("Server is ready to accept connections")
|
||
if err := grpcServer.Serve(lis); err != nil {
|
||
logger.Fatalf("failed to serve: %v", err)
|
||
}
|
||
}
|