mirror of
https://github.com/blackboxprogramming/BlackRoad-Operating-System.git
synced 2026-03-17 01:57:11 -05:00
107 lines
3.6 KiB
Python
107 lines
3.6 KiB
Python
"""Utilities for encrypting wallet secrets."""
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import hashlib
|
|
import string
|
|
from dataclasses import dataclass
|
|
from typing import Tuple
|
|
|
|
from cryptography.fernet import Fernet, InvalidToken
|
|
from sqlalchemy import select
|
|
|
|
from app.config import settings
|
|
from app.database import AsyncSessionLocal
|
|
from app.models.blockchain import Wallet
|
|
from app.models.user import User
|
|
|
|
|
|
class WalletKeyEncryptionError(RuntimeError):
|
|
"""Raised when a wallet key cannot be encrypted."""
|
|
|
|
|
|
class WalletKeyDecryptionError(RuntimeError):
|
|
"""Raised when a wallet key cannot be decrypted."""
|
|
|
|
|
|
@dataclass
|
|
class WalletCryptoService:
|
|
"""Simple Fernet wrapper used for wallet key protection."""
|
|
|
|
master_key: str
|
|
|
|
def __post_init__(self) -> None:
|
|
if not self.master_key or len(self.master_key) < 32:
|
|
raise ValueError("WALLET_MASTER_KEY must be at least 32 characters long")
|
|
derived_key = self._derive_key(self.master_key)
|
|
self._fernet = Fernet(derived_key)
|
|
|
|
@staticmethod
|
|
def _derive_key(master_key: str) -> bytes:
|
|
"""Derive a url-safe 32-byte key for Fernet from the provided secret."""
|
|
digest = hashlib.sha256(master_key.encode("utf-8")).digest()
|
|
return base64.urlsafe_b64encode(digest)
|
|
|
|
def encrypt(self, plaintext: str) -> str:
|
|
"""Encrypt a wallet secret."""
|
|
if plaintext is None:
|
|
raise WalletKeyEncryptionError("Cannot encrypt an empty wallet key")
|
|
try:
|
|
token = self._fernet.encrypt(plaintext.encode("utf-8"))
|
|
return token.decode("utf-8")
|
|
except Exception as exc: # pragma: no cover - defensive
|
|
raise WalletKeyEncryptionError("Unable to encrypt wallet key") from exc
|
|
|
|
def decrypt(self, token: str) -> str:
|
|
"""Decrypt a wallet secret."""
|
|
if token is None:
|
|
raise WalletKeyDecryptionError("Wallet key is missing")
|
|
try:
|
|
plaintext = self._fernet.decrypt(token.encode("utf-8"))
|
|
return plaintext.decode("utf-8")
|
|
except (InvalidToken, ValueError, TypeError) as exc:
|
|
raise WalletKeyDecryptionError("Unable to decrypt wallet key") from exc
|
|
|
|
|
|
wallet_crypto = WalletCryptoService(settings.WALLET_MASTER_KEY)
|
|
|
|
|
|
def _looks_like_plaintext_key(value: str | None) -> bool:
|
|
if not value:
|
|
return False
|
|
normalized = value.strip()
|
|
return len(normalized) == 64 and all(ch in string.hexdigits for ch in normalized)
|
|
|
|
|
|
async def rotate_plaintext_wallet_keys() -> Tuple[int, int]:
|
|
"""Encrypt plaintext wallet keys that may still exist in the database."""
|
|
updated_users = 0
|
|
updated_wallets = 0
|
|
|
|
async with AsyncSessionLocal() as session:
|
|
user_result = await session.execute(select(User).where(User.wallet_private_key.is_not(None)))
|
|
for user in user_result.scalars():
|
|
if _looks_like_plaintext_key(user.wallet_private_key):
|
|
user.wallet_private_key = wallet_crypto.encrypt(user.wallet_private_key)
|
|
updated_users += 1
|
|
|
|
wallet_result = await session.execute(select(Wallet).where(Wallet.private_key.is_not(None)))
|
|
for wallet in wallet_result.scalars():
|
|
if _looks_like_plaintext_key(wallet.private_key):
|
|
wallet.private_key = wallet_crypto.encrypt(wallet.private_key)
|
|
updated_wallets += 1
|
|
|
|
if updated_users or updated_wallets:
|
|
await session.commit()
|
|
|
|
return updated_users, updated_wallets
|
|
|
|
|
|
__all__ = [
|
|
"WalletCryptoService",
|
|
"WalletKeyEncryptionError",
|
|
"WalletKeyDecryptionError",
|
|
"wallet_crypto",
|
|
"rotate_plaintext_wallet_keys",
|
|
]
|