Add wallet key encryption service

This commit is contained in:
Alexa Amundson
2025-11-16 01:47:22 -06:00
parent b2f933762a
commit 1aa9329491
7 changed files with 149 additions and 9 deletions

View File

@@ -8,6 +8,7 @@ from sqlalchemy import select, desc
from app.models.blockchain import Block, Transaction, Wallet
from app.models.user import User
from app.config import settings
from app.services.crypto import wallet_crypto, WalletKeyDecryptionError
import secrets
@@ -140,9 +141,16 @@ class BlockchainService:
from_address: str,
to_address: str,
amount: float,
private_key: str
encrypted_private_key: str
) -> Transaction:
"""Create a new transaction"""
try:
private_key = wallet_crypto.decrypt(encrypted_private_key)
except WalletKeyDecryptionError as exc:
raise WalletKeyDecryptionError(
"Unable to decrypt wallet key for transaction"
) from exc
# Generate transaction hash
tx_data = f"{from_address}{to_address}{amount}{datetime.utcnow()}"
transaction_hash = hashlib.sha256(tx_data.encode()).hexdigest()

View File

@@ -0,0 +1,106 @@
"""Utilities for encrypting wallet secrets."""
from __future__ import annotations
import base64
import hashlib
import string
from dataclasses import dataclass
from typing import Tuple
from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import select
from app.config import settings
from app.database import AsyncSessionLocal
from app.models.blockchain import Wallet
from app.models.user import User
class WalletKeyEncryptionError(RuntimeError):
"""Raised when a wallet key cannot be encrypted."""
class WalletKeyDecryptionError(RuntimeError):
"""Raised when a wallet key cannot be decrypted."""
@dataclass
class WalletCryptoService:
"""Simple Fernet wrapper used for wallet key protection."""
master_key: str
def __post_init__(self) -> None:
if not self.master_key or len(self.master_key) < 32:
raise ValueError("WALLET_MASTER_KEY must be at least 32 characters long")
derived_key = self._derive_key(self.master_key)
self._fernet = Fernet(derived_key)
@staticmethod
def _derive_key(master_key: str) -> bytes:
"""Derive a url-safe 32-byte key for Fernet from the provided secret."""
digest = hashlib.sha256(master_key.encode("utf-8")).digest()
return base64.urlsafe_b64encode(digest)
def encrypt(self, plaintext: str) -> str:
"""Encrypt a wallet secret."""
if plaintext is None:
raise WalletKeyEncryptionError("Cannot encrypt an empty wallet key")
try:
token = self._fernet.encrypt(plaintext.encode("utf-8"))
return token.decode("utf-8")
except Exception as exc: # pragma: no cover - defensive
raise WalletKeyEncryptionError("Unable to encrypt wallet key") from exc
def decrypt(self, token: str) -> str:
"""Decrypt a wallet secret."""
if token is None:
raise WalletKeyDecryptionError("Wallet key is missing")
try:
plaintext = self._fernet.decrypt(token.encode("utf-8"))
return plaintext.decode("utf-8")
except (InvalidToken, ValueError, TypeError) as exc:
raise WalletKeyDecryptionError("Unable to decrypt wallet key") from exc
wallet_crypto = WalletCryptoService(settings.WALLET_MASTER_KEY)
def _looks_like_plaintext_key(value: str | None) -> bool:
if not value:
return False
normalized = value.strip()
return len(normalized) == 64 and all(ch in string.hexdigits for ch in normalized)
async def rotate_plaintext_wallet_keys() -> Tuple[int, int]:
"""Encrypt plaintext wallet keys that may still exist in the database."""
updated_users = 0
updated_wallets = 0
async with AsyncSessionLocal() as session:
user_result = await session.execute(select(User).where(User.wallet_private_key.is_not(None)))
for user in user_result.scalars():
if _looks_like_plaintext_key(user.wallet_private_key):
user.wallet_private_key = wallet_crypto.encrypt(user.wallet_private_key)
updated_users += 1
wallet_result = await session.execute(select(Wallet).where(Wallet.private_key.is_not(None)))
for wallet in wallet_result.scalars():
if _looks_like_plaintext_key(wallet.private_key):
wallet.private_key = wallet_crypto.encrypt(wallet.private_key)
updated_wallets += 1
if updated_users or updated_wallets:
await session.commit()
return updated_users, updated_wallets
__all__ = [
"WalletCryptoService",
"WalletKeyEncryptionError",
"WalletKeyDecryptionError",
"wallet_crypto",
"rotate_plaintext_wallet_keys",
]