package repository import ( "context" "database/sql" "errors" "fmt" "tailly_back_v2/internal/domain" "time" ) var ( ErrMessageNotFound = errors.New("message not found") ErrChatNotFound = errors.New("chat not found") ) type ChatRepository interface { // Основные методы сообщений SaveMessage(ctx context.Context, message *domain.Message) error GetMessageByID(ctx context.Context, id int) (*domain.Message, error) GetMessagesByChat(ctx context.Context, chatID int, limit, offset int) ([]*domain.Message, error) UpdateMessageStatus(ctx context.Context, id int, status string) error DeleteMessage(ctx context.Context, id int) error // Методы чатов CreateChat(ctx context.Context, user1ID, user2ID int) (*domain.Chat, error) GetChatByID(ctx context.Context, id int) (*domain.Chat, error) GetUserChats(ctx context.Context, userID int) ([]*domain.Chat, error) GetChatByParticipants(ctx context.Context, user1ID, user2ID int) (*domain.Chat, error) GetUnreadCount(ctx context.Context, chatID, userID int) (int, error) } type chatRepository struct { db *sql.DB } func NewChatRepository(db *sql.DB) ChatRepository { return &chatRepository{db: db} } func (r *chatRepository) SaveMessage(ctx context.Context, message *domain.Message) error { if message.ReceiverID == 0 { return errors.New("receiver_id is required") } query := ` INSERT INTO messages (chat_id, sender_id, receiver_id, content, status, created_at) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id ` err := r.db.QueryRowContext(ctx, query, message.ChatID, message.SenderID, message.ReceiverID, message.Content, message.Status, message.CreatedAt, ).Scan(&message.ID) return err } func (r *chatRepository) GetMessageByID(ctx context.Context, id int) (*domain.Message, error) { query := ` SELECT id, chat_id, sender_id, content, status, created_at FROM messages WHERE id = $1 ` message := &domain.Message{} err := r.db.QueryRowContext(ctx, query, id).Scan( &message.ID, &message.ChatID, &message.SenderID, &message.Content, &message.Status, &message.CreatedAt, ) if errors.Is(err, sql.ErrNoRows) { return nil, ErrMessageNotFound } return message, err } func (r *chatRepository) GetMessagesByChat(ctx context.Context, chatID int, limit, offset int) ([]*domain.Message, error) { query := ` SELECT id, chat_id, sender_id, content, status, created_at FROM messages WHERE chat_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3 ` rows, err := r.db.QueryContext(ctx, query, chatID, limit, offset) if err != nil { return nil, err } defer rows.Close() var messages []*domain.Message for rows.Next() { var message domain.Message if err := rows.Scan( &message.ID, &message.ChatID, &message.SenderID, &message.Content, &message.Status, &message.CreatedAt, ); err != nil { return nil, err } messages = append(messages, &message) } return messages, nil } func (r *chatRepository) UpdateMessageStatus(ctx context.Context, id int, status string) error { query := ` UPDATE messages SET status = $1 WHERE id = $2 ` _, err := r.db.ExecContext(ctx, query, status, id) return err } func (r *chatRepository) DeleteMessage(ctx context.Context, id int) error { query := `DELETE FROM messages WHERE id = $1` _, err := r.db.ExecContext(ctx, query, id) return err } func (r *chatRepository) CreateChat(ctx context.Context, user1ID, user2ID int) (*domain.Chat, error) { // Проверяем, что пользователи разные if user1ID == user2ID { return nil, errors.New("cannot create chat with yourself") } // Упорядочиваем ID пользователей согласно CHECK constraint if user1ID > user2ID { user1ID, user2ID = user2ID, user1ID } query := ` INSERT INTO chats (user1_id, user2_id, created_at) VALUES ($1, $2, $3) RETURNING id ` chat := &domain.Chat{ User1ID: user1ID, User2ID: user2ID, CreatedAt: time.Now(), } err := r.db.QueryRowContext(ctx, query, user1ID, user2ID, chat.CreatedAt).Scan(&chat.ID) if err != nil { return nil, fmt.Errorf("failed to create chat: %v", err) } return chat, nil } func (r *chatRepository) GetChatByID(ctx context.Context, id int) (*domain.Chat, error) { query := ` SELECT id, user1_id, user2_id, created_at FROM chats WHERE id = $1 ` chat := &domain.Chat{} err := r.db.QueryRowContext(ctx, query, id).Scan( &chat.ID, &chat.User1ID, &chat.User2ID, &chat.CreatedAt, ) if errors.Is(err, sql.ErrNoRows) { return nil, ErrChatNotFound } return chat, err } func (r *chatRepository) GetUserChats(ctx context.Context, userID int) ([]*domain.Chat, error) { query := ` SELECT id, user1_id, user2_id, created_at FROM chats WHERE user1_id = $1 OR user2_id = $1 ORDER BY created_at DESC ` rows, err := r.db.QueryContext(ctx, query, userID) if err != nil { return nil, err } defer rows.Close() var chats []*domain.Chat for rows.Next() { var chat domain.Chat if err := rows.Scan( &chat.ID, &chat.User1ID, &chat.User2ID, &chat.CreatedAt, ); err != nil { return nil, err } chats = append(chats, &chat) } return chats, nil } func (r *chatRepository) GetChatByParticipants(ctx context.Context, user1ID, user2ID int) (*domain.Chat, error) { // Упорядочиваем ID пользователей согласно CHECK constraint if user1ID > user2ID { user1ID, user2ID = user2ID, user1ID } query := ` SELECT id, user1_id, user2_id, created_at FROM chats WHERE user1_id = $1 AND user2_id = $2 LIMIT 1 ` chat := &domain.Chat{} err := r.db.QueryRowContext(ctx, query, user1ID, user2ID).Scan( &chat.ID, &chat.User1ID, &chat.User2ID, &chat.CreatedAt, ) if errors.Is(err, sql.ErrNoRows) { return nil, ErrChatNotFound } return chat, err } func (r *chatRepository) GetUnreadCount(ctx context.Context, chatID, userID int) (int, error) { var count int err := r.db.QueryRowContext(ctx, ` SELECT COUNT(*) FROM messages WHERE chat_id = $1 AND sender_id != $2 AND status = 'sent' `, chatID, userID).Scan(&count) return count, err }