Source code for wallet_manager

"""
Ethereum Wallet Manager Module

Handles HD wallet creation, derivation, and encrypted storage in Redis.
Supports any EVM-compatible network through configurable RPC endpoints.
Uses BIP39 mnemonics and BIP44 derivation paths (m/44'/60'/0'/0/x).
"""

import base64
import hashlib
import json
import logging
import os
from datetime import datetime, timezone
from typing import Optional, Dict, Any, Tuple, List

from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.exceptions import InvalidTag

from wallet_key_utils import ensure_master_key

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Redis key prefix for one wallet record; WalletManager appends
# "<user_id>:<wallet_name>" to it.
WALLET_KEY_PREFIX = "eth_wallet:"
# Redis key prefix for the per-user set of wallet names; WalletManager
# appends "<user_id>" to it.
WALLET_INDEX_PREFIX = "eth_wallet_index:"
# Handed to wallet_key_utils.ensure_master_key as its fourth argument.
# NOTE(review): presumably the name of the environment variable that may
# hold the master key -- confirm against ensure_master_key.
WALLET_MASTER_KEY_ENV = "ETH_WALLET_MASTER_KEY"
# Handed to wallet_key_utils.ensure_master_key as its third argument.
# NOTE(review): presumably the Redis key under which the master key is
# persisted -- confirm against ensure_master_key.
WALLET_MASTER_REDIS_KEY = "wallet_master_key:eth"


class WalletManager:
    """
    Manages Ethereum HD wallets with encrypted storage in Redis.

    Features:
    - BIP39 mnemonic generation and import
    - BIP44 derivation (m/44'/60'/0'/0/x)
    - AES-256-GCM encryption for private keys at rest
    - Per-user wallet isolation
    - Address caching for derived addresses

    Unlike the v2 singleton, this class accepts an async Redis client via
    its async methods so it can work with the v3 ToolContext.redis.

    Wallet names are canonicalised (stripped and lower-cased) by every
    method before they are used in a Redis key, so lookups are
    case-insensitive and results always carry the canonical name.
    """

    # Layout of an encrypted blob: nonce || ciphertext || GCM tag.
    _NONCE_SIZE = 12
    _TAG_SIZE = 16
    _PBKDF2_ITERATIONS = 100000
    _MAX_NAME_LENGTH = 32

    def __init__(self):
        """Create a manager; the master key is loaded lazily on first use."""
        self._master_key: Optional[bytes] = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _canonical_name(wallet_name: str) -> str:
        """Return the wallet name as used in Redis keys (stripped, lower-case)."""
        return wallet_name.strip().lower()

    @classmethod
    def _validated_name(cls, wallet_name: str) -> str:
        """Canonicalise a new wallet name and raise ValueError if it is unusable."""
        name = cls._canonical_name(wallet_name)
        if not name or len(name) > cls._MAX_NAME_LENGTH:
            raise ValueError("Wallet name must be 1-32 characters")
        if not name.replace('_', '').replace('-', '').isalnum():
            raise ValueError("Wallet name must be alphanumeric (with - or _)")
        return name

    @staticmethod
    def _wallet_key(user_id: str, wallet_name: str) -> str:
        """Build the Redis key of one wallet record (name already canonical)."""
        return f"{WALLET_KEY_PREFIX}{user_id}:{wallet_name}"

    @staticmethod
    def _index_key(user_id: str) -> str:
        """Build the Redis key of the set holding a user's wallet names."""
        return f"{WALLET_INDEX_PREFIX}{user_id}"

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as a naive ISO-8601 string.

        Produces the same format as the deprecated ``datetime.utcnow()``
        so new records sort consistently with previously stored ones.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    @staticmethod
    def _with_hex_prefix(private_key: str) -> str:
        """Return the private key with a leading ``0x`` added if missing."""
        if private_key.startswith('0x'):
            return private_key
        return '0x' + private_key

    async def _ensure_master_key(self, redis_client) -> None:
        """Lazily load or generate the master key, persisting to Redis."""
        self._master_key = await ensure_master_key(
            self._master_key,
            redis_client,
            WALLET_MASTER_REDIS_KEY,
            WALLET_MASTER_KEY_ENV,
        )

    def _derive_wallet_key(self, user_id: str, wallet_name: str) -> bytes:
        """Derive the 32-byte encryption key for one wallet.

        The key is PBKDF2-HMAC-SHA256 over the master key, salted with
        ``"<user_id>:<wallet_name>"``.

        Args:
            user_id (str): Unique identifier for the user.
            wallet_name (str): Canonical wallet name.

        Returns:
            bytes: The 32-byte derived key.

        Raises:
            RuntimeError: If the master key has not been loaded yet.
        """
        if self._master_key is None:
            # Fail with a clear message instead of an opaque TypeError
            # from hashlib when the key was never loaded.
            raise RuntimeError(
                "Master key is not loaded; call _ensure_master_key first"
            )
        salt = f"{user_id}:{wallet_name}".encode('utf-8')
        return hashlib.pbkdf2_hmac(
            'sha256',
            self._master_key,
            salt,
            iterations=self._PBKDF2_ITERATIONS,
            dklen=32,
        )

    def _encrypt(self, key: bytes, plaintext: str) -> str:
        """Encrypt text with AES-GCM under a fresh random nonce.

        Args:
            key (bytes): AES key as accepted by ``AESGCM``.
            plaintext (str): Text to protect.

        Returns:
            str: URL-safe base64 of ``nonce || ciphertext-with-tag``.
        """
        nonce = os.urandom(self._NONCE_SIZE)
        ciphertext = AESGCM(key).encrypt(nonce, plaintext.encode('utf-8'), None)
        return base64.urlsafe_b64encode(nonce + ciphertext).decode('utf-8')

    def _decrypt(self, key: bytes, encrypted_data: str) -> str:
        """Reverse :meth:`_encrypt`.

        Args:
            key (bytes): The AES key used for encryption.
            encrypted_data (str): URL-safe base64 blob produced by ``_encrypt``.

        Returns:
            str: The decrypted text.

        Raises:
            ValueError: If the blob is too short to hold a nonce and tag
                (malformed base64 and undecodable UTF-8 also surface as
                ValueError subclasses).
            cryptography.exceptions.InvalidTag: If authentication fails.
        """
        combined = base64.urlsafe_b64decode(encrypted_data.encode('utf-8'))
        if len(combined) < self._NONCE_SIZE + self._TAG_SIZE:
            raise ValueError("Encrypted data too short")
        nonce = combined[:self._NONCE_SIZE]
        ciphertext = combined[self._NONCE_SIZE:]
        return AESGCM(key).decrypt(nonce, ciphertext, None).decode('utf-8')

    async def _load_record(
        self, user_id: str, wallet_name: str, redis_client
    ) -> Optional[Dict[str, Any]]:
        """Fetch and parse the stored record of a wallet (name already canonical).

        Returns:
            Optional[Dict[str, Any]]: The parsed JSON record, or None when
            Redis holds nothing under the wallet key.
        """
        raw = await redis_client.get(self._wallet_key(user_id, wallet_name))
        if not raw:
            return None
        return json.loads(raw)

    def _decrypt_record_seed(
        self, user_id: str, wallet_name: str, record: Dict[str, Any]
    ) -> Optional[str]:
        """Decrypt the secret stored in a wallet record.

        Returns:
            Optional[str]: The mnemonic or private key, or None when the
            record has no secret or decryption fails (failure is logged).
        """
        encrypted_seed = record.get("encrypted_seed")
        if not encrypted_seed:
            return None
        try:
            encryption_key = self._derive_wallet_key(user_id, wallet_name)
            return self._decrypt(encryption_key, encrypted_seed)
        except (InvalidTag, ValueError) as exc:
            logger.error("Failed to decrypt wallet seed: %s", exc)
            return None

    async def _store_new_wallet(
        self,
        user_id: str,
        wallet_name: str,
        wallet_type: str,
        secret: str,
        address: str,
        redis_client,
    ) -> Dict[str, Any]:
        """Encrypt a secret, persist the wallet record and index entry.

        Returns:
            Dict[str, Any]: Public summary with ``wallet_name``, ``type``,
            ``address`` and ``created_at``.
        """
        encryption_key = self._derive_wallet_key(user_id, wallet_name)
        wallet_data = {
            "type": wallet_type,
            "encrypted_seed": self._encrypt(encryption_key, secret),
            "addresses": {"0": address},
            "created_at": self._utc_timestamp(),
        }
        await redis_client.set(
            self._wallet_key(user_id, wallet_name), json.dumps(wallet_data)
        )
        await redis_client.sadd(self._index_key(user_id), wallet_name)
        return {
            "wallet_name": wallet_name,
            "type": wallet_type,
            "address": address,
            "created_at": wallet_data["created_at"],
        }

    # ------------------------------------------------------------------
    # Stateless key / mnemonic utilities
    # ------------------------------------------------------------------

    @staticmethod
    def generate_mnemonic(strength: int = 128) -> str:
        """Generate a new English mnemonic phrase.

        Args:
            strength (int): Forwarded to ``Mnemonic.generate`` as ``strength``.

        Returns:
            str: The generated mnemonic phrase.
        """
        from mnemonic import Mnemonic
        return Mnemonic("english").generate(strength=strength)

    @staticmethod
    def validate_mnemonic(mnemonic: str) -> bool:
        """Check a phrase against the English mnemonic word list.

        Args:
            mnemonic (str): Phrase to check.

        Returns:
            bool: The result of ``Mnemonic.check`` for the phrase.
        """
        from mnemonic import Mnemonic
        return Mnemonic("english").check(mnemonic)

    @staticmethod
    def derive_address_from_mnemonic(mnemonic: str, index: int = 0) -> Tuple[str, str]:
        """Derive the account at ``m/44'/60'/0'/0/<index>`` from a mnemonic.

        Args:
            mnemonic (str): Mnemonic phrase.
            index (int): Address index in the derivation path.

        Returns:
            Tuple[str, str]: ``(address, private_key_hex)``.

        Raises:
            ValueError: If ``index`` is a negative integer.
        """
        # Reject an impossible path component before touching the library.
        if isinstance(index, int) and index < 0:
            raise ValueError("Derivation index must be non-negative")
        from eth_account import Account
        Account.enable_unaudited_hdwallet_features()
        path = f"m/44'/60'/0'/0/{index}"
        account = Account.from_mnemonic(mnemonic, account_path=path)
        return account.address, account.key.hex()

    @staticmethod
    def derive_address_from_private_key(private_key: str) -> str:
        """Return the address controlled by a private key.

        Args:
            private_key (str): Hex private key, with or without ``0x``.

        Returns:
            str: The account address.
        """
        from eth_account import Account
        account = Account.from_key(WalletManager._with_hex_prefix(private_key))
        return account.address

    @staticmethod
    def is_valid_private_key(private_key: str) -> bool:
        """Report whether ``eth_account`` accepts the given private key.

        Args:
            private_key (str): Hex private key, with or without ``0x``.

        Returns:
            bool: True if ``Account.from_key`` succeeds, False on any error.
        """
        try:
            from eth_account import Account
            Account.from_key(WalletManager._with_hex_prefix(private_key))
            return True
        except Exception:
            # Deliberately broad: any failure to load the key, whatever
            # its exception type, means "not valid".
            return False

    # ------------------------------------------------------------------
    # Wallet lifecycle
    # ------------------------------------------------------------------

    async def create_wallet(
        self,
        user_id: str,
        wallet_name: str,
        redis_client,
        mnemonic: Optional[str] = None,
        strength: int = 128
    ) -> Dict[str, Any]:
        """Create an HD wallet from a supplied or freshly generated mnemonic.

        Args:
            user_id (str): Unique identifier for the user.
            wallet_name (str): Desired name; 1-32 chars, alphanumeric plus
                ``-``/``_``, stored stripped and lower-cased.
            redis_client: Async Redis client.
            mnemonic (Optional[str]): Existing phrase to import; a new one
                is generated when omitted or empty.
            strength (int): Strength used when generating a mnemonic.

        Returns:
            Dict[str, Any]: ``wallet_name``, ``type`` (``"hd"``),
            ``address`` (index 0) and ``created_at``.

        Raises:
            ValueError: On an invalid name, an existing wallet of that
                name, or an invalid mnemonic.
        """
        wallet_name = self._validated_name(wallet_name)
        await self._ensure_master_key(redis_client)

        # NOTE(review): check-then-set is not atomic; two concurrent
        # creations of the same name could both pass this check.
        if await self.wallet_exists(user_id, wallet_name, redis_client):
            raise ValueError(f"Wallet '{wallet_name}' already exists")

        if mnemonic:
            if not self.validate_mnemonic(mnemonic):
                raise ValueError("Invalid mnemonic phrase")
        else:
            mnemonic = self.generate_mnemonic(strength=strength)

        address, _ = self.derive_address_from_mnemonic(mnemonic, index=0)
        summary = await self._store_new_wallet(
            user_id, wallet_name, "hd", mnemonic, address, redis_client
        )
        logger.info("Created HD wallet '%s' for user %s", wallet_name, user_id)
        return summary

    async def import_private_key(
        self,
        user_id: str,
        wallet_name: str,
        private_key: str,
        redis_client,
    ) -> Dict[str, Any]:
        """Store a single private key as a "simple" (non-HD) wallet.

        Args:
            user_id (str): Unique identifier for the user.
            wallet_name (str): Desired name; same rules as ``create_wallet``.
            private_key (str): Hex private key, with or without ``0x``.
            redis_client: Async Redis client.

        Returns:
            Dict[str, Any]: ``wallet_name``, ``type`` (``"simple"``),
            ``address`` and ``created_at``.

        Raises:
            ValueError: On an invalid name, an existing wallet of that
                name, or an invalid private key.
        """
        wallet_name = self._validated_name(wallet_name)
        await self._ensure_master_key(redis_client)

        if await self.wallet_exists(user_id, wallet_name, redis_client):
            raise ValueError(f"Wallet '{wallet_name}' already exists")

        # The key is stored in its 0x-prefixed form.
        private_key = self._with_hex_prefix(private_key)
        if not self.is_valid_private_key(private_key):
            raise ValueError("Invalid private key")

        address = self.derive_address_from_private_key(private_key)
        summary = await self._store_new_wallet(
            user_id, wallet_name, "simple", private_key, address, redis_client
        )
        logger.info("Imported simple wallet '%s' for user %s", wallet_name, user_id)
        return summary

    async def wallet_exists(self, user_id: str, wallet_name: str, redis_client) -> bool:
        """Report whether a wallet record exists for the user.

        Args:
            user_id (str): Unique identifier for the user.
            wallet_name (str): Wallet name (canonicalised before lookup).
            redis_client: Async Redis client.

        Returns:
            bool: True if Redis reports the wallet key exists.
        """
        wallet_key = self._wallet_key(user_id, self._canonical_name(wallet_name))
        return await redis_client.exists(wallet_key) > 0

    async def get_wallet(self, user_id: str, wallet_name: str, redis_client) -> Optional[Dict[str, Any]]:
        """Return the public (non-secret) view of a wallet.

        Args:
            user_id (str): Unique identifier for the user.
            wallet_name (str): Wallet name (canonicalised before lookup).
            redis_client: Async Redis client.

        Returns:
            Optional[Dict[str, Any]]: ``wallet_name`` (canonical form),
            ``type`` (defaults to ``"hd"``), ``addresses`` and
            ``created_at``; None if the wallet does not exist.
        """
        name = self._canonical_name(wallet_name)
        record = await self._load_record(user_id, name, redis_client)
        if record is None:
            return None
        return {
            # Report the canonical name, matching what create/import return.
            "wallet_name": name,
            "type": record.get("type", "hd"),
            "addresses": record.get("addresses", {}),
            "created_at": record.get("created_at"),
        }

    async def get_decrypted_seed(self, user_id: str, wallet_name: str, redis_client) -> Optional[str]:
        """Return the decrypted mnemonic (HD) or private key (simple wallet).

        Args:
            user_id (str): Unique identifier for the user.
            wallet_name (str): Wallet name (canonicalised before lookup).
            redis_client: Async Redis client.

        Returns:
            Optional[str]: The secret, or None if the wallet is missing,
            holds no secret, or cannot be decrypted.
        """
        name = self._canonical_name(wallet_name)
        await self._ensure_master_key(redis_client)
        record = await self._load_record(user_id, name, redis_client)
        if record is None:
            return None
        return self._decrypt_record_seed(user_id, name, record)

    async def derive_address(
        self,
        user_id: str,
        wallet_name: str,
        index: int,
        redis_client,
    ) -> Optional[str]:
        """Return the address at ``index``, deriving and caching it if needed.

        Args:
            user_id (str): Unique identifier for the user.
            wallet_name (str): Wallet name (canonicalised before lookup).
            index (int): Address index.
            redis_client: Async Redis client.

        Returns:
            Optional[str]: The address, or None if the wallet is missing
            or its seed cannot be decrypted.

        Raises:
            ValueError: If a non-zero index is requested from a non-HD
                wallet, or the index is negative.
        """
        name = self._canonical_name(wallet_name)
        await self._ensure_master_key(redis_client)

        record = await self._load_record(user_id, name, redis_client)
        if record is None:
            return None

        addresses = record.get("addresses", {})
        slot = str(index)
        if slot in addresses:
            return addresses[slot]

        # A missing type means "hd", the same default get_wallet reports.
        if record.get("type", "hd") != "hd":
            if index == 0:
                return addresses.get("0")
            raise ValueError("Cannot derive additional addresses from a simple (non-HD) wallet")

        seed = self._decrypt_record_seed(user_id, name, record)
        if not seed:
            return None

        address, _ = self.derive_address_from_mnemonic(seed, index=index)

        # Cache the derived address so later calls skip decryption.
        addresses[slot] = address
        record["addresses"] = addresses
        await redis_client.set(self._wallet_key(user_id, name), json.dumps(record))
        return address

    async def get_private_key(
        self,
        user_id: str,
        wallet_name: str,
        redis_client,
        index: int = 0,
    ) -> Optional[str]:
        """Return the private key for the address at ``index``.

        Args:
            user_id (str): Unique identifier for the user.
            wallet_name (str): Wallet name (canonicalised before lookup).
            redis_client: Async Redis client.
            index (int): Address index; must be 0 for simple wallets.

        Returns:
            Optional[str]: The private key, or None if the wallet is
            missing or its secret cannot be decrypted.

        Raises:
            ValueError: If a non-zero index is requested from a simple
                wallet, or the index is negative.
        """
        name = self._canonical_name(wallet_name)
        await self._ensure_master_key(redis_client)

        record = await self._load_record(user_id, name, redis_client)
        if record is None:
            return None

        seed = self._decrypt_record_seed(user_id, name, record)
        if not seed:
            return None

        if record.get("type", "hd") == "simple":
            if index != 0:
                raise ValueError("Simple wallets only have one address (index 0)")
            # For simple wallets the stored secret is the private key itself.
            return seed

        _, private_key = self.derive_address_from_mnemonic(seed, index=index)
        return private_key

    async def list_wallets(self, user_id: str, redis_client) -> List[Dict[str, Any]]:
        """List the user's wallets, oldest first.

        Args:
            user_id (str): Unique identifier for the user.
            redis_client: Async Redis client.

        Returns:
            List[Dict[str, Any]]: ``get_wallet`` views sorted by
            ``created_at``; wallets without a timestamp sort first. Index
            entries whose record is gone are skipped.
        """
        wallet_names = await redis_client.smembers(self._index_key(user_id))

        wallets = []
        for name in wallet_names:
            if isinstance(name, bytes):
                name = name.decode('utf-8')
            wallet = await self.get_wallet(user_id, name, redis_client)
            if wallet:
                wallets.append(wallet)

        # get_wallet always sets "created_at", possibly to None, so coerce
        # None to "" here to keep the sort from comparing None with str.
        return sorted(wallets, key=lambda w: w.get("created_at") or "")

    async def delete_wallet(self, user_id: str, wallet_name: str, redis_client) -> bool:
        """Delete a wallet record and its index entry.

        Args:
            user_id (str): Unique identifier for the user.
            wallet_name (str): Wallet name (canonicalised before lookup).
            redis_client: Async Redis client.

        Returns:
            bool: True if a record was deleted, False if none existed.
        """
        name = self._canonical_name(wallet_name)
        deleted = await redis_client.delete(self._wallet_key(user_id, name))
        if not deleted:
            return False

        await redis_client.srem(self._index_key(user_id), name)
        logger.info("Deleted wallet '%s' for user %s", name, user_id)
        return True
# Shared module-level instance. It is constructed with no master key;
# the key is filled in by the async methods that call _ensure_master_key.
wallet_manager = WalletManager()