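// Package main implements the messaging gRPC service: it serves
// MessageService on TCP port 50052, stores chats and messages in PostgreSQL,
// and hands new messages to recipients through per-user RabbitMQ queues.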
package main
import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"log"
	"net"
	"os"
	"sync"
	"time"

	"github.com/jackc/pgx/v4/pgxpool"
	amqp "github.com/rabbitmq/amqp091-go"
	"google.golang.org/grpc"
	protobuf "google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/timestamppb"

	"tailly_messages/proto"
)

type server struct {
|
||
proto.UnimplementedMessageServiceServer
|
||
db *pgxpool.Pool
|
||
rabbitConn *amqp.Connection
|
||
mu sync.Mutex
|
||
}
|
||
|
||
func NewServer(db *pgxpool.Pool, rabbitConn *amqp.Connection) *server {
|
||
return &server{db: db, rabbitConn: rabbitConn}
|
||
}
|
||
|
||
func (s *server) CreateChat(ctx context.Context, req *proto.CreateChatRequest) (*proto.ChatResponse, error) {
|
||
log.Printf("CreateChat request received: user1_id=%d, user2_id=%d", req.GetUser1Id(), req.GetUser2Id())
|
||
|
||
user1, user2 := req.GetUser1Id(), req.GetUser2Id()
|
||
if user1 > user2 {
|
||
user1, user2 = user2, user1
|
||
}
|
||
|
||
// Проверка существования пользователей
|
||
var user1Exists, user2Exists bool
|
||
err := s.db.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)", user1).Scan(&user1Exists)
|
||
if err != nil {
|
||
log.Printf("Error checking user1 existence: %v", err)
|
||
return nil, fmt.Errorf("failed to check user existence")
|
||
}
|
||
err = s.db.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)", user2).Scan(&user2Exists)
|
||
if err != nil {
|
||
log.Printf("Error checking user2 existence: %v", err)
|
||
return nil, fmt.Errorf("failed to check user existence")
|
||
}
|
||
|
||
if !user1Exists || !user2Exists {
|
||
errMsg := fmt.Sprintf("One or both users don't exist: user1=%d (%t), user2=%d (%t)",
|
||
user1, user1Exists, user2, user2Exists)
|
||
log.Println(errMsg)
|
||
return nil, fmt.Errorf(errMsg)
|
||
}
|
||
|
||
var chat proto.Chat
|
||
var createdAt, updatedAt time.Time
|
||
|
||
// Проверяем, не существует ли уже чат
|
||
var chatExists bool
|
||
err = s.db.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM chats WHERE user1_id = $1 AND user2_id = $2)",
|
||
user1, user2).Scan(&chatExists)
|
||
if err != nil {
|
||
log.Printf("Error checking chat existence: %v", err)
|
||
return nil, fmt.Errorf("failed to check chat existence")
|
||
}
|
||
|
||
if chatExists {
|
||
log.Printf("Chat already exists between users %d and %d", user1, user2)
|
||
return s.GetChat(ctx, &proto.GetChatRequest{
|
||
User1Id: user1,
|
||
User2Id: user2,
|
||
})
|
||
}
|
||
|
||
// Создаем новый чат
|
||
err = s.db.QueryRow(ctx, `
|
||
INSERT INTO chats (user1_id, user2_id)
|
||
VALUES ($1, $2)
|
||
RETURNING id, user1_id, user2_id, created_at, updated_at
|
||
`, user1, user2).Scan(
|
||
&chat.Id, &chat.User1Id, &chat.User2Id, &createdAt, &updatedAt,
|
||
)
|
||
if err != nil {
|
||
log.Printf("Failed to create chat: %v", err)
|
||
return nil, fmt.Errorf("failed to create chat: %v", err)
|
||
}
|
||
|
||
log.Printf("Successfully created new chat: id=%d, user1_id=%d, user2_id=%d",
|
||
chat.Id, chat.User1Id, chat.User2Id)
|
||
|
||
chat.CreatedAt = timestamppb.New(createdAt)
|
||
chat.UpdatedAt = timestamppb.New(updatedAt)
|
||
|
||
return &proto.ChatResponse{Chat: &chat}, nil
|
||
}
|
||
|
||
func (s *server) SendMessage(ctx context.Context, req *proto.SendMessageRequest) (*proto.MessageResponse, error) {
|
||
// Получаем информацию о чате (как в оригинале)
|
||
var user1Id, user2Id int32
|
||
err := s.db.QueryRow(ctx, "SELECT user1_id, user2_id FROM chats WHERE id = $1", req.ChatId).Scan(&user1Id, &user2Id)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get chat info: %v", err)
|
||
}
|
||
|
||
// Определяем получателя
|
||
var receiverId int32
|
||
if req.SenderId == user1Id {
|
||
receiverId = user2Id
|
||
} else if req.SenderId == user2Id {
|
||
receiverId = user1Id
|
||
} else {
|
||
return nil, fmt.Errorf("sender %d is not a participant of chat %d", req.SenderId, req.ChatId)
|
||
}
|
||
|
||
// Создаем сообщение в БД
|
||
var message proto.Message
|
||
var createdAt time.Time
|
||
|
||
err = s.db.QueryRow(ctx, `
|
||
INSERT INTO messages (chat_id, sender_id, receiver_id, content)
|
||
VALUES ($1, $2, $3, $4)
|
||
RETURNING id, chat_id, sender_id, receiver_id, content, status, created_at
|
||
`, req.ChatId, req.SenderId, receiverId, req.Content).Scan(
|
||
&message.Id, &message.ChatId, &message.SenderId, &message.ReceiverId, &message.Content, &message.Status, &createdAt,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
message.CreatedAt = timestamppb.New(createdAt)
|
||
|
||
// Обновляем время чата
|
||
_, err = s.db.Exec(ctx, `UPDATE chats SET updated_at = NOW() WHERE id = $1`, req.ChatId)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var lastErr error
|
||
for i := 0; i < 3; i++ {
|
||
ch, err := s.rabbitConn.Channel()
|
||
if err != nil {
|
||
lastErr = fmt.Errorf("failed to open channel (attempt %d): %v", i+1, err)
|
||
time.Sleep(time.Second * time.Duration(i+1))
|
||
continue
|
||
}
|
||
|
||
queueName := fmt.Sprintf("user_%d_messages", receiverId)
|
||
msgBytes, _ := json.Marshal(message)
|
||
|
||
err = ch.PublishWithContext(ctx,
|
||
"", // exchange
|
||
queueName, // routing key
|
||
false, // mandatory
|
||
false, // immediate
|
||
amqp.Publishing{
|
||
ContentType: "application/json",
|
||
Body: msgBytes,
|
||
DeliveryMode: amqp.Persistent, // Сохраняем сообщения на диск
|
||
})
|
||
|
||
ch.Close()
|
||
if err == nil {
|
||
return &proto.MessageResponse{Message: &message}, nil
|
||
}
|
||
lastErr = err
|
||
}
|
||
return nil, fmt.Errorf("failed to publish message after 3 attempts: %v", lastErr)
|
||
}
|
||
|
||
func mustMarshal(msg protobuf.Message) []byte {
|
||
data, err := protobuf.Marshal(msg)
|
||
if err != nil {
|
||
log.Fatalf("failed to marshal message: %v", err)
|
||
}
|
||
return data
|
||
}
|
||
func protoMessageToMap(msg *proto.Message) map[string]interface{} {
|
||
return map[string]interface{}{
|
||
"id": msg.Id,
|
||
"chatId": msg.ChatId,
|
||
"senderId": msg.SenderId,
|
||
"receiverId": msg.ReceiverId,
|
||
"content": msg.Content,
|
||
"status": msg.Status,
|
||
"createdAt": msg.CreatedAt.AsTime().Format(time.RFC3339Nano),
|
||
}
|
||
}
|
||
func (s *server) GetChat(ctx context.Context, req *proto.GetChatRequest) (*proto.ChatResponse, error) {
|
||
var chat proto.Chat
|
||
var createdAt, updatedAt time.Time
|
||
var lastMessageID sql.NullInt32
|
||
var lastMessageContent sql.NullString
|
||
var lastMessageStatus sql.NullString
|
||
var lastMessageCreatedAt sql.NullTime
|
||
|
||
user1, user2 := req.User1Id, req.User2Id
|
||
if user1 > user2 {
|
||
user1, user2 = user2, user1
|
||
}
|
||
|
||
err := s.db.QueryRow(ctx, `
|
||
SELECT c.id, c.user1_id, c.user2_id, c.created_at, c.updated_at,
|
||
m.id, m.content, m.status, m.created_at
|
||
FROM chats c
|
||
LEFT JOIN messages m ON m.id = (
|
||
SELECT id FROM messages WHERE chat_id = c.id
|
||
ORDER BY created_at DESC LIMIT 1
|
||
)
|
||
WHERE c.user1_id = $1 AND c.user2_id = $2
|
||
`, user1, user2).Scan(
|
||
&chat.Id, &chat.User1Id, &chat.User2Id, &createdAt, &updatedAt,
|
||
&lastMessageID, &lastMessageContent, &lastMessageStatus, &lastMessageCreatedAt,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
chat.CreatedAt = timestamppb.New(createdAt)
|
||
chat.UpdatedAt = timestamppb.New(updatedAt)
|
||
|
||
if lastMessageID.Valid {
|
||
chat.LastMessage = &proto.Message{
|
||
Id: lastMessageID.Int32,
|
||
ChatId: chat.Id,
|
||
Content: lastMessageContent.String,
|
||
Status: lastMessageStatus.String,
|
||
CreatedAt: timestamppb.New(lastMessageCreatedAt.Time),
|
||
}
|
||
}
|
||
|
||
return &proto.ChatResponse{Chat: &chat}, nil
|
||
}
|
||
|
||
func (s *server) GetChatMessages(ctx context.Context, req *proto.GetChatMessagesRequest) (*proto.MessagesResponse, error) {
|
||
rows, err := s.db.Query(ctx, `
|
||
SELECT id, chat_id, sender_id, content, status, created_at
|
||
FROM messages
|
||
WHERE chat_id = $1
|
||
ORDER BY created_at DESC
|
||
LIMIT $2 OFFSET $3
|
||
`, req.ChatId, req.Limit, req.Offset)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var messages []*proto.Message
|
||
for rows.Next() {
|
||
var msg proto.Message
|
||
var createdAt time.Time
|
||
err := rows.Scan(
|
||
&msg.Id, &msg.ChatId, &msg.SenderId, &msg.Content, &msg.Status, &createdAt,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
msg.CreatedAt = timestamppb.New(createdAt)
|
||
messages = append(messages, &msg)
|
||
}
|
||
|
||
return &proto.MessagesResponse{Messages: messages}, nil
|
||
}
|
||
|
||
func (s *server) GetUserChats(ctx context.Context, req *proto.GetUserChatsRequest) (*proto.UserChatsResponse, error) {
|
||
rows, err := s.db.Query(ctx, `
|
||
SELECT c.id, c.user1_id, c.user2_id, c.created_at, c.updated_at,
|
||
m.id, m.content, m.status, m.created_at
|
||
FROM chats c
|
||
LEFT JOIN messages m ON m.id = (
|
||
SELECT id FROM messages WHERE chat_id = c.id
|
||
ORDER BY created_at DESC LIMIT 1
|
||
)
|
||
WHERE c.user1_id = $1 OR c.user2_id = $1
|
||
ORDER BY c.updated_at DESC
|
||
`, req.UserId)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var chats []*proto.Chat
|
||
for rows.Next() {
|
||
var chat proto.Chat
|
||
var createdAt, updatedAt time.Time
|
||
var lastMessageID sql.NullInt32
|
||
var lastMessageContent sql.NullString
|
||
var lastMessageStatus sql.NullString
|
||
var lastMessageCreatedAt sql.NullTime
|
||
|
||
err := rows.Scan(
|
||
&chat.Id, &chat.User1Id, &chat.User2Id, &createdAt, &updatedAt,
|
||
&lastMessageID, &lastMessageContent, &lastMessageStatus, &lastMessageCreatedAt,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
chat.CreatedAt = timestamppb.New(createdAt)
|
||
chat.UpdatedAt = timestamppb.New(updatedAt)
|
||
|
||
if lastMessageID.Valid {
|
||
chat.LastMessage = &proto.Message{
|
||
Id: lastMessageID.Int32,
|
||
ChatId: chat.Id,
|
||
Content: lastMessageContent.String,
|
||
Status: lastMessageStatus.String,
|
||
CreatedAt: timestamppb.New(lastMessageCreatedAt.Time),
|
||
}
|
||
}
|
||
|
||
chats = append(chats, &chat)
|
||
}
|
||
|
||
return &proto.UserChatsResponse{Chats: chats}, nil
|
||
}
|
||
|
||
func (s *server) UpdateMessageStatus(ctx context.Context, req *proto.UpdateMessageStatusRequest) (*proto.MessageResponse, error) {
|
||
var message proto.Message
|
||
var createdAt time.Time
|
||
|
||
err := s.db.QueryRow(ctx, `
|
||
UPDATE messages
|
||
SET status = $1
|
||
WHERE id = $2
|
||
RETURNING id, chat_id, sender_id, content, status, created_at
|
||
`, req.Status, req.MessageId).Scan(
|
||
&message.Id, &message.ChatId, &message.SenderId, &message.Content, &message.Status, &createdAt,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
message.CreatedAt = timestamppb.New(createdAt)
|
||
|
||
return &proto.MessageResponse{Message: &message}, nil
|
||
}
|
||
|
||
func (s *server) StreamMessages(req *proto.StreamMessagesRequest, stream proto.MessageService_StreamMessagesServer) error {
|
||
const maxRetries = 5
|
||
retryDelay := time.Second
|
||
|
||
for i := 0; i < maxRetries; i++ {
|
||
err := s.runStream(req, stream)
|
||
if err == nil {
|
||
return nil
|
||
}
|
||
|
||
log.Printf("Stream error (attempt %d/%d): %v", i+1, maxRetries, err)
|
||
time.Sleep(retryDelay)
|
||
retryDelay *= 2
|
||
}
|
||
|
||
return fmt.Errorf("max retries (%d) exceeded", maxRetries)
|
||
}
|
||
|
||
func (s *server) runStream(req *proto.StreamMessagesRequest, stream proto.MessageService_StreamMessagesServer) error {
|
||
ctx := stream.Context()
|
||
queueName := fmt.Sprintf("user_%d_messages", req.UserId)
|
||
|
||
ch, err := s.rabbitConn.Channel()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to open channel: %v", err)
|
||
}
|
||
defer ch.Close()
|
||
|
||
// Объявляем очередь с persistence
|
||
_, err = ch.QueueDeclare(
|
||
queueName,
|
||
true, // durable
|
||
false, // autoDelete
|
||
false, // exclusive
|
||
false, // noWait
|
||
amqp.Table{
|
||
"x-message-ttl": int32(86400000), // 24 часа TTL
|
||
},
|
||
)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to declare queue: %v", err)
|
||
}
|
||
|
||
// QoS для контроля скорости обработки
|
||
err = ch.Qos(
|
||
1, // prefetch count
|
||
0, // prefetch size
|
||
false, // global
|
||
)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to set QoS: %v", err)
|
||
}
|
||
|
||
msgs, err := ch.Consume(
|
||
queueName,
|
||
"", // consumer
|
||
false, // auto-ack (false для ручного подтверждения)
|
||
false, // exclusive
|
||
false, // noLocal
|
||
false, // noWait
|
||
nil, // args
|
||
)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to consume: %v", err)
|
||
}
|
||
|
||
log.Printf("Starting message stream for user %d", req.UserId)
|
||
defer log.Printf("Stopping message stream for user %d", req.UserId)
|
||
|
||
for {
|
||
select {
|
||
case <-ctx.Done():
|
||
return nil
|
||
case d, ok := <-msgs:
|
||
if !ok {
|
||
return fmt.Errorf("message channel closed")
|
||
}
|
||
|
||
var msg proto.Message
|
||
if err := json.Unmarshal(d.Body, &msg); err != nil {
|
||
log.Printf("Failed to unmarshal message: %v", err)
|
||
d.Nack(false, false) // Отбрасываем некорректное сообщение
|
||
continue
|
||
}
|
||
|
||
log.Printf("Sending message to user %d: %+v", req.UserId, msg)
|
||
if err := stream.Send(&proto.MessageResponse{Message: &msg}); err != nil {
|
||
d.Nack(false, true) // Возвращаем в очередь при ошибке отправки
|
||
return err
|
||
}
|
||
|
||
d.Ack(false) // Подтверждаем обработку
|
||
}
|
||
}
|
||
}
|
||
|
||
func main() {
|
||
// Инициализация подключения к БД
|
||
pool, err := pgxpool.Connect(context.Background(), "postgres://tailly_v2:i0Oq%2675LA%26M612ceuy@79.174.89.104:15452/tailly_v2")
|
||
if err != nil {
|
||
log.Fatalf("Unable to connect to database: %v", err)
|
||
}
|
||
defer pool.Close()
|
||
|
||
// Инициализация подключения к RabbitMQ
|
||
var rabbitConn *amqp.Connection
|
||
|
||
for i := 0; i < 5; i++ {
|
||
rabbitConn, err = amqp.DialConfig("amqp://tailly_rabbitmq:o2p2S80MPbl27LUU@89.104.69.222:5673/", amqp.Config{
|
||
Heartbeat: 10 * time.Second,
|
||
Locale: "en_US",
|
||
})
|
||
if err == nil {
|
||
break
|
||
}
|
||
log.Printf("Failed to connect to RabbitMQ (attempt %d): %v", i+1, err)
|
||
time.Sleep(time.Second * time.Duration(i+1))
|
||
}
|
||
if err != nil {
|
||
log.Fatalf("Failed to connect to RabbitMQ after 5 attempts: %v", err)
|
||
}
|
||
defer rabbitConn.Close()
|
||
|
||
// Обработка событий соединения
|
||
go func(conn **amqp.Connection, mu *sync.Mutex) {
|
||
for {
|
||
reason, ok := <-(*conn).NotifyClose(make(chan *amqp.Error))
|
||
if !ok {
|
||
log.Println("RabbitMQ connection closed")
|
||
break
|
||
}
|
||
log.Printf("RabbitMQ connection closed: %v", reason)
|
||
|
||
// Попытка переподключения
|
||
for i := 0; i < 5; i++ {
|
||
time.Sleep(time.Second * time.Duration(i+1))
|
||
newConn, err := amqp.Dial("amqp://tailly_rabbitmq:o2p2S80MPbl27LUU@89.104.69.222:5673/")
|
||
if err == nil {
|
||
mu.Lock()
|
||
*conn = newConn
|
||
mu.Unlock()
|
||
log.Println("Successfully reconnected to RabbitMQ")
|
||
break
|
||
}
|
||
log.Printf("Failed to reconnect to RabbitMQ (attempt %d): %v", i+1, err)
|
||
}
|
||
}
|
||
}(&rabbitConn, &sync.Mutex{})
|
||
|
||
// Создаем gRPC сервер
|
||
grpcServer := grpc.NewServer()
|
||
proto.RegisterMessageServiceServer(grpcServer, NewServer(pool, rabbitConn))
|
||
|
||
// Запускаем сервер
|
||
lis, err := net.Listen("tcp", ":50052")
|
||
if err != nil {
|
||
log.Fatalf("failed to listen: %v", err)
|
||
}
|
||
log.Println("Server started on port 50052")
|
||
if err := grpcServer.Serve(lis); err != nil {
|
||
log.Fatalf("failed to serve: %v", err)
|
||
}
|
||
}
|