tailly_back_v2/internal/repository/chat_repository.go
madipo2611 6f5298d420 v0.0.3
2025-05-03 02:37:08 +03:00

219 lines
5.4 KiB
Go

package repository
import (
"context"
"database/sql"
"errors"
"tailly_back_v2/internal/domain"
"time"
)
// Sentinel errors returned by ChatRepository lookups. Compare against them
// with errors.Is.
var (
	// ErrMessageNotFound signals that no message row matched the requested id.
	ErrMessageNotFound = errors.New("message not found")
	// ErrChatNotFound signals that no chat row matched the requested id or
	// participant pair.
	ErrChatNotFound = errors.New("chat not found")
)
// ChatRepository is the persistence contract for two-participant chats and
// the messages posted in them.
type ChatRepository interface {
	// --- Message operations ---

	// SaveMessage stores a new message.
	SaveMessage(ctx context.Context, message *domain.Message) error
	// GetMessageByID loads one message by its id.
	GetMessageByID(ctx context.Context, id int) (*domain.Message, error)
	// GetMessagesByChat returns one page of a chat's messages.
	GetMessagesByChat(ctx context.Context, chatID int, limit, offset int) ([]*domain.Message, error)
	// UpdateMessageStatus overwrites the status of a single message.
	UpdateMessageStatus(ctx context.Context, id int, status string) error
	// DeleteMessage removes a single message.
	DeleteMessage(ctx context.Context, id int) error

	// --- Chat operations ---

	// CreateChat opens a new chat between two users.
	CreateChat(ctx context.Context, user1ID, user2ID int) (*domain.Chat, error)
	// GetChatByID loads one chat by its id.
	GetChatByID(ctx context.Context, id int) (*domain.Chat, error)
	// GetUserChats lists every chat the user takes part in.
	GetUserChats(ctx context.Context, userID int) ([]*domain.Chat, error)
	// GetChatByParticipants finds the chat shared by two users.
	GetChatByParticipants(ctx context.Context, user1ID, user2ID int) (*domain.Chat, error)
	// GetUnreadCount counts messages in a chat that the user has not read yet.
	GetUnreadCount(ctx context.Context, chatID, userID int) (int, error)
}
// chatRepository is the database/sql implementation of ChatRepository. Every
// query it issues uses PostgreSQL-style $N placeholders.
type chatRepository struct {
	db *sql.DB // handle all queries are executed through
}
// NewChatRepository builds a ChatRepository that runs its queries through db.
func NewChatRepository(db *sql.DB) ChatRepository {
	repo := new(chatRepository)
	repo.db = db
	return repo
}
// SaveMessage inserts message into the messages table. On success, the id
// generated by the database (RETURNING id) is copied into message.ID.
// ChatID, SenderID, Content, Status and CreatedAt are written exactly as the
// caller set them; nothing is defaulted here.
func (r *chatRepository) SaveMessage(ctx context.Context, message *domain.Message) error {
	const insertSQL = `
INSERT INTO messages (chat_id, sender_id, content, status, created_at)
VALUES ($1, $2, $3, $4, $5)
RETURNING id
`
	row := r.db.QueryRowContext(ctx, insertSQL,
		message.ChatID, message.SenderID, message.Content,
		message.Status, message.CreatedAt)
	return row.Scan(&message.ID)
}
// GetMessageByID loads a single message by primary key.
//
// It returns ErrMessageNotFound when no row has the given id. Any other
// failure is returned unchanged with a nil message, so callers never receive a
// partially scanned value alongside an error.
func (r *chatRepository) GetMessageByID(ctx context.Context, id int) (*domain.Message, error) {
	query := `
SELECT id, chat_id, sender_id, content, status, created_at
FROM messages
WHERE id = $1
`
	message := &domain.Message{}
	err := r.db.QueryRowContext(ctx, query, id).Scan(
		&message.ID,
		&message.ChatID,
		&message.SenderID,
		&message.Content,
		&message.Status,
		&message.CreatedAt,
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrMessageNotFound
	case err != nil:
		return nil, err
	}
	return message, nil
}
// GetMessagesByChat returns one page of messages belonging to chatID, newest
// first (ORDER BY created_at DESC). The page is selected with SQL LIMIT and
// OFFSET, using limit and offset as given.
//
// When the page is empty the result is a nil slice with a nil error. Any error
// raised while iterating the result set is reported rather than being returned
// as a silently truncated page.
func (r *chatRepository) GetMessagesByChat(ctx context.Context, chatID int, limit, offset int) ([]*domain.Message, error) {
	query := `
SELECT id, chat_id, sender_id, content, status, created_at
FROM messages
WHERE chat_id = $1
ORDER BY created_at DESC
LIMIT $2 OFFSET $3
`
	rows, err := r.db.QueryContext(ctx, query, chatID, limit, offset)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var messages []*domain.Message
	for rows.Next() {
		var message domain.Message
		if err := rows.Scan(
			&message.ID,
			&message.ChatID,
			&message.SenderID,
			&message.Content,
			&message.Status,
			&message.CreatedAt,
		); err != nil {
			return nil, err
		}
		messages = append(messages, &message)
	}
	// rows.Next returns false both at the end of the result set and when the
	// driver hits an error mid-stream; only rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return messages, nil
}
// UpdateMessageStatus sets the status column of the message with the given id.
//
// It returns ErrMessageNotFound when no row matched id. Without this check, an
// update to a missing message is silently reported as a success.
func (r *chatRepository) UpdateMessageStatus(ctx context.Context, id int, status string) error {
	query := `
UPDATE messages
SET status = $1
WHERE id = $2
`
	res, err := r.db.ExecContext(ctx, query, status, id)
	if err != nil {
		return err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if affected == 0 {
		return ErrMessageNotFound
	}
	return nil
}
// DeleteMessage removes the message with the given id.
//
// It returns ErrMessageNotFound when no row matched id, consistent with
// GetMessageByID. A delete aimed at a wrong or already-removed id is therefore
// visible to the caller instead of silently succeeding.
func (r *chatRepository) DeleteMessage(ctx context.Context, id int) error {
	query := `DELETE FROM messages WHERE id = $1`
	res, err := r.db.ExecContext(ctx, query, id)
	if err != nil {
		return err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if affected == 0 {
		return ErrMessageNotFound
	}
	return nil
}
// CreateChat inserts a new chat between user1ID and user2ID and returns it
// with the database-generated ID.
//
// CreatedAt is stamped from the application clock (time.Now) rather than by
// the database. This method does not itself look for an existing chat between
// the same pair; callers that need one chat per pair can use
// GetChatByParticipants first (whether the schema enforces uniqueness is not
// visible here).
//
// On failure it returns a nil chat rather than a half-populated one.
func (r *chatRepository) CreateChat(ctx context.Context, user1ID, user2ID int) (*domain.Chat, error) {
	query := `
INSERT INTO chats (user1_id, user2_id, created_at)
VALUES ($1, $2, $3)
RETURNING id
`
	chat := &domain.Chat{
		User1ID:   user1ID,
		User2ID:   user2ID,
		CreatedAt: time.Now(),
	}
	if err := r.db.QueryRowContext(ctx, query, user1ID, user2ID, chat.CreatedAt).Scan(&chat.ID); err != nil {
		return nil, err
	}
	return chat, nil
}
// GetChatByID loads a single chat by primary key.
//
// It returns ErrChatNotFound when no row has the given id. Any other failure is
// returned unchanged with a nil chat, so no partially scanned value leaks to
// the caller.
func (r *chatRepository) GetChatByID(ctx context.Context, id int) (*domain.Chat, error) {
	query := `
SELECT id, user1_id, user2_id, created_at
FROM chats
WHERE id = $1
`
	chat := &domain.Chat{}
	err := r.db.QueryRowContext(ctx, query, id).Scan(
		&chat.ID,
		&chat.User1ID,
		&chat.User2ID,
		&chat.CreatedAt,
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrChatNotFound
	case err != nil:
		return nil, err
	}
	return chat, nil
}
// GetUserChats lists every chat in which userID is either participant
// (user1_id or user2_id), most recently created first.
//
// When the user has no chats the result is a nil slice with a nil error.
// Errors raised while iterating the result set are reported instead of being
// returned as a silently truncated list.
func (r *chatRepository) GetUserChats(ctx context.Context, userID int) ([]*domain.Chat, error) {
	query := `
SELECT id, user1_id, user2_id, created_at
FROM chats
WHERE user1_id = $1 OR user2_id = $1
ORDER BY created_at DESC
`
	rows, err := r.db.QueryContext(ctx, query, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var chats []*domain.Chat
	for rows.Next() {
		var chat domain.Chat
		if err := rows.Scan(
			&chat.ID,
			&chat.User1ID,
			&chat.User2ID,
			&chat.CreatedAt,
		); err != nil {
			return nil, err
		}
		chats = append(chats, &chat)
	}
	// A false rows.Next can mean "done" or "failed"; rows.Err tells which.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return chats, nil
}
// GetChatByParticipants finds the chat between two users, regardless of which
// one was stored as user1 and which as user2.
//
// If more than one row matches, LIMIT 1 without an ORDER BY returns an
// unspecified one of them. It returns ErrChatNotFound when no chat exists
// between the pair. Any other failure is returned with a nil chat.
func (r *chatRepository) GetChatByParticipants(ctx context.Context, user1ID, user2ID int) (*domain.Chat, error) {
	query := `
SELECT id, user1_id, user2_id, created_at
FROM chats
WHERE (user1_id = $1 AND user2_id = $2)
OR (user1_id = $2 AND user2_id = $1)
LIMIT 1
`
	chat := &domain.Chat{}
	err := r.db.QueryRowContext(ctx, query, user1ID, user2ID).Scan(
		&chat.ID,
		&chat.User1ID,
		&chat.User2ID,
		&chat.CreatedAt,
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrChatNotFound
	case err != nil:
		return nil, err
	}
	return chat, nil
}
// GetUnreadCount counts the messages in chatID that were sent by someone other
// than userID and whose status is still 'sent'.
// NOTE(review): 'sent' is assumed to mean "delivered but not yet read"; confirm
// against the status values the service layer writes.
func (r *chatRepository) GetUnreadCount(ctx context.Context, chatID, userID int) (int, error) {
	const unreadCountSQL = `
SELECT COUNT(*)
FROM messages
WHERE chat_id = $1
AND sender_id != $2
AND status = 'sent'
`
	row := r.db.QueryRowContext(ctx, unreadCountSQL, chatID, userID)
	var unread int
	err := row.Scan(&unread)
	return unread, err
}