# Source code for api_key_encryption

"""
API Key Encryption Module

Per-user AES-256-GCM encryption for API keys. Per-user encryption keys are
stored in a dedicated SQLite database, each encrypted with a master
key-encryption key (KEK) loaded from the environment.
"""

from __future__ import annotations

import asyncio
import base64
import hashlib
import logging
import os
import sqlite3
from datetime import datetime, timezone
from pathlib import Path

from cryptography.hazmat.primitives.ciphers.aead import AESGCM

# Module-level logger; resolve_master_key() reports rejected KEK values here.
logger = logging.getLogger(__name__)

# Marker that encrypt() prepends to its output and is_encrypted() tests for.
ENCRYPTED_PREFIX: str = "v2:"
"""Prefix for encrypted values in Redis. Values without this are legacy plaintext."""

# Fixed salt used by get_pool_key() when deriving the pool key from the master
# KEK. The derivation is deterministic, so changing this value changes the
# derived key and anything encrypted under the previous pool key would no
# longer decrypt.
POOL_KEY_SALT: bytes = b"stargazer:api_key_pool"


def _ensure_db_dir(path: str | Path) -> None:
    """Ensure the directory containing the SQLite file exists."""
    Path(path).parent.mkdir(parents=True, exist_ok=True)


def _init_schema(conn: sqlite3.Connection) -> None:
    """Create the encryption_keys table if it does not exist."""
    conn.execute("""
        CREATE TABLE IF NOT EXISTS encryption_keys (
            user_id TEXT PRIMARY KEY,
            encrypted_key BLOB NOT NULL,
            created_at TEXT NOT NULL
        )
    """)
    conn.commit()


def _wrap_user_key(key: bytes, master_key: bytes) -> bytes:
    """Encrypt *key* under *master_key* (AES-GCM); return ``nonce || ciphertext+tag``."""
    nonce = os.urandom(12)
    return nonce + AESGCM(master_key).encrypt(nonce, key, None)


def _unwrap_user_key(blob: bytes, master_key: bytes) -> bytes:
    """Decrypt a blob produced by :func:`_wrap_user_key`; raise ValueError if truncated."""
    combined = bytes(blob)
    # 12-byte nonce + 16-byte GCM tag is the smallest well-formed blob.
    if len(combined) < 28:
        raise ValueError("Stored encrypted key too short")
    return AESGCM(master_key).decrypt(combined[:12], combined[12:], None)


def _select_wrapped_key(conn: sqlite3.Connection, user_id: str) -> bytes | None:
    """Return the stored encrypted key blob for *user_id*, or None if there is no row."""
    row = conn.execute(
        "SELECT encrypted_key FROM encryption_keys WHERE user_id = ?",
        (user_id,),
    ).fetchone()
    return None if row is None else row[0]


def _get_or_create_user_key_sync(
    user_id: str,
    sqlite_path: str | Path,
    master_key: bytes,
) -> bytes:
    """Synchronous: load or generate per-user 32-byte key; persist encrypted in SQLite.

    The per-user key is stored in ``encryption_keys`` encrypted with
    *master_key* (AES-GCM, layout ``nonce(12) || ciphertext+tag``).

    Concurrent first-time calls for the same *user_id* (e.g. several worker
    threads) are safe: the insert is ``INSERT OR IGNORE`` and a caller whose
    insert was ignored re-reads and returns the row that won, so every caller
    ends up with the same key instead of one failing with
    ``sqlite3.IntegrityError``.

    NOTE(review): the encrypted key is not bound to ``user_id`` (no associated
    data is passed to AES-GCM), so a blob copied to another user's row would
    still decrypt. Adding AAD would require migrating existing rows.

    Args:
        user_id: Primary key of the row holding this user's key.
        sqlite_path: SQLite database file; its parent directory is created
            if missing.
        master_key: KEK used to encrypt/decrypt the stored per-user key.

    Returns:
        The raw per-user key bytes.

    Raises:
        ValueError: If the stored blob is shorter than nonce + tag.
        cryptography.exceptions.InvalidTag: From ``AESGCM.decrypt`` when the
            stored blob does not authenticate under *master_key*.
        RuntimeError: If our insert was ignored but the conflicting row is
            gone when re-read (concurrent deletion).
    """
    _ensure_db_dir(sqlite_path)
    conn = sqlite3.connect(str(sqlite_path))
    try:
        _init_schema(conn)

        stored = _select_wrapped_key(conn, user_id)
        if stored is not None:
            return _unwrap_user_key(stored, master_key)

        # No key yet: generate one and try to claim the row.
        key = AESGCM.generate_key(bit_length=256)
        created_at = datetime.now(timezone.utc).isoformat()
        cursor = conn.execute(
            "INSERT OR IGNORE INTO encryption_keys (user_id, encrypted_key, created_at) "
            "VALUES (?, ?, ?)",
            (user_id, _wrap_user_key(key, master_key), created_at),
        )
        conn.commit()
        if cursor.rowcount == 1:
            return key

        # Another connection inserted a key for this user between our SELECT
        # and INSERT; adopt it so all callers agree on a single key.
        stored = _select_wrapped_key(conn, user_id)
        if stored is None:
            raise RuntimeError("Encryption key row vanished during creation")
        return _unwrap_user_key(stored, master_key)
    finally:
        conn.close()


async def get_or_create_user_key(
    user_id: str,
    sqlite_path: str | Path,
    master_key: bytes,
) -> bytes:
    """Async front end: fetch or create the per-user 32-byte key.

    The blocking SQLite work is delegated to ``_get_or_create_user_key_sync``
    and run in a worker thread via ``asyncio.to_thread``, so the event loop is
    not blocked.
    """
    blocking_call = _get_or_create_user_key_sync
    key = await asyncio.to_thread(blocking_call, user_id, sqlite_path, master_key)
    return key
def encrypt(plaintext: str, key: bytes) -> str:
    """Encrypt *plaintext* with AES-GCM under *key* (AES-256 for a 32-byte key).

    A fresh random 12-byte nonce is generated on every call and prepended to
    the ciphertext (which carries the GCM tag). The combined bytes are
    URL-safe base64 encoded and prefixed with ``ENCRYPTED_PREFIX``.
    """
    cipher = AESGCM(key)
    iv = os.urandom(12)
    sealed = cipher.encrypt(iv, plaintext.encode("utf-8"), None)
    token = base64.urlsafe_b64encode(iv + sealed).decode("utf-8")
    return f"{ENCRYPTED_PREFIX}{token}"
def decrypt(ciphertext: str, key: bytes) -> str:
    """Reverse :func:`encrypt`: decode URL-safe base64 and AES-GCM-decrypt with *key*.

    The ``ENCRYPTED_PREFIX`` marker is stripped if present but not required.

    Raises:
        ValueError: If the decoded payload is shorter than nonce + tag, or the
            base64 is malformed (``binascii.Error`` subclasses ValueError).
    """
    payload = base64.urlsafe_b64decode(ciphertext.removeprefix(ENCRYPTED_PREFIX))
    # 12-byte nonce + 16-byte GCM tag is the minimum well-formed payload.
    if len(payload) < 28:
        raise ValueError("Encrypted data too short")
    iv, sealed = payload[:12], payload[12:]
    return AESGCM(key).decrypt(iv, sealed, None).decode("utf-8")
def get_pool_key(master_key: bytes) -> bytes:
    """Derive the 32-byte pool encryption key from *master_key*.

    Uses PBKDF2-HMAC-SHA256 with the fixed ``POOL_KEY_SALT`` and 100,000
    iterations, so the same master key always yields the same pool key.
    """
    rounds = 100_000
    key_len = 32
    return hashlib.pbkdf2_hmac(
        "sha256", master_key, POOL_KEY_SALT, iterations=rounds, dklen=key_len
    )
def resolve_master_key() -> bytes | None:
    """Load the master KEK from the ``API_KEY_MASTER_KEY`` environment variable.

    The value is stripped of surrounding whitespace and decoded with
    ``base64.urlsafe_b64decode``; it must decode to exactly 32 bytes.

    Returns:
        The 32-byte key, or None when the variable is unset/blank, cannot be
        decoded, or decodes to the wrong length (the last two cases log a
        warning).
    """
    key_b64 = os.environ.get("API_KEY_MASTER_KEY", "").strip()
    if not key_b64:
        return None
    try:
        key = base64.urlsafe_b64decode(key_b64)
    except ValueError as e:
        # binascii.Error (bad padding/characters) subclasses ValueError, as does
        # the error raised for non-ASCII input; nothing else can occur here.
        logger.warning("Failed to decode API_KEY_MASTER_KEY: %s", e)
        return None
    if len(key) != 32:
        logger.warning("API_KEY_MASTER_KEY is not 32 bytes, ignoring")
        return None
    return key
def is_encrypted(value: str) -> bool:
    """Report whether *value* starts with ``ENCRYPTED_PREFIX``, the marker encrypt() emits.

    Values without the marker are treated as legacy plaintext.
    """
    has_marker = value.startswith(ENCRYPTED_PREFIX)
    return has_marker
def api_key_hash(api_key: str) -> str:
    """Return the hex SHA-256 digest of the UTF-8 encoded *api_key*.

    Deterministic, so it can serve as a pool lookup key without storing the
    plaintext key itself.
    """
    digest = hashlib.sha256()
    digest.update(api_key.encode("utf-8"))
    return digest.hexdigest()