"""Bot configuration loaded from config.yaml with environment variable overrides.
Supports a ``platforms`` list so multiple chat platforms (Matrix, Discord, ...)
can be configured independently alongside the shared LLM / web settings.
"""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import yaml
@dataclass
class Config:
    """Bot configuration: shared LLM / web settings plus per-platform entries.

    Instances are normally created via :meth:`load`, which reads
    ``config.yaml`` and then applies environment-variable overrides on top.
    Legacy top-level Matrix fields are kept for backward compatibility and
    are synthesised into a ``platforms`` entry when no explicit platform
    list is configured.
    """

    # --- Shared LLM settings ---
    api_key: str = ""
    gemini_api_key: str = ""
    llm_base_url: str = "http://localhost:3000/openai"
    model: str = "x-ai/grok-4.1-fast"
    temperature: float = 0.7
    max_tokens: int = 60000
    system_prompt_file: str = "system_prompt.j2"
    max_history: int = 100
    tools_dir: str = "tools"

    # --- Tool permissions (tool_name -> list of allowed user IDs) ---
    tool_permissions: dict[str, list[str]] = field(default_factory=dict)

    # --- API keys for external services (Brave Search, Vultr, etc.) ---
    api_keys: dict[str, Any] = field(default_factory=dict)

    # --- Redis message cache ---
    redis_url: str = ""
    redis_tls_cert: str = ""
    redis_tls_key: str = ""
    redis_tls_ca: str = ""
    embedding_model: str = "google/gemini-embedding-001"

    # --- Deferred embedding batching ---
    embedding_batch_size: int = 50
    embedding_flush_interval: float = 3600.0

    # --- Knowledge graph ---
    kg_extraction_model: str = "gemini-3-flash-preview"
    kg_max_hops: int = 2
    kg_seed_top_k: int = 8
    kg_max_context_entities: int = 30
    kg_entity_dedup_threshold: float = 0.90
    kg_relationship_decay_factor: float = 0.95
    kg_per_message_extraction: bool = False
    kg_min_message_length: int = 100
    kg_per_user_extraction_limit: int = 5
    # Optional `platform:channel_id` -> human label for KG extraction prompts.
    kg_extraction_channel_hints: dict[str, str] = field(default_factory=dict)

    @property
    def openrouter_api_key(self) -> str:
        """Backward-compat alias -- embeddings code still reads this."""
        return self.api_key

    @openrouter_api_key.setter
    def openrouter_api_key(self, value: str) -> None:
        """Write through the legacy alias to :attr:`api_key`.

        Args:
            value (str): API key to store.
        """
        self.api_key = value

    @property
    def API_KEYS(self) -> dict:
        """Backward-compat: tools expect config.API_KEYS['brave'], etc."""
        return self.api_keys

    # --- LLM quality filter ---
    llm_filter_enabled: bool = False

    # --- Proactive responses ---
    proactive_enabled: bool = False
    proactive_default_frequency: float = 0.05
    proactive_triage_enabled: bool = True
    proactive_triage_model: str = "gemini-3.1-flash-lite-preview"

    # --- Message batching ---
    batch_window: float = 5.0
    max_batch_size: int = 10

    # --- Threadweave ---
    dna_vault_path: str = "data/dna_vault"

    # --- API key encryption (per-user keys in SQLite) ---
    api_key_encryption_db_path: str = "data/api_key_encryption_keys.db"

    # --- Media cache ---
    media_cache_dir: str = "media_cache"
    media_cache_max_mb: int = 500

    # --- Per-user LLM sandboxes (filesystem + Tor-enforced code execution) ---
    # Root directory for ``{user_id}/workspace`` sandbox trees.
    user_sandboxes_dir: str = "data/user_sandboxes"
    # Docker container name for ``docker run --network container:...`` Tor sidecar.
    tor_gateway_container: str = "stargazer-tor-gateway"
    # Image for HTTPS downloads into a user sandbox over the Tor netns.
    sandbox_curl_image: str = "curlimages/curl:8.11.1"

    # --- Emoji resolution ---
    resolve_emojis_as_images: bool = True
    max_emojis_per_message: int = 5

    # --- Web GUI ---
    web_host: str = "127.0.0.1"
    web_port: int = 8080

    # --- Admin user IDs (bypass privilege escalation, access admin UI) ---
    admin_user_ids: list[str] = field(default_factory=list)

    # --- Webhook ---
    webhook_secret: str = ""

    # --- Discord OAuth2 (web UI authentication) ---
    discord_oauth_client_id: str = ""
    discord_oauth_client_secret: str = ""
    discord_oauth_redirect_uri: str = ""

    # --- OAuth2 token management (per-user service connections) ---
    oauth_encryption_key: str = ""
    oauth_base_url: str = ""
    oauth_providers: dict[str, dict[str, Any]] = field(default_factory=dict)

    # --- Per-platform configs ---
    platforms: list[PlatformConfig] = field(default_factory=list)

    # ---- Legacy top-level Matrix fields (kept for backward compat) ----
    homeserver: str = "https://matrix.org"
    user_id: str = ""
    password: str = ""
    store_path: str = "nio_store"
    credentials_file: str = "credentials.json"

    @classmethod
    def _parse_kg_config(cls, data: dict) -> dict:
        """Extract the ``knowledge_graph`` sub-section from YAML data.

        Args:
            data (dict): Parsed top-level YAML mapping.

        Returns:
            dict: ``kg_*`` keyword arguments for the ``Config`` constructor,
            falling back to the class-level defaults for anything missing.
        """
        kg = data.get("knowledge_graph", {})
        if not isinstance(kg, dict):
            kg = {}
        # Channel hints must be a str -> str mapping; coerce and drop Nones
        # so malformed YAML entries cannot leak odd types into prompts.
        hints_raw = kg.get("extraction_channel_hints", {})
        if not isinstance(hints_raw, dict):
            hints_raw = {}
        hints: dict[str, str] = {}
        for k, v in hints_raw.items():
            if k is None or v is None:
                continue
            hints[str(k)] = str(v)
        return {
            "kg_extraction_model": kg.get("extraction_model", cls.kg_extraction_model),
            "kg_max_hops": int(kg.get("max_hops", cls.kg_max_hops)),
            "kg_seed_top_k": int(kg.get("seed_top_k", cls.kg_seed_top_k)),
            "kg_max_context_entities": int(kg.get("max_context_entities", cls.kg_max_context_entities)),
            "kg_entity_dedup_threshold": float(kg.get("entity_dedup_threshold", cls.kg_entity_dedup_threshold)),
            "kg_relationship_decay_factor": float(kg.get("relationship_decay_factor", cls.kg_relationship_decay_factor)),
            "kg_per_message_extraction": bool(kg.get("per_message_extraction", cls.kg_per_message_extraction)),
            "kg_min_message_length": int(kg.get("min_message_length_for_extraction", cls.kg_min_message_length)),
            "kg_per_user_extraction_limit": int(kg.get("per_user_extraction_limit", cls.kg_per_user_extraction_limit)),
            "kg_extraction_channel_hints": hints,
        }

    @classmethod
    def load(cls, path: str | Path = "config.yaml") -> "Config":
        """Load config from a YAML file, then apply environment variable overrides.

        Args:
            path (str | Path): Location of the YAML config file. A missing
                file is not an error; defaults + env vars are used instead.

        Returns:
            Config: Fully resolved configuration instance.
        """
        data: dict = {}
        config_path = Path(path)
        if config_path.exists():
            with open(config_path, "r", encoding="utf-8") as f:
                # safe_load: config files are semi-trusted, never use yaml.load.
                data = yaml.safe_load(f) or {}

        # --- Parse platform list (new-style) --------------------------
        raw_platforms: list[dict[str, Any]] = data.get("platforms", [])
        platform_configs: list[PlatformConfig] = []
        for raw in raw_platforms:
            ptype = raw.get("type", "")
            enabled = raw.get("enabled", True)
            # Everything except the discriminator keys is platform-specific.
            settings = {
                k: v for k, v in raw.items() if k not in ("type", "enabled")
            }
            platform_configs.append(PlatformConfig(
                type=ptype,
                enabled=enabled,
                settings=settings,
            ))

        # --- Legacy top-level Matrix fields ---------------------------
        homeserver = data.get("homeserver", cls.homeserver)
        user_id_val = data.get("user_id", cls.user_id)
        password = data.get("password", cls.password)
        store_path = data.get("store_path", cls.store_path)
        credentials_file = data.get("credentials_file", cls.credentials_file)

        # If no platforms list but legacy Matrix fields are present,
        # synthesise a Matrix platform entry for backward compatibility.
        if not platform_configs and user_id_val:
            platform_configs.append(PlatformConfig(
                type="matrix",
                enabled=True,
                settings={
                    "homeserver": homeserver,
                    "user_id": user_id_val,
                    "password": password,
                    "store_path": store_path,
                    "credentials_file": credentials_file,
                },
            ))

        # --- Tool permissions --------------------------------------
        raw_perms: dict[str, list[str]] = data.get("tool_permissions", {})
        tool_permissions = {
            k: [str(uid) for uid in v] for k, v in raw_perms.items()
        } if isinstance(raw_perms, dict) else {}

        # --- API keys -----------------------------------------------
        raw_api_keys = data.get("api_keys", {})
        api_keys = dict(raw_api_keys) if isinstance(raw_api_keys, dict) else {}
        # Inject API keys from env vars if not in YAML; YAML wins when both exist.
        for env_name, key_name in (
            ("XAI_API_KEY", "xai"),
            ("ELEVENLABS_API_KEY", "elevenlabs"),
            ("SUNO_COOKIE", "suno_cookie"),
        ):
            if key_name not in api_keys:
                env_val = os.environ.get(env_name, "")
                if env_val:
                    api_keys[key_name] = env_val

        # --- Admin user IDs ----------------------------------------
        raw_admin_ids = data.get("admin_user_ids", [])
        admin_user_ids = [str(uid) for uid in raw_admin_ids] if isinstance(raw_admin_ids, list) else []

        # --- Discord OAuth2 ----------------------------------------
        discord_oauth = data.get("discord_oauth", {})
        if not isinstance(discord_oauth, dict):
            discord_oauth = {}

        # --- OAuth2 token management --------------------------------
        oauth_cfg = data.get("oauth", {})
        if not isinstance(oauth_cfg, dict):
            oauth_cfg = {}
        oauth_providers_raw = oauth_cfg.get("providers", {})
        if not isinstance(oauth_providers_raw, dict):
            oauth_providers_raw = {}

        # --- Proactive responses ----------------------------------
        proactive_cfg = data.get("proactive", {})
        if not isinstance(proactive_cfg, dict):
            proactive_cfg = {}

        # Accept both the new ``api_key`` and legacy ``openrouter_api_key``.
        resolved_api_key = (
            data.get("api_key")
            or data.get("openrouter_api_key")
            or cls.api_key
        )

        cfg = cls(
            api_key=resolved_api_key,
            gemini_api_key=data.get("gemini_api_key", cls.gemini_api_key),
            llm_base_url=data.get("llm_base_url", cls.llm_base_url),
            model=data.get("model", cls.model),
            temperature=float(data.get("temperature", cls.temperature)),
            max_tokens=int(data.get("max_tokens", cls.max_tokens)),
            system_prompt_file=data.get("system_prompt_file", cls.system_prompt_file),
            max_history=int(data.get("max_history", cls.max_history)),
            tools_dir=data.get("tools_dir", cls.tools_dir),
            tool_permissions=tool_permissions,
            api_keys=api_keys,
            redis_url=data.get("redis_url", cls.redis_url),
            redis_tls_cert=data.get("redis_tls_cert", cls.redis_tls_cert),
            redis_tls_key=data.get("redis_tls_key", cls.redis_tls_key),
            redis_tls_ca=data.get("redis_tls_ca", cls.redis_tls_ca),
            embedding_model=data.get("embedding_model", cls.embedding_model),
            embedding_batch_size=int(data.get("embedding_batch_size", cls.embedding_batch_size)),
            embedding_flush_interval=float(data.get("embedding_flush_interval", cls.embedding_flush_interval)),
            **cls._parse_kg_config(data),
            llm_filter_enabled=bool(data.get("llm_filter_enabled", cls.llm_filter_enabled)),
            proactive_enabled=bool(proactive_cfg.get("enabled", cls.proactive_enabled)),
            proactive_default_frequency=float(proactive_cfg.get("default_frequency", cls.proactive_default_frequency)),
            proactive_triage_enabled=bool(proactive_cfg.get("triage_enabled", cls.proactive_triage_enabled)),
            proactive_triage_model=str(proactive_cfg.get("triage_model", cls.proactive_triage_model)),
            batch_window=float(data.get("batch_window", cls.batch_window)),
            max_batch_size=int(data.get("max_batch_size", cls.max_batch_size)),
            dna_vault_path=data.get("dna_vault_path", cls.dna_vault_path),
            api_key_encryption_db_path=data.get(
                "api_key_encryption_db_path", cls.api_key_encryption_db_path
            ),
            media_cache_dir=data.get("media_cache_dir", cls.media_cache_dir),
            media_cache_max_mb=int(data.get("media_cache_max_mb", cls.media_cache_max_mb)),
            user_sandboxes_dir=data.get("user_sandboxes_dir", cls.user_sandboxes_dir),
            tor_gateway_container=data.get(
                "tor_gateway_container", cls.tor_gateway_container
            ),
            sandbox_curl_image=data.get("sandbox_curl_image", cls.sandbox_curl_image),
            resolve_emojis_as_images=bool(data.get("resolve_emojis_as_images", cls.resolve_emojis_as_images)),
            max_emojis_per_message=int(data.get("max_emojis_per_message", cls.max_emojis_per_message)),
            web_host=data.get("web_host", cls.web_host),
            web_port=int(data.get("web_port", cls.web_port)),
            admin_user_ids=admin_user_ids,
            webhook_secret=data.get("webhook_secret", cls.webhook_secret),
            discord_oauth_client_id=discord_oauth.get("client_id", ""),
            discord_oauth_client_secret=discord_oauth.get("client_secret", ""),
            discord_oauth_redirect_uri=discord_oauth.get("redirect_uri", ""),
            oauth_encryption_key=oauth_cfg.get("encryption_key", ""),
            oauth_base_url=oauth_cfg.get("base_url", ""),
            oauth_providers=oauth_providers_raw,
            platforms=platform_configs,
            homeserver=homeserver,
            user_id=user_id_val,
            password=password,
            store_path=store_path,
            credentials_file=credentials_file,
        )

        # --- Environment variable overrides ---------------------------
        # NOTE: OPENROUTER_API_KEY is listed after API_KEY, so it wins
        # when both are set (dicts preserve insertion order).
        env_map = {
            "API_KEY": "api_key",
            "OPENROUTER_API_KEY": "api_key",
            "GEMINI_API_KEY": "gemini_api_key",
            "LLM_BASE_URL": "llm_base_url",
            "OPENROUTER_MODEL": "model",
            "OPENROUTER_TEMPERATURE": "temperature",
            "OPENROUTER_MAX_TOKENS": "max_tokens",
            "BOT_SYSTEM_PROMPT_FILE": "system_prompt_file",
            "BOT_MAX_HISTORY": "max_history",
            "BOT_TOOLS_DIR": "tools_dir",
            "REDIS_URL": "redis_url",
            "REDIS_TLS_CERT": "redis_tls_cert",
            "REDIS_TLS_KEY": "redis_tls_key",
            "REDIS_TLS_CA": "redis_tls_ca",
            "EMBEDDING_MODEL": "embedding_model",
            "EMBEDDING_BATCH_SIZE": "embedding_batch_size",
            "EMBEDDING_FLUSH_INTERVAL": "embedding_flush_interval",
            "BOT_MEDIA_CACHE_DIR": "media_cache_dir",
            "BOT_MEDIA_CACHE_MAX_MB": "media_cache_max_mb",
            "STARGAZER_USER_SANDBOXES_DIR": "user_sandboxes_dir",
            "STARGAZER_TOR_GATEWAY_CONTAINER": "tor_gateway_container",
            "STARGAZER_SANDBOX_CURL_IMAGE": "sandbox_curl_image",
            "RESOLVE_EMOJIS_AS_IMAGES": "resolve_emojis_as_images",
            "MAX_EMOJIS_PER_MESSAGE": "max_emojis_per_message",
            "BOT_WEB_HOST": "web_host",
            "BOT_WEB_PORT": "web_port",
            # Webhook
            "WEBHOOK_SECRET": "webhook_secret",
            # Discord OAuth2
            "DISCORD_OAUTH_CLIENT_ID": "discord_oauth_client_id",
            "DISCORD_OAUTH_CLIENT_SECRET": "discord_oauth_client_secret",
            "DISCORD_OAUTH_REDIRECT_URI": "discord_oauth_redirect_uri",
            # OAuth2 token management
            "OAUTH_ENCRYPTION_KEY": "oauth_encryption_key",
            "OAUTH_BASE_URL": "oauth_base_url",
            # LLM quality filter
            "LLM_FILTER_ENABLED": "llm_filter_enabled",
            # Proactive responses
            "PROACTIVE_ENABLED": "proactive_enabled",
            "PROACTIVE_DEFAULT_FREQUENCY": "proactive_default_frequency",
            "PROACTIVE_TRIAGE_ENABLED": "proactive_triage_enabled",
            "PROACTIVE_TRIAGE_MODEL": "proactive_triage_model",
            # Legacy Matrix env vars
            "MATRIX_HOMESERVER": "homeserver",
            "MATRIX_USER_ID": "user_id",
            "MATRIX_PASSWORD": "password",
            "MATRIX_STORE_PATH": "store_path",
            "MATRIX_CREDENTIALS_FILE": "credentials_file",
        }
        for env_var, attr in env_map.items():
            val = os.environ.get(env_var)
            if val is not None:
                # Coerce the env string to the current attribute's type.
                # bool must be tested before int (bool is an int subclass).
                current = getattr(cfg, attr)
                if isinstance(current, bool):
                    setattr(cfg, attr, val.lower() not in ("0", "false", "no", ""))
                elif isinstance(current, float):
                    setattr(cfg, attr, float(val))
                elif isinstance(current, int):
                    setattr(cfg, attr, int(val))
                else:
                    setattr(cfg, attr, val)

        # Discord token from env var
        discord_token = os.environ.get("DISCORD_TOKEN")
        if discord_token:
            # Check if a discord platform already exists
            has_discord = any(p.type == "discord" for p in cfg.platforms)
            if not has_discord:
                cfg.platforms.append(PlatformConfig(
                    type="discord",
                    enabled=True,
                    settings={"token": discord_token},
                ))
            else:
                # Update existing discord platform token
                for p in cfg.platforms:
                    if p.type == "discord":
                        p.settings["token"] = discord_token

        # OAuth provider credentials from env vars
        for provider in ("github", "google", "discord", "microsoft"):
            env_prefix = f"OAUTH_{provider.upper()}"
            cid = os.environ.get(f"{env_prefix}_CLIENT_ID", "")
            csecret = os.environ.get(f"{env_prefix}_CLIENT_SECRET", "")
            if cid or csecret:
                if provider not in cfg.oauth_providers:
                    cfg.oauth_providers[provider] = {}
                if cid:
                    cfg.oauth_providers[provider]["client_id"] = cid
                if csecret:
                    cfg.oauth_providers[provider]["client_secret"] = csecret
        return cfg

    def redis_ssl_kwargs(self) -> dict:
        """Return SSL keyword arguments for redis.asyncio clients.

        Returns an empty dict when mTLS is not configured, so callers can
        unconditionally unpack the result into ``Redis()`` / ``from_url()``.

        NOTE(review): certificate verification is explicitly disabled
        (``CERT_NONE`` + no hostname check) — intentional per the original
        comment, but worth confirming this is acceptable for the deployment.
        """
        if not (self.redis_tls_cert and self.redis_tls_key):
            return {}
        import ssl as _ssl
        kwargs: dict[str, Any] = {
            "ssl_cert_reqs": _ssl.CERT_NONE,
            "ssl_check_hostname": False,
            "ssl_certfile": self.redis_tls_cert,
            "ssl_keyfile": self.redis_tls_key,
        }
        if self.redis_tls_ca:
            kwargs["ssl_ca_certs"] = self.redis_tls_ca
        return kwargs