Source code for oauth_manager

"""Per-user OAuth token management with encrypted storage.

Handles the full OAuth2 authorization-code flow for multiple providers
(GitHub, Google, Discord, Microsoft), stores tokens encrypted at rest
in Redis, and transparently refreshes expired access tokens.

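Redis key pattern:
    stargazer:oauth_tokens:{user_id}:{provider}   (Fernet-encrypted JSON; plaintext JSON if no valid key is configured)
    stargazer:oauth_link:{link_code}               (one-time link code -> user_id, expires after LINK_CODE_TTL)
    stargazer:oauth_state:{state}                  (OAuth ``state`` value -> JSON {link_code, provider}, expires after LINK_CODE_TTL)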
"""

from __future__ import annotations

import asyncio
import json
import logging
import secrets
import time
from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING

import aiohttp
from cryptography.fernet import Fernet, InvalidToken

if TYPE_CHECKING:
    pass

logger = logging.getLogger(__name__)

# Redis key namespaces; full keys are built as "{prefix}:{...}".
REDIS_TOKEN_PREFIX = "stargazer:oauth_tokens"
REDIS_LINK_PREFIX = "stargazer:oauth_link"

# Lifetime of a one-time link code, in seconds.
LINK_CODE_TTL = 5 * 60

# An access token counts as expired this many seconds before its real expiry,
# so it is refreshed slightly early.
REFRESH_BUFFER_SECONDS = 2 * 60


# ---------------------------------------------------------------------------
# Provider definitions
# ---------------------------------------------------------------------------

@dataclass
class OAuthProvider:
    """Configuration for a single OAuth2 provider.

    Only ``name`` and the two endpoint URLs are required; everything else has
    an "unset" default.
    """

    # Provider key, e.g. "github".
    name: str
    # Endpoint the user's browser is sent to for consent.
    authorize_url: str
    # Endpoint used for code exchange and refresh requests.
    token_url: str
    # Client credentials; both empty means the provider is not configured.
    client_id: str = field(default="")
    client_secret: str = field(default="")
    # Scopes requested when the caller does not pass its own.
    scopes: list[str] = field(default_factory=list)
    # Whether access tokens from this provider carry an expiry.
    tokens_expire: bool = field(default=True)
    # Token revocation endpoint; empty string means revocation is skipped.
    revoke_url: str = field(default="")
    # Extra query parameters merged into the authorization URL.
    extra_auth_params: dict[str, str] = field(default_factory=dict)
# Built-in defaults for every supported provider. OAuthManager.__init__ builds
# one OAuthProvider per entry, taking endpoints/expiry/revoke settings from
# here and client_id, client_secret and (optionally) scopes from the user's
# providers_config. "default_scopes" is used only when no scopes are configured.
PROVIDER_TEMPLATES: dict[str, dict[str, Any]] = {
    "github": {
        "authorize_url": "https://github.com/login/oauth/authorize",
        "token_url": "https://github.com/login/oauth/access_token",
        # Treated as non-expiring here (presumably classic OAuth-app tokens).
        "tokens_expire": False,
        # Empty: no revocation call is made for GitHub.
        "revoke_url": "",
        "default_scopes": ["repo", "read:user", "gist", "notifications"],
    },
    "google": {
        "authorize_url": "https://accounts.google.com/o/oauth2/v2/auth",
        "token_url": "https://oauth2.googleapis.com/token",
        "tokens_expire": True,
        "revoke_url": "https://oauth2.googleapis.com/revoke",
        "default_scopes": [
            "https://www.googleapis.com/auth/drive.file",
            "https://www.googleapis.com/auth/gmail.modify",
            "https://www.googleapis.com/auth/calendar",
        ],
        # Appended to the authorization URL. Per Google's OAuth docs, these
        # request a refresh token and force the consent screen so one is issued.
        "extra_auth_params": {"access_type": "offline", "prompt": "consent"},
    },
    "discord": {
        "authorize_url": "https://discord.com/api/oauth2/authorize",
        "token_url": "https://discord.com/api/oauth2/token",
        "tokens_expire": True,
        "revoke_url": "https://discord.com/api/oauth2/token/revoke",
        "default_scopes": ["identify", "guilds", "email", "connections"],
    },
    "microsoft": {
        "authorize_url": "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
        "token_url": "https://login.microsoftonline.com/common/oauth2/v2.0/token",
        "tokens_expire": True,
        # Empty: no revocation call is made for Microsoft.
        "revoke_url": "",
        "default_scopes": [
            "User.Read",
            "Files.ReadWrite",
            "Mail.ReadWrite",
            "Calendars.ReadWrite",
            # Per Microsoft identity platform docs, required to receive a refresh token.
            "offline_access",
        ],
    },
}
@dataclass
class TokenData:
    """Decrypted OAuth token bundle.

    ``expires_at`` is an absolute Unix timestamp in seconds (compared against
    ``time.time()``); a non-positive value means "no known expiry".
    """

    access_token: str
    refresh_token: str = ""
    expires_at: float = 0.0
    scopes: list[str] = field(default_factory=list)
    token_type: str = "Bearer"
    provider: str = ""

    @property
    def is_expired(self) -> bool:
        """True once the token is within REFRESH_BUFFER_SECONDS of expiring.

        Tokens without a known expiry (``expires_at <= 0``) never count as expired.
        """
        if self.expires_at <= 0:
            return False
        return time.time() >= (self.expires_at - REFRESH_BUFFER_SECONDS)

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a JSON-compatible dict; the inverse of :meth:`from_dict`."""
        return {
            "access_token": self.access_token,
            "refresh_token": self.refresh_token,
            "expires_at": self.expires_at,
            # Copy so mutating the returned dict cannot alter this instance.
            "scopes": list(self.scopes),
            "token_type": self.token_type,
            "provider": self.provider,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> TokenData:
        """Build a TokenData from a dict produced by :meth:`to_dict`.

        Missing keys *and* explicit ``None`` values fall back to the field
        defaults, so a stored ``"expires_at": null`` no longer raises TypeError.

        Raises:
            ValueError / TypeError: if a present value has an unconvertible type
            (e.g. a non-numeric ``expires_at``).
        """
        return cls(
            access_token=d.get("access_token") or "",
            refresh_token=d.get("refresh_token") or "",
            expires_at=float(d.get("expires_at") or 0),
            # Copy so the instance does not share a list with the input dict.
            scopes=list(d.get("scopes") or []),
            token_type=d.get("token_type") or "Bearer",
            provider=d.get("provider") or "",
        )
class OAuthNotConnected(Exception):
    """Signals that a tool needs an OAuth connection the user has not made.

    The exception message already contains ``connect_url`` so it can be shown
    to the user verbatim; both values are also kept as attributes.
    """

    def __init__(self, provider: str, connect_url: str) -> None:
        message = f"You haven't connected {provider} yet. Click here to connect: {connect_url}"
        self.provider = provider
        self.connect_url = connect_url
        super().__init__(message)
# --------------------------------------------------------------------------- # Manager # ---------------------------------------------------------------------------
class OAuthManager:
    """Core OAuth2 token lifecycle manager.

    Builds provider authorization URLs, exchanges authorization codes for
    tokens, and persists one token bundle per ``(user_id, provider)`` in
    Redis. Tokens are Fernet-encrypted when a valid key was supplied and
    stored as plaintext otherwise. The manager also refreshes access tokens
    that are about to expire, and revokes tokens on disconnect where the
    provider has a revoke endpoint.

    Every ``redis`` argument is used duck-typed: the code awaits ``get``,
    ``set`` (optionally with ``ex=``) and ``delete`` on it, and accepts
    either ``str`` or ``bytes`` back from ``get``.
    """

    def __init__(
        self,
        encryption_key: str = "",
        base_url: str = "",
        providers_config: dict[str, dict[str, Any]] | None = None,
    ) -> None:
        """
        Args:
            encryption_key: Fernet key (str or bytes). If empty or invalid,
                tokens are stored UNENCRYPTED (an invalid key is logged).
            base_url: Public base URL. Callbacks are
                ``{base_url}/oauth/{provider}/callback``.
            providers_config: Per-provider overrides keyed by provider name.
                The keys read are ``client_id``, ``client_secret`` and ``scopes``.
        """
        self._fernet: Fernet | None = None
        if encryption_key:
            key = encryption_key.encode() if isinstance(encryption_key, str) else encryption_key
            try:
                self._fernet = Fernet(key)
            except (ValueError, TypeError):
                # Deliberately non-fatal: the bot keeps working, but tokens are stored in plaintext.
                logger.error("Invalid OAuth encryption key -- tokens will NOT be encrypted")

        self.base_url = base_url.rstrip("/") if base_url else ""
        self.providers: dict[str, OAuthProvider] = {}
        providers_config = providers_config or {}
        for name, template in PROVIDER_TEMPLATES.items():
            user_cfg = providers_config.get(name) or {}
            self.providers[name] = OAuthProvider(
                name=name,
                authorize_url=template["authorize_url"],
                token_url=template["token_url"],
                client_id=user_cfg.get("client_id", ""),
                client_secret=user_cfg.get("client_secret", ""),
                # Copies: mutating one manager's provider must not alter the
                # module-level templates shared by every instance.
                scopes=list(user_cfg.get("scopes", template.get("default_scopes", []))),
                tokens_expire=template.get("tokens_expire", True),
                revoke_url=template.get("revoke_url", ""),
                extra_auth_params=dict(template.get("extra_auth_params", {})),
            )

    def _redirect_uri(self, provider: str) -> str:
        """Return the OAuth callback URL used for *provider* (authorize and exchange must match)."""
        return f"{self.base_url}/oauth/{provider}/callback"

    def is_provider_configured(self, provider: str) -> bool:
        """True if *provider* is known and has both a client id and a client secret."""
        p = self.providers.get(provider)
        return p is not None and bool(p.client_id and p.client_secret)

    def list_configured_providers(self) -> list[str]:
        """Names of all providers for which :meth:`is_provider_configured` is true."""
        return [n for n in self.providers if self.is_provider_configured(n)]

    # -- Encryption helpers -------------------------------------------------

    def _encrypt(self, plaintext: str) -> str:
        """Fernet-encrypt *plaintext*; pass it through unchanged when no key is set."""
        if self._fernet is None:
            return plaintext
        return self._fernet.encrypt(plaintext.encode()).decode()

    def _decrypt(self, ciphertext: str) -> str:
        """Reverse :meth:`_encrypt`. Returns "" when the ciphertext fails authentication."""
        if self._fernet is None:
            return ciphertext
        try:
            return self._fernet.decrypt(ciphertext.encode()).decode()
        except InvalidToken:
            logger.warning("Failed to decrypt OAuth token -- returning empty")
            return ""

    # -- Redis helpers ------------------------------------------------------

    @staticmethod
    def _token_key(user_id: str, provider: str) -> str:
        """Redis key holding the token bundle for ``(user_id, provider)``."""
        return f"{REDIS_TOKEN_PREFIX}:{user_id}:{provider}"

    async def _store_token(self, redis: Any, user_id: str, token: TokenData) -> None:
        """Serialise, encrypt and persist *token* (no TTL; it lives until deleted)."""
        payload = json.dumps(token.to_dict())
        encrypted = self._encrypt(payload)
        await redis.set(self._token_key(user_id, token.provider), encrypted)

    async def _load_token(self, redis: Any, user_id: str, provider: str) -> TokenData | None:
        """Load and decrypt the stored token, or return None if absent or unreadable."""
        raw = await redis.get(self._token_key(user_id, provider))
        if raw is None:
            return None
        decrypted = self._decrypt(raw if isinstance(raw, str) else raw.decode())
        if not decrypted:
            return None
        try:
            return TokenData.from_dict(json.loads(decrypted))
        except (ValueError, TypeError, KeyError, AttributeError):
            # Bad JSON (JSONDecodeError is a ValueError), non-dict JSON, or
            # wrongly typed fields: treat the user as not connected.
            logger.warning("Discarding unreadable %s token for user %s", provider, user_id)
            return None

    async def _delete_token(self, redis: Any, user_id: str, provider: str) -> None:
        """Remove the stored token bundle for ``(user_id, provider)``."""
        await redis.delete(self._token_key(user_id, provider))

    # -- Link codes (chat-initiated flow) -----------------------------------

    async def create_link_code(self, redis: Any, user_id: str) -> str:
        """Mint a one-time link code that maps back to *user_id*.

        Stored under ``{REDIS_LINK_PREFIX}:{code}`` and expires after
        LINK_CODE_TTL seconds.
        """
        link_code = secrets.token_urlsafe(16)
        await redis.set(f"{REDIS_LINK_PREFIX}:{link_code}", user_id, ex=LINK_CODE_TTL)
        return link_code

    async def consume_link_code(self, redis: Any, link_code: str) -> str | None:
        """Resolve *link_code* to its user id and delete it, or return None if unknown or expired.

        NOTE(review): get + delete is not atomic. Two concurrent consumers
        could both resolve the same code; switch to GETDEL if the Redis
        server supports it.
        """
        key = f"{REDIS_LINK_PREFIX}:{link_code}"
        raw = await redis.get(key)
        if raw is None:
            return None
        await redis.delete(key)
        return raw if isinstance(raw, str) else raw.decode()

    # -- Authorization URL --------------------------------------------------

    def get_authorize_url(
        self,
        provider: str,
        state: str,
        scopes: list[str] | None = None,
    ) -> str:
        """Build the provider consent URL the user's browser should visit.

        Args:
            provider: Provider key.
            state: Opaque anti-CSRF value echoed back to the callback.
            scopes: Scopes to request; defaults to the provider's configured scopes.

        Raises:
            ValueError: if *provider* is unknown.
        """
        from urllib.parse import urlencode

        p = self.providers.get(provider)
        if p is None:
            raise ValueError(f"Unknown provider: {provider}")
        params = {
            "client_id": p.client_id,
            "redirect_uri": self._redirect_uri(provider),
            "response_type": "code",
            "scope": " ".join(scopes or p.scopes),
            "state": state,
        }
        # Provider-specific extras (e.g. Google's access_type/prompt) may override the above.
        params.update(p.extra_auth_params)
        return f"{p.authorize_url}?{urlencode(params)}"

    # -- Token exchange -----------------------------------------------------

    @staticmethod
    def _parse_scopes(raw: Any) -> list[str]:
        """Normalise a token response's ``scope`` field to a list of scopes.

        Most providers return a space-separated string, while GitHub returns a
        comma-separated one (e.g. ``"repo,gist"``), so both separators are
        split. Lists are copied; anything else yields ``[]``.
        """
        if isinstance(raw, str):
            return raw.replace(",", " ").split()
        if isinstance(raw, list):
            return list(raw)
        return []

    @staticmethod
    def _absolute_expiry(expires_in: Any) -> float:
        """Turn a relative ``expires_in`` (seconds) into an absolute timestamp; 0.0 if absent."""
        return time.time() + int(expires_in) if expires_in else 0.0

    @staticmethod
    def _check_token_response(payload: Any, action: str) -> dict[str, Any]:
        """Validate a decoded token-endpoint body and return it.

        Some providers (notably GitHub) report failures with HTTP 200 and an
        ``error`` field, so the status code alone is not sufficient.

        Raises:
            RuntimeError: if the body is not a JSON object or carries an ``error``.
        """
        if not isinstance(payload, dict):
            raise RuntimeError(f"{action} failed: unexpected response {payload!r}")
        if payload.get("error"):
            detail = payload.get("error_description") or payload["error"]
            raise RuntimeError(f"{action} failed: {detail}")
        return payload

    async def exchange_code(
        self,
        provider: str,
        code: str,
    ) -> TokenData:
        """Exchange an authorization *code* for a token bundle (not persisted here).

        Raises:
            KeyError: if *provider* is unknown.
            RuntimeError: on a non-200 response, an ``error`` in the body, or a
                body without an ``access_token``.
        """
        p = self.providers[provider]
        data: dict[str, str] = {
            "client_id": p.client_id,
            "client_secret": p.client_secret,
            "code": code,
            "redirect_uri": self._redirect_uri(provider),
            "grant_type": "authorization_code",
        }
        headers: dict[str, str] = {}
        if provider == "github":
            # GitHub replies form-encoded unless JSON is explicitly requested.
            headers["Accept"] = "application/json"
        else:
            headers["Content-Type"] = "application/x-www-form-urlencoded"

        async with aiohttp.ClientSession() as session:
            async with session.post(p.token_url, data=data, headers=headers) as resp:
                if resp.status != 200:
                    body = await resp.text()
                    raise RuntimeError(f"Token exchange failed ({resp.status}): {body}")
                token_data = self._check_token_response(
                    await resp.json(content_type=None), "Token exchange"
                )

        access_token = token_data.get("access_token") or ""
        if not access_token:
            raise RuntimeError("Token exchange failed: response contained no access_token")
        return TokenData(
            access_token=access_token,
            refresh_token=token_data.get("refresh_token") or "",
            expires_at=self._absolute_expiry(token_data.get("expires_in")),
            scopes=self._parse_scopes(token_data.get("scope")),
            token_type=token_data.get("token_type") or "Bearer",
            provider=provider,
        )

    # -- Token refresh ------------------------------------------------------

    async def _refresh_token(
        self,
        redis: Any,
        user_id: str,
        token: TokenData,
    ) -> TokenData:
        """Use *token*'s refresh token to obtain and persist a new access token.

        Raises:
            RuntimeError: if no refresh token is available or the provider rejects the refresh.
            aiohttp.ClientError / asyncio.TimeoutError: on network-level failures.
        """
        p = self.providers[token.provider]
        if not token.refresh_token:
            raise RuntimeError(f"No refresh token available for {token.provider}")
        data: dict[str, str] = {
            "client_id": p.client_id,
            "client_secret": p.client_secret,
            "refresh_token": token.refresh_token,
            "grant_type": "refresh_token",
        }
        # Request JSON explicitly: GitHub otherwise answers form-encoded,
        # which resp.json() cannot parse.
        headers = {"Accept": "application/json"}

        async with aiohttp.ClientSession() as session:
            async with session.post(p.token_url, data=data, headers=headers) as resp:
                if resp.status != 200:
                    body = await resp.text()
                    raise RuntimeError(f"Token refresh failed ({resp.status}): {body}")
                new_data = self._check_token_response(
                    await resp.json(content_type=None), "Token refresh"
                )

        new_access = new_data.get("access_token")
        if not new_access:
            # Keeping the old (expired) token here would store it with
            # expires_at=0, i.e. "never expires", and it would never refresh again.
            raise RuntimeError("Token refresh failed: response contained no access_token")
        refreshed = TokenData(
            access_token=new_access,
            # Providers may omit refresh_token when it is not rotated; keep the old one then.
            refresh_token=new_data.get("refresh_token") or token.refresh_token,
            expires_at=self._absolute_expiry(new_data.get("expires_in")),
            scopes=token.scopes,
            token_type=new_data.get("token_type") or token.token_type,
            provider=token.provider,
        )
        await self._store_token(redis, user_id, refreshed)
        return refreshed

    # -- Public API for tools -----------------------------------------------

    async def get_token(
        self,
        user_id: str,
        provider: str,
        redis: Any,
    ) -> str | None:
        """Return a valid access token, refreshing if needed. Returns None if not connected.

        Also returns None when *redis* is None or a refresh fails.
        - If the provider rejects the refresh, the stored token is deleted and
          the user must reconnect.
        - On a network-level failure, the stored token is kept so a later call
          can retry.
        """
        if redis is None:
            return None
        token = await self._load_token(redis, user_id, provider)
        if token is None:
            return None
        if token.is_expired and token.refresh_token:
            try:
                token = await self._refresh_token(redis, user_id, token)
            except (aiohttp.ClientError, asyncio.TimeoutError):
                logger.warning(
                    "Transient error refreshing %s token for user %s; keeping stored token",
                    provider,
                    user_id,
                    exc_info=True,
                )
                return None
            except Exception:
                logger.exception("Failed to refresh %s token for user %s", provider, user_id)
                await self._delete_token(redis, user_id, provider)
                return None
        return token.access_token or None

    async def store_token(
        self,
        redis: Any,
        user_id: str,
        token: TokenData,
    ) -> None:
        """Persist *token* for *user_id* (keyed by ``token.provider``)."""
        await self._store_token(redis, user_id, token)

    async def delete_token(
        self,
        redis: Any,
        user_id: str,
        provider: str,
    ) -> None:
        """Best-effort revoke the stored token at the provider, then delete it locally."""
        token = await self._load_token(redis, user_id, provider)
        p = self.providers.get(provider)
        if token is not None and p is not None and p.revoke_url:
            try:
                await self._revoke(provider, token)
            except Exception:
                # A failed revoke must not prevent the local deletion below.
                logger.warning(
                    "Failed to revoke %s token for user %s", provider, user_id, exc_info=True
                )
        await self._delete_token(redis, user_id, provider)

    async def _revoke(self, provider: str, token: TokenData) -> None:
        """Revoke *token* at the provider. Only Google and Discord have revoke calls here.

        Raises:
            aiohttp.ClientResponseError: if the provider answers with an error status.
        """
        p = self.providers[provider]
        if not p.revoke_url:
            return
        if provider == "google":
            request_kwargs: dict[str, Any] = {"params": {"token": token.access_token}}
        elif provider == "discord":
            request_kwargs = {
                "data": {
                    "token": token.access_token,
                    "client_id": p.client_id,
                    "client_secret": p.client_secret,
                }
            }
        else:
            return
        async with aiohttp.ClientSession() as session:
            async with session.post(p.revoke_url, **request_kwargs) as resp:
                resp.raise_for_status()

    async def list_user_connections(
        self,
        user_id: str,
        redis: Any,
    ) -> list[dict[str, Any]]:
        """Return a list of providers the user has connected.

        Each entry has ``provider``, ``scopes``, ``expires_at`` (None when
        there is no known expiry) and ``has_refresh_token``.
        """
        if redis is None:
            return []
        result = []
        for name in self.providers:
            token = await self._load_token(redis, user_id, name)
            if token and token.access_token:
                result.append({
                    "provider": name,
                    "scopes": token.scopes,
                    "expires_at": token.expires_at if token.expires_at > 0 else None,
                    "has_refresh_token": bool(token.refresh_token),
                })
        return result

    async def has_token(self, user_id: str, provider: str, redis: Any) -> bool:
        """True if a non-empty access token is stored (expiry is not checked)."""
        if redis is None:
            return False
        token = await self._load_token(redis, user_id, provider)
        return token is not None and bool(token.access_token)

    # -- Chat-initiated connect URL -----------------------------------------

    async def generate_connect_url(
        self,
        user_id: str,
        provider: str,
        redis: Any,
        scopes: list[str] | None = None,
    ) -> str:
        """Create a link code and a random OAuth ``state``, then return the consent URL.

        The state maps (for LINK_CODE_TTL seconds) to JSON
        ``{"link_code", "provider"}`` under ``stargazer:oauth_state:{state}``.
        Presumably the callback handler resolves it. TODO confirm against the
        web route.

        Raises:
            RuntimeError: if *redis* is None.
            ValueError: if *provider* is unknown (checked before writing to Redis).
        """
        if redis is None:
            raise RuntimeError("Redis is required to generate an OAuth connect URL")
        if provider not in self.providers:
            raise ValueError(f"Unknown provider: {provider}")
        link_code = await self.create_link_code(redis, user_id)
        state_payload = json.dumps({"link_code": link_code, "provider": provider})
        state = secrets.token_urlsafe(16)
        await redis.set(
            f"stargazer:oauth_state:{state}",
            state_payload,
            ex=LINK_CODE_TTL,
        )
        return self.get_authorize_url(provider, state, scopes=scopes)
# ---------------------------------------------------------------------------
# Singleton accessor
# ---------------------------------------------------------------------------

# Process-wide manager; None until init_oauth_manager() runs.
_instance: OAuthManager | None = None


def get_oauth_manager() -> OAuthManager:
    """Return the global OAuthManager singleton (must be initialized first).

    Raises:
        RuntimeError: if init_oauth_manager() has not been called yet.
    """
    if _instance is None:
        raise RuntimeError("OAuthManager not initialized -- call init_oauth_manager() first")
    return _instance
def init_oauth_manager(
    encryption_key: str = "",
    base_url: str = "",
    providers_config: dict[str, dict[str, Any]] | None = None,
) -> OAuthManager:
    """Create the process-wide OAuthManager, replacing any previous one, and return it.

    Arguments are forwarded unchanged to ``OAuthManager``. If construction
    raises, the previous singleton (if any) is left in place.
    """
    global _instance
    manager = OAuthManager(
        encryption_key=encryption_key,
        base_url=base_url,
        providers_config=providers_config,
    )
    _instance = manager
    return manager
async def require_oauth_token(ctx: Any, provider: str) -> str:
    """Get a valid token or raise OAuthNotConnected with a connect link.

    Args:
        ctx: Tool context; its ``user_id`` and ``redis`` attributes are read.
        provider: Provider key, e.g. "github".

    Raises:
        OAuthNotConnected: when no usable token exists. Its message contains
            a freshly generated connect URL.
    """
    manager = get_oauth_manager()
    access_token = await manager.get_token(ctx.user_id, provider, ctx.redis)
    if access_token is None:
        url = await manager.generate_connect_url(ctx.user_id, provider, ctx.redis)
        raise OAuthNotConnected(provider, url)
    return access_token