"""Per-user OAuth token management with encrypted storage.
Handles the full OAuth2 authorization-code flow for multiple providers
(GitHub, Google, Discord, Microsoft), stores tokens encrypted at rest
in Redis, and transparently refreshes expired access tokens.
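Redis key patterns:
    stargazer:oauth_tokens:{user_id}:{provider} (Fernet-encrypted JSON; plain JSON if no valid encryption key is configured)
    stargazer:oauth_link:{link_code} (one-time link code -> user_id)
    stargazer:oauth_state:{state} (OAuth `state` value -> JSON {"link_code", "provider"})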
"""
from __future__ import annotations

import asyncio
import json
import logging
import secrets
import time
from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING
from urllib.parse import urlencode

import aiohttp
from cryptography.fernet import Fernet, InvalidToken
logger = logging.getLogger(__name__)

# Redis key namespaces (see the module docstring for the full key layout).
REDIS_TOKEN_PREFIX = "stargazer:oauth_tokens"
REDIS_LINK_PREFIX = "stargazer:oauth_link"

# Lifetime, in seconds, of one-time link codes and OAuth state entries.
LINK_CODE_TTL = 5 * 60
# A token counts as expired this many seconds before its real expiry,
# so it is refreshed before a request can race the deadline.
REFRESH_BUFFER_SECONDS = 2 * 60
# ---------------------------------------------------------------------------
# Provider definitions
# ---------------------------------------------------------------------------
@dataclass
class OAuthProvider:
    """Resolved configuration for a single OAuth2 provider.

    Endpoint URLs and flags are copied from PROVIDER_TEMPLATES; client
    credentials and scopes come from the per-deployment provider config.
    """

    name: str  # registry key, e.g. "github"
    authorize_url: str  # consent page the user is redirected to
    token_url: str  # endpoint for code exchange and refresh grants
    client_id: str = ""
    client_secret: str = ""
    scopes: list[str] = field(default_factory=list)  # used when the caller requests none
    tokens_expire: bool = True  # copied from the template (False for github)
    revoke_url: str = ""  # empty string means no revocation call is made
    extra_auth_params: dict[str, str] = field(default_factory=dict)  # appended to the authorize URL query
# Static endpoints and defaults for each supported provider. OAuthManager
# merges these with deployment config; key order here is the order in which
# providers are registered and listed.
PROVIDER_TEMPLATES: dict[str, dict[str, Any]] = dict(
    github=dict(
        authorize_url="https://github.com/login/oauth/authorize",
        token_url="https://github.com/login/oauth/access_token",
        tokens_expire=False,
        revoke_url="",
        default_scopes=["repo", "read:user", "gist", "notifications"],
    ),
    google=dict(
        authorize_url="https://accounts.google.com/o/oauth2/v2/auth",
        token_url="https://oauth2.googleapis.com/token",
        tokens_expire=True,
        revoke_url="https://oauth2.googleapis.com/revoke",
        default_scopes=[
            "https://www.googleapis.com/auth/drive.file",
            "https://www.googleapis.com/auth/gmail.modify",
            "https://www.googleapis.com/auth/calendar",
        ],
        # Ask for offline access and force the consent screen so a refresh
        # token is issued.
        extra_auth_params=dict(access_type="offline", prompt="consent"),
    ),
    discord=dict(
        authorize_url="https://discord.com/api/oauth2/authorize",
        token_url="https://discord.com/api/oauth2/token",
        tokens_expire=True,
        revoke_url="https://discord.com/api/oauth2/token/revoke",
        default_scopes=["identify", "guilds", "email", "connections"],
    ),
    microsoft=dict(
        authorize_url="https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
        token_url="https://login.microsoftonline.com/common/oauth2/v2.0/token",
        tokens_expire=True,
        revoke_url="",
        default_scopes=[
            "User.Read",
            "Files.ReadWrite",
            "Mail.ReadWrite",
            "Calendars.ReadWrite",
            "offline_access",
        ],
    ),
)
@dataclass
class TokenData:
    """Decrypted OAuth token bundle.

    ``expires_at`` is an absolute Unix timestamp in seconds; a value of 0 (or
    any non-positive value) means no expiry was recorded.
    """

    access_token: str
    refresh_token: str = ""
    expires_at: float = 0.0
    scopes: list[str] = field(default_factory=list)
    token_type: str = "Bearer"
    provider: str = ""

    @property
    def is_expired(self) -> bool:
        """True once the token is within REFRESH_BUFFER_SECONDS of ``expires_at``.

        Tokens without a recorded expiry never count as expired.
        """
        if self.expires_at <= 0:
            return False
        return time.time() >= (self.expires_at - REFRESH_BUFFER_SECONDS)

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a JSON-compatible dict (inverse of :meth:`from_dict`)."""
        return {
            "access_token": self.access_token,
            "refresh_token": self.refresh_token,
            "expires_at": self.expires_at,
            # Copy so a caller mutating the returned dict cannot alter this token.
            "scopes": list(self.scopes or []),
            "token_type": self.token_type,
            "provider": self.provider,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> TokenData:
        """Rebuild a token from a :meth:`to_dict`-shaped mapping.

        Missing keys *and* JSON ``null`` values fall back to the field
        defaults, so a stored record with nulls still loads instead of
        yielding ``scopes=None`` or failing on ``float(None)``.

        Raises:
            ValueError: if ``expires_at`` is present but not numeric.
        """
        return cls(
            access_token=d.get("access_token") or "",
            refresh_token=d.get("refresh_token") or "",
            expires_at=float(d.get("expires_at") or 0),
            scopes=list(d.get("scopes") or []),
            token_type=d.get("token_type") or "Bearer",
            provider=d.get("provider") or "",
        )
class OAuthNotConnected(Exception):
    """Signals that a tool needs an OAuth token the user has not provided.

    Attributes:
        provider: provider key the tool asked for.
        connect_url: link the user can follow to connect that provider.
    """

    def __init__(self, provider: str, connect_url: str) -> None:
        message = f"You haven't connected {provider} yet. Click here to connect: {connect_url}"
        super().__init__(message)
        self.provider = provider
        self.connect_url = connect_url
# ---------------------------------------------------------------------------
# Manager
# ---------------------------------------------------------------------------
class OAuthManager:
    """Core OAuth2 token lifecycle manager.

    Registers one :class:`OAuthProvider` per entry in PROVIDER_TEMPLATES.
    Stores per-user token bundles in Redis, Fernet-encrypted when a valid
    key is configured. Performs the authorization-code exchange and
    refresh-token grants, and issues one-time link codes for the
    chat-initiated connect flow.
    """

    # Total time budget (seconds) for one request to a provider endpoint.
    # aiohttp's default total timeout is five minutes, which is far too long
    # to stall a tool call on an unresponsive provider.
    HTTP_TIMEOUT_SECONDS: float = 30.0

    class _TokenEndpointError(RuntimeError):
        """A provider token endpoint rejected a request.

        Subclasses RuntimeError so existing ``except RuntimeError`` callers
        keep working. ``status`` is the HTTP status the endpoint returned,
        which lets callers tell transient failures apart from rejections.
        """

        def __init__(self, message: str, status: int) -> None:
            super().__init__(message)
            self.status = status

    def __init__(
        self,
        encryption_key: str = "",
        base_url: str = "",
        providers_config: dict[str, dict[str, Any]] | None = None,
    ) -> None:
        """Build the provider registry and, if possible, the Fernet cipher.

        Args:
            encryption_key: Fernet key for tokens at rest. If empty, tokens are
                stored as plain JSON. If invalid, an error is logged and tokens
                are also stored as plain JSON.
            base_url: public base URL used to build
                ``{base_url}/oauth/{provider}/callback`` redirect URIs; a
                trailing slash is stripped.
            providers_config: per-provider overrides keyed by provider name;
                ``client_id``, ``client_secret`` and ``scopes`` are read.
        """
        self._fernet: Fernet | None = None
        if encryption_key:
            try:
                self._fernet = Fernet(
                    encryption_key.encode() if isinstance(encryption_key, str) else encryption_key
                )
            except (ValueError, TypeError):
                # Fernet rejects malformed keys with ValueError (binascii.Error
                # is a subclass) and non-bytes keys with TypeError.
                logger.error("Invalid OAuth encryption key -- tokens will NOT be encrypted")
        self.base_url = base_url.rstrip("/") if base_url else ""
        self.providers: dict[str, OAuthProvider] = {}
        providers_config = providers_config or {}
        for name, template in PROVIDER_TEMPLATES.items():
            user_cfg = providers_config.get(name, {})
            scopes = user_cfg.get("scopes")
            if scopes is None:
                scopes = template.get("default_scopes", [])
            self.providers[name] = OAuthProvider(
                name=name,
                authorize_url=template["authorize_url"],
                token_url=template["token_url"],
                client_id=user_cfg.get("client_id", ""),
                client_secret=user_cfg.get("client_secret", ""),
                # Copy the list/dict so mutating one provider's config can
                # never leak into PROVIDER_TEMPLATES or another manager.
                scopes=list(scopes),
                tokens_expire=template.get("tokens_expire", True),
                revoke_url=template.get("revoke_url", ""),
                extra_auth_params=dict(template.get("extra_auth_params", {})),
            )

    # -- Encryption helpers -------------------------------------------------

    def _encrypt(self, plaintext: str) -> str:
        """Encrypt *plaintext*, or return it unchanged when no cipher is configured."""
        if self._fernet is None:
            return plaintext
        return self._fernet.encrypt(plaintext.encode()).decode()

    def _decrypt(self, ciphertext: str) -> str:
        """Decrypt *ciphertext*; return "" if it fails authentication with the current key."""
        if self._fernet is None:
            return ciphertext
        try:
            return self._fernet.decrypt(ciphertext.encode()).decode()
        except InvalidToken:
            logger.warning("Failed to decrypt OAuth token -- returning empty")
            return ""

    # -- Redis helpers ------------------------------------------------------

    @staticmethod
    def _token_key(user_id: str, provider: str) -> str:
        """Redis key holding *user_id*'s token bundle for *provider*."""
        return f"{REDIS_TOKEN_PREFIX}:{user_id}:{provider}"

    async def _store_token(self, redis: Any, user_id: str, token: TokenData) -> None:
        """Serialise, encrypt and write *token*, overwriting any previous value."""
        payload = json.dumps(token.to_dict())
        await redis.set(self._token_key(user_id, token.provider), self._encrypt(payload))

    async def _load_token(self, redis: Any, user_id: str, provider: str) -> TokenData | None:
        """Read and decrypt a stored token; None if absent, undecryptable or malformed."""
        raw = await redis.get(self._token_key(user_id, provider))
        if raw is None:
            return None
        decrypted = self._decrypt(raw if isinstance(raw, str) else raw.decode())
        if not decrypted:
            return None
        try:
            payload = json.loads(decrypted)
            if not isinstance(payload, dict):
                raise TypeError("token payload is not a JSON object")
            return TokenData.from_dict(payload)
        except (ValueError, TypeError, KeyError):
            # json.JSONDecodeError is a ValueError; from_dict's float()
            # coercion raises ValueError/TypeError on malformed fields.
            logger.warning("Discarding unreadable %s OAuth token for user %s", provider, user_id)
            return None

    async def _delete_token(self, redis: Any, user_id: str, provider: str) -> None:
        """Remove the stored token (no provider-side revocation)."""
        await redis.delete(self._token_key(user_id, provider))

    # -- Link codes (chat-initiated flow) -----------------------------------

    async def create_link_code(self, redis: Any, user_id: str) -> str:
        """Store a fresh random code mapping to *user_id* for LINK_CODE_TTL seconds and return it."""
        code = secrets.token_urlsafe(32)
        await redis.set(f"{REDIS_LINK_PREFIX}:{code}", user_id, ex=LINK_CODE_TTL)
        return code

    async def resolve_link_code(self, redis: Any, code: str) -> str | None:
        """Consume a link code and return its user id, or None if unknown/expired/used."""
        key = f"{REDIS_LINK_PREFIX}:{code}"
        raw = await redis.get(key)
        if raw is None:
            return None
        # DEL reports how many keys it removed. Only the caller whose DEL
        # actually removed the key may consume the code, so two concurrent
        # resolutions of the same code cannot both succeed.
        if not await redis.delete(key):
            return None
        return raw if isinstance(raw, str) else raw.decode()

    # -- Authorization URL --------------------------------------------------

    def _redirect_uri(self, provider: str) -> str:
        """Callback URL sent to *provider*; must match between authorize and exchange."""
        return f"{self.base_url}/oauth/{provider}/callback"

    def get_authorize_url(
        self,
        provider: str,
        state: str,
        scopes: list[str] | None = None,
    ) -> str:
        """Build the provider consent URL.

        Args:
            provider: registered provider key.
            state: opaque value echoed back to the callback.
            scopes: scopes to request; the provider's configured scopes are
                used when this is None or empty.

        Raises:
            ValueError: if *provider* is not registered.
        """
        p = self.providers.get(provider)
        if p is None:
            raise ValueError(f"Unknown provider: {provider}")
        params = {
            "client_id": p.client_id,
            "redirect_uri": self._redirect_uri(provider),
            "response_type": "code",
            "scope": " ".join(scopes or p.scopes),
            "state": state,
        }
        params.update(p.extra_auth_params)
        return f"{p.authorize_url}?{urlencode(params)}"

    # -- Token endpoint plumbing --------------------------------------------

    def _session(self) -> aiohttp.ClientSession:
        """Create an HTTP session whose requests are bounded by HTTP_TIMEOUT_SECONDS."""
        return aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=self.HTTP_TIMEOUT_SECONDS)
        )

    @staticmethod
    def _token_request_headers(provider: str) -> dict[str, str]:
        """Headers for a token-endpoint POST.

        GitHub replies form-encoded unless JSON is requested explicitly; this
        applies to refresh grants as well as code exchange.
        """
        if provider == "github":
            return {"Accept": "application/json"}
        return {"Content-Type": "application/x-www-form-urlencoded"}

    @staticmethod
    def _parse_scopes(raw: Any) -> list[str]:
        """Normalise a token response's ``scope`` field to a list of scopes.

        RFC 6749 uses space-delimited strings, but GitHub returns a
        comma-delimited one (``"repo,gist"``), so both separators are
        accepted. An already-split list is copied; anything else gives [].
        """
        if isinstance(raw, str):
            return raw.replace(",", " ").split()
        if isinstance(raw, (list, tuple)):
            return [str(s) for s in raw]
        return []

    @staticmethod
    def _expires_at(expires_in: Any) -> float:
        """Absolute expiry for a relative ``expires_in``; 0.0 when none is reported."""
        return time.time() + float(expires_in) if expires_in else 0.0

    @classmethod
    def _validated_token_payload(cls, payload: Any, action: str, status: int) -> dict[str, Any]:
        """Return *payload* if it carries an access token, else raise _TokenEndpointError.

        GitHub signals failures such as ``bad_verification_code`` with HTTP
        200 and an ``error`` field, so checking the status alone is not enough.
        """
        if not isinstance(payload, dict):
            raise cls._TokenEndpointError(
                f"Token {action} failed ({status}): unexpected response body", status
            )
        if "error" in payload:
            error = payload["error"]
            description = payload.get("error_description")
            detail = f"{error}: {description}" if description else str(error)
            raise cls._TokenEndpointError(f"Token {action} failed ({status}): {detail}", status)
        if not payload.get("access_token"):
            raise cls._TokenEndpointError(
                f"Token {action} failed ({status}): response contained no access_token", status
            )
        return payload

    async def _request_token(
        self, p: OAuthProvider, data: dict[str, str], action: str
    ) -> dict[str, Any]:
        """POST *data* to *p*'s token endpoint and return the validated JSON payload.

        Raises:
            _TokenEndpointError: on a non-200 status or an error/empty payload.
        """
        async with self._session() as session:
            async with session.post(
                p.token_url, data=data, headers=self._token_request_headers(p.name)
            ) as resp:
                status = resp.status
                if status != 200:
                    body = await resp.text()
                    raise self._TokenEndpointError(
                        f"Token {action} failed ({status}): {body}", status
                    )
                payload = await resp.json(content_type=None)
        return self._validated_token_payload(payload, action, status)

    # -- Token exchange -----------------------------------------------------

    async def exchange_code(
        self,
        provider: str,
        code: str,
    ) -> TokenData:
        """Exchange an authorization *code* for a token bundle (not persisted).

        Raises:
            KeyError: if *provider* is not registered.
            RuntimeError: if the provider rejects the exchange.
        """
        p = self.providers[provider]
        payload = await self._request_token(
            p,
            {
                "client_id": p.client_id,
                "client_secret": p.client_secret,
                "code": code,
                "redirect_uri": self._redirect_uri(provider),
                "grant_type": "authorization_code",
            },
            "exchange",
        )
        return TokenData(
            access_token=payload["access_token"],
            refresh_token=payload.get("refresh_token") or "",
            expires_at=self._expires_at(payload.get("expires_in")),
            scopes=self._parse_scopes(payload.get("scope")),
            token_type=payload.get("token_type") or "Bearer",
            provider=provider,
        )

    # -- Token refresh ------------------------------------------------------

    async def _refresh_token(
        self,
        redis: Any,
        user_id: str,
        token: TokenData,
    ) -> TokenData:
        """Run a refresh-token grant for *token*, persist the result and return it.

        Raises:
            RuntimeError: if *token* has no refresh token, or the provider
                rejects the grant (as _TokenEndpointError).
        """
        p = self.providers[token.provider]
        if not token.refresh_token:
            raise RuntimeError(f"No refresh token available for {token.provider}")
        payload = await self._request_token(
            p,
            {
                "client_id": p.client_id,
                "client_secret": p.client_secret,
                "refresh_token": token.refresh_token,
                "grant_type": "refresh_token",
            },
            "refresh",
        )
        refreshed = TokenData(
            access_token=payload["access_token"],
            # Providers that do not rotate refresh tokens omit this field;
            # keep the existing one in that case.
            refresh_token=payload.get("refresh_token") or token.refresh_token,
            expires_at=self._expires_at(payload.get("expires_in")),
            scopes=list(token.scopes or []),
            token_type=payload.get("token_type") or token.token_type,
            provider=token.provider,
        )
        await self._store_token(redis, user_id, refreshed)
        return refreshed

    # -- Public API for tools -----------------------------------------------

    async def get_token(
        self,
        user_id: str,
        provider: str,
        redis: Any,
    ) -> str | None:
        """Return a valid access token, refreshing if needed. Returns None if not connected.

        If the provider definitively rejects a refresh (4xx, or an error
        payload), the stored token is deleted. If the failure looks transient
        (network error, timeout, 429, or 5xx), the stored token is kept so a
        later call can retry, and None is returned for now.
        """
        if redis is None:
            return None
        token = await self._load_token(redis, user_id, provider)
        if token is None:
            return None
        if token.is_expired and token.refresh_token:
            try:
                token = await self._refresh_token(redis, user_id, token)
            except self._TokenEndpointError as exc:
                if exc.status == 429 or exc.status >= 500:
                    logger.warning(
                        "Transient error refreshing %s token for user %s (HTTP %s); keeping stored credentials",
                        provider, user_id, exc.status,
                    )
                    return None
                logger.exception("Failed to refresh %s token for user %s", provider, user_id)
                await self._delete_token(redis, user_id, provider)
                return None
            except (aiohttp.ClientError, asyncio.TimeoutError):
                logger.warning(
                    "Transient error refreshing %s token for user %s; keeping stored credentials",
                    provider, user_id, exc_info=True,
                )
                return None
            except Exception:
                logger.exception("Failed to refresh %s token for user %s", provider, user_id)
                await self._delete_token(redis, user_id, provider)
                return None
        return token.access_token or None

    async def store_token(
        self,
        redis: Any,
        user_id: str,
        token: TokenData,
    ) -> None:
        """Persist *token* for *user_id* under ``token.provider``, overwriting any previous value."""
        await self._store_token(redis, user_id, token)

    async def delete_token(
        self,
        redis: Any,
        user_id: str,
        provider: str,
    ) -> None:
        """Best-effort revoke the stored token with the provider, then delete it locally.

        A revocation failure is logged and does not prevent local deletion.
        """
        token = await self._load_token(redis, user_id, provider)
        p = self.providers.get(provider)
        if token is not None and p is not None and p.revoke_url:
            try:
                await self._revoke(provider, token)
            except Exception:
                logger.warning(
                    "Failed to revoke %s token for user %s", provider, user_id, exc_info=True
                )
        await self._delete_token(redis, user_id, provider)

    async def _revoke(self, provider: str, token: TokenData) -> None:
        """Call the provider's revocation endpoint (google and discord only).

        Raises:
            aiohttp.ClientResponseError: if the endpoint answers with an error status.
        """
        p = self.providers[provider]
        if not p.revoke_url:
            return
        if provider == "google":
            # Prefer the refresh token: it outlives the access token, and
            # revoking it revokes the whole grant.
            request: dict[str, Any] = {
                "params": {"token": token.refresh_token or token.access_token},
            }
        elif provider == "discord":
            request = {
                "data": {
                    "token": token.access_token,
                    "client_id": p.client_id,
                    "client_secret": p.client_secret,
                },
            }
        else:
            return
        async with self._session() as session:
            async with session.post(p.revoke_url, **request) as resp:
                resp.raise_for_status()

    async def list_user_connections(
        self,
        user_id: str,
        redis: Any,
    ) -> list[dict[str, Any]]:
        """Return a list of providers the user has connected.

        Each entry carries ``provider``, ``scopes``, ``expires_at`` (None
        when no expiry is recorded) and ``has_refresh_token``.
        """
        if redis is None:
            return []
        result = []
        for name in self.providers:
            token = await self._load_token(redis, user_id, name)
            if token and token.access_token:
                result.append({
                    "provider": name,
                    "scopes": token.scopes,
                    "expires_at": token.expires_at if token.expires_at > 0 else None,
                    "has_refresh_token": bool(token.refresh_token),
                })
        return result

    async def has_token(self, user_id: str, provider: str, redis: Any) -> bool:
        """True if a non-empty access token is stored (expiry is not checked)."""
        if redis is None:
            return False
        token = await self._load_token(redis, user_id, provider)
        return token is not None and bool(token.access_token)

    # -- Chat-initiated connect URL -----------------------------------------

    async def generate_connect_url(
        self,
        user_id: str,
        provider: str,
        redis: Any,
        scopes: list[str] | None = None,
    ) -> str:
        """Create a one-time link code plus OAuth state and return the consent URL.

        The random ``state`` sent to the provider maps, for LINK_CODE_TTL
        seconds, to JSON holding the link code and provider name.

        Raises:
            ValueError: if *provider* is not registered (checked before any
                Redis write).
        """
        if provider not in self.providers:
            raise ValueError(f"Unknown provider: {provider}")
        link_code = await self.create_link_code(redis, user_id)
        state_id = secrets.token_urlsafe(16)
        await redis.set(
            f"stargazer:oauth_state:{state_id}",
            json.dumps({"link_code": link_code, "provider": provider}),
            ex=LINK_CODE_TTL,
        )
        return self.get_authorize_url(provider, state_id, scopes=scopes)
# ---------------------------------------------------------------------------
# Singleton accessor
# ---------------------------------------------------------------------------
# Process-wide manager; populated by init_oauth_manager().
_instance: OAuthManager | None = None


def get_oauth_manager() -> OAuthManager:
    """Return the global OAuthManager singleton (must be initialized first).

    Raises:
        RuntimeError: if init_oauth_manager() has not been called yet.
    """
    manager = _instance
    if manager is not None:
        return manager
    raise RuntimeError("OAuthManager not initialized -- call init_oauth_manager() first")
def init_oauth_manager(
    encryption_key: str = "",
    base_url: str = "",
    providers_config: dict[str, dict[str, Any]] | None = None,
) -> OAuthManager:
    """Create the global OAuthManager, replacing any existing one, and return it.

    Arguments are forwarded unchanged to :class:`OAuthManager`.
    """
    global _instance
    manager = OAuthManager(
        encryption_key=encryption_key,
        base_url=base_url,
        providers_config=providers_config,
    )
    _instance = manager
    return manager
async def require_oauth_token(ctx: Any, provider: str) -> str:
    """Get a valid token or raise OAuthNotConnected with a connect link.

    Reads ``ctx.user_id`` and ``ctx.redis``.

    Raises:
        OAuthNotConnected: if no usable token is stored; carries a one-time
            connect URL for the user.
        RuntimeError: if ``ctx.redis`` is None. No connect link can be issued
            without token storage, so this replaces an opaque AttributeError
            that would otherwise come from deep inside the link-code code.
    """
    mgr = get_oauth_manager()
    token = await mgr.get_token(ctx.user_id, provider, ctx.redis)
    if token is not None:
        return token
    if ctx.redis is None:
        raise RuntimeError(
            f"Cannot connect {provider}: OAuth token storage (Redis) is unavailable"
        )
    connect_url = await mgr.generate_connect_url(ctx.user_id, provider, ctx.redis)
    raise OAuthNotConnected(provider, connect_url)