# Source of the ``config`` module (extracted from rendered documentation).

"""Bot configuration loaded from config.yaml with environment variable overrides.

Supports a ``platforms`` list so multiple chat platforms (Matrix, Discord, ...)
can be configured independently alongside the shared LLM / web settings.
"""

from __future__ import annotations

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import yaml


@dataclass
class PlatformConfig:
    """Configuration for a single chat platform."""

    type: str = ""
    """Platform identifier: ``"matrix"``, ``"discord"``, etc."""

    enabled: bool = True
    """Whether this platform should be started."""

    # Arbitrary platform-specific keys
    settings: dict[str, Any] = field(default_factory=dict)
    """All remaining keys from the YAML block are stored here."""

    def get(self, key: str, default: Any = None) -> Any:
        """Return the platform-specific setting stored under *key*.

        Args:
            key (str): Name of the setting inside :attr:`settings`.
            default (Any): Value to return when *key* is absent.

        Returns:
            Any: The stored setting, or *default* when not present.
        """
        return self.settings.get(key, default)
@dataclass
class Config:
    """Top-level bot configuration.

    Built from ``config.yaml`` via :meth:`load`, with environment-variable
    overrides applied afterwards.  Shared LLM / web settings live at the top
    level; per-platform settings live in :attr:`platforms`.
    """

    # --- Shared LLM settings ---
    api_key: str = ""
    gemini_api_key: str = ""
    llm_base_url: str = "http://localhost:3000/openai"
    model: str = "x-ai/grok-4.1-fast"
    temperature: float = 0.7
    max_tokens: int = 60000
    system_prompt_file: str = "system_prompt.j2"
    max_history: int = 100
    tools_dir: str = "tools"

    # --- Tool permissions (tool_name -> list of allowed user IDs) ---
    tool_permissions: dict[str, list[str]] = field(default_factory=dict)

    # --- API keys for external services (Brave Search, Vultr, etc.) ---
    api_keys: dict[str, Any] = field(default_factory=dict)

    # --- Redis message cache ---
    redis_url: str = ""
    redis_tls_cert: str = ""
    redis_tls_key: str = ""
    redis_tls_ca: str = ""
    embedding_model: str = "google/gemini-embedding-001"

    # --- Deferred embedding batching ---
    embedding_batch_size: int = 50
    embedding_flush_interval: float = 3600.0

    # --- Knowledge graph ---
    kg_extraction_model: str = "gemini-3-flash-preview"
    kg_max_hops: int = 2
    kg_seed_top_k: int = 8
    kg_max_context_entities: int = 30
    kg_entity_dedup_threshold: float = 0.90
    kg_relationship_decay_factor: float = 0.95
    kg_per_message_extraction: bool = False
    kg_min_message_length: int = 100
    kg_per_user_extraction_limit: int = 5
    kg_extraction_channel_hints: dict[str, str] = field(default_factory=dict)
    """Optional `platform:channel_id` -> human label for KG extraction prompts."""

    @property
    def openrouter_api_key(self) -> str:
        """Backward-compat alias -- embeddings code still reads this."""
        return self.api_key

    @openrouter_api_key.setter
    def openrouter_api_key(self, value: str) -> None:
        """Backward-compat setter: writes straight through to :attr:`api_key`.

        Args:
            value (str): The API key to store.
        """
        self.api_key = value

    @property
    def API_KEYS(self) -> dict:
        """Backward-compat: tools expect config.API_KEYS['brave'], etc."""
        return self.api_keys

    # --- LLM quality filter ---
    llm_filter_enabled: bool = False

    # --- Proactive responses ---
    proactive_enabled: bool = False
    proactive_default_frequency: float = 0.05
    proactive_triage_enabled: bool = True
    proactive_triage_model: str = "gemini-3.1-flash-lite-preview"

    # --- Message batching ---
    batch_window: float = 5.0
    max_batch_size: int = 10

    # --- Threadweave ---
    dna_vault_path: str = "data/dna_vault"

    # --- API key encryption (per-user keys in SQLite) ---
    api_key_encryption_db_path: str = "data/api_key_encryption_keys.db"

    # --- Media cache ---
    media_cache_dir: str = "media_cache"
    media_cache_max_mb: int = 500

    # --- Per-user LLM sandboxes (filesystem + Tor-enforced code execution) ---
    user_sandboxes_dir: str = "data/user_sandboxes"
    """Root directory for ``{user_id}/workspace`` sandbox trees."""
    tor_gateway_container: str = "stargazer-tor-gateway"
    """Docker container name for ``docker run --network container:...`` Tor sidecar."""
    sandbox_curl_image: str = "curlimages/curl:8.11.1"
    """Image for HTTPS downloads into a user sandbox over the Tor netns."""

    # --- Emoji resolution ---
    resolve_emojis_as_images: bool = True
    max_emojis_per_message: int = 5

    # --- Web GUI ---
    web_host: str = "127.0.0.1"
    web_port: int = 8080

    # --- Admin user IDs (bypass privilege escalation, access admin UI) ---
    admin_user_ids: list[str] = field(default_factory=list)

    # --- Webhook ---
    webhook_secret: str = ""

    # --- Discord OAuth2 (web UI authentication) ---
    discord_oauth_client_id: str = ""
    discord_oauth_client_secret: str = ""
    discord_oauth_redirect_uri: str = ""

    # --- OAuth2 token management (per-user service connections) ---
    oauth_encryption_key: str = ""
    oauth_base_url: str = ""
    oauth_providers: dict[str, dict[str, Any]] = field(default_factory=dict)

    # --- Per-platform configs ---
    platforms: list[PlatformConfig] = field(default_factory=list)

    # ---- Legacy top-level Matrix fields (kept for backward compat) ----
    homeserver: str = "https://matrix.org"
    user_id: str = ""
    password: str = ""
    store_path: str = "nio_store"
    credentials_file: str = "credentials.json"

    @classmethod
    def _parse_kg_config(cls, data: dict) -> dict:
        """Extract knowledge_graph sub-section from YAML data.

        Args:
            data (dict): The full parsed YAML mapping.

        Returns:
            dict: ``kg_*`` keyword arguments suitable for the ``Config``
            constructor, falling back to the class defaults.
        """
        kg = data.get("knowledge_graph", {})
        if not isinstance(kg, dict):
            kg = {}
        hints_raw = kg.get("extraction_channel_hints", {})
        if not isinstance(hints_raw, dict):
            hints_raw = {}
        # Stringify keys/values and drop null entries so downstream code can
        # rely on a clean dict[str, str].
        hints: dict[str, str] = {}
        for k, v in hints_raw.items():
            if k is None or v is None:
                continue
            hints[str(k)] = str(v)
        return {
            "kg_extraction_model": kg.get("extraction_model", cls.kg_extraction_model),
            "kg_max_hops": int(kg.get("max_hops", cls.kg_max_hops)),
            "kg_seed_top_k": int(kg.get("seed_top_k", cls.kg_seed_top_k)),
            "kg_max_context_entities": int(kg.get("max_context_entities", cls.kg_max_context_entities)),
            "kg_entity_dedup_threshold": float(kg.get("entity_dedup_threshold", cls.kg_entity_dedup_threshold)),
            "kg_relationship_decay_factor": float(kg.get("relationship_decay_factor", cls.kg_relationship_decay_factor)),
            "kg_per_message_extraction": bool(kg.get("per_message_extraction", cls.kg_per_message_extraction)),
            "kg_min_message_length": int(kg.get("min_message_length_for_extraction", cls.kg_min_message_length)),
            "kg_per_user_extraction_limit": int(kg.get("per_user_extraction_limit", cls.kg_per_user_extraction_limit)),
            "kg_extraction_channel_hints": hints,
        }

    @classmethod
    def load(cls, path: str | Path = "config.yaml") -> "Config":
        """Load config from a YAML file, then apply environment variable overrides."""
        data: dict = {}
        config_path = Path(path)
        if config_path.exists():
            with open(config_path, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f) or {}

        # --- Parse platform list (new-style) --------------------------
        raw_platforms = data.get("platforms", [])
        # Guard like every other section: a null/scalar `platforms:` entry or a
        # malformed list element must not crash startup.
        if not isinstance(raw_platforms, list):
            raw_platforms = []
        platform_configs: list[PlatformConfig] = []
        for raw in raw_platforms:
            if not isinstance(raw, dict):
                continue
            ptype = raw.get("type", "")
            enabled = raw.get("enabled", True)
            settings = {
                k: v for k, v in raw.items()
                if k not in ("type", "enabled")
            }
            platform_configs.append(PlatformConfig(
                type=ptype,
                enabled=enabled,
                settings=settings,
            ))

        # --- Legacy top-level Matrix fields ---------------------------
        homeserver = data.get("homeserver", cls.homeserver)
        user_id_val = data.get("user_id", cls.user_id)
        password = data.get("password", cls.password)
        store_path = data.get("store_path", cls.store_path)
        credentials_file = data.get("credentials_file", cls.credentials_file)

        # If no platforms list but legacy Matrix fields are present,
        # synthesise a Matrix platform entry for backward compatibility.
        if not platform_configs and user_id_val:
            platform_configs.append(PlatformConfig(
                type="matrix",
                enabled=True,
                settings={
                    "homeserver": homeserver,
                    "user_id": user_id_val,
                    "password": password,
                    "store_path": store_path,
                    "credentials_file": credentials_file,
                },
            ))

        # --- Tool permissions --------------------------------------
        raw_perms: dict[str, list[str]] = data.get("tool_permissions", {})
        tool_permissions = {
            k: [str(uid) for uid in v]
            for k, v in raw_perms.items()
        } if isinstance(raw_perms, dict) else {}

        # --- API keys -----------------------------------------------
        raw_api_keys = data.get("api_keys", {})
        api_keys = dict(raw_api_keys) if isinstance(raw_api_keys, dict) else {}
        # Inject API keys from env vars if not in YAML
        for env_name, key_name in (
            ("XAI_API_KEY", "xai"),
            ("ELEVENLABS_API_KEY", "elevenlabs"),
            ("SUNO_COOKIE", "suno_cookie"),
        ):
            if key_name not in api_keys:
                env_val = os.environ.get(env_name, "")
                if env_val:
                    api_keys[key_name] = env_val

        # --- Admin user IDs ----------------------------------------
        raw_admin_ids = data.get("admin_user_ids", [])
        admin_user_ids = [str(uid) for uid in raw_admin_ids] if isinstance(raw_admin_ids, list) else []

        # --- Discord OAuth2 ----------------------------------------
        discord_oauth = data.get("discord_oauth", {})
        if not isinstance(discord_oauth, dict):
            discord_oauth = {}

        # --- OAuth2 token management --------------------------------
        oauth_cfg = data.get("oauth", {})
        if not isinstance(oauth_cfg, dict):
            oauth_cfg = {}
        oauth_providers_raw = oauth_cfg.get("providers", {})
        if not isinstance(oauth_providers_raw, dict):
            oauth_providers_raw = {}

        # --- Proactive responses ----------------------------------
        proactive_cfg = data.get("proactive", {})
        if not isinstance(proactive_cfg, dict):
            proactive_cfg = {}

        # `api_key` wins over the legacy `openrouter_api_key` YAML key.
        resolved_api_key = (
            data.get("api_key")
            or data.get("openrouter_api_key")
            or cls.api_key
        )

        cfg = cls(
            api_key=resolved_api_key,
            gemini_api_key=data.get("gemini_api_key", cls.gemini_api_key),
            llm_base_url=data.get("llm_base_url", cls.llm_base_url),
            model=data.get("model", cls.model),
            temperature=float(data.get("temperature", cls.temperature)),
            max_tokens=int(data.get("max_tokens", cls.max_tokens)),
            system_prompt_file=data.get("system_prompt_file", cls.system_prompt_file),
            max_history=int(data.get("max_history", cls.max_history)),
            tools_dir=data.get("tools_dir", cls.tools_dir),
            tool_permissions=tool_permissions,
            api_keys=api_keys,
            redis_url=data.get("redis_url", cls.redis_url),
            redis_tls_cert=data.get("redis_tls_cert", cls.redis_tls_cert),
            redis_tls_key=data.get("redis_tls_key", cls.redis_tls_key),
            redis_tls_ca=data.get("redis_tls_ca", cls.redis_tls_ca),
            embedding_model=data.get("embedding_model", cls.embedding_model),
            embedding_batch_size=int(data.get("embedding_batch_size", cls.embedding_batch_size)),
            embedding_flush_interval=float(data.get("embedding_flush_interval", cls.embedding_flush_interval)),
            **cls._parse_kg_config(data),
            llm_filter_enabled=bool(data.get("llm_filter_enabled", cls.llm_filter_enabled)),
            proactive_enabled=bool(proactive_cfg.get("enabled", cls.proactive_enabled)),
            proactive_default_frequency=float(proactive_cfg.get("default_frequency", cls.proactive_default_frequency)),
            proactive_triage_enabled=bool(proactive_cfg.get("triage_enabled", cls.proactive_triage_enabled)),
            proactive_triage_model=str(proactive_cfg.get("triage_model", cls.proactive_triage_model)),
            batch_window=float(data.get("batch_window", cls.batch_window)),
            max_batch_size=int(data.get("max_batch_size", cls.max_batch_size)),
            dna_vault_path=data.get("dna_vault_path", cls.dna_vault_path),
            api_key_encryption_db_path=data.get(
                "api_key_encryption_db_path", cls.api_key_encryption_db_path
            ),
            media_cache_dir=data.get("media_cache_dir", cls.media_cache_dir),
            media_cache_max_mb=int(data.get("media_cache_max_mb", cls.media_cache_max_mb)),
            user_sandboxes_dir=data.get("user_sandboxes_dir", cls.user_sandboxes_dir),
            tor_gateway_container=data.get(
                "tor_gateway_container", cls.tor_gateway_container
            ),
            sandbox_curl_image=data.get("sandbox_curl_image", cls.sandbox_curl_image),
            resolve_emojis_as_images=bool(data.get("resolve_emojis_as_images", cls.resolve_emojis_as_images)),
            max_emojis_per_message=int(data.get("max_emojis_per_message", cls.max_emojis_per_message)),
            web_host=data.get("web_host", cls.web_host),
            web_port=int(data.get("web_port", cls.web_port)),
            admin_user_ids=admin_user_ids,
            webhook_secret=data.get("webhook_secret", cls.webhook_secret),
            discord_oauth_client_id=discord_oauth.get("client_id", ""),
            discord_oauth_client_secret=discord_oauth.get("client_secret", ""),
            discord_oauth_redirect_uri=discord_oauth.get("redirect_uri", ""),
            oauth_encryption_key=oauth_cfg.get("encryption_key", ""),
            oauth_base_url=oauth_cfg.get("base_url", ""),
            oauth_providers=oauth_providers_raw,
            platforms=platform_configs,
            homeserver=homeserver,
            user_id=user_id_val,
            password=password,
            store_path=store_path,
            credentials_file=credentials_file,
        )

        # --- Environment variable overrides ---------------------------
        env_map = {
            "API_KEY": "api_key",
            "OPENROUTER_API_KEY": "api_key",
            "GEMINI_API_KEY": "gemini_api_key",
            "LLM_BASE_URL": "llm_base_url",
            "OPENROUTER_MODEL": "model",
            "OPENROUTER_TEMPERATURE": "temperature",
            "OPENROUTER_MAX_TOKENS": "max_tokens",
            "BOT_SYSTEM_PROMPT_FILE": "system_prompt_file",
            "BOT_MAX_HISTORY": "max_history",
            "BOT_TOOLS_DIR": "tools_dir",
            "REDIS_URL": "redis_url",
            "REDIS_TLS_CERT": "redis_tls_cert",
            "REDIS_TLS_KEY": "redis_tls_key",
            "REDIS_TLS_CA": "redis_tls_ca",
            "EMBEDDING_MODEL": "embedding_model",
            "EMBEDDING_BATCH_SIZE": "embedding_batch_size",
            "EMBEDDING_FLUSH_INTERVAL": "embedding_flush_interval",
            "BOT_MEDIA_CACHE_DIR": "media_cache_dir",
            "BOT_MEDIA_CACHE_MAX_MB": "media_cache_max_mb",
            "STARGAZER_USER_SANDBOXES_DIR": "user_sandboxes_dir",
            "STARGAZER_TOR_GATEWAY_CONTAINER": "tor_gateway_container",
            "STARGAZER_SANDBOX_CURL_IMAGE": "sandbox_curl_image",
            "RESOLVE_EMOJIS_AS_IMAGES": "resolve_emojis_as_images",
            "MAX_EMOJIS_PER_MESSAGE": "max_emojis_per_message",
            "BOT_WEB_HOST": "web_host",
            "BOT_WEB_PORT": "web_port",
            # Webhook
            "WEBHOOK_SECRET": "webhook_secret",
            # Discord OAuth2
            "DISCORD_OAUTH_CLIENT_ID": "discord_oauth_client_id",
            "DISCORD_OAUTH_CLIENT_SECRET": "discord_oauth_client_secret",
            "DISCORD_OAUTH_REDIRECT_URI": "discord_oauth_redirect_uri",
            # OAuth2 token management
            "OAUTH_ENCRYPTION_KEY": "oauth_encryption_key",
            "OAUTH_BASE_URL": "oauth_base_url",
            # LLM quality filter
            "LLM_FILTER_ENABLED": "llm_filter_enabled",
            # Proactive responses
            "PROACTIVE_ENABLED": "proactive_enabled",
            "PROACTIVE_DEFAULT_FREQUENCY": "proactive_default_frequency",
            "PROACTIVE_TRIAGE_ENABLED": "proactive_triage_enabled",
            "PROACTIVE_TRIAGE_MODEL": "proactive_triage_model",
            # Legacy Matrix env vars
            "MATRIX_HOMESERVER": "homeserver",
            "MATRIX_USER_ID": "user_id",
            "MATRIX_PASSWORD": "password",
            "MATRIX_STORE_PATH": "store_path",
            "MATRIX_CREDENTIALS_FILE": "credentials_file",
        }
        for env_var, attr in env_map.items():
            val = os.environ.get(env_var)
            if val is not None:
                # Coerce to the current attribute's type; bool must be checked
                # before int because bool is a subclass of int.
                current = getattr(cfg, attr)
                if isinstance(current, bool):
                    setattr(cfg, attr, val.lower() not in ("0", "false", "no", ""))
                elif isinstance(current, float):
                    setattr(cfg, attr, float(val))
                elif isinstance(current, int):
                    setattr(cfg, attr, int(val))
                else:
                    setattr(cfg, attr, val)

        # Discord token from env var
        discord_token = os.environ.get("DISCORD_TOKEN")
        if discord_token:
            # Check if a discord platform already exists
            has_discord = any(p.type == "discord" for p in cfg.platforms)
            if not has_discord:
                cfg.platforms.append(PlatformConfig(
                    type="discord",
                    enabled=True,
                    settings={"token": discord_token},
                ))
            else:
                # Update existing discord platform token
                for p in cfg.platforms:
                    if p.type == "discord":
                        p.settings["token"] = discord_token

        # OAuth provider credentials from env vars
        for provider in ("github", "google", "discord", "microsoft"):
            env_prefix = f"OAUTH_{provider.upper()}"
            cid = os.environ.get(f"{env_prefix}_CLIENT_ID", "")
            csecret = os.environ.get(f"{env_prefix}_CLIENT_SECRET", "")
            if cid or csecret:
                if provider not in cfg.oauth_providers:
                    cfg.oauth_providers[provider] = {}
                if cid:
                    cfg.oauth_providers[provider]["client_id"] = cid
                if csecret:
                    cfg.oauth_providers[provider]["client_secret"] = csecret

        return cfg

    def redis_ssl_kwargs(self) -> dict:
        """Return SSL keyword arguments for redis.asyncio clients.

        Returns an empty dict when mTLS is not configured, so callers can
        unconditionally unpack the result into ``Redis()`` / ``from_url()``.
        All certificate verification is explicitly disabled.
        """
        if not (self.redis_tls_cert and self.redis_tls_key):
            return {}
        import ssl as _ssl
        kwargs: dict[str, Any] = {
            "ssl_cert_reqs": _ssl.CERT_NONE,
            "ssl_check_hostname": False,
            "ssl_certfile": self.redis_tls_cert,
            "ssl_keyfile": self.redis_tls_key,
        }
        if self.redis_tls_ca:
            kwargs["ssl_ca_certs"] = self.redis_tls_ca
        return kwargs