"""Mutable process-wide state for the FastAPI web layer (set by main.py)."""
from __future__ import annotations

import math
import time
from typing import Any

import jsonutil as json
# Process-wide handles installed by main.py (see set_bot_runner / set_config).
_bot_runner: Any = None
_cfg: Any = None

# In-memory stores used by the session / OAuth-state helpers whenever they are
# called with redis=None; keyed by session token and OAuth state respectively.
_sessions: dict[str, dict[str, Any]] = dict()
_oauth_states: dict[str, dict[str, Any]] = dict()

# Redis handles; not assigned in this section — presumably set by main.py.
_limbic_redis: Any = None
_sprite_redis: Any = None  # default DB (tools write sprite state here)

# str -> list of connection objects; presumably open WebSockets per client or
# channel, populated elsewhere (TODO confirm).
active_websockets: dict[str, list[Any]] = dict()
def set_bot_runner(runner: Any) -> None:
    """Install the process-wide bot runner.

    Args:
        runner: Object stored in the module-level ``_bot_runner`` slot,
            replacing any previously installed runner.
    """
    global _bot_runner
    _bot_runner = runner
def set_config(cfg: Any) -> None:
    """Install the process-wide bot configuration.

    Args:
        cfg: Object stored in the module-level ``_cfg`` slot, replacing any
            previously installed configuration.
    """
    global _cfg
    _cfg = cfg
# --- Session & OAuth-state helpers: Redis-backed, with an in-memory fallback when redis is None ---
async def create_session(redis: Any, token: str, session_data: dict[str, Any]) -> None:
    """Persist a login session under ``token``.

    With a Redis client the session is written to the hash
    ``sg:session:{token}`` (``user`` JSON-encoded, ``expires_at`` as a string),
    and the key's TTL is set to the time remaining until ``expires_at``. With
    ``redis=None`` the mapping is stored in the in-memory ``_sessions`` dict.

    Args:
        redis: Async Redis client, or ``None`` to use the in-memory store.
        token: Session token used as the lookup key.
        session_data: Mapping with ``user`` and ``expires_at`` (a Unix
            timestamp compared against ``time.time()``); a missing
            ``expires_at`` is treated as 0, i.e. already expired.
    """
    expires_at = session_data.get("expires_at", 0)

    if redis is None:
        now = time.time()
        # Evict sessions get_session would reject anyway, so the fallback store
        # cannot grow without bound when sessions are never read again.
        stale = [t for t, s in _sessions.items() if now > s.get("expires_at", 0)]
        for t in stale:
            del _sessions[t]
        _sessions[token] = session_data
        return

    key = f"sg:session:{token}"
    # Round up: truncating a sub-second remainder to 0 would skip EXPIRE and
    # leave the key in Redis with no TTL at all.
    ttl = math.ceil(expires_at - time.time())
    if ttl <= 0:
        # Already expired: never write a key that would have no TTL, and clear
        # any previous entry so the token reads back as "no session".
        await redis.delete(key)
        return

    mapping = {
        "user": json.dumps(session_data.get("user", {})),
        "expires_at": str(expires_at),
    }
    await redis.hset(key, mapping=mapping)
    await redis.expire(key, ttl)
async def get_session(redis: Any, token: str) -> dict[str, Any] | None:
    """Return the live session for ``token``, or ``None``.

    Args:
        redis: Async Redis client, or ``None`` to read the in-memory store.
        token: Session token.

    Returns:
        For Redis, ``{"user": <decoded JSON>, "expires_at": float}``; for the
        in-memory store, the mapping that was stored. ``None`` when the session
        is missing, malformed, or past ``expires_at``. Expired entries are
        removed as a side effect in both backends.
    """
    if redis is None:
        session = _sessions.get(token)
        if not session:
            return None
        if time.time() > session.get("expires_at", 0):
            _sessions.pop(token, None)
            return None
        return session

    key = f"sg:session:{token}"
    raw = await redis.hgetall(key)
    if not raw:
        return None

    # Field names/values may come back as str or bytes depending on the client
    # configuration, so accept both.
    user_raw = raw.get("user") or raw.get(b"user")
    expires_raw = raw.get("expires_at") or raw.get(b"expires_at")
    if user_raw is None or expires_raw is None:
        return None
    user_text = user_raw.decode() if isinstance(user_raw, bytes) else str(user_raw)
    expires_text = expires_raw.decode() if isinstance(expires_raw, bytes) else str(expires_raw)

    try:
        user = json.loads(user_text)
        expires_at = float(expires_text)
    except Exception:
        # Best effort: a corrupt record is treated as "no session", not an error.
        return None

    if time.time() > expires_at:
        # The key may have no TTL (e.g. EXPIRE was skipped or never ran after
        # HSET), so enforce expires_at here exactly as the in-memory path does.
        await redis.delete(key)
        return None
    return {"user": user, "expires_at": expires_at}
async def delete_session(redis: Any, token: str) -> None:
    """Revoke the session identified by ``token``.

    With a Redis client, the ``sg:session:{token}`` key is deleted and the
    token is then published on the ``sg:session:revocation`` channel (any
    subscribers live outside this module). With ``redis=None`` the token is
    dropped from the in-memory store; unknown tokens are ignored.
    """
    if redis is None:
        _sessions.pop(token, None)
        return

    await redis.delete(f"sg:session:{token}")
    await redis.publish("sg:session:revocation", token)
async def store_oauth_state(redis: Any, state: str, data: dict[str, Any], expires_in: int = 900) -> None:
    """Save data associated with an OAuth ``state`` value for ``expires_in`` seconds.

    Both backends store the same record: ``created_at`` and ``expires_at``
    (Unix timestamps) merged with ``data``. Keys in ``data`` are applied last,
    so a caller-supplied ``created_at``/``expires_at`` overrides the computed
    one (presumably never intended — verify callers).

    With a Redis client, the record is written to the hash ``sg:oauth:{state}``
    with each value JSON-encoded, and the key expires after ``expires_in``
    seconds. With ``redis=None``, it goes into the in-memory ``_oauth_states``
    dict.

    Args:
        redis: Async Redis client, or ``None`` to use the in-memory store.
        state: The OAuth state token used as the lookup key.
        data: JSON-serializable values to keep alongside the state.
        expires_in: Lifetime in seconds (default 900).
    """
    now = time.time()
    record = {"created_at": now, "expires_at": now + expires_in, **data}

    if redis is None:
        # Evict states get_oauth_state would reject anyway; abandoned OAuth
        # flows would otherwise accumulate here forever.
        stale = [s for s, d in _oauth_states.items() if now > d.get("expires_at", 0)]
        for s in stale:
            del _oauth_states[s]
        _oauth_states[state] = record
        return

    key = f"sg:oauth:{state}"
    # The record always carries the timestamps, so the mapping is never empty
    # (HSET rejects an empty mapping) and no placeholder field is needed.
    await redis.hset(key, mapping={k: json.dumps(v) for k, v in record.items()})
    await redis.expire(key, expires_in)
async def get_oauth_state(redis: Any, state: str) -> dict[str, Any] | None:
    """Fetch the data stored for an OAuth ``state`` without consuming it.

    With a Redis client, every field of ``sg:oauth:{state}`` is decoded to
    ``str`` and JSON-decoded; values that are not valid JSON are returned as
    plain strings. With ``redis=None``, the in-memory entry is returned,
    unless it is past its ``expires_at``, in which case it is evicted.

    Returns:
        The stored mapping, or ``None`` when the state is unknown or expired.
    """
    if redis is None:
        entry = _oauth_states.get(state)
        if not entry:
            return None
        if time.time() > entry.get("expires_at", 0):
            _oauth_states.pop(state, None)
            return None
        return entry

    fields = await redis.hgetall(f"sg:oauth:{state}")
    if not fields:
        return None

    def _text(value: Any) -> str:
        # The client may hand back bytes or str.
        return value.decode() if isinstance(value, bytes) else str(value)

    decoded: dict[str, Any] = {}
    for raw_name, raw_value in fields.items():
        name = _text(raw_name)
        text = _text(raw_value)
        try:
            decoded[name] = json.loads(text)
        except Exception:
            # Not JSON: keep the raw string.
            decoded[name] = text
    return decoded
async def delete_oauth_state(redis: Any, state: str) -> dict[str, Any] | None:
    """Consume an OAuth ``state``: remove it and return its stored data.

    A state is single-use. With a Redis client, the data is read through
    ``get_oauth_state`` and the key is then deleted. The data is returned only
    if this call's DEL actually removed the key, so when concurrent callbacks
    race on the same state, at most one of them receives it.

    With ``redis=None``, the entry is popped from the in-memory store, and
    ``None`` is returned if it had already expired.

    Returns:
        The stored mapping, or ``None`` when the state is unknown, expired, or
        was consumed concurrently by another caller.
    """
    if redis is None:
        entry = _oauth_states.pop(state, None)
        if entry and time.time() > entry.get("expires_at", 0):
            return None
        return entry

    data = await get_oauth_state(redis, state)
    if data is None:
        return None
    # DEL returns the number of keys it removed. If two requests both read the
    # state above, only one DEL removes the key, and only that caller wins.
    removed = await redis.delete(f"sg:oauth:{state}")
    return data if removed else None