diff --git a/crypto/server_crypto.go b/crypto/server_crypto.go index ecb477a..51f06db 100644 --- a/crypto/server_crypto.go +++ b/crypto/server_crypto.go @@ -7,6 +7,7 @@ import ( "crypto/rsa" "crypto/sha256" "crypto/x509" + "encoding/base64" "io" ) @@ -38,7 +39,7 @@ func NewCryptoService() (*CryptoService, error) { }, nil } -// Шифрование сообщения +// Шифрование сообщения - возвращает зашифрованный контент, ключ и nonce func (cs *CryptoService) EncryptMessage(content string) ([]byte, []byte, []byte, error) { // Генерация сессионного AES ключа aesKey := make([]byte, 32) @@ -62,7 +63,7 @@ func (cs *CryptoService) EncryptMessage(content string) ([]byte, []byte, []byte, return nil, nil, nil, err } - encryptedContent := gcm.Seal(nonce, nonce, []byte(content), nil) + encryptedContent := gcm.Seal(nil, nonce, []byte(content), nil) // Шифрование AES ключа RSA-OAEP encryptedKey, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, &cs.privateKey.PublicKey, aesKey, nil) @@ -99,3 +100,13 @@ func (cs *CryptoService) DecryptMessage(encryptedContent, encryptedKey, nonce [] return string(plaintext), nil } + +// Утилитарная функция для base64 кодирования (для хранения в Vault) +func (cs *CryptoService) EncodeBase64(data []byte) string { + return base64.StdEncoding.EncodeToString(data) +} + +// Утилитарная функция для base64 декодирования +func (cs *CryptoService) DecodeBase64(data string) ([]byte, error) { + return base64.StdEncoding.DecodeString(data) +} diff --git a/crypto/vault.go b/crypto/vault.go index 5a9e116..ab97947 100644 --- a/crypto/vault.go +++ b/crypto/vault.go @@ -71,8 +71,8 @@ func (v *VaultManager) GetMasterPrivateKey() ([]byte, error) { return nil, fmt.Errorf("invalid data format in Vault secret") } - // Получаем приватный ключ в base64 - ОБРАТИТЕ ВНИМАНИЕ НА НОВОЕ ИМЯ ПОЛЯ! - keyInterface, ok := data["private_key_base64"] // ← Изменилось здесь! + // Получаем приватный ключ в base64 + keyInterface, ok := data["private_key_base64"] if !ok { return nil, fmt.Errorf("private_key_base64 not found in Vault data") } diff --git a/migrations/0002_initial_schema.down.sql b/migrations/0002_initial_schema.down.sql deleted file mode 100644 index 950ffe8..0000000 --- a/migrations/0002_initial_schema.down.sql +++ /dev/null @@ -1,3 +0,0 @@ -ALTER TABLE messages -DROP COLUMN encrypted_key, -DROP COLUMN nonce; \ No newline at end of file diff --git a/migrations/0002_initial_schema.up.sql b/migrations/0002_initial_schema.up.sql index beb9fde..dce04a1 100644 --- a/migrations/0002_initial_schema.up.sql +++ b/migrations/0002_initial_schema.up.sql @@ -1,18 +1,9 @@ +-- Добавляем колонки для хранения зашифрованных данных ALTER TABLE messages - ADD COLUMN encrypted_key BYTEA, -ADD COLUMN nonce BYTEA; + ADD COLUMN encrypted_content BYTEA NOT NULL, +ADD COLUMN encrypted_key BYTEA NOT NULL, +ADD COLUMN nonce BYTEA NOT NULL, +DROP COLUMN content; --- Обновляем существующие сообщения (если есть) -UPDATE messages SET - encrypted_key = ''::bytea, -nonce = ''::bytea -WHERE encrypted_key IS NULL OR nonce IS NULL; - --- Делаем поля обязательными -ALTER TABLE messages - ALTER COLUMN encrypted_key SET NOT NULL, -ALTER COLUMN nonce SET NOT NULL; -ALTER TABLE messages -ALTER COLUMN content TYPE BYTEA USING content::bytea, -ALTER COLUMN encrypted_key TYPE BYTEA USING encrypted_key::bytea, -ALTER COLUMN nonce TYPE BYTEA USING nonce::bytea; \ No newline at end of file +-- Обновляем индексы (если нужно) +CREATE INDEX idx_messages_encrypted ON messages(encrypted_content); \ No newline at end of file diff --git a/server.go b/server.go index 14905f6..f593252 100644 --- a/server.go +++ b/server.go @@ -3,7 +3,6 @@ package main import ( "context" "database/sql" - "encoding/base64" "encoding/json" "errors" "fmt" @@ -18,7 +17,7 @@ import ( "net" "os" "sync" - crypto2 "tailly_messages/crypto" + "tailly_messages/crypto" "tailly_messages/proto" "time" ) @@ -27,21 +26,15 @@ type server struct { proto.UnimplementedMessageServiceServer db *pgxpool.Pool rabbitConn *amqp.Connection - crypto *crypto2.CryptoService mu sync.Mutex logger *log.Logger + crypto *crypto.CryptoService } func NewServer(db *pgxpool.Pool, rabbitConn *amqp.Connection) *server { - cryptoService, err := crypto2.NewCryptoService() - if err != nil { - log.Fatalf("Failed to initialize crypto service: %v", err) - } - return &server{ db: db, rabbitConn: rabbitConn, - crypto: cryptoService, logger: log.New(os.Stdout, "MESSAGE_SERVICE: ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile), } } @@ -86,7 +79,7 @@ func (s *server) CreateChat(ctx context.Context, req *proto.CreateChatRequest) ( errMsg := fmt.Sprintf("One or both users don't exist: user1=%d (%t), user2=%d (%t)", user1, user1Exists, user2, user2Exists) s.logger.Println(errMsg) - return nil, fmt.Errorf("%w", errors.New(errMsg)) + return nil, fmt.Errorf("%w", errors.New(errMsg)) // Оборачиваем ошибку } var chat proto.Chat @@ -139,44 +132,13 @@ func (s *server) SendMessage(ctx context.Context, req *proto.SendMessageRequest) s.logger.Printf("SendMessage execution time: %v", time.Since(start)) }(time.Now()) - // Шифрование сообщения - s.logger.Printf("Encrypting message for chat_id=%d", req.ChatId) + // Шифруем сообщение encryptedContent, encryptedKey, nonce, err := s.crypto.EncryptMessage(req.Content) if err != nil { s.logger.Printf("Failed to encrypt message: %v", err) return nil, fmt.Errorf("failed to encrypt message: %v", err) } - // Кодируем в base64 для хранения в TEXT поле - encodedContent := base64.StdEncoding.EncodeToString(encryptedContent) - encodedKey := base64.StdEncoding.EncodeToString(encryptedKey) - encodedNonce := base64.StdEncoding.EncodeToString(nonce) - - s.logger.Printf("Base64 encoded: content_len=%d, key_len=%d, nonce_len=%d", - len(encodedContent), len(encodedKey), len(encodedNonce)) - - // Тестовая расшифровка сразу после шифрования - testDecrypted, err := s.crypto.DecryptMessage(encryptedContent, encryptedKey, nonce) - if err != nil { - s.logger.Printf("IMMEDIATE DECRYPTION TEST FAILED: %v", err) - return nil, fmt.Errorf("encryption/decryption test failed: %v", err) - } - - if testDecrypted != req.Content { - s.logger.Printf("IMMEDIATE DECRYPTION CONTENT MISMATCH") - return nil, fmt.Errorf("encryption/decryption content mismatch") - } - - s.logger.Printf("Immediate decryption test passed") - - // Сохранение сессионного ключа в Vault - s.logger.Printf("Storing session key in Vault for chat_id=%d", req.ChatId) - err = s.crypto.Vault.StoreSessionKey(int(req.ChatId), 0, encryptedKey) - if err != nil { - s.logger.Printf("Failed to store session key: %v", err) - return nil, fmt.Errorf("failed to store session key: %v", err) - } - s.logger.Printf("Getting chat info for chat_id=%d", req.ChatId) var user1Id, user2Id int32 err = s.db.QueryRow(ctx, "SELECT user1_id, user2_id FROM chats WHERE id = $1", req.ChatId).Scan(&user1Id, &user2Id) @@ -201,27 +163,28 @@ func (s *server) SendMessage(ctx context.Context, req *proto.SendMessageRequest) var message proto.Message var createdAt time.Time - // Используем закодированные base64 строки err = s.db.QueryRow(ctx, ` - INSERT INTO messages (chat_id, sender_id, receiver_id, content, encrypted_key, nonce) + INSERT INTO messages (chat_id, sender_id, receiver_id, encrypted_content, encrypted_key, nonce) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id, chat_id, sender_id, receiver_id, status, created_at - `, req.ChatId, req.SenderId, receiverId, encodedContent, encodedKey, encodedNonce).Scan( + `, req.ChatId, req.SenderId, receiverId, encryptedContent, encryptedKey, nonce).Scan( &message.Id, &message.ChatId, &message.SenderId, &message.ReceiverId, &message.Status, &createdAt, ) if err != nil { - s.logger.Printf("Failed to insert message: %v", err) + s.logger.Printf("Failed to insert encrypted message: %v", err) return nil, err } message.CreatedAt = timestamppb.New(createdAt) - // Обновление сессионного ключа в Vault с правильным messageID - s.logger.Printf("Updating session key in Vault with actual message_id=%d", message.Id) - err = s.crypto.Vault.StoreSessionKey(int(req.ChatId), int(message.Id), encryptedKey) + // Для RabbitMQ отправляем расшифрованное сообщение (или зашифрованное, в зависимости от требований) + // Здесь отправляем расшифрованное для совместимости с существующими клиентами + decryptedContent, err := s.crypto.DecryptMessage(encryptedContent, encryptedKey, nonce) if err != nil { - s.logger.Printf("Failed to update session key with message ID: %v", err) + s.logger.Printf("Failed to decrypt for RabbitMQ: %v", err) + return nil, err } + message.Content = decryptedContent s.logger.Printf("Updating chat updated_at for chat_id=%d", req.ChatId) _, err = s.db.Exec(ctx, `UPDATE chats SET updated_at = NOW() WHERE id = $1`, req.ChatId) @@ -230,24 +193,7 @@ func (s *server) SendMessage(ctx context.Context, req *proto.SendMessageRequest) return nil, err } - // Расшифровка для отправки через RabbitMQ - decryptedContent, err := s.crypto.DecryptMessage(encryptedContent, encryptedKey, nonce) - if err != nil { - s.logger.Printf("Failed to decrypt message for RabbitMQ: %v", err) - return nil, fmt.Errorf("failed to decrypt message: %v", err) - } - - rabbitMsg := proto.Message{ - Id: message.Id, - ChatId: message.ChatId, - SenderId: message.SenderId, - ReceiverId: message.ReceiverId, - Content: decryptedContent, - Status: message.Status, - CreatedAt: message.CreatedAt, - } - - s.logger.Printf("Publishing decrypted message to RabbitMQ for user_id=%d", receiverId) + s.logger.Printf("Publishing message to RabbitMQ for user_id=%d", receiverId) var lastErr error for i := 0; i < 3; i++ { ch, err := s.rabbitConn.Channel() @@ -259,7 +205,7 @@ func (s *server) SendMessage(ctx context.Context, req *proto.SendMessageRequest) } queueName := fmt.Sprintf("user_%d_messages", receiverId) - msgBytes, _ := json.Marshal(rabbitMsg) + msgBytes, _ := json.Marshal(message) s.logger.Printf("Publishing to queue %s: %s", queueName, string(msgBytes)) err = ch.PublishWithContext(ctx, @@ -298,7 +244,7 @@ func (s *server) GetChat(ctx context.Context, req *proto.GetChatRequest) (*proto var chat proto.Chat var createdAt, updatedAt time.Time var lastMessageID sql.NullInt32 - var lastMessageContentBase64, lastMessageEncryptedKeyBase64, lastMessageNonceBase64 sql.NullString + var lastMessageEncryptedContent, lastMessageEncryptedKey, lastMessageNonce []byte var lastMessageStatus sql.NullString var lastMessageCreatedAt sql.NullTime @@ -309,18 +255,18 @@ func (s *server) GetChat(ctx context.Context, req *proto.GetChatRequest) (*proto s.logger.Printf("Querying chat between users %d and %d", user1, user2) err := s.db.QueryRow(ctx, ` - SELECT c.id, c.user1_id, c.user2_id, c.created_at, c.updated_at, - m.id, m.content, m.encrypted_key, m.nonce, m.status, m.created_at - FROM chats c - LEFT JOIN messages m ON m.id = ( - SELECT id FROM messages WHERE chat_id = c.id - ORDER BY created_at DESC LIMIT 1 - ) - WHERE c.user1_id = $1 AND c.user2_id = $2 - `, user1, user2).Scan( + SELECT c.id, c.user1_id, c.user2_id, c.created_at, c.updated_at, + m.id, m.encrypted_content, m.encrypted_key, m.nonce, m.status, m.created_at + FROM chats c + LEFT JOIN messages m ON m.id = ( + SELECT id FROM messages WHERE chat_id = c.id + ORDER BY created_at DESC LIMIT 1 + ) + WHERE c.user1_id = $1 AND c.user2_id = $2 + `, user1, user2).Scan( &chat.Id, &chat.User1Id, &chat.User2Id, &createdAt, &updatedAt, - &lastMessageID, &lastMessageContentBase64, &lastMessageEncryptedKeyBase64, - &lastMessageNonceBase64, &lastMessageStatus, &lastMessageCreatedAt, + &lastMessageID, &lastMessageEncryptedContent, &lastMessageEncryptedKey, + &lastMessageNonce, &lastMessageStatus, &lastMessageCreatedAt, ) if err != nil { s.logger.Printf("Failed to get chat: %v", err) @@ -332,39 +278,15 @@ func (s *server) GetChat(ctx context.Context, req *proto.GetChatRequest) (*proto if lastMessageID.Valid { // Расшифровываем последнее сообщение - var decryptedContent string - - if lastMessageContentBase64.Valid && lastMessageEncryptedKeyBase64.Valid && lastMessageNonceBase64.Valid { - // Декодируем из base64 - encryptedContent, err := base64.StdEncoding.DecodeString(lastMessageContentBase64.String) - if err != nil { - s.logger.Printf("Failed to decode content for last message %d: %v", lastMessageID.Int32, err) - decryptedContent = "[decode error: content]" - } else { - encryptedKey, err := base64.StdEncoding.DecodeString(lastMessageEncryptedKeyBase64.String) - if err != nil { - s.logger.Printf("Failed to decode key for last message %d: %v", lastMessageID.Int32, err) - decryptedContent = "[decode error: key]" - } else { - nonce, err := base64.StdEncoding.DecodeString(lastMessageNonceBase64.String) - if err != nil { - s.logger.Printf("Failed to decode nonce for last message %d: %v", lastMessageID.Int32, err) - decryptedContent = "[decode error: nonce]" - } else { - // Расшифровка сообщения - decrypted, err := s.crypto.DecryptMessage(encryptedContent, encryptedKey, nonce) - if err != nil { - s.logger.Printf("Failed to decrypt last message %d: %v", lastMessageID.Int32, err) - decryptedContent = "[encrypted message]" - } else { - decryptedContent = decrypted - s.logger.Printf("Successfully decrypted last message for chat %d", chat.Id) - } - } - } - } - } else { - decryptedContent = "[no message content]" + decryptedContent, err := s.crypto.DecryptMessage( + lastMessageEncryptedContent, + lastMessageEncryptedKey, + lastMessageNonce, + ) + if err != nil { + s.logger.Printf("Failed to decrypt last message: %v", err) + // Можно вернуть ошибку или пустое сообщение + decryptedContent = "[не удалось расшифровать]" } chat.LastMessage = &proto.Message{ @@ -390,12 +312,12 @@ func (s *server) GetChatMessages(ctx context.Context, req *proto.GetChatMessages s.logger.Printf("Querying messages for chat_id=%d, limit=%d, offset=%d", req.ChatId, req.Limit, req.Offset) rows, err := s.db.Query(ctx, ` - SELECT id, chat_id, sender_id, receiver_id, content, encrypted_key, nonce, status, created_at - FROM messages - WHERE chat_id = $1 - ORDER BY created_at DESC - LIMIT $2 OFFSET $3 - `, req.ChatId, req.Limit, req.Offset) + SELECT id, chat_id, sender_id, receiver_id, encrypted_content, encrypted_key, nonce, status, created_at + FROM messages + WHERE chat_id = $1 + ORDER BY created_at DESC + LIMIT $2 OFFSET $3 + `, req.ChatId, req.Limit, req.Offset) if err != nil { s.logger.Printf("Failed to query messages: %v", err) return nil, err @@ -406,69 +328,25 @@ func (s *server) GetChatMessages(ctx context.Context, req *proto.GetChatMessages for rows.Next() { var msg proto.Message var createdAt time.Time - var encryptedKeyBase64, nonceBase64, encryptedContentBase64 string + var encryptedContent, encryptedKey, nonce []byte err := rows.Scan( &msg.Id, &msg.ChatId, &msg.SenderId, &msg.ReceiverId, - &encryptedContentBase64, &encryptedKeyBase64, &nonceBase64, &msg.Status, &createdAt, + &encryptedContent, &encryptedKey, &nonce, + &msg.Status, &createdAt, ) if err != nil { s.logger.Printf("Failed to scan message row: %v", err) return nil, err } - s.logger.Printf("Processing message %d: content_len=%d, key_len=%d, nonce_len=%d", - msg.Id, len(encryptedContentBase64), len(encryptedKeyBase64), len(nonceBase64)) - - // Декодируем из base64 - encryptedContent, err := base64.StdEncoding.DecodeString(encryptedContentBase64) - if err != nil { - s.logger.Printf("Failed to decode content for message %d: %v", msg.Id, err) - msg.Content = "[base64 decode error: content]" - continue - } - - encryptedKey, err := base64.StdEncoding.DecodeString(encryptedKeyBase64) - if err != nil { - s.logger.Printf("Failed to decode key for message %d: %v", msg.Id, err) - msg.Content = "[base64 decode error: key]" - continue - } - - nonce, err := base64.StdEncoding.DecodeString(nonceBase64) - if err != nil { - s.logger.Printf("Failed to decode nonce for message %d: %v", msg.Id, err) - msg.Content = "[base64 decode error: nonce]" - continue - } - - // Проверяем что данные не пустые после декодирования - if len(encryptedContent) == 0 { - s.logger.Printf("Empty content after base64 decode for message %d", msg.Id) - msg.Content = "[empty content after decode]" - continue - } - if len(encryptedKey) == 0 { - s.logger.Printf("Empty key after base64 decode for message %d", msg.Id) - msg.Content = "[empty key after decode]" - continue - } - if len(nonce) == 0 { - s.logger.Printf("Empty nonce after base64 decode for message %d", msg.Id) - msg.Content = "[empty nonce after decode]" - continue - } - - // Расшифровка сообщения + // Расшифровываем сообщение decryptedContent, err := s.crypto.DecryptMessage(encryptedContent, encryptedKey, nonce) if err != nil { s.logger.Printf("Failed to decrypt message %d: %v", msg.Id, err) - s.logger.Printf("Decryption failed details: content_len=%d, key_len=%d, nonce_len=%d", - len(encryptedContent), len(encryptedKey), len(nonce)) - msg.Content = "[decryption failed: " + err.Error() + "]" + msg.Content = "[не удалось расшифровать]" } else { msg.Content = decryptedContent - s.logger.Printf("Successfully decrypted message %d", msg.Id) } msg.CreatedAt = timestamppb.New(createdAt) @@ -480,7 +358,7 @@ func (s *server) GetChatMessages(ctx context.Context, req *proto.GetChatMessages return nil, err } - s.logger.Printf("Retrieved and decrypted %d messages for chat_id=%d", len(messages), req.ChatId) + s.logger.Printf("Retrieved %d messages for chat_id=%d", len(messages), req.ChatId) resp := &proto.MessagesResponse{Messages: messages} s.logResponse("GetChatMessages", resp, nil) return resp, nil @@ -494,16 +372,16 @@ func (s *server) GetUserChats(ctx context.Context, req *proto.GetUserChatsReques s.logger.Printf("Querying chats for user_id=%d", req.UserId) rows, err := s.db.Query(ctx, ` - SELECT c.id, c.user1_id, c.user2_id, c.created_at, c.updated_at, - m.id, m.content, m.encrypted_key, m.nonce, m.status, m.created_at - FROM chats c - LEFT JOIN messages m ON m.id = ( - SELECT id FROM messages WHERE chat_id = c.id - ORDER BY created_at DESC LIMIT 1 - ) - WHERE c.user1_id = $1 OR c.user2_id = $1 - ORDER BY c.updated_at DESC - `, req.UserId) + SELECT c.id, c.user1_id, c.user2_id, c.created_at, c.updated_at, + m.id, m.content, m.status, m.created_at + FROM chats c + LEFT JOIN messages m ON m.id = ( + SELECT id FROM messages WHERE chat_id = c.id + ORDER BY created_at DESC LIMIT 1 + ) + WHERE c.user1_id = $1 OR c.user2_id = $1 + ORDER BY c.updated_at DESC + `, req.UserId) if err != nil { s.logger.Printf("Failed to query user chats: %v", err) return nil, err @@ -515,14 +393,13 @@ func (s *server) GetUserChats(ctx context.Context, req *proto.GetUserChatsReques var chat proto.Chat var createdAt, updatedAt time.Time var lastMessageID sql.NullInt32 - var lastMessageContentBase64, lastMessageEncryptedKeyBase64, lastMessageNonceBase64 sql.NullString + var lastMessageContent sql.NullString var lastMessageStatus sql.NullString var lastMessageCreatedAt sql.NullTime err := rows.Scan( &chat.Id, &chat.User1Id, &chat.User2Id, &createdAt, &updatedAt, - &lastMessageID, &lastMessageContentBase64, &lastMessageEncryptedKeyBase64, - &lastMessageNonceBase64, &lastMessageStatus, &lastMessageCreatedAt, + &lastMessageID, &lastMessageContent, &lastMessageStatus, &lastMessageCreatedAt, ) if err != nil { s.logger.Printf("Failed to scan chat row: %v", err) @@ -533,46 +410,10 @@ func (s *server) GetUserChats(ctx context.Context, req *proto.GetUserChatsReques chat.UpdatedAt = timestamppb.New(updatedAt) if lastMessageID.Valid { - // Расшифровываем последнее сообщение - var decryptedContent string - - if lastMessageContentBase64.Valid && lastMessageEncryptedKeyBase64.Valid && lastMessageNonceBase64.Valid { - // Декодируем из base64 - encryptedContent, err := base64.StdEncoding.DecodeString(lastMessageContentBase64.String) - if err != nil { - s.logger.Printf("Failed to decode content for last message %d: %v", lastMessageID.Int32, err) - decryptedContent = "[decode error: content]" - } else { - encryptedKey, err := base64.StdEncoding.DecodeString(lastMessageEncryptedKeyBase64.String) - if err != nil { - s.logger.Printf("Failed to decode key for last message %d: %v", lastMessageID.Int32, err) - decryptedContent = "[decode error: key]" - } else { - nonce, err := base64.StdEncoding.DecodeString(lastMessageNonceBase64.String) - if err != nil { - s.logger.Printf("Failed to decode nonce for last message %d: %v", lastMessageID.Int32, err) - decryptedContent = "[decode error: nonce]" - } else { - // Расшифровка сообщения - decrypted, err := s.crypto.DecryptMessage(encryptedContent, encryptedKey, nonce) - if err != nil { - s.logger.Printf("Failed to decrypt last message %d: %v", lastMessageID.Int32, err) - decryptedContent = "[encrypted message]" - } else { - decryptedContent = decrypted - s.logger.Printf("Successfully decrypted last message for chat %d", chat.Id) - } - } - } - } - } else { - decryptedContent = "[no message content]" - } - chat.LastMessage = &proto.Message{ Id: lastMessageID.Int32, ChatId: chat.Id, - Content: decryptedContent, + Content: lastMessageContent.String, Status: lastMessageStatus.String, CreatedAt: timestamppb.New(lastMessageCreatedAt.Time), }