"""
API Key Encryption Module
Per-user AES-256-GCM encryption for API keys. Encryption keys are stored in a
dedicated SQLite database, protected by a master KEK from environment.
"""
from __future__ import annotations
import asyncio
import base64
import hashlib
import logging
import os
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
# Module logger, named after this module.
logger = logging.getLogger(__name__)
# Fixed salt fed to PBKDF2 when deriving the shared pool key (see get_pool_key);
# changing it changes the derived key.
POOL_KEY_SALT = b"stargazer:api_key_pool"
# Marker that encrypt() prepends to every ciphertext it produces.
ENCRYPTED_PREFIX = "v2:"
"""Prefix for encrypted values in Redis. Values without this are legacy plaintext."""
def _ensure_db_dir(path: str | Path) -> None:
"""Ensure the directory containing the SQLite file exists."""
Path(path).parent.mkdir(parents=True, exist_ok=True)
def _init_schema(conn: sqlite3.Connection) -> None:
"""Create the encryption_keys table if it does not exist."""
conn.execute("""
CREATE TABLE IF NOT EXISTS encryption_keys (
user_id TEXT PRIMARY KEY,
encrypted_key BLOB NOT NULL,
created_at TEXT NOT NULL
)
""")
conn.commit()
def _select_wrapped_user_key(conn: sqlite3.Connection, user_id: str) -> bytes | None:
    """Return the stored wrapped-key blob for ``user_id``, or None if no row exists."""
    row = conn.execute(
        "SELECT encrypted_key FROM encryption_keys WHERE user_id = ?",
        (user_id,),
    ).fetchone()
    return None if row is None else bytes(row[0])


def _unwrap_user_key(wrapped: bytes, master_key: bytes) -> bytes:
    """Decrypt a ``nonce (12 bytes) || ciphertext+tag`` blob with ``master_key``."""
    # 12-byte nonce + 16-byte GCM tag is the minimum possible size.
    if len(wrapped) < 28:
        raise ValueError("Stored encrypted key too short")
    return AESGCM(master_key).decrypt(wrapped[:12], wrapped[12:], None)


def _get_or_create_user_key_sync(
    user_id: str,
    sqlite_path: str | Path,
    master_key: bytes,
) -> bytes:
    """Synchronous: load or generate per-user 32-byte key; persist encrypted in SQLite.

    The per-user key is stored wrapped under ``master_key`` with AES-GCM as
    ``nonce (12 bytes) || ciphertext+tag``. If no row exists for ``user_id``,
    a fresh 256-bit key is generated, wrapped and inserted.

    Concurrent first-time callers for the same user (other threads, or other
    processes sharing the database file) are tolerated. The insert uses
    ``INSERT OR IGNORE``, and a caller whose insert was ignored reloads and
    returns the key that was actually persisted. Every caller therefore gets
    the same key instead of one failing on the PRIMARY KEY constraint.

    Args:
        user_id: Key under which the wrapped key is stored.
        sqlite_path: SQLite database file; its parent directory is created if missing.
        master_key: Key-encryption key used to wrap/unwrap the per-user key.

    Returns:
        The raw per-user key bytes.

    Raises:
        ValueError: If the stored blob is too short to hold a nonce and tag.
        cryptography.exceptions.InvalidTag: If the stored blob cannot be
            unwrapped with ``master_key`` (wrong KEK or corrupted row).
        RuntimeError: If the insert was ignored yet no row can be read back.
    """
    _ensure_db_dir(sqlite_path)
    conn = sqlite3.connect(str(sqlite_path))
    try:
        _init_schema(conn)
        stored = _select_wrapped_user_key(conn, user_id)
        if stored is None:
            key = AESGCM.generate_key(bit_length=256)
            nonce = os.urandom(12)
            wrapped = nonce + AESGCM(master_key).encrypt(nonce, key, None)
            created_at = datetime.now(timezone.utc).isoformat()
            # OR IGNORE: if another writer inserted this user_id after our SELECT,
            # keep their row instead of raising sqlite3.IntegrityError.
            cursor = conn.execute(
                "INSERT OR IGNORE INTO encryption_keys (user_id, encrypted_key, created_at) VALUES (?, ?, ?)",
                (user_id, wrapped, created_at),
            )
            conn.commit()
            if cursor.rowcount == 1:
                return key
            # Lost the race: the persisted key is authoritative, since other
            # callers may already be encrypting with it.
            stored = _select_wrapped_user_key(conn, user_id)
            if stored is None:
                raise RuntimeError(
                    f"Encryption key for user {user_id!r} could not be created"
                )
        return _unwrap_user_key(stored, master_key)
    finally:
        conn.close()
[docs]
async def get_or_create_user_key(
    user_id: str,
    sqlite_path: str | Path,
    master_key: bytes,
) -> bytes:
    """Async front-end for fetching (or creating) a user's 32-byte data key.

    The SQLite access and key unwrapping are blocking. They run on a worker
    thread via ``asyncio.to_thread`` so the event loop is not stalled.
    """
    user_key = await asyncio.to_thread(
        _get_or_create_user_key_sync,
        user_id=user_id,
        sqlite_path=sqlite_path,
        master_key=master_key,
    )
    return user_key
[docs]
def encrypt(plaintext: str, key: bytes) -> str:
    """Seal *plaintext* with AES-GCM under *key*.

    A fresh 12-byte random nonce is used per call. The result is
    ``ENCRYPTED_PREFIX`` followed by URL-safe base64 of
    ``nonce || ciphertext || tag``.
    """
    nonce = os.urandom(12)
    sealed = AESGCM(key).encrypt(nonce, plaintext.encode("utf-8"), None)
    token = base64.urlsafe_b64encode(nonce + sealed).decode("utf-8")
    return f"{ENCRYPTED_PREFIX}{token}"
[docs]
def decrypt(ciphertext: str, key: bytes) -> str:
    """Reverse encrypt(): decode, authenticate and decrypt *ciphertext* under *key*.

    A leading ``ENCRYPTED_PREFIX`` is stripped if present (at most once). The
    remainder must be URL-safe base64 of ``nonce (12 bytes) || ciphertext || tag``.

    Raises:
        ValueError: If the decoded payload is shorter than nonce + tag (28 bytes),
            or the base64 is malformed (``binascii.Error`` is a ValueError).
        cryptography.exceptions.InvalidTag: On a wrong key or tampered data.
    """
    payload = ciphertext.removeprefix(ENCRYPTED_PREFIX)
    raw = base64.urlsafe_b64decode(payload)
    if len(raw) < 28:
        raise ValueError("Encrypted data too short")
    nonce, sealed = raw[:12], raw[12:]
    return AESGCM(key).decrypt(nonce, sealed, None).decode("utf-8")
[docs]
def get_pool_key(master_key: bytes) -> bytes:
    """Derive the 32-byte shared pool key from *master_key*.

    Uses PBKDF2-HMAC-SHA256 with the fixed ``POOL_KEY_SALT`` and 100,000
    iterations. The derivation is deterministic, so the same master key always
    yields the same pool key. Changing the salt or iteration count changes it.
    """
    derived = hashlib.pbkdf2_hmac("sha256", master_key, POOL_KEY_SALT, 100_000, 32)
    return derived
[docs]
def resolve_master_key() -> bytes | None:
"""Load master KEK from API_KEY_MASTER_KEY env var (base64, 32 bytes)."""
key_b64 = os.environ.get("API_KEY_MASTER_KEY", "").strip()
if not key_b64:
return None
try:
key = base64.urlsafe_b64decode(key_b64)
if len(key) == 32:
return key
logger.warning("API_KEY_MASTER_KEY is not 32 bytes, ignoring")
return None
except Exception as e:
logger.warning("Failed to decode API_KEY_MASTER_KEY: %s", e)
return None
[docs]
def is_encrypted(value: str) -> bool:
    """Tell whether *value* starts with the ``ENCRYPTED_PREFIX`` marker written by encrypt().

    Values without the marker are considered legacy plaintext.
    """
    marker = ENCRYPTED_PREFIX
    return value.startswith(marker)
[docs]
def api_key_hash(api_key: str) -> str:
    """Return the lowercase hex SHA-256 digest of *api_key* (UTF-8 encoded).

    The digest is deterministic, so it can serve as the pool lookup key
    without storing the plaintext API key itself.
    """
    digest = hashlib.sha256()
    digest.update(api_key.encode("utf-8"))
    return digest.hexdigest()