# Source code for oauth_manager

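"""Per-user OAuth token management with encrypted storage.

Handles the full OAuth2 authorization-code flow for multiple providers
(GitHub, Google, Discord, Microsoft), stores tokens in Redis (encrypted at
rest when a valid Fernet key is configured), and transparently refreshes
expired access tokens.

Redis key pattern:
    stargazer:oauth_tokens:{user_id}:{provider}   (Fernet-encrypted JSON;
                                                   plaintext JSON when no valid
                                                   encryption key is configured)
    stargazer:oauth_link:{link_code}               (one-time link code -> user_id)
    stargazer:oauth_state:{state}                  (opaque OAuth state -> JSON with
                                                   link_code and provider;
                                                   expires after LINK_CODE_TTL)
"""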

from __future__ import annotations

import jsonutil as json
import logging
import secrets
import time
from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING

import aiohttp
from cryptography.fernet import Fernet, InvalidToken

if TYPE_CHECKING:  # reserved for type-only imports; none are needed yet
    pass

logger = logging.getLogger(__name__)

# -- Redis key namespaces ---------------------------------------------------
REDIS_TOKEN_PREFIX = "stargazer:oauth_tokens"  # + ":{user_id}:{provider}"
REDIS_LINK_PREFIX = "stargazer:oauth_link"  # + ":{link_code}"

# -- Timing (all values in seconds) -----------------------------------------
LINK_CODE_TTL = 5 * 60  # lifetime of chat-initiated connect state
REFRESH_BUFFER_SECONDS = 2 * 60  # refresh this long before real expiry


# ---------------------------------------------------------------------------
# Provider definitions
# ---------------------------------------------------------------------------


@dataclass
class OAuthProvider:
    """Static endpoints plus per-deployment credentials for one OAuth2 provider.

    Bundles everything the manager needs to build authorization URLs and to
    exchange, refresh or revoke tokens for a single provider. Instances are
    built by ``OAuthManager.__init__``, which layers a deployment's
    ``providers_config`` over the static ``PROVIDER_TEMPLATES``.

    Args:
        name (str): Provider name (e.g. ``"github"``, ``"google"``).
        authorize_url (str): OAuth2 authorization endpoint.
        token_url (str): OAuth2 token endpoint (code exchange and refresh).
        client_id (str): OAuth client id for this deployment.
        client_secret (str): OAuth client secret for this deployment.
        scopes (list[str]): Scopes requested by default at authorization time.
        tokens_expire (bool): ``False`` for providers whose access tokens are
            long-lived (GitHub in the bundled templates).
        revoke_url (str): Revocation endpoint, or ``""`` when unsupported.
        extra_auth_params (dict[str, str]): Extra query parameters merged into
            the authorize URL (e.g. Google's ``access_type``/``prompt``).
    """

    # Identity and endpoints: always required.
    name: str
    authorize_url: str
    token_url: str
    # Deployment credentials; empty until configured.
    client_id: str = field(default="")
    client_secret: str = field(default="")
    # Default scopes for the authorize URL (fresh list per instance).
    scopes: list[str] = field(default_factory=list)
    # Whether issued access tokens expire at all.
    tokens_expire: bool = field(default=True)
    # Revocation endpoint; "" means the provider has none configured.
    revoke_url: str = field(default="")
    # Provider-specific extras appended to the authorize query string.
    extra_auth_params: dict[str, str] = field(default_factory=dict)
# Static, deployment-independent settings for every supported provider.
# ``OAuthManager.__init__`` overlays per-deployment client credentials (and
# optional scope overrides) on top of these entries.
PROVIDER_TEMPLATES: dict[str, dict[str, Any]] = {
    # Access tokens issued here never expire; no revoke endpoint configured.
    "github": dict(
        authorize_url="https://github.com/login/oauth/authorize",
        token_url="https://github.com/login/oauth/access_token",
        tokens_expire=False,
        revoke_url="",
        default_scopes=["repo", "read:user", "gist", "notifications"],
    ),
    # offline access + forced consent so a refresh token is returned.
    "google": dict(
        authorize_url="https://accounts.google.com/o/oauth2/v2/auth",
        token_url="https://oauth2.googleapis.com/token",
        tokens_expire=True,
        revoke_url="https://oauth2.googleapis.com/revoke",
        default_scopes=[
            "https://www.googleapis.com/auth/drive.file",
            "https://www.googleapis.com/auth/gmail.modify",
            "https://www.googleapis.com/auth/calendar",
        ],
        extra_auth_params={"access_type": "offline", "prompt": "consent"},
    ),
    "discord": dict(
        authorize_url="https://discord.com/api/oauth2/authorize",
        token_url="https://discord.com/api/oauth2/token",
        tokens_expire=True,
        revoke_url="https://discord.com/api/oauth2/token/revoke",
        default_scopes=["identify", "guilds", "email", "connections"],
    ),
    # "offline_access" is what makes Microsoft issue a refresh token.
    "microsoft": dict(
        authorize_url="https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
        token_url="https://login.microsoftonline.com/common/oauth2/v2.0/token",
        tokens_expire=True,
        revoke_url="",
        default_scopes=[
            "User.Read",
            "Files.ReadWrite",
            "Mail.ReadWrite",
            "Calendars.ReadWrite",
            "offline_access",
        ],
    ),
}
@dataclass
class TokenData:
    """Decrypted OAuth token bundle.

    The in-memory, plaintext representation of a user's OAuth credentials for
    one provider: access/refresh tokens, an absolute ``expires_at`` epoch and
    the granted scopes. :meth:`to_dict` produces the shape that
    ``OAuthManager`` JSON-encodes (and Fernet-encrypts when a key is set)
    before writing to Redis; :meth:`from_dict` rebuilds it after reading.

    Args:
        access_token (str): Bearer access token used to call provider APIs.
        refresh_token (str): Refresh token, if the provider issued one.
        expires_at (float): Absolute epoch seconds at which the access token
            expires; ``0`` means it never expires.
        scopes (list[str]): Scopes actually granted by the provider.
        token_type (str): Token type, normally ``"Bearer"``.
        provider (str): Provider name this bundle belongs to.
    """

    access_token: str
    refresh_token: str = ""
    expires_at: float = 0.0
    scopes: list[str] = field(default_factory=list)
    token_type: str = "Bearer"
    provider: str = ""

    @property
    def is_expired(self) -> bool:
        """Report whether the access token is at or past its refresh threshold.

        A non-positive ``expires_at`` marks a non-expiring token. Otherwise the
        token counts as expired ``REFRESH_BUFFER_SECONDS`` before its real
        expiry, so callers refresh slightly early.

        Returns:
            bool: ``True`` if the token should be refreshed now, ``False`` if it
            is still comfortably valid or never expires.
        """
        if self.expires_at <= 0:
            return False
        return time.time() >= (self.expires_at - REFRESH_BUFFER_SECONDS)

    def to_dict(self) -> dict[str, Any]:
        """Serialize this bundle to a plain JSON-compatible dict.

        The inverse of :meth:`from_dict`. ``scopes`` is copied so mutating the
        returned dict can never alias this bundle's own list.

        Returns:
            dict[str, Any]: Mapping with ``access_token``, ``refresh_token``,
            ``expires_at``, ``scopes``, ``token_type`` and ``provider`` keys.
        """
        return {
            "access_token": self.access_token,
            "refresh_token": self.refresh_token,
            "expires_at": self.expires_at,
            "scopes": list(self.scopes or ()),
            "token_type": self.token_type,
            "provider": self.provider,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> TokenData:
        """Reconstruct a :class:`TokenData` from its serialized dict form.

        Tolerates partial or legacy payloads: missing keys *and* keys whose
        value is JSON ``null`` fall back to the dataclass defaults. This
        prevents crashes such as ``float(None)`` for ``expires_at``.

        Args:
            d (dict[str, Any]): Decoded token mapping as produced by
                :meth:`to_dict` (any subset of its keys).

        Returns:
            TokenData: A token bundle populated from ``d``.

        Raises:
            ValueError: If ``expires_at`` is present but not numeric.
        """
        return cls(
            access_token=d.get("access_token") or "",
            refresh_token=d.get("refresh_token") or "",
            expires_at=float(d.get("expires_at") or 0),
            scopes=list(d.get("scopes") or []),
            token_type=d.get("token_type") or "Bearer",
            provider=d.get("provider") or "",
        )
class OAuthNotConnected(Exception):
    """Signal that a tool needs an OAuth provider the user has not connected.

    Carries the provider name and a one-time connect URL. The tool layer can
    then show a "click here to connect" prompt instead of a hard failure.
    Raised by :func:`require_oauth_token`.

    Attributes:
        provider (str): Provider the user still has to authorize.
        connect_url (str): Authorization URL the user should open.
    """

    def __init__(self, provider: str, connect_url: str) -> None:
        """Compose the user-facing connect prompt and remember its parts.

        Args:
            provider (str): Provider name to connect (e.g. ``"github"``).
            connect_url (str): One-time OAuth authorization URL.
        """
        message = (
            f"You haven't connected {provider} yet. "
            f"Click here to connect: {connect_url}"
        )
        super().__init__(message)
        self.provider = provider
        self.connect_url = connect_url
# ---------------------------------------------------------------------------
# Manager
# ---------------------------------------------------------------------------
class OAuthManager:
    """Core OAuth2 token lifecycle manager.

    Owns the per-user OAuth story for every configured provider:
    - building authorization URLs;
    - exchanging authorization codes;
    - transparently refreshing expired access tokens;
    - revoking tokens on disconnect;
    - persisting every token in Redis under
      ``stargazer:oauth_tokens:{user_id}:{provider}``, Fernet-encrypted when a
      valid key is configured.

    It also drives the chat-initiated connect flow via one-time link codes
    (``stargazer:oauth_link:{code}``) and opaque state tokens
    (``stargazer:oauth_state:{state}``).

    A single instance is created by :func:`init_oauth_manager` and returned by
    :func:`get_oauth_manager`. Provider-backed tools reach it through
    :func:`require_oauth_token`.
    """

    def __init__(
        self,
        encryption_key: str = "",
        base_url: str = "",
        providers_config: dict[str, dict[str, Any]] | None = None,
    ) -> None:
        """Initialize the manager with encryption, base URL and provider config.

        Builds a Fernet cipher from ``encryption_key``. An invalid key is
        logged and leaves ``self._fernet`` as ``None``, in which case tokens
        are stored in plaintext. ``providers_config`` entries (client
        ids/secrets/scopes) are merged over :data:`PROVIDER_TEMPLATES` to
        populate ``self.providers``.

        Args:
            encryption_key (str): Fernet key (str or bytes); empty disables
                encryption.
            base_url (str): Public base URL used to build
                ``/oauth/{provider}/callback`` redirect URIs; a trailing slash
                is stripped.
            providers_config (dict[str, dict[str, Any]] | None): Optional
                per-provider overrides with ``client_id``, ``client_secret``
                and/or ``scopes``.
        """
        self._fernet: Fernet | None = None
        if encryption_key:
            try:
                self._fernet = Fernet(
                    encryption_key.encode()
                    if isinstance(encryption_key, str)
                    else encryption_key
                )
            except Exception:
                # A bad key must not take the service down; fall back to
                # plaintext storage but make the misconfiguration loud.
                logger.error(
                    "Invalid OAuth encryption key -- tokens will NOT be encrypted"
                )

        self.base_url = base_url.rstrip("/") if base_url else ""

        self.providers: dict[str, OAuthProvider] = {}
        providers_config = providers_config or {}
        for name, template in PROVIDER_TEMPLATES.items():
            user_cfg = providers_config.get(name, {})
            scopes = user_cfg.get("scopes", template.get("default_scopes", []))
            self.providers[name] = OAuthProvider(
                name=name,
                authorize_url=template["authorize_url"],
                token_url=template["token_url"],
                client_id=user_cfg.get("client_id", ""),
                client_secret=user_cfg.get("client_secret", ""),
                # Copies: never share mutable containers with the module-level
                # templates (or the caller's config) across instances.
                scopes=list(scopes or []),
                tokens_expire=template.get("tokens_expire", True),
                revoke_url=template.get("revoke_url", ""),
                extra_auth_params=dict(template.get("extra_auth_params", {})),
            )

    def is_provider_configured(self, provider: str) -> bool:
        """Report whether a provider is known and has both client credentials.

        Args:
            provider (str): Provider name to check (e.g. ``"google"``).

        Returns:
            bool: ``True`` if the provider exists and has a non-empty
            ``client_id`` and ``client_secret``.
        """
        p = self.providers.get(provider)
        return p is not None and bool(p.client_id and p.client_secret)

    def list_configured_providers(self) -> list[str]:
        """Return the names of all providers with complete credentials.

        Returns:
            list[str]: Provider names passing :meth:`is_provider_configured`.
        """
        return [n for n in self.providers if self.is_provider_configured(n)]

    # -- Encryption helpers -------------------------------------------------

    def _encrypt(self, plaintext: str) -> str:
        """Fernet-encrypt ``plaintext``, or return it unchanged if encryption is off.

        Args:
            plaintext (str): Value to encrypt (the serialized token JSON).

        Returns:
            str: Fernet ciphertext, or ``plaintext`` when encryption is disabled.
        """
        if self._fernet is None:
            return plaintext
        return self._fernet.encrypt(plaintext.encode()).decode()

    def _decrypt(self, ciphertext: str) -> str:
        """Fernet-decrypt a stored value, returning ``""`` on failure.

        Returns the input unchanged when encryption is disabled. On
        :class:`~cryptography.fernet.InvalidToken` (wrong key or corrupted
        data) it logs a warning and returns an empty string.

        Args:
            ciphertext (str): Stored ciphertext (or plaintext when encryption
                is disabled).

        Returns:
            str: Decrypted plaintext, the unchanged input, or ``""``.
        """
        if self._fernet is None:
            return ciphertext
        try:
            return self._fernet.decrypt(ciphertext.encode()).decode()
        except InvalidToken:
            logger.warning("Failed to decrypt OAuth token -- returning empty")
            return ""

    # -- Redis helpers ------------------------------------------------------

    @staticmethod
    def _token_key(user_id: str, provider: str) -> str:
        """Build ``stargazer:oauth_tokens:{user_id}:{provider}``.

        Args:
            user_id (str): Stargazer user identifier.
            provider (str): Provider name (e.g. ``"discord"``).

        Returns:
            str: Fully-qualified Redis key for that user/provider token.
        """
        return f"{REDIS_TOKEN_PREFIX}:{user_id}:{provider}"

    async def _store_token(self, redis: Any, user_id: str, token: TokenData) -> None:
        """Serialize, encrypt and write a token bundle to Redis.

        Args:
            redis (Any): Async Redis client supporting ``set``.
            user_id (str): Owner of the token.
            token (TokenData): Bundle to persist; ``token.provider`` picks the key.
        """
        payload = json.dumps(token.to_dict())
        encrypted = self._encrypt(payload)
        await redis.set(self._token_key(user_id, token.provider), encrypted)

    async def _load_token(
        self, redis: Any, user_id: str, provider: str
    ) -> TokenData | None:
        """Fetch, decrypt and deserialize a user's stored token, if any.

        Args:
            redis (Any): Async Redis client supporting ``get``.
            user_id (str): Owner of the token.
            provider (str): Provider whose token to load.

        Returns:
            TokenData | None: The token, or ``None`` in any of these cases:
            the key is absent, decryption fails, the payload is not a JSON
            object, or a field has an unusable value.
        """
        raw = await redis.get(self._token_key(user_id, provider))
        if raw is None:
            return None
        decrypted = self._decrypt(raw if isinstance(raw, str) else raw.decode())
        if not decrypted:
            return None
        try:
            payload = json.loads(decrypted)
        except (json.JSONDecodeError, ValueError, TypeError):
            return None
        if not isinstance(payload, dict):
            # e.g. a stray list/string under the key; from_dict needs a mapping.
            return None
        try:
            return TokenData.from_dict(payload)
        except (KeyError, ValueError, TypeError):
            return None

    async def _delete_token(self, redis: Any, user_id: str, provider: str) -> None:
        """Remove a user's stored token for a provider from Redis.

        Args:
            redis (Any): Async Redis client supporting ``delete``.
            user_id (str): Owner of the token.
            provider (str): Provider whose token to delete.
        """
        await redis.delete(self._token_key(user_id, provider))

    # -- Link codes (chat-initiated flow) -----------------------------------

    async def create_link_code(self, redis: Any, user_id: str) -> str:
        """Mint a one-time link code that maps back to ``user_id``.

        Stores ``user_id`` under ``stargazer:oauth_link:{code}`` with a
        :data:`LINK_CODE_TTL` expiry. Used by :meth:`generate_connect_url` so
        the OAuth callback can tell which chat user started the flow.

        Args:
            redis (Any): Async Redis client supporting ``set`` with ``ex``.
            user_id (str): User the code should resolve to.

        Returns:
            str: The freshly generated, URL-safe link code.
        """
        code = secrets.token_urlsafe(24)
        await redis.set(f"{REDIS_LINK_PREFIX}:{code}", user_id, ex=LINK_CODE_TTL)
        return code

    async def resolve_link_code(self, redis: Any, code: str) -> str | None:
        """Consume a link code minted by :meth:`create_link_code`.

        The code is single-use: only the caller whose ``delete`` actually
        removes the key gets the user id. A concurrent second consumer
        therefore sees ``None``.

        Args:
            redis (Any): Async Redis client supporting ``get`` and ``delete``.
            code (str): Link code to consume.

        Returns:
            str | None: The user id the code was minted for, or ``None`` if the
            code is unknown, expired or already used.
        """
        key = f"{REDIS_LINK_PREFIX}:{code}"
        raw = await redis.get(key)
        if raw is None:
            return None
        if not await redis.delete(key):
            return None
        return raw if isinstance(raw, str) else raw.decode()

    # -- Authorization URL --------------------------------------------------

    def get_authorize_url(
        self,
        provider: str,
        state: str,
        scopes: list[str] | None = None,
    ) -> str:
        """Build the provider's OAuth2 authorization URL for a redirect.

        Args:
            provider (str): Provider to authorize against.
            state (str): Opaque round-trip state echoed back to the callback.
            scopes (list[str] | None): Optional scope override; defaults to the
                provider's configured scopes.

        Returns:
            str: The fully-formed authorization URL.

        Raises:
            ValueError: If ``provider`` is not a known provider.
        """
        from urllib.parse import urlencode

        p = self.providers.get(provider)
        if p is None:
            raise ValueError(f"Unknown provider: {provider}")

        params = {
            "client_id": p.client_id,
            "redirect_uri": f"{self.base_url}/oauth/{provider}/callback",
            "response_type": "code",
            "scope": " ".join(scopes or p.scopes),
            "state": state,
        }
        params.update(p.extra_auth_params)
        return f"{p.authorize_url}?{urlencode(params)}"

    # -- Token-endpoint response helpers ------------------------------------

    @staticmethod
    def _check_token_response(payload: Any, action: str) -> dict[str, Any]:
        """Validate a token-endpoint body that came back with HTTP 200.

        Some providers report failures in the body of a 200 response, with an
        ``error`` field (GitHub does this, e.g. for ``bad_verification_code``).
        A 200 status alone is therefore not proof of success.

        Args:
            payload (Any): Decoded JSON body.
            action (str): Prefix for error messages (e.g. ``"Token exchange"``).

        Returns:
            dict[str, Any]: ``payload`` unchanged when it is a success body.

        Raises:
            RuntimeError: If the body is not an object or carries ``error``.
        """
        if not isinstance(payload, dict):
            raise RuntimeError(f"{action} failed: unexpected response {payload!r}")
        if "error" in payload:
            detail = payload.get("error_description") or payload["error"]
            raise RuntimeError(f"{action} failed: {detail}")
        return payload

    @staticmethod
    def _parse_scopes(raw: Any) -> list[str]:
        """Normalize a token response's ``scope`` field into a list.

        Space-delimited strings are split. Commas are also treated as
        separators because GitHub returns comma-separated scopes. Lists are
        kept; anything else (missing/``null``) yields ``[]``.
        """
        if isinstance(raw, str):
            return raw.replace(",", " ").split()
        if isinstance(raw, list):
            return [str(s) for s in raw]
        return []

    @staticmethod
    def _expires_at(expires_in: Any) -> float:
        """Turn a relative ``expires_in`` into an absolute epoch (0 = never)."""
        return time.time() + float(expires_in) if expires_in else 0.0

    # -- Token exchange -----------------------------------------------------

    async def exchange_code(
        self,
        provider: str,
        code: str,
    ) -> TokenData:
        """Exchange an authorization code for an access/refresh token bundle.

        POSTs ``grant_type=authorization_code`` to the provider's token
        endpoint. GitHub gets ``Accept: application/json``; others get a
        form-encoded content type. The response is normalized into a
        :class:`TokenData`.

        Args:
            provider (str): Provider that issued the code.
            code (str): Authorization code returned to the redirect URI.

        Returns:
            TokenData: The new token bundle (not yet persisted).

        Raises:
            RuntimeError: If the endpoint returns a non-200 status, reports an
                ``error`` in a 200 body, or omits the access token.
        """
        p = self.providers[provider]
        data: dict[str, str] = {
            "client_id": p.client_id,
            "client_secret": p.client_secret,
            "code": code,
            "redirect_uri": f"{self.base_url}/oauth/{provider}/callback",
            "grant_type": "authorization_code",
        }
        headers: dict[str, str] = {}
        if provider == "github":
            # GitHub answers form-encoded unless JSON is explicitly requested.
            headers["Accept"] = "application/json"
        else:
            headers["Content-Type"] = "application/x-www-form-urlencoded"

        async with aiohttp.ClientSession() as session:
            async with session.post(p.token_url, data=data, headers=headers) as resp:
                if resp.status != 200:
                    body = await resp.text()
                    raise RuntimeError(f"Token exchange failed ({resp.status}): {body}")
                token_data = await resp.json(content_type=None)

        token_data = self._check_token_response(token_data, "Token exchange")
        access_token = token_data.get("access_token") or ""
        if not access_token:
            raise RuntimeError("Token exchange failed: response contained no access_token")

        return TokenData(
            access_token=access_token,
            refresh_token=token_data.get("refresh_token") or "",
            expires_at=self._expires_at(token_data.get("expires_in", 0)),
            scopes=self._parse_scopes(token_data.get("scope")),
            token_type=token_data.get("token_type") or "Bearer",
            provider=provider,
        )

    # -- Token refresh ------------------------------------------------------

    async def _refresh_token(
        self,
        redis: Any,
        user_id: str,
        token: TokenData,
    ) -> TokenData:
        """Refresh an expired access token and persist the result.

        POSTs ``grant_type=refresh_token`` to the provider's token endpoint.
        Fields the response omits (access/refresh token, token type) are
        carried over from ``token``, and the previous scopes are kept.

        Args:
            redis (Any): Async Redis client, forwarded to :meth:`_store_token`.
            user_id (str): Owner of the token being refreshed.
            token (TokenData): The current (expired) token bundle.

        Returns:
            TokenData: The newly issued, persisted token bundle.

        Raises:
            RuntimeError: If there is no refresh token, the endpoint returns a
                non-200 status, or a 200 body reports an ``error``.
        """
        p = self.providers[token.provider]
        if not token.refresh_token:
            raise RuntimeError(f"No refresh token available for {token.provider}")

        data: dict[str, str] = {
            "client_id": p.client_id,
            "client_secret": p.client_secret,
            "refresh_token": token.refresh_token,
            "grant_type": "refresh_token",
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(p.token_url, data=data) as resp:
                if resp.status != 200:
                    body = await resp.text()
                    raise RuntimeError(f"Token refresh failed ({resp.status}): {body}")
                new_data = await resp.json(content_type=None)

        # Without this, a 200 error body would be stored as a token with
        # expires_at=0, i.e. a stale access token that never expires again.
        new_data = self._check_token_response(new_data, "Token refresh")

        refreshed = TokenData(
            access_token=new_data.get("access_token", token.access_token),
            refresh_token=new_data.get("refresh_token", token.refresh_token),
            expires_at=self._expires_at(new_data.get("expires_in", 0)),
            scopes=token.scopes,
            token_type=new_data.get("token_type", token.token_type),
            provider=token.provider,
        )
        await self._store_token(redis, user_id, refreshed)
        return refreshed

    # -- Public API for tools -----------------------------------------------

    async def get_token(
        self,
        user_id: str,
        provider: str,
        redis: Any,
    ) -> str | None:
        """Return a usable access token for a user/provider, refreshing if needed.

        When the stored token is within its refresh window and has a refresh
        token, it is renewed via :meth:`_refresh_token`. If that fails, the
        broken credentials are deleted and ``None`` is returned. A token that
        is past its real expiry and has no refresh token can never work again,
        so it is likewise deleted and ``None`` is returned.

        Args:
            user_id (str): Owner of the token.
            provider (str): Provider whose token to return.
            redis (Any): Async Redis client, or ``None``.

        Returns:
            str | None: An access token, or ``None`` in any of these cases:
            ``redis`` is ``None``, nothing is stored, or the token cannot be
            used or refreshed.
        """
        if redis is None:
            return None

        token = await self._load_token(redis, user_id, provider)
        if token is None:
            return None

        if token.is_expired:
            if token.refresh_token:
                try:
                    token = await self._refresh_token(redis, user_id, token)
                except Exception:
                    logger.exception(
                        "Failed to refresh %s token for user %s", provider, user_id
                    )
                    await self._delete_token(redis, user_id, provider)
                    return None
            elif time.time() >= token.expires_at:
                # is_expired fires REFRESH_BUFFER_SECONDS early; a token inside
                # that window still works, but one past real expiry does not.
                logger.info(
                    "%s token for user %s expired with no refresh token",
                    provider,
                    user_id,
                )
                await self._delete_token(redis, user_id, provider)
                return None

        return token.access_token or None

    async def store_token(
        self,
        redis: Any,
        user_id: str,
        token: TokenData,
    ) -> None:
        """Persist a token bundle for a user (public wrapper of storage).

        Args:
            redis (Any): Async Redis client.
            user_id (str): Owner of the token.
            token (TokenData): Token bundle to store.
        """
        await self._store_token(redis, user_id, token)

    async def delete_token(
        self,
        redis: Any,
        user_id: str,
        provider: str,
    ) -> None:
        """Revoke (best-effort) and delete a user's stored token.

        Remote revocation is attempted only when a token exists and the
        provider has a ``revoke_url``. Revocation failures are logged and
        ignored; the local copy is always removed.

        Args:
            redis (Any): Async Redis client.
            user_id (str): Owner of the token to disconnect.
            provider (str): Provider to disconnect.
        """
        token = await self._load_token(redis, user_id, provider)
        # Look the provider up without a constructed fallback: the previous
        # ``OAuthProvider(name="")`` default lacked required fields and raised
        # TypeError whenever a token existed.
        p = self.providers.get(provider)
        if token is not None and p is not None and p.revoke_url:
            try:
                await self._revoke(provider, token)
            except Exception:
                logger.warning(
                    "Failed to revoke %s token for user %s", provider, user_id
                )
        await self._delete_token(redis, user_id, provider)

    async def _revoke(self, provider: str, token: TokenData) -> None:
        """Ask the provider to invalidate a token, where supported.

        Google receives the token as a query parameter. The refresh token is
        preferred: Google documents that revoking it also revokes its access
        tokens, and it still works after the access token has expired.
        Discord receives a form body with client credentials. Other providers
        are a no-op.

        Args:
            provider (str): Provider whose API to call.
            token (TokenData): The token bundle being revoked.

        Raises:
            aiohttp.ClientResponseError: On a non-2xx revocation response, so
                :meth:`delete_token` logs the failure instead of hiding it.
        """
        p = self.providers[provider]
        if not p.revoke_url:
            return

        if provider == "google":
            request: dict[str, Any] = {
                "params": {"token": token.refresh_token or token.access_token}
            }
        elif provider == "discord":
            request = {
                "data": {
                    "token": token.access_token,
                    "client_id": p.client_id,
                    "client_secret": p.client_secret,
                }
            }
        else:
            return

        async with aiohttp.ClientSession() as session:
            async with session.post(p.revoke_url, **request) as resp:
                resp.raise_for_status()

    async def list_user_connections(
        self,
        user_id: str,
        redis: Any,
    ) -> list[dict[str, Any]]:
        """Summarize every provider the user currently has connected.

        Never exposes token values and never triggers a refresh.

        Args:
            user_id (str): User whose connections to enumerate.
            redis (Any): Async Redis client, or ``None``.

        Returns:
            list[dict[str, Any]]: One dict per connected provider, with keys
            ``provider``, ``scopes``, ``expires_at`` (``None`` if the token
            never expires) and ``has_refresh_token``. The list is empty when
            ``redis`` is ``None``.
        """
        if redis is None:
            return []
        result = []
        for name in self.providers:
            token = await self._load_token(redis, user_id, name)
            if token and token.access_token:
                result.append(
                    {
                        "provider": name,
                        "scopes": token.scopes,
                        "expires_at": (
                            token.expires_at if token.expires_at > 0 else None
                        ),
                        "has_refresh_token": bool(token.refresh_token),
                    }
                )
        return result

    async def has_token(self, user_id: str, provider: str, redis: Any) -> bool:
        """Report whether a non-empty access token is stored (no refresh).

        Args:
            user_id (str): User to check.
            provider (str): Provider to check.
            redis (Any): Async Redis client, or ``None``.

        Returns:
            bool: ``True`` if a token with a non-empty access token exists.
        """
        if redis is None:
            return False
        token = await self._load_token(redis, user_id, provider)
        return token is not None and bool(token.access_token)

    # -- Chat-initiated connect URL -----------------------------------------

    async def generate_connect_url(
        self,
        user_id: str,
        provider: str,
        redis: Any,
        scopes: list[str] | None = None,
    ) -> str:
        """Produce a ready-to-click OAuth connect URL for a chat user.

        Steps:
        1. Mint a one-time link code via :meth:`create_link_code`.
        2. Store ``{"link_code", "provider"}`` under
           ``stargazer:oauth_state:{state}`` (random opaque ``state``,
           :data:`LINK_CODE_TTL` expiry).
        3. Return the authorization URL carrying that state.

        The provider is validated first, so no Redis state is written for
        unknown providers.

        Args:
            user_id (str): User initiating the connection from chat.
            provider (str): Provider to connect.
            redis (Any): Async Redis client supporting ``set`` with ``ex``.
            scopes (list[str] | None): Optional scope override.

        Returns:
            str: The provider authorization URL the user should open.

        Raises:
            ValueError: If ``provider`` is not a known provider.
        """
        if provider not in self.providers:
            raise ValueError(f"Unknown provider: {provider}")

        link_code = await self.create_link_code(redis, user_id)
        state = json.dumps({"link_code": link_code, "provider": provider})
        state_encoded = secrets.token_urlsafe(16)
        await redis.set(
            f"stargazer:oauth_state:{state_encoded}",
            state,
            ex=LINK_CODE_TTL,
        )
        return self.get_authorize_url(provider, state_encoded, scopes=scopes)
# --------------------------------------------------------------------------- # Singleton accessor # --------------------------------------------------------------------------- _instance: OAuthManager | None = None
def get_oauth_manager() -> OAuthManager:
    """Return the process-wide :class:`OAuthManager` singleton.

    Reads the module-level ``_instance`` that :func:`init_oauth_manager`
    assigns. It refuses to hand out anything before that has happened, so
    callers never see a half-configured manager.

    Returns:
        OAuthManager: The initialized global manager.

    Raises:
        RuntimeError: If :func:`init_oauth_manager` has not been called yet.
    """
    manager = _instance
    if manager is None:
        raise RuntimeError(
            "OAuthManager not initialized -- call init_oauth_manager() first"
        )
    return manager
def init_oauth_manager(
    encryption_key: str = "",
    base_url: str = "",
    providers_config: dict[str, dict[str, Any]] | None = None,
) -> OAuthManager:
    """Create the global :class:`OAuthManager` and install it as the singleton.

    Must run (typically at service startup) before :func:`get_oauth_manager`
    is used. Calling it again replaces the previous instance.

    Args:
        encryption_key (str): Fernet key for at-rest token encryption (empty
            disables encryption).
        base_url (str): Public base URL used to derive OAuth callback URIs.
        providers_config (dict[str, dict[str, Any]] | None): Optional
            per-provider credential/scope overrides.

    Returns:
        OAuthManager: The newly created, now-global manager.
    """
    global _instance
    manager = OAuthManager(
        encryption_key=encryption_key,
        base_url=base_url,
        providers_config=providers_config,
    )
    _instance = manager
    return manager
async def require_oauth_token(ctx: Any, provider: str) -> str:
    """Return a usable access token for the current user, or demand a connection.

    This is the guard provider-backed tools call to obtain credentials. It
    asks the global manager for a (possibly refreshed) token for
    ``ctx.user_id``. When none is available, it mints a connect URL and raises
    :class:`OAuthNotConnected`, which the tool layer shows as a clickable
    prompt.

    Args:
        ctx (Any): Tool context exposing ``user_id`` and ``redis``.
        provider (str): Provider whose access token is required.

    Returns:
        str: A valid access token for ``provider``.

    Raises:
        OAuthNotConnected: If the user has not connected the provider (or the
            token could not be refreshed); carries a connect URL.
    """
    manager = get_oauth_manager()
    access_token = await manager.get_token(ctx.user_id, provider, ctx.redis)
    if access_token is None:
        connect_url = await manager.generate_connect_url(
            ctx.user_id, provider, ctx.redis
        )
        raise OAuthNotConnected(provider, connect_url)
    return access_token