"""Per-user OAuth token management with encrypted storage.

Handles the full OAuth2 authorization-code flow for multiple providers
(GitHub, Google, Discord, Microsoft), stores tokens in Redis (Fernet-encrypted
at rest when a valid encryption key is configured), and transparently
refreshes expired access tokens.

Redis key patterns:
    stargazer:oauth_tokens:{user_id}:{provider}  (Fernet-encrypted JSON)
    stargazer:oauth_link:{link_code}             (one-time link code -> user_id)
    stargazer:oauth_state:{state}                (opaque state -> JSON link_code/provider)
"""
from __future__ import annotations
import jsonutil as json
import logging
import secrets
import time
from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING
import aiohttp
from cryptography.fernet import Fernet, InvalidToken
if TYPE_CHECKING:  # placeholder for type-only imports; none are needed yet
    pass

logger = logging.getLogger(__name__)

# Redis key namespaces (full key layouts are listed in the module docstring).
REDIS_TOKEN_PREFIX = "stargazer:oauth_tokens"
REDIS_LINK_PREFIX = "stargazer:oauth_link"

# Seconds a one-time link code (and its state entry) stays valid.
LINK_CODE_TTL = 5 * 60
# Refresh access tokens once they are this many seconds from expiry.
REFRESH_BUFFER_SECONDS = 2 * 60
# ---------------------------------------------------------------------------
# Provider definitions
# ---------------------------------------------------------------------------
@dataclass
class OAuthProvider:
    """Endpoints and deployment credentials for a single OAuth2 provider.

    Bundles the static URLs for one provider (authorize, token, optional
    revoke) with the per-deployment client credentials and default scopes.
    :class:`OAuthManager` builds one instance per entry in
    :data:`PROVIDER_TEMPLATES`, overlaying any ``providers_config`` values.
    It reads them when building authorize URLs, exchanging or refreshing
    tokens, and revoking.

    Attributes:
        name: Provider key such as ``"github"`` or ``"google"``.
        authorize_url: Authorization endpoint the user is redirected to.
        token_url: Endpoint used for code exchange and refresh.
        client_id: OAuth client id for this deployment.
        client_secret: OAuth client secret for this deployment.
        scopes: Scopes requested by default at authorization time.
        tokens_expire: ``False`` for providers such as GitHub that issue
            long-lived tokens.
        revoke_url: Revocation endpoint, or ``""`` if unsupported.
        extra_auth_params: Additional query parameters merged into the
            authorize URL (for example Google's ``access_type``/``prompt``).
    """

    # Required identity and endpoints (positional order is part of the API).
    name: str
    authorize_url: str
    token_url: str
    # Deployment credentials; empty until configured.
    client_id: str = ""
    client_secret: str = ""
    # Behavioural knobs with per-instance mutable defaults.
    scopes: list[str] = field(default_factory=list)
    tokens_expire: bool = True
    revoke_url: str = ""
    extra_auth_params: dict[str, str] = field(default_factory=dict)
# Static per-provider OAuth endpoints and defaults. ``default_scopes`` apply
# when a deployment's providers_config does not override ``scopes``;
# ``extra_auth_params`` are merged into the authorization URL.
PROVIDER_TEMPLATES: dict[str, dict[str, Any]] = dict(
    github=dict(
        authorize_url="https://github.com/login/oauth/authorize",
        token_url="https://github.com/login/oauth/access_token",
        # Tokens are long-lived and no revoke endpoint is configured.
        tokens_expire=False,
        revoke_url="",
        default_scopes=["repo", "read:user", "gist", "notifications"],
    ),
    google=dict(
        authorize_url="https://accounts.google.com/o/oauth2/v2/auth",
        token_url="https://oauth2.googleapis.com/token",
        tokens_expire=True,
        revoke_url="https://oauth2.googleapis.com/revoke",
        default_scopes=[
            "https://www.googleapis.com/auth/drive.file",
            "https://www.googleapis.com/auth/gmail.modify",
            "https://www.googleapis.com/auth/calendar",
        ],
        extra_auth_params=dict(access_type="offline", prompt="consent"),
    ),
    discord=dict(
        authorize_url="https://discord.com/api/oauth2/authorize",
        token_url="https://discord.com/api/oauth2/token",
        tokens_expire=True,
        revoke_url="https://discord.com/api/oauth2/token/revoke",
        default_scopes=["identify", "guilds", "email", "connections"],
    ),
    microsoft=dict(
        authorize_url="https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
        token_url="https://login.microsoftonline.com/common/oauth2/v2.0/token",
        tokens_expire=True,
        revoke_url="",
        default_scopes=[
            "User.Read",
            "Files.ReadWrite",
            "Mail.ReadWrite",
            "Calendars.ReadWrite",
            "offline_access",
        ],
    ),
)
@dataclass
class TokenData:
    """Decrypted OAuth token bundle.

    The in-memory, plaintext representation of a user's OAuth credentials for
    one provider. It carries the access/refresh tokens, an absolute
    ``expires_at`` epoch and the granted scopes. :class:`OAuthManager`
    serializes it (via :meth:`to_dict`) and Fernet-encrypts it before writing
    to Redis. After decrypting, it reconstructs the bundle via
    :meth:`from_dict`.

    Args:
        access_token (str): The bearer access token used to call provider APIs.
        refresh_token (str): The refresh token, if the provider issued one.
        expires_at (float): Absolute epoch seconds at which the access token
            expires; ``0`` means it never expires.
        scopes (list[str]): Scopes actually granted by the provider.
        token_type (str): Token type, normally ``"Bearer"``.
        provider (str): Provider name this bundle belongs to.
    """

    access_token: str
    refresh_token: str = ""
    expires_at: float = 0.0
    scopes: list[str] = field(default_factory=list)
    token_type: str = "Bearer"
    provider: str = ""

    @property
    def is_expired(self) -> bool:
        """Report whether the access token is at or past its refresh threshold.

        A non-positive ``expires_at`` marks a non-expiring token (e.g. GitHub).
        Otherwise the current wall-clock time is compared against
        ``expires_at - REFRESH_BUFFER_SECONDS``, so refresh happens slightly
        before true expiry. Consulted by :meth:`OAuthManager.get_token`.

        Returns:
            bool: ``True`` if the token should be refreshed now, ``False`` if
            it is still valid or never expires.
        """
        if self.expires_at <= 0:
            return False
        return time.time() >= (self.expires_at - REFRESH_BUFFER_SECONDS)

    def to_dict(self) -> dict[str, Any]:
        """Serialize this token bundle to a plain JSON-compatible dict.

        Produces the exact shape persisted to Redis; the inverse is
        :meth:`from_dict`. ``scopes`` is copied so callers cannot mutate this
        bundle through the returned mapping.

        Returns:
            dict[str, Any]: Mapping with ``access_token``, ``refresh_token``,
            ``expires_at``, ``scopes``, ``token_type`` and ``provider`` keys.
        """
        return {
            "access_token": self.access_token,
            "refresh_token": self.refresh_token,
            "expires_at": self.expires_at,
            "scopes": list(self.scopes or []),
            "token_type": self.token_type,
            "provider": self.provider,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> TokenData:
        """Reconstruct a :class:`TokenData` from its serialized dict form.

        Tolerant of partial or legacy payloads:
        - Missing keys *and* explicit ``None``/empty values fall back to the
          field defaults (previously ``expires_at: null`` raised
          ``TypeError``).
        - ``expires_at`` is coerced to ``float``.
        - ``scopes`` may be a list or a space-delimited string.

        Inverse of :meth:`to_dict`.

        Args:
            d (dict[str, Any]): Decoded token mapping as produced by
                :meth:`to_dict` (any subset of its keys).

        Returns:
            TokenData: A token bundle populated from ``d``.

        Raises:
            ValueError: If ``expires_at`` is present but not numeric.
        """
        raw_scopes = d.get("scopes")
        if isinstance(raw_scopes, str):
            scopes = raw_scopes.split()
        else:
            # Copy so the new bundle does not alias the caller's list.
            scopes = list(raw_scopes or [])
        return cls(
            access_token=d.get("access_token") or "",
            refresh_token=d.get("refresh_token") or "",
            expires_at=float(d.get("expires_at") or 0),
            scopes=scopes,
            token_type=d.get("token_type") or "Bearer",
            provider=d.get("provider") or "",
        )
class OAuthNotConnected(Exception):
    """Signal that a tool needs an OAuth grant the user has not given yet.

    Rather than a hard failure, this carries everything needed to prompt the
    user: the provider name and a one-time connect URL. The tool layer turns
    it into a "click here to connect" message. Raised by
    :func:`require_oauth_token` and surfaced to chat users by the OAuth tools
    (``tools/microsoft_tools.py``, ``tools/google_oauth_tools.py``,
    ``tools/github_tools.py``, ``tools/discord_user_tools.py``).

    Attributes:
        provider: Provider the user still has to connect.
        connect_url: One-time authorization URL to open.
    """

    def __init__(self, provider: str, connect_url: str) -> None:
        """Compose the user-facing connect prompt and record its parts.

        Args:
            provider (str): Provider name to connect (e.g. ``"github"``).
            connect_url (str): One-time OAuth authorization URL for the user.
        """
        message = (
            f"You haven't connected {provider} yet. "
            f"Click here to connect: {connect_url}"
        )
        super().__init__(message)
        self.provider = provider
        self.connect_url = connect_url
# ---------------------------------------------------------------------------
# Manager
# ---------------------------------------------------------------------------
class OAuthManager:
    """Core OAuth2 token lifecycle manager.

    Owns the per-user OAuth story for every configured provider:
    - building authorization URLs;
    - exchanging authorization codes;
    - transparently refreshing expired access tokens;
    - revoking on disconnect;
    - persisting tokens in Redis under
      ``stargazer:oauth_tokens:{user_id}:{provider}``, Fernet-encrypted when a
      valid key is configured.

    It also drives the chat-initiated connect flow via one-time link codes
    (``stargazer:oauth_link:{code}``) and opaque state tokens
    (``stargazer:oauth_state:{state}``). A single instance is created and held
    as the module singleton by :func:`init_oauth_manager` /
    :func:`get_oauth_manager`. It is reached by:
    - the web OAuth routes (``web/auth_routes.py``);
    - the ``connect_service`` tool (``tools/connect_service.py``);
    - every provider-backed tool, via :func:`require_oauth_token`.
    """

    # Redis key prefix for chat-flow state entries (value: JSON link_code/provider).
    _STATE_KEY_PREFIX = "stargazer:oauth_state"

    def __init__(
        self,
        encryption_key: str = "",
        base_url: str = "",
        providers_config: dict[str, dict[str, Any]] | None = None,
    ) -> None:
        """Initialize the manager with encryption, base URL and provider config.

        Builds a :class:`~cryptography.fernet.Fernet` cipher from
        ``encryption_key``. If the key is invalid it logs an error and leaves
        ``self._fernet`` as ``None``, in which case tokens are stored in
        plaintext.

        Merges per-deployment ``providers_config`` (client IDs, secrets,
        scopes) over the static :data:`PROVIDER_TEMPLATES` to populate
        ``self.providers`` with one :class:`OAuthProvider` per supported
        provider. Scope lists and extra auth params are copied, so mutating a
        provider never leaks into the shared templates.

        Args:
            encryption_key (str): Fernet key (str or bytes) for at-rest token
                encryption; empty disables encryption.
            base_url (str): Public base URL of the web service, used to build
                ``/oauth/{provider}/callback`` redirect URIs. A trailing slash
                is stripped.
            providers_config (dict[str, dict[str, Any]] | None): Optional
                per-provider overrides keyed by provider name. Each entry may
                supply ``client_id``, ``client_secret`` and/or ``scopes``
                (a list, or a space-delimited string).
        """
        self._fernet: Fernet | None = None
        if encryption_key:
            try:
                self._fernet = Fernet(
                    encryption_key.encode()
                    if isinstance(encryption_key, str)
                    else encryption_key
                )
            except Exception:
                logger.error(
                    "Invalid OAuth encryption key -- tokens will NOT be encrypted"
                )
        self.base_url = base_url.rstrip("/") if base_url else ""
        self.providers: dict[str, OAuthProvider] = {}
        providers_config = providers_config or {}
        for name, template in PROVIDER_TEMPLATES.items():
            # An entry may be present but set to None in deployment config.
            user_cfg = providers_config.get(name) or {}
            raw_scopes = user_cfg.get("scopes", template.get("default_scopes", []))
            if isinstance(raw_scopes, str):
                raw_scopes = raw_scopes.split()
            self.providers[name] = OAuthProvider(
                name=name,
                authorize_url=template["authorize_url"],
                token_url=template["token_url"],
                client_id=user_cfg.get("client_id", ""),
                client_secret=user_cfg.get("client_secret", ""),
                # Copies: never alias PROVIDER_TEMPLATES' (or config's) containers.
                scopes=list(raw_scopes),
                tokens_expire=template.get("tokens_expire", True),
                revoke_url=template.get("revoke_url", ""),
                extra_auth_params=dict(template.get("extra_auth_params", {})),
            )

    # -- Encryption helpers -------------------------------------------------

    def _encrypt(self, plaintext: str) -> str:
        """Fernet-encrypt a string, or return it unchanged if encryption is off.

        Args:
            plaintext (str): Value to encrypt (the serialized token JSON).

        Returns:
            str: The base64 Fernet ciphertext, or ``plaintext`` if encryption
            is disabled.
        """
        if self._fernet is None:
            return plaintext
        return self._fernet.encrypt(plaintext.encode()).decode()

    def _decrypt(self, ciphertext: str) -> str:
        """Fernet-decrypt a stored value, returning empty on failure.

        Returns the input unchanged when encryption is disabled. On
        :class:`~cryptography.fernet.InvalidToken` (wrong key or corrupted
        data) it logs a warning and returns ``""``.

        Args:
            ciphertext (str): Stored Fernet ciphertext (or plaintext when
                encryption is disabled).

        Returns:
            str: The decrypted plaintext, the input unchanged if encryption is
            off, or ``""`` if decryption failed.
        """
        if self._fernet is None:
            return ciphertext
        try:
            return self._fernet.decrypt(ciphertext.encode()).decode()
        except InvalidToken:
            logger.warning("Failed to decrypt OAuth token -- returning empty")
            return ""

    # -- Token-response helpers --------------------------------------------

    @staticmethod
    def _token_headers(provider: str) -> dict[str, str]:
        """Return headers for a POST to a provider's token endpoint.

        GitHub's token endpoint answers form-encoded unless JSON is requested
        via ``Accept``. Other providers get an explicit form content type.
        Shared by :meth:`exchange_code` and :meth:`_refresh_token` so both
        paths parse responses the same way.
        """
        if provider == "github":
            return {"Accept": "application/json"}
        return {"Content-Type": "application/x-www-form-urlencoded"}

    @staticmethod
    def _token_response_error(payload: Any) -> str | None:
        """Describe why a token-endpoint JSON body is unusable, else ``None``.

        Some providers (notably GitHub, e.g. ``bad_verification_code``) report
        failures with HTTP 200 and an ``error`` field. A 200 status alone
        therefore does not prove a token was issued.
        """
        if not isinstance(payload, dict):
            return f"unexpected response type {type(payload).__name__}"
        error = payload.get("error")
        if error:
            return str(payload.get("error_description") or error)
        return None

    @staticmethod
    def _parse_scopes(raw: Any) -> list[str]:
        """Normalize a token response's ``scope`` field into a list.

        Accepts space- or comma-delimited strings (GitHub returns
        ``"repo,gist"``) or a list. Anything else (missing/``None``) yields
        ``[]``.
        """
        if isinstance(raw, str):
            return raw.replace(",", " ").split()
        if isinstance(raw, (list, tuple)):
            return list(raw)
        return []

    @staticmethod
    def _absolute_expiry(expires_in: Any) -> float:
        """Convert a relative ``expires_in`` (seconds) to an absolute epoch.

        Returns ``0`` (meaning "never expires") when ``expires_in`` is missing
        or zero. Numeric strings are accepted.
        """
        if not expires_in:
            return 0
        return time.time() + float(expires_in)

    # -- Redis helpers ------------------------------------------------------

    @staticmethod
    def _token_key(user_id: str, provider: str) -> str:
        """Build the Redis key holding a user's token for a provider.

        Args:
            user_id (str): Stargazer user identifier.
            provider (str): Provider name (e.g. ``"discord"``).

        Returns:
            str: ``stargazer:oauth_tokens:{user_id}:{provider}``.
        """
        return f"{REDIS_TOKEN_PREFIX}:{user_id}:{provider}"

    async def _store_token(self, redis: Any, user_id: str, token: TokenData) -> None:
        """Encrypt and persist a token bundle to Redis for a user.

        Serializes ``token`` via :meth:`TokenData.to_dict`, JSON-encodes it,
        encrypts with :meth:`_encrypt`, and writes it to the key from
        :meth:`_token_key`. The key's provider part comes from
        ``token.provider``.

        Args:
            redis (Any): Async Redis client supporting ``set``.
            user_id (str): Owner of the token.
            token (TokenData): Token bundle to persist.
        """
        payload = json.dumps(token.to_dict())
        encrypted = self._encrypt(payload)
        await redis.set(self._token_key(user_id, token.provider), encrypted)

    async def _load_token(
        self, redis: Any, user_id: str, provider: str
    ) -> TokenData | None:
        """Fetch, decrypt and deserialize a user's stored token, if any.

        Returns ``None`` when the key is absent, decryption yields empty, or
        the payload is not a usable token mapping (malformed JSON, non-object
        JSON, non-numeric ``expires_at``). A payload missing ``provider`` is
        stamped with ``provider``, so later refreshes can find its config.

        Args:
            redis (Any): Async Redis client supporting ``get``.
            user_id (str): Owner of the token.
            provider (str): Provider whose token to load.

        Returns:
            TokenData | None: The decrypted token, or ``None`` if missing or
            unusable.
        """
        raw = await redis.get(self._token_key(user_id, provider))
        if raw is None:
            return None
        decrypted = self._decrypt(raw if isinstance(raw, str) else raw.decode())
        if not decrypted:
            return None
        try:
            payload = json.loads(decrypted)
            if not isinstance(payload, dict):
                return None
            token = TokenData.from_dict(payload)
        except (json.JSONDecodeError, KeyError, TypeError, ValueError):
            # Corrupt or legacy payloads read as "not connected" instead of
            # crashing every caller.
            return None
        if not token.provider:
            token.provider = provider
        return token

    async def _delete_token(self, redis: Any, user_id: str, provider: str) -> None:
        """Remove a user's stored token for a provider from Redis.

        Args:
            redis (Any): Async Redis client supporting ``delete``.
            user_id (str): Owner of the token.
            provider (str): Provider whose token to delete.
        """
        await redis.delete(self._token_key(user_id, provider))

    # -- Link codes (chat-initiated flow) -----------------------------------

    async def create_link_code(self, redis: Any, user_id: str) -> str:
        """Mint a one-time link code mapping to a user for the chat connect flow.

        Stores ``stargazer:oauth_link:{code}`` -> ``user_id`` with a
        :data:`LINK_CODE_TTL` expiry. This lets a browser completing the OAuth
        flow be tied back to the chat user who requested it.

        Args:
            redis (Any): Async Redis client supporting ``set`` with ``ex``.
            user_id (str): User to associate with the generated code.

        Returns:
            str: The freshly minted, single-use link code.
        """
        code = secrets.token_urlsafe(32)
        await redis.set(
            f"{REDIS_LINK_PREFIX}:{code}",
            user_id,
            ex=LINK_CODE_TTL,
        )
        return code

    async def resolve_link_code(self, redis: Any, code: str) -> str | None:
        """Consume a one-time link code and return its associated user id.

        Args:
            redis (Any): Async Redis client supporting ``get`` and ``delete``.
            code (str): A code previously issued by :meth:`create_link_code`.

        Returns:
            str | None: The associated user id, or ``None`` if the code is
            unknown or expired.
        """
        key = f"{REDIS_LINK_PREFIX}:{code}"
        raw = await redis.get(key)
        if raw is None:
            return None
        # NOTE(review): get + delete are separate round-trips, so two
        # concurrent resolvers could both read the code before either deletes
        # it. Consider an atomic GETDEL if the Redis client supports it.
        await redis.delete(key)
        return raw if isinstance(raw, str) else raw.decode()

    # -- Authorization URL --------------------------------------------------

    def get_authorize_url(
        self,
        provider: str,
        state: str,
        scopes: list[str] | None = None,
    ) -> str:
        """Build the provider's OAuth2 authorization URL for a redirect.

        Args:
            provider (str): Provider to authorize against.
            state (str): Opaque CSRF/round-trip state echoed back to the
                callback.
            scopes (list[str] | None): Optional scope override; defaults to the
                provider's configured scopes.

        Returns:
            str: The fully-formed authorization URL.

        Raises:
            ValueError: If ``provider`` is not a known provider.
        """
        from urllib.parse import urlencode

        p = self.providers.get(provider)
        if p is None:
            raise ValueError(f"Unknown provider: {provider}")
        redirect_uri = f"{self.base_url}/oauth/{provider}/callback"
        params = {
            "client_id": p.client_id,
            "redirect_uri": redirect_uri,
            "response_type": "code",
            "scope": " ".join(scopes or p.scopes),
            "state": state,
        }
        # Provider-specific extras, e.g. Google's access_type/prompt.
        params.update(p.extra_auth_params)
        return f"{p.authorize_url}?{urlencode(params)}"

    # -- Token exchange -----------------------------------------------------

    async def exchange_code(
        self,
        provider: str,
        code: str,
    ) -> TokenData:
        """Exchange an authorization code for an access/refresh token bundle.

        POSTs ``grant_type=authorization_code`` to the provider's token
        endpoint and normalizes the response into a :class:`TokenData`:
        - computes an absolute ``expires_at`` from ``expires_in``;
        - splits a space- or comma-delimited ``scope`` string into a list.

        Responses that carry an ``error`` field, or no ``access_token``, are
        rejected even with HTTP 200.

        Args:
            provider (str): Provider the code was issued by.
            code (str): The authorization code returned to the redirect URI.

        Returns:
            TokenData: The token bundle from the provider's response.

        Raises:
            RuntimeError: If the token endpoint returns a non-200 status, or
                the body reports an error / contains no access token.
        """
        p = self.providers[provider]
        redirect_uri = f"{self.base_url}/oauth/{provider}/callback"
        data: dict[str, str] = {
            "client_id": p.client_id,
            "client_secret": p.client_secret,
            "code": code,
            "redirect_uri": redirect_uri,
            "grant_type": "authorization_code",
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(
                p.token_url, data=data, headers=self._token_headers(provider)
            ) as resp:
                if resp.status != 200:
                    body = await resp.text()
                    raise RuntimeError(f"Token exchange failed ({resp.status}): {body}")
                token_data = await resp.json(content_type=None)
        problem = self._token_response_error(token_data)
        if problem is None and not token_data.get("access_token"):
            problem = "response did not include an access_token"
        if problem is not None:
            raise RuntimeError(f"Token exchange failed: {problem}")
        return TokenData(
            access_token=token_data["access_token"],
            refresh_token=token_data.get("refresh_token") or "",
            expires_at=self._absolute_expiry(token_data.get("expires_in")),
            scopes=self._parse_scopes(token_data.get("scope")),
            token_type=token_data.get("token_type") or "Bearer",
            provider=provider,
        )

    # -- Token refresh ------------------------------------------------------

    async def _refresh_token(
        self,
        redis: Any,
        user_id: str,
        token: TokenData,
    ) -> TokenData:
        """Refresh an expired access token using its refresh token.

        POSTs ``grant_type=refresh_token`` to the provider's token endpoint and
        builds a new :class:`TokenData`:
        - the previous access/refresh token and scopes are carried forward
          when the response omits them;
        - ``expires_at`` is recomputed.

        The result is persisted via :meth:`_store_token`.

        Args:
            redis (Any): Async Redis client, forwarded to :meth:`_store_token`.
            user_id (str): Owner of the token being refreshed.
            token (TokenData): The current (expired) token bundle.

        Returns:
            TokenData: The newly issued, persisted token bundle.

        Raises:
            RuntimeError: If the token has no refresh token, the endpoint
                returns a non-200 status, or the body reports an error.
        """
        p = self.providers[token.provider]
        if not token.refresh_token:
            raise RuntimeError(f"No refresh token available for {token.provider}")
        data: dict[str, str] = {
            "client_id": p.client_id,
            "client_secret": p.client_secret,
            "refresh_token": token.refresh_token,
            "grant_type": "refresh_token",
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(
                p.token_url, data=data, headers=self._token_headers(token.provider)
            ) as resp:
                if resp.status != 200:
                    body = await resp.text()
                    raise RuntimeError(f"Token refresh failed ({resp.status}): {body}")
                new_data = await resp.json(content_type=None)
        problem = self._token_response_error(new_data)
        if problem is not None:
            raise RuntimeError(f"Token refresh failed: {problem}")
        refreshed = TokenData(
            access_token=new_data.get("access_token") or token.access_token,
            refresh_token=new_data.get("refresh_token") or token.refresh_token,
            expires_at=self._absolute_expiry(new_data.get("expires_in")),
            scopes=(
                self._parse_scopes(new_data["scope"])
                if new_data.get("scope")
                else token.scopes
            ),
            token_type=new_data.get("token_type") or token.token_type,
            provider=token.provider,
        )
        await self._store_token(redis, user_id, refreshed)
        return refreshed

    # -- Public API for tools -----------------------------------------------

    async def get_token(
        self,
        user_id: str,
        provider: str,
        redis: Any,
    ) -> str | None:
        """Return a valid access token for a user/provider, refreshing if needed.

        Loads the stored token and, when it reports
        :attr:`TokenData.is_expired` and has a refresh token, renews it via
        :meth:`_refresh_token`. Refresh failures are handled by kind:
        - Transient network/timeout errors are logged and the stored
          credentials are kept for a later retry.
        - Any other failure (e.g. the provider rejecting the refresh token)
          drops the credentials via :meth:`_delete_token`.

        In both cases ``None`` is returned. ``None`` is also returned when
        ``redis`` is ``None``, no token is stored, or the token has no access
        token.

        Args:
            user_id (str): Owner of the token.
            provider (str): Provider whose token to return.
            redis (Any): Async Redis client, or ``None``.

        Returns:
            str | None: A valid access token string, or ``None``.
        """
        import asyncio

        if redis is None:
            return None
        token = await self._load_token(redis, user_id, provider)
        if token is None:
            return None
        if token.is_expired and token.refresh_token:
            try:
                token = await self._refresh_token(redis, user_id, token)
            except (aiohttp.ClientError, asyncio.TimeoutError):
                logger.warning(
                    "Transient error refreshing %s token for user %s; "
                    "keeping stored credentials",
                    provider,
                    user_id,
                    exc_info=True,
                )
                return None
            except Exception:
                logger.exception(
                    "Failed to refresh %s token for user %s", provider, user_id
                )
                await self._delete_token(redis, user_id, provider)
                return None
        return token.access_token or None

    async def store_token(
        self,
        redis: Any,
        user_id: str,
        token: TokenData,
    ) -> None:
        """Persist a token bundle for a user (public wrapper of storage).

        Args:
            redis (Any): Async Redis client.
            user_id (str): Owner of the token.
            token (TokenData): Token bundle to store.
        """
        await self._store_token(redis, user_id, token)

    async def delete_token(
        self,
        redis: Any,
        user_id: str,
        provider: str,
    ) -> None:
        """Revoke (best-effort) and delete a user's stored token.

        If a token is stored and the provider has a ``revoke_url``, remote
        revocation is attempted via :meth:`_revoke`; failures are logged and
        ignored. The local copy is then always removed.

        Args:
            redis (Any): Async Redis client.
            user_id (str): Owner of the token to disconnect.
            provider (str): Provider to disconnect.
        """
        token = await self._load_token(redis, user_id, provider)
        # Look up without a constructed fallback: OAuthProvider(name="")
        # raised TypeError (missing required URLs) and blocked deletion.
        p = self.providers.get(provider)
        if token is not None and p is not None and p.revoke_url:
            try:
                await self._revoke(provider, token)
            except Exception:
                logger.warning(
                    "Failed to revoke %s token for user %s", provider, user_id
                )
        await self._delete_token(redis, user_id, provider)

    async def _revoke(self, provider: str, token: TokenData) -> None:
        """Tell the provider to invalidate a token, where supported.

        Request shape by provider:
        - Google receives the access token as a query parameter.
        - Discord receives it as a form body with client credentials.
        - Providers without a ``revoke_url``, or without a known request
          shape, are a no-op.

        Responses are released via ``async with``, and HTTP error statuses are
        logged.

        Args:
            provider (str): Provider whose API to call.
            token (TokenData): The token bundle being revoked.
        """
        p = self.providers[provider]
        if not p.revoke_url:
            return
        request_kwargs: dict[str, Any]
        if provider == "google":
            request_kwargs = {"params": {"token": token.access_token}}
        elif provider == "discord":
            request_kwargs = {
                "data": {
                    "token": token.access_token,
                    "client_id": p.client_id,
                    "client_secret": p.client_secret,
                }
            }
        else:
            return
        async with aiohttp.ClientSession() as session:
            async with session.post(p.revoke_url, **request_kwargs) as resp:
                if resp.status >= 400:
                    logger.warning(
                        "Revoking %s token returned HTTP %s", provider, resp.status
                    )

    async def list_user_connections(
        self,
        user_id: str,
        redis: Any,
    ) -> list[dict[str, Any]]:
        """Summarize every provider the user currently has connected.

        Only providers with a non-empty access token are included. Secret
        token values are never returned, and no refresh is triggered.

        Args:
            user_id (str): User whose connections to enumerate.
            redis (Any): Async Redis client, or ``None``.

        Returns:
            list[dict[str, Any]]: One dict per connected provider with these
            keys:
            - ``provider``;
            - ``scopes``;
            - ``expires_at`` (``None`` for non-expiring tokens);
            - ``has_refresh_token``.

            Empty when ``redis`` is ``None`` or nothing is connected.
        """
        if redis is None:
            return []
        result = []
        for name in self.providers:
            token = await self._load_token(redis, user_id, name)
            if token and token.access_token:
                result.append(
                    {
                        "provider": name,
                        "scopes": token.scopes,
                        "expires_at": (
                            token.expires_at if token.expires_at > 0 else None
                        ),
                        "has_refresh_token": bool(token.refresh_token),
                    }
                )
        return result

    async def has_token(self, user_id: str, provider: str, redis: Any) -> bool:
        """Report whether a user has a stored, non-empty token for a provider.

        Unlike :meth:`get_token`, this does not trigger a refresh.

        Args:
            user_id (str): User to check.
            provider (str): Provider to check.
            redis (Any): Async Redis client, or ``None``.

        Returns:
            bool: ``True`` if a token with a non-empty access token exists.
        """
        if redis is None:
            return False
        token = await self._load_token(redis, user_id, provider)
        return token is not None and bool(token.access_token)

    # -- Chat-initiated connect URL -----------------------------------------

    async def generate_connect_url(
        self,
        user_id: str,
        provider: str,
        redis: Any,
        scopes: list[str] | None = None,
    ) -> str:
        """Produce a ready-to-click OAuth connect URL for a chat user.

        Steps:
        1. Validate the provider up front.
        2. Mint a one-time link code via :meth:`create_link_code`.
        3. Store a JSON state payload (link code + provider) under
           ``stargazer:oauth_state:{state}`` with a :data:`LINK_CODE_TTL`
           expiry.
        4. Return the authorization URL from :meth:`get_authorize_url`.

        Args:
            user_id (str): User initiating the connection from chat.
            provider (str): Provider to connect.
            redis (Any): Async Redis client supporting ``set`` with ``ex``.
            scopes (list[str] | None): Optional scope override forwarded to
                :meth:`get_authorize_url`.

        Returns:
            str: The provider authorization URL the user should open.

        Raises:
            ValueError: If ``provider`` is not a known provider (raised before
                any Redis keys are written).
        """
        if provider not in self.providers:
            raise ValueError(f"Unknown provider: {provider}")
        link_code = await self.create_link_code(redis, user_id)
        state = json.dumps({"link_code": link_code, "provider": provider})
        state_encoded = secrets.token_urlsafe(16)
        await redis.set(
            f"{self._STATE_KEY_PREFIX}:{state_encoded}",
            state,
            ex=LINK_CODE_TTL,
        )
        return self.get_authorize_url(provider, state_encoded, scopes=scopes)
# ---------------------------------------------------------------------------
# Singleton accessor
# ---------------------------------------------------------------------------
# Process-wide manager, populated by init_oauth_manager() at startup.
_instance: OAuthManager | None = None


def get_oauth_manager() -> OAuthManager:
    """Return the process-wide :class:`OAuthManager` singleton.

    Reads the module-level ``_instance`` set by :func:`init_oauth_manager`,
    and refuses to hand out anything before initialization. Used by the web
    OAuth routes (``web/auth_routes.py``), the ``connect_service`` tool
    (``tools/connect_service.py``) and :func:`require_oauth_token`.

    Returns:
        OAuthManager: The initialized global manager.

    Raises:
        RuntimeError: If :func:`init_oauth_manager` has not been called yet.
    """
    manager = _instance
    if manager is not None:
        return manager
    raise RuntimeError(
        "OAuthManager not initialized -- call init_oauth_manager() first"
    )
def init_oauth_manager(
    encryption_key: str = "",
    base_url: str = "",
    providers_config: dict[str, dict[str, Any]] | None = None,
) -> OAuthManager:
    """Create the global :class:`OAuthManager` and install it as the singleton.

    Builds a manager from the supplied encryption key, public base URL and
    per-provider credentials, then publishes it via the module-level
    ``_instance`` read by :func:`get_oauth_manager`. It must run (typically at
    service startup) before any call to :func:`get_oauth_manager`.

    Args:
        encryption_key (str): Fernet key for at-rest token encryption; empty
            disables encryption.
        base_url (str): Public base URL used to derive OAuth callback redirect
            URIs.
        providers_config (dict[str, dict[str, Any]] | None): Optional
            per-provider credential/scope overrides.

    Returns:
        OAuthManager: The newly created, now-global manager instance.
    """
    global _instance
    manager = OAuthManager(
        encryption_key=encryption_key,
        base_url=base_url,
        providers_config=providers_config,
    )
    _instance = manager
    return manager
async def require_oauth_token(ctx: Any, provider: str) -> str:
    """Return a valid access token for the current user or demand a connection.

    This is the single guard every provider-backed tool calls to obtain
    credentials. It asks :meth:`OAuthManager.get_token` for a (possibly
    refreshed) token, using the user and Redis client carried on ``ctx``.

    On success it returns the bare access-token string. Otherwise it mints a
    one-time connect URL via :meth:`OAuthManager.generate_connect_url` and
    raises :class:`OAuthNotConnected`, which the tool layer surfaces as a
    clickable connect prompt.

    Args:
        ctx (Any): Tool context exposing ``user_id`` and ``redis`` (the latter
            may be ``None``).
        provider (str): Provider whose access token is required.

    Returns:
        str: A valid access token for ``provider``.

    Raises:
        OAuthNotConnected: If the user has not connected the provider (or the
            token could not be refreshed), carrying a connect URL.
        RuntimeError: If ``ctx.redis`` is ``None``. No token can be loaded and
            no connect link can be minted; previously this surfaced as an
            opaque ``AttributeError`` on ``None.set``.
    """
    mgr = get_oauth_manager()
    redis = ctx.redis
    token = await mgr.get_token(ctx.user_id, provider, redis)
    if token is not None:
        return token
    if redis is None:
        raise RuntimeError(
            f"Cannot obtain a {provider} OAuth token: Redis is unavailable"
        )
    connect_url = await mgr.generate_connect_url(ctx.user_id, provider, redis)
    raise OAuthNotConnected(provider, connect_url)