v.0.0.4.6 Добавлено шифрование сообщения
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
madipo2611 2025-08-21 00:50:27 +03:00
parent 261f14b451
commit f983a2f9d9
5 changed files with 83 additions and 243 deletions

View File

@ -7,6 +7,7 @@ import (
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"io"
)
@ -38,7 +39,7 @@ func NewCryptoService() (*CryptoService, error) {
}, nil
}
// Шифрование сообщения
// Шифрование сообщения - возвращает зашифрованный контент, ключ и nonce
func (cs *CryptoService) EncryptMessage(content string) ([]byte, []byte, []byte, error) {
// Генерация сессионного AES ключа
aesKey := make([]byte, 32)
@ -62,7 +63,7 @@ func (cs *CryptoService) EncryptMessage(content string) ([]byte, []byte, []byte,
return nil, nil, nil, err
}
encryptedContent := gcm.Seal(nonce, nonce, []byte(content), nil)
encryptedContent := gcm.Seal(nil, nonce, []byte(content), nil)
// Шифрование AES ключа RSA-OAEP
encryptedKey, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, &cs.privateKey.PublicKey, aesKey, nil)
@ -99,3 +100,13 @@ func (cs *CryptoService) DecryptMessage(encryptedContent, encryptedKey, nonce []
return string(plaintext), nil
}
// Утилитарная функция для base64 кодирования (для хранения в Vault)
func (cs *CryptoService) EncodeBase64(data []byte) string {
return base64.StdEncoding.EncodeToString(data)
}
// Утилитарная функция для base64 декодирования
func (cs *CryptoService) DecodeBase64(data string) ([]byte, error) {
return base64.StdEncoding.DecodeString(data)
}

View File

@ -71,8 +71,8 @@ func (v *VaultManager) GetMasterPrivateKey() ([]byte, error) {
return nil, fmt.Errorf("invalid data format in Vault secret")
}
// Получаем приватный ключ в base64 - ОБРАТИТЕ ВНИМАНИЕ НА НОВОЕ ИМЯ ПОЛЯ!
keyInterface, ok := data["private_key_base64"] // ← Изменилось здесь!
// Получаем приватный ключ в base64
keyInterface, ok := data["private_key_base64"]
if !ok {
return nil, fmt.Errorf("private_key_base64 not found in Vault data")
}

View File

@ -1,3 +0,0 @@
ALTER TABLE messages
DROP COLUMN encrypted_key,
DROP COLUMN nonce;

View File

@ -1,18 +1,9 @@
-- Добавляем колонки для хранения зашифрованных данных
ALTER TABLE messages
ADD COLUMN encrypted_key BYTEA,
ADD COLUMN nonce BYTEA;
ADD COLUMN encrypted_content BYTEA NOT NULL,
ADD COLUMN encrypted_key BYTEA NOT NULL,
ADD COLUMN nonce BYTEA NOT NULL,
DROP COLUMN content;
-- Обновляем существующие сообщения (если есть)
UPDATE messages SET
encrypted_key = ''::bytea,
nonce = ''::bytea
WHERE encrypted_key IS NULL OR nonce IS NULL;
-- Делаем поля обязательными
ALTER TABLE messages
ALTER COLUMN encrypted_key SET NOT NULL,
ALTER COLUMN nonce SET NOT NULL;
ALTER TABLE messages
ALTER COLUMN content TYPE BYTEA USING content::bytea,
ALTER COLUMN encrypted_key TYPE BYTEA USING encrypted_key::bytea,
ALTER COLUMN nonce TYPE BYTEA USING nonce::bytea;
-- Обновляем индексы (если нужно)
CREATE INDEX idx_messages_encrypted ON messages(encrypted_content);

235
server.go
View File

@ -3,7 +3,6 @@ package main
import (
"context"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
@ -18,7 +17,7 @@ import (
"net"
"os"
"sync"
crypto2 "tailly_messages/crypto"
"tailly_messages/crypto"
"tailly_messages/proto"
"time"
)
@ -27,21 +26,15 @@ type server struct {
proto.UnimplementedMessageServiceServer
db *pgxpool.Pool
rabbitConn *amqp.Connection
crypto *crypto2.CryptoService
mu sync.Mutex
logger *log.Logger
crypto *crypto.CryptoService
}
func NewServer(db *pgxpool.Pool, rabbitConn *amqp.Connection) *server {
cryptoService, err := crypto2.NewCryptoService()
if err != nil {
log.Fatalf("Failed to initialize crypto service: %v", err)
}
return &server{
db: db,
rabbitConn: rabbitConn,
crypto: cryptoService,
logger: log.New(os.Stdout, "MESSAGE_SERVICE: ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile),
}
}
@ -86,7 +79,7 @@ func (s *server) CreateChat(ctx context.Context, req *proto.CreateChatRequest) (
errMsg := fmt.Sprintf("One or both users don't exist: user1=%d (%t), user2=%d (%t)",
user1, user1Exists, user2, user2Exists)
s.logger.Println(errMsg)
return nil, fmt.Errorf("%w", errors.New(errMsg))
return nil, fmt.Errorf("%w", errors.New(errMsg)) // Оборачиваем ошибку
}
var chat proto.Chat
@ -139,44 +132,13 @@ func (s *server) SendMessage(ctx context.Context, req *proto.SendMessageRequest)
s.logger.Printf("SendMessage execution time: %v", time.Since(start))
}(time.Now())
// Шифрование сообщения
s.logger.Printf("Encrypting message for chat_id=%d", req.ChatId)
// Шифруем сообщение
encryptedContent, encryptedKey, nonce, err := s.crypto.EncryptMessage(req.Content)
if err != nil {
s.logger.Printf("Failed to encrypt message: %v", err)
return nil, fmt.Errorf("failed to encrypt message: %v", err)
}
// Кодируем в base64 для хранения в TEXT поле
encodedContent := base64.StdEncoding.EncodeToString(encryptedContent)
encodedKey := base64.StdEncoding.EncodeToString(encryptedKey)
encodedNonce := base64.StdEncoding.EncodeToString(nonce)
s.logger.Printf("Base64 encoded: content_len=%d, key_len=%d, nonce_len=%d",
len(encodedContent), len(encodedKey), len(encodedNonce))
// Тестовая расшифровка сразу после шифрования
testDecrypted, err := s.crypto.DecryptMessage(encryptedContent, encryptedKey, nonce)
if err != nil {
s.logger.Printf("IMMEDIATE DECRYPTION TEST FAILED: %v", err)
return nil, fmt.Errorf("encryption/decryption test failed: %v", err)
}
if testDecrypted != req.Content {
s.logger.Printf("IMMEDIATE DECRYPTION CONTENT MISMATCH")
return nil, fmt.Errorf("encryption/decryption content mismatch")
}
s.logger.Printf("Immediate decryption test passed")
// Сохранение сессионного ключа в Vault
s.logger.Printf("Storing session key in Vault for chat_id=%d", req.ChatId)
err = s.crypto.Vault.StoreSessionKey(int(req.ChatId), 0, encryptedKey)
if err != nil {
s.logger.Printf("Failed to store session key: %v", err)
return nil, fmt.Errorf("failed to store session key: %v", err)
}
s.logger.Printf("Getting chat info for chat_id=%d", req.ChatId)
var user1Id, user2Id int32
err = s.db.QueryRow(ctx, "SELECT user1_id, user2_id FROM chats WHERE id = $1", req.ChatId).Scan(&user1Id, &user2Id)
@ -201,27 +163,28 @@ func (s *server) SendMessage(ctx context.Context, req *proto.SendMessageRequest)
var message proto.Message
var createdAt time.Time
// Используем закодированные base64 строки
err = s.db.QueryRow(ctx, `
INSERT INTO messages (chat_id, sender_id, receiver_id, content, encrypted_key, nonce)
INSERT INTO messages (chat_id, sender_id, receiver_id, encrypted_content, encrypted_key, nonce)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, chat_id, sender_id, receiver_id, status, created_at
`, req.ChatId, req.SenderId, receiverId, encodedContent, encodedKey, encodedNonce).Scan(
`, req.ChatId, req.SenderId, receiverId, encryptedContent, encryptedKey, nonce).Scan(
&message.Id, &message.ChatId, &message.SenderId, &message.ReceiverId, &message.Status, &createdAt,
)
if err != nil {
s.logger.Printf("Failed to insert message: %v", err)
s.logger.Printf("Failed to insert encrypted message: %v", err)
return nil, err
}
message.CreatedAt = timestamppb.New(createdAt)
// Обновление сессионного ключа в Vault с правильным messageID
s.logger.Printf("Updating session key in Vault with actual message_id=%d", message.Id)
err = s.crypto.Vault.StoreSessionKey(int(req.ChatId), int(message.Id), encryptedKey)
// Для RabbitMQ отправляем расшифрованное сообщение (или зашифрованное, в зависимости от требований)
// Здесь отправляем расшифрованное для совместимости с существующими клиентами
decryptedContent, err := s.crypto.DecryptMessage(encryptedContent, encryptedKey, nonce)
if err != nil {
s.logger.Printf("Failed to update session key with message ID: %v", err)
s.logger.Printf("Failed to decrypt for RabbitMQ: %v", err)
return nil, err
}
message.Content = decryptedContent
s.logger.Printf("Updating chat updated_at for chat_id=%d", req.ChatId)
_, err = s.db.Exec(ctx, `UPDATE chats SET updated_at = NOW() WHERE id = $1`, req.ChatId)
@ -230,24 +193,7 @@ func (s *server) SendMessage(ctx context.Context, req *proto.SendMessageRequest)
return nil, err
}
// Расшифровка для отправки через RabbitMQ
decryptedContent, err := s.crypto.DecryptMessage(encryptedContent, encryptedKey, nonce)
if err != nil {
s.logger.Printf("Failed to decrypt message for RabbitMQ: %v", err)
return nil, fmt.Errorf("failed to decrypt message: %v", err)
}
rabbitMsg := proto.Message{
Id: message.Id,
ChatId: message.ChatId,
SenderId: message.SenderId,
ReceiverId: message.ReceiverId,
Content: decryptedContent,
Status: message.Status,
CreatedAt: message.CreatedAt,
}
s.logger.Printf("Publishing decrypted message to RabbitMQ for user_id=%d", receiverId)
s.logger.Printf("Publishing message to RabbitMQ for user_id=%d", receiverId)
var lastErr error
for i := 0; i < 3; i++ {
ch, err := s.rabbitConn.Channel()
@ -259,7 +205,7 @@ func (s *server) SendMessage(ctx context.Context, req *proto.SendMessageRequest)
}
queueName := fmt.Sprintf("user_%d_messages", receiverId)
msgBytes, _ := json.Marshal(rabbitMsg)
msgBytes, _ := json.Marshal(message)
s.logger.Printf("Publishing to queue %s: %s", queueName, string(msgBytes))
err = ch.PublishWithContext(ctx,
@ -298,7 +244,7 @@ func (s *server) GetChat(ctx context.Context, req *proto.GetChatRequest) (*proto
var chat proto.Chat
var createdAt, updatedAt time.Time
var lastMessageID sql.NullInt32
var lastMessageContentBase64, lastMessageEncryptedKeyBase64, lastMessageNonceBase64 sql.NullString
var lastMessageEncryptedContent, lastMessageEncryptedKey, lastMessageNonce []byte
var lastMessageStatus sql.NullString
var lastMessageCreatedAt sql.NullTime
@ -310,7 +256,7 @@ func (s *server) GetChat(ctx context.Context, req *proto.GetChatRequest) (*proto
s.logger.Printf("Querying chat between users %d and %d", user1, user2)
err := s.db.QueryRow(ctx, `
SELECT c.id, c.user1_id, c.user2_id, c.created_at, c.updated_at,
m.id, m.content, m.encrypted_key, m.nonce, m.status, m.created_at
m.id, m.encrypted_content, m.encrypted_key, m.nonce, m.status, m.created_at
FROM chats c
LEFT JOIN messages m ON m.id = (
SELECT id FROM messages WHERE chat_id = c.id
@ -319,8 +265,8 @@ func (s *server) GetChat(ctx context.Context, req *proto.GetChatRequest) (*proto
WHERE c.user1_id = $1 AND c.user2_id = $2
`, user1, user2).Scan(
&chat.Id, &chat.User1Id, &chat.User2Id, &createdAt, &updatedAt,
&lastMessageID, &lastMessageContentBase64, &lastMessageEncryptedKeyBase64,
&lastMessageNonceBase64, &lastMessageStatus, &lastMessageCreatedAt,
&lastMessageID, &lastMessageEncryptedContent, &lastMessageEncryptedKey,
&lastMessageNonce, &lastMessageStatus, &lastMessageCreatedAt,
)
if err != nil {
s.logger.Printf("Failed to get chat: %v", err)
@ -332,39 +278,15 @@ func (s *server) GetChat(ctx context.Context, req *proto.GetChatRequest) (*proto
if lastMessageID.Valid {
// Расшифровываем последнее сообщение
var decryptedContent string
if lastMessageContentBase64.Valid && lastMessageEncryptedKeyBase64.Valid && lastMessageNonceBase64.Valid {
// Декодируем из base64
encryptedContent, err := base64.StdEncoding.DecodeString(lastMessageContentBase64.String)
decryptedContent, err := s.crypto.DecryptMessage(
lastMessageEncryptedContent,
lastMessageEncryptedKey,
lastMessageNonce,
)
if err != nil {
s.logger.Printf("Failed to decode content for last message %d: %v", lastMessageID.Int32, err)
decryptedContent = "[decode error: content]"
} else {
encryptedKey, err := base64.StdEncoding.DecodeString(lastMessageEncryptedKeyBase64.String)
if err != nil {
s.logger.Printf("Failed to decode key for last message %d: %v", lastMessageID.Int32, err)
decryptedContent = "[decode error: key]"
} else {
nonce, err := base64.StdEncoding.DecodeString(lastMessageNonceBase64.String)
if err != nil {
s.logger.Printf("Failed to decode nonce for last message %d: %v", lastMessageID.Int32, err)
decryptedContent = "[decode error: nonce]"
} else {
// Расшифровка сообщения
decrypted, err := s.crypto.DecryptMessage(encryptedContent, encryptedKey, nonce)
if err != nil {
s.logger.Printf("Failed to decrypt last message %d: %v", lastMessageID.Int32, err)
decryptedContent = "[encrypted message]"
} else {
decryptedContent = decrypted
s.logger.Printf("Successfully decrypted last message for chat %d", chat.Id)
}
}
}
}
} else {
decryptedContent = "[no message content]"
s.logger.Printf("Failed to decrypt last message: %v", err)
// Можно вернуть ошибку или пустое сообщение
decryptedContent = "[не удалось расшифровать]"
}
chat.LastMessage = &proto.Message{
@ -390,7 +312,7 @@ func (s *server) GetChatMessages(ctx context.Context, req *proto.GetChatMessages
s.logger.Printf("Querying messages for chat_id=%d, limit=%d, offset=%d", req.ChatId, req.Limit, req.Offset)
rows, err := s.db.Query(ctx, `
SELECT id, chat_id, sender_id, receiver_id, content, encrypted_key, nonce, status, created_at
SELECT id, chat_id, sender_id, receiver_id, encrypted_content, encrypted_key, nonce, status, created_at
FROM messages
WHERE chat_id = $1
ORDER BY created_at DESC
@ -406,69 +328,25 @@ func (s *server) GetChatMessages(ctx context.Context, req *proto.GetChatMessages
for rows.Next() {
var msg proto.Message
var createdAt time.Time
var encryptedKeyBase64, nonceBase64, encryptedContentBase64 string
var encryptedContent, encryptedKey, nonce []byte
err := rows.Scan(
&msg.Id, &msg.ChatId, &msg.SenderId, &msg.ReceiverId,
&encryptedContentBase64, &encryptedKeyBase64, &nonceBase64, &msg.Status, &createdAt,
&encryptedContent, &encryptedKey, &nonce,
&msg.Status, &createdAt,
)
if err != nil {
s.logger.Printf("Failed to scan message row: %v", err)
return nil, err
}
s.logger.Printf("Processing message %d: content_len=%d, key_len=%d, nonce_len=%d",
msg.Id, len(encryptedContentBase64), len(encryptedKeyBase64), len(nonceBase64))
// Декодируем из base64
encryptedContent, err := base64.StdEncoding.DecodeString(encryptedContentBase64)
if err != nil {
s.logger.Printf("Failed to decode content for message %d: %v", msg.Id, err)
msg.Content = "[base64 decode error: content]"
continue
}
encryptedKey, err := base64.StdEncoding.DecodeString(encryptedKeyBase64)
if err != nil {
s.logger.Printf("Failed to decode key for message %d: %v", msg.Id, err)
msg.Content = "[base64 decode error: key]"
continue
}
nonce, err := base64.StdEncoding.DecodeString(nonceBase64)
if err != nil {
s.logger.Printf("Failed to decode nonce for message %d: %v", msg.Id, err)
msg.Content = "[base64 decode error: nonce]"
continue
}
// Проверяем что данные не пустые после декодирования
if len(encryptedContent) == 0 {
s.logger.Printf("Empty content after base64 decode for message %d", msg.Id)
msg.Content = "[empty content after decode]"
continue
}
if len(encryptedKey) == 0 {
s.logger.Printf("Empty key after base64 decode for message %d", msg.Id)
msg.Content = "[empty key after decode]"
continue
}
if len(nonce) == 0 {
s.logger.Printf("Empty nonce after base64 decode for message %d", msg.Id)
msg.Content = "[empty nonce after decode]"
continue
}
// Расшифровка сообщения
// Расшифровываем сообщение
decryptedContent, err := s.crypto.DecryptMessage(encryptedContent, encryptedKey, nonce)
if err != nil {
s.logger.Printf("Failed to decrypt message %d: %v", msg.Id, err)
s.logger.Printf("Decryption failed details: content_len=%d, key_len=%d, nonce_len=%d",
len(encryptedContent), len(encryptedKey), len(nonce))
msg.Content = "[decryption failed: " + err.Error() + "]"
msg.Content = "[не удалось расшифровать]"
} else {
msg.Content = decryptedContent
s.logger.Printf("Successfully decrypted message %d", msg.Id)
}
msg.CreatedAt = timestamppb.New(createdAt)
@ -480,7 +358,7 @@ func (s *server) GetChatMessages(ctx context.Context, req *proto.GetChatMessages
return nil, err
}
s.logger.Printf("Retrieved and decrypted %d messages for chat_id=%d", len(messages), req.ChatId)
s.logger.Printf("Retrieved %d messages for chat_id=%d", len(messages), req.ChatId)
resp := &proto.MessagesResponse{Messages: messages}
s.logResponse("GetChatMessages", resp, nil)
return resp, nil
@ -495,7 +373,7 @@ func (s *server) GetUserChats(ctx context.Context, req *proto.GetUserChatsReques
s.logger.Printf("Querying chats for user_id=%d", req.UserId)
rows, err := s.db.Query(ctx, `
SELECT c.id, c.user1_id, c.user2_id, c.created_at, c.updated_at,
m.id, m.content, m.encrypted_key, m.nonce, m.status, m.created_at
m.id, m.content, m.status, m.created_at
FROM chats c
LEFT JOIN messages m ON m.id = (
SELECT id FROM messages WHERE chat_id = c.id
@ -515,14 +393,13 @@ func (s *server) GetUserChats(ctx context.Context, req *proto.GetUserChatsReques
var chat proto.Chat
var createdAt, updatedAt time.Time
var lastMessageID sql.NullInt32
var lastMessageContentBase64, lastMessageEncryptedKeyBase64, lastMessageNonceBase64 sql.NullString
var lastMessageContent sql.NullString
var lastMessageStatus sql.NullString
var lastMessageCreatedAt sql.NullTime
err := rows.Scan(
&chat.Id, &chat.User1Id, &chat.User2Id, &createdAt, &updatedAt,
&lastMessageID, &lastMessageContentBase64, &lastMessageEncryptedKeyBase64,
&lastMessageNonceBase64, &lastMessageStatus, &lastMessageCreatedAt,
&lastMessageID, &lastMessageContent, &lastMessageStatus, &lastMessageCreatedAt,
)
if err != nil {
s.logger.Printf("Failed to scan chat row: %v", err)
@ -533,46 +410,10 @@ func (s *server) GetUserChats(ctx context.Context, req *proto.GetUserChatsReques
chat.UpdatedAt = timestamppb.New(updatedAt)
if lastMessageID.Valid {
// Расшифровываем последнее сообщение
var decryptedContent string
if lastMessageContentBase64.Valid && lastMessageEncryptedKeyBase64.Valid && lastMessageNonceBase64.Valid {
// Декодируем из base64
encryptedContent, err := base64.StdEncoding.DecodeString(lastMessageContentBase64.String)
if err != nil {
s.logger.Printf("Failed to decode content for last message %d: %v", lastMessageID.Int32, err)
decryptedContent = "[decode error: content]"
} else {
encryptedKey, err := base64.StdEncoding.DecodeString(lastMessageEncryptedKeyBase64.String)
if err != nil {
s.logger.Printf("Failed to decode key for last message %d: %v", lastMessageID.Int32, err)
decryptedContent = "[decode error: key]"
} else {
nonce, err := base64.StdEncoding.DecodeString(lastMessageNonceBase64.String)
if err != nil {
s.logger.Printf("Failed to decode nonce for last message %d: %v", lastMessageID.Int32, err)
decryptedContent = "[decode error: nonce]"
} else {
// Расшифровка сообщения
decrypted, err := s.crypto.DecryptMessage(encryptedContent, encryptedKey, nonce)
if err != nil {
s.logger.Printf("Failed to decrypt last message %d: %v", lastMessageID.Int32, err)
decryptedContent = "[encrypted message]"
} else {
decryptedContent = decrypted
s.logger.Printf("Successfully decrypted last message for chat %d", chat.Id)
}
}
}
}
} else {
decryptedContent = "[no message content]"
}
chat.LastMessage = &proto.Message{
Id: lastMessageID.Int32,
ChatId: chat.Id,
Content: decryptedContent,
Content: lastMessageContent.String,
Status: lastMessageStatus.String,
CreatedAt: timestamppb.New(lastMessageCreatedAt.Time),
}