# Source code for tools.brave_search

"""Web search via the Brave Search API with rate limiting and key rotation."""

import asyncio
import json
import logging
import os
import time
from dataclasses import dataclass
from typing import Optional

import aiohttp

logger = logging.getLogger(__name__)

# Endpoints for the primary (Brave) and fallback (Tavily) search providers.
BRAVE_SEARCH_URL = "https://api.search.brave.com/res/v1/web/search"
TAVILY_SEARCH_URL = "https://api.tavily.com/search"

# SECURITY: a credential committed to source control must be treated as leaked.
# The TAVILY_API_KEY environment variable takes precedence so deployments can
# supply (and rotate) their own key. The literal is kept only as a
# backward-compatible fallback; revoke and remove it once every deployment
# sets the environment variable.
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY") or "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"

# Maximum number of attempts made for a single Brave request.
MAX_RETRIES = 3
# Base delay for exponential backoff: INITIAL_BACKOFF_SECONDS * 2 ** attempt.
INITIAL_BACKOFF_SECONDS = 1.0


class BraveAPIKeyManager:
    """Round-robin API key manager with automatic rotation.

    Keys are held in ``self.keys`` and ``self.current_index`` points at the
    key that will be handed out next. Note that :meth:`get_next_key` returns
    the key at the current index and *then* advances the index, whereas
    :meth:`rotate_key` advances first and returns the key at the new index.
    """

    def __init__(self):
        """Create an empty manager; call :meth:`load_keys` to populate it."""
        self.keys: list = []
        self.current_index: int = 0
        # Created on first use by _ensure_lock() rather than here.
        self._lock: Optional[asyncio.Lock] = None

    def load_keys(self, config):
        """Load Brave keys from ``config.API_KEYS["brave"]``.

        The configured value may be a single string or a list/tuple of
        strings. Every key is stripped of surrounding whitespace; blank and
        non-string entries are dropped. Any failure is logged and leaves the
        manager with no keys.

        Args:
            config: Bot configuration object exposing an ``API_KEYS`` mapping.
        """
        try:
            configured = (getattr(config, "API_KEYS", None) or {}).get("brave")
            if isinstance(configured, str):
                candidates = [configured]
            elif isinstance(configured, (list, tuple)):
                candidates = configured
            else:
                candidates = []

            # Strip every key (not only the single-string form) so stray
            # whitespace from config files never ends up in a request header.
            self.keys = [
                key.strip()
                for key in candidates
                if isinstance(key, str) and key.strip()
            ]

            if self.keys:
                logger.info(f"Loaded {len(self.keys)} Brave Search API key(s)")
            else:
                logger.warning("No Brave Search API keys configured")
        except Exception as e:
            logger.error(f"Failed to load Brave API keys: {e}")
            self.keys = []

        # A reload may yield fewer keys than before; keep the index valid so
        # get_current_key()/get_next_key() cannot raise IndexError.
        if self.current_index >= len(self.keys):
            self.current_index = 0

    async def _ensure_lock(self):
        """Create the asyncio lock on first use."""
        if self._lock is None:
            self._lock = asyncio.Lock()

    async def get_next_key(self) -> Optional[str]:
        """Return the key at the current index, then advance the index.

        Returns:
            Optional[str]: The selected key, or ``None`` when no keys are loaded.
        """
        if not self.keys:
            return None

        await self._ensure_lock()
        async with self._lock:
            key = self.keys[self.current_index]
            self.current_index = (self.current_index + 1) % len(self.keys)
            return key

    def rotate_key(self) -> Optional[str]:
        """Advance the index by one and return the key at the new index.

        Returns:
            Optional[str]: The newly selected key, or ``None`` when no keys
            are loaded.
        """
        if not self.keys:
            return None

        self.current_index = (self.current_index + 1) % len(self.keys)
        logger.info(f"Rotated to Brave API key {self.current_index + 1}/{len(self.keys)}")
        return self.keys[self.current_index]

    def get_current_key(self) -> Optional[str]:
        """Return the key at the current index without advancing.

        Returns:
            Optional[str]: The key, or ``None`` when no keys are loaded.
        """
        if not self.keys:
            return None
        return self.keys[self.current_index]

    def get_key_count(self) -> int:
        """Return the number of loaded keys.

        Returns:
            int: Number of keys currently held.
        """
        return len(self.keys)
# Shared key pool used by every search, and a flag recording whether the pool
# has been populated from configuration yet.
_key_manager = BraveAPIKeyManager()
_keys_loaded = False


async def _tavily_fallback(query: str, count: int) -> Optional[str]:
    """Try Tavily Search as a fallback. Returns normalised JSON or None.

    The Tavily response is reshaped into the same result schema the Brave
    path produces (``title``, ``description``, ``url``, ...), with
    ``"source": "tavily"``. Fields Tavily results are not read for are filled
    with empty strings.

    Args:
        query (str): Search query.
        count (int): Desired number of results; coerced to an int (10 when
            not convertible) and clamped to 1-20.

    Returns:
        Optional[str]: Pretty-printed JSON string on success, ``None`` on any
        failure (non-200 response, network error, unexpected payload).
    """
    try:
        # Mirror the coercion done on the Brave path: tool callers may hand
        # over the count as a string, which would otherwise make min() raise
        # and silently disable the fallback.
        try:
            max_results = int(count)
        except (TypeError, ValueError):
            max_results = 10

        payload = {
            "query": query.strip(),
            "max_results": max(1, min(20, max_results)),
            "search_depth": "basic",
        }
        headers = {
            "Authorization": f"Bearer {TAVILY_API_KEY}",
            "Content-Type": "application/json",
        }
        async with aiohttp.ClientSession() as session:
            logger.info("Tavily fallback search: '%s'", query)
            async with session.post(
                TAVILY_SEARCH_URL,
                json=payload,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=30),
            ) as resp:
                if resp.status != 200:
                    body = await resp.text()
                    logger.warning("Tavily fallback HTTP %d: %s", resp.status, body[:300])
                    return None
                data = await resp.json()

        results = data.get("results") or []
        formatted = {
            "query": query,
            "source": "tavily",
            "total_results": len(results),
            "results": [
                {
                    "title": r.get("title", ""),
                    "description": r.get("content", ""),
                    "url": r.get("url", ""),
                    "display_url": "",
                    "age": "",
                    "page_age": "",
                    "language": "",
                    "family_friendly": True,
                }
                for r in results
            ],
        }
        logger.info("Tavily fallback OK: %d results", len(formatted["results"]))
        return json.dumps(formatted, indent=2, ensure_ascii=False)
    except Exception as exc:
        # Deliberately broad: the fallback is best-effort and must never raise
        # into the caller, which still has the original Brave error to return.
        logger.warning("Tavily fallback failed: %s", exc, exc_info=True)
        return None
@dataclass
class BraveSearchTask:
    """One queued search request, as handed to the rate limiter's worker.

    The first six fields carry the search parameters, ``future`` is the
    channel through which the worker delivers the outcome, and the remaining
    optional fields identify whose key/quota the request runs under.
    """

    # --- search parameters -------------------------------------------------
    query: str
    count: Optional[int]
    country: Optional[str]
    search_lang: Optional[str]
    ui_lang: Optional[str]
    safesearch: Optional[str]

    # Completed by the worker with the result string, or with the exception
    # raised while executing the search.
    future: asyncio.Future

    # --- caller identity (all optional) ------------------------------------
    # When set, this key is used instead of one from the shared pool.
    user_api_key: Optional[str] = None
    # user_id and redis_client are both required for the per-user daily
    # limit on pooled keys to be checked.
    user_id: Optional[str] = None
    redis_client: Optional[object] = None
class BraveSearchRateLimiter:
    """Global rate limiter that serialises Brave API calls.

    Requests submitted through :meth:`search` are placed on a queue and
    executed one at a time by a single background worker task. Consecutive
    calls are spaced at least ``interval_seconds / calls_per_interval``
    seconds apart. When the Brave request yields an error payload, the
    Tavily fallback is tried before the error is returned.
    """

    def __init__(self, calls_per_interval: float = 0.5, interval_seconds: float = 1.0):
        """Initialize the instance.

        Args:
            calls_per_interval (float): Calls allowed per interval; may be
                fractional (0.5 per 1.0 s means one call every 2 s).
            interval_seconds (float): Length of the interval in seconds.
        """
        self.calls_per_interval = calls_per_interval
        self.interval_seconds = interval_seconds
        self.task_queue: Optional[asyncio.Queue] = None
        self.last_call_time: float = 0.0
        self._lock: Optional[asyncio.Lock] = None
        self._worker_task: Optional[asyncio.Task] = None
        self._initialized = False

    async def _ensure_initialized(self):
        """Create the queue, lock and worker task on first use.

        Done lazily because ``asyncio.create_task`` requires a running event
        loop, which is not available when the limiter is constructed.
        """
        if not self._initialized:
            self.task_queue = asyncio.Queue()
            self._lock = asyncio.Lock()
            self._worker_task = asyncio.create_task(self._worker())
            self._initialized = True

    async def _worker(self):
        """Consume queued tasks forever, resolving each task's future.

        Every dequeued task is matched by exactly one ``task_done()`` call,
        and a future is only completed if the caller has not already
        cancelled it.
        """
        while True:
            try:
                task = await self.task_queue.get()
            except asyncio.CancelledError:
                break

            # The caller gave up (its await was cancelled) while the task was
            # queued: skip it instead of spending an API call on it.
            if task.future.cancelled():
                self.task_queue.task_done()
                continue

            try:
                await self._wait_for_rate_limit()
                result = await self._execute_search(task)
            except asyncio.CancelledError:
                # Worker shutdown mid-request: do not leave the caller waiting
                # on a future that nobody will ever resolve.
                if not task.future.done():
                    task.future.cancel()
                break
            except Exception as e:
                logger.error(f"Error in Brave search rate limiter worker: {e}")
                if not task.future.done():
                    task.future.set_exception(e)
            else:
                if not task.future.done():
                    task.future.set_result(result)
            finally:
                self.task_queue.task_done()

    async def _wait_for_rate_limit(self):
        """Sleep until the minimum spacing since the previous call has elapsed."""
        async with self._lock:
            current_time = time.time()
            time_since_last = current_time - self.last_call_time
            min_interval = self.interval_seconds / self.calls_per_interval
            if time_since_last < min_interval:
                await asyncio.sleep(min_interval - time_since_last)
            self.last_call_time = time.time()

    async def _execute_search(self, task: "BraveSearchTask") -> str:
        """Run the Brave search and fall back to Tavily on an error payload.

        Args:
            task (BraveSearchTask): The queued request.

        Returns:
            str: JSON string. On Brave success the payload gains
            ``"source": "brave"``; on Brave failure the Tavily result is
            returned if available, otherwise the original Brave error JSON.
        """
        brave_result = await self._brave_search(task)
        try:
            parsed = json.loads(brave_result)
        except (json.JSONDecodeError, TypeError):
            parsed = {"error": "unparseable response"}

        if "error" not in parsed:
            parsed["source"] = "brave"
            return json.dumps(parsed, indent=2, ensure_ascii=False)

        # str(): the error value comes from helpers outside this module and is
        # not guaranteed to be sliceable.
        logger.warning(
            "Brave search failed (%s), trying Tavily fallback",
            str(parsed.get("error", ""))[:120],
        )
        count = task.count if task.count is not None else 10
        fallback = await _tavily_fallback(task.query, count)
        if fallback is not None:
            return fallback
        return brave_result

    @staticmethod
    def _rotate_pool_key(current_key: str) -> Optional[str]:
        """Rotate the shared pool to a key different from ``current_key``.

        ``get_next_key()`` advances the pool index after handing a key out, so
        a single ``rotate_key()`` call does not necessarily land on a different
        key (with two pooled keys it lands on the one that just failed). Keep
        rotating, at most once per pooled key, until a different key turns up.

        Returns:
            Optional[str]: A different pooled key, or ``None`` if there is none.
        """
        for _ in range(_key_manager.get_key_count()):
            candidate = _key_manager.rotate_key()
            if candidate and candidate != current_key:
                return candidate
        return None

    @staticmethod
    def _format_results(query: str, data: dict, count: int) -> dict:
        """Reduce a Brave API response to the fields this tool reports.

        Tolerates a missing or null ``web``/``results``/``profile`` entry.
        """
        web_results = (data.get("web") or {}).get("results") or []
        formatted = {
            "query": query,
            "total_results": len(web_results),
            "results": [],
        }
        for result in web_results[:count]:
            item = {
                "title": result.get("title", ""),
                "description": result.get("description", ""),
                "url": result.get("url", ""),
                "display_url": result.get("display_url", ""),
                "age": result.get("age", ""),
                "page_age": result.get("page_age", ""),
                "language": result.get("language", ""),
                "family_friendly": result.get("family_friendly", True),
            }
            if "profile" in result:
                profile = result["profile"] or {}
                item["profile"] = {
                    "name": profile.get("name", ""),
                    "url": profile.get("url", ""),
                    "long_name": profile.get("long_name", ""),
                    "img": profile.get("img", ""),
                }
            formatted["results"].append(item)
        return formatted

    async def _brave_search(self, task: "BraveSearchTask") -> str:
        """Execute the Brave Search API request with retries.

        Uses the task's own API key when present, otherwise a key from the
        shared pool (subject to the per-user daily limit when ``user_id`` and
        ``redis_client`` are set). HTTP 429 responses, network errors and
        timeouts are retried up to ``MAX_RETRIES`` times with exponential
        backoff, rotating to another pooled key where possible.

        Returns:
            str: JSON string with either the formatted results or an
            ``"error"`` entry.
        """
        if task.user_api_key:
            api_key = task.user_api_key
            using_user_key = True
        else:
            api_key = await _key_manager.get_next_key()
            using_user_key = False

        if not api_key:
            from tools.manage_api_keys import missing_api_key_error
            return json.dumps({"error": missing_api_key_error("brave")})

        # Rate-limit default/pool key usage (50/day for search)
        if not using_user_key and task.user_id and task.redis_client:
            from tools.manage_api_keys import check_default_key_limit, default_key_limit_error
            allowed, current, limit = await check_default_key_limit(
                task.user_id, "brave_web_search", task.redis_client, daily_limit=50,
            )
            if not allowed:
                return json.dumps({"error": default_key_limit_error("brave_web_search", current, limit)})

        # None or anything not convertible to int falls back to 10; the value
        # is then clamped to the 1-20 range.
        try:
            count = int(task.count)
        except (TypeError, ValueError):
            count = 10
        count = max(1, min(20, count))

        params = {
            "q": task.query.strip(),
            "count": count,
            "safesearch": task.safesearch or "moderate",
        }
        if task.country:
            params["country"] = task.country.upper()
        if task.search_lang:
            params["search_lang"] = task.search_lang.lower()
        if task.ui_lang:
            params["ui_lang"] = task.ui_lang.lower()

        last_error = None
        for attempt in range(MAX_RETRIES):
            # Rebuilt every attempt because api_key may have been rotated.
            headers = {
                "Accept": "application/json",
                "Accept-Encoding": "gzip",
                "X-Subscription-Token": api_key,
            }
            try:
                async with aiohttp.ClientSession() as session:
                    logger.info(
                        f"Brave search: '{task.query}' (attempt {attempt + 1}/{MAX_RETRIES}, "
                        f"key {_key_manager.current_index + 1}/{_key_manager.get_key_count()})"
                    )
                    async with session.get(
                        BRAVE_SEARCH_URL,
                        headers=headers,
                        params=params,
                        timeout=aiohttp.ClientTimeout(total=30),
                    ) as response:
                        if response.status == 200:
                            data = await response.json()
                            formatted_results = self._format_results(task.query, data, count)
                            logger.info(f"Brave search OK: {len(formatted_results['results'])} results")

                            # Increment default-key usage counter on success
                            if not using_user_key and task.user_id and task.redis_client:
                                from tools.manage_api_keys import increment_default_key_usage
                                await increment_default_key_usage(
                                    task.user_id, "brave_web_search", task.redis_client
                                )
                            return json.dumps(formatted_results, indent=2, ensure_ascii=False)

                        elif response.status == 429:
                            last_error = "Rate limit exceeded"
                            if attempt < MAX_RETRIES - 1:
                                if not using_user_key:
                                    api_key = self._rotate_pool_key(api_key) or api_key
                                backoff_time = INITIAL_BACKOFF_SECONDS * (2 ** attempt)
                                await asyncio.sleep(backoff_time)
                                continue
                            return json.dumps({"error": "Rate limit exceeded - all retries exhausted"})

                        elif response.status == 402:
                            return json.dumps({
                                "error": "payment_required",
                                "message": (
                                    "The Brave Search API key has exceeded its free tier quota. "
                                    "The user needs to provide their own Brave Search API key. "
                                    "They can obtain one for free at https://brave.com/search/api/ "
                                    "(the free plan allows 2,000 queries/month). "
                                    "To add their key, they should send you a DM with something like: "
                                    "'Save my Brave search API key: BSA_xxxxxxxxx' "
                                    "and you will store it securely for their future searches."
                                ),
                            })
                        elif response.status == 401:
                            return json.dumps({"error": "Invalid API key or unauthorized access"})
                        elif response.status == 403:
                            return json.dumps({"error": "Forbidden - check API key permissions"})
                        elif response.status == 400:
                            return json.dumps({"error": "Bad request - check query parameters"})
                        else:
                            return json.dumps({"error": f"API error: HTTP {response.status}"})

            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                # The total-request timeout surfaces as asyncio.TimeoutError,
                # which is not an aiohttp.ClientError; without catching it a
                # timeout would skip both the retries and the Tavily fallback.
                # TimeoutError stringifies to "", so fall back to the type name.
                last_error = str(e) or type(e).__name__
                logger.error(f"Network error on attempt {attempt + 1}: {last_error}")
                if attempt < MAX_RETRIES - 1:
                    backoff_time = INITIAL_BACKOFF_SECONDS * (2 ** attempt)
                    await asyncio.sleep(backoff_time)
                    if not using_user_key:
                        api_key = self._rotate_pool_key(api_key) or api_key
                continue

        return json.dumps({"error": f"Search failed after {MAX_RETRIES} attempts: {last_error}"})

    async def search(
        self,
        query: str,
        count: Optional[int] = 10,
        country: Optional[str] = None,
        search_lang: Optional[str] = None,
        ui_lang: Optional[str] = None,
        safesearch: Optional[str] = "moderate",
        user_api_key: Optional[str] = None,
        user_id: Optional[str] = None,
        redis_client: Optional[object] = None,
    ) -> str:
        """Queue a search and wait for its result.

        Args:
            query (str): Search query; must contain non-whitespace text.
            count (Optional[int]): Number of results to return (clamped to 1-20).
            country (Optional[str]): Country code, sent upper-cased.
            search_lang (Optional[str]): Result language code, sent lower-cased.
            ui_lang (Optional[str]): UI language code, sent lower-cased.
            safesearch (Optional[str]): Safe-search level; ``"moderate"`` is
                used when falsy.
            user_api_key (Optional[str]): Key to use instead of the shared pool.
            user_id (Optional[str]): User the request is attributed to.
            redis_client (Optional[object]): Redis client used for the
                per-user pooled-key limit.

        Returns:
            str: JSON string with the results, or with an ``"error"`` entry
            (this method does not raise for search failures).
        """
        if not query or not query.strip():
            return json.dumps({"error": "Search query cannot be empty"})

        await self._ensure_initialized()

        # Bind the future to the loop that is actually running this coroutine.
        future = asyncio.get_running_loop().create_future()
        task = BraveSearchTask(
            query=query,
            count=count,
            country=country,
            search_lang=search_lang,
            ui_lang=ui_lang,
            safesearch=safesearch,
            future=future,
            user_api_key=user_api_key,
            user_id=user_id,
            redis_client=redis_client,
        )
        await self.task_queue.put(task)

        try:
            return await future
        except Exception as e:
            logger.error(f"Error in Brave search: {e}")
            return json.dumps({"error": f"Search failed: {str(e)}"})
# Module-wide limiter shared by every entry point below. With 0.5 calls per
# 1.0-second interval the worker leaves at least two seconds between calls.
_rate_limiter = BraveSearchRateLimiter(
    calls_per_interval=0.5,
    interval_seconds=1.0,
)


# ---------------------------------------------------------------------------
# Public API for programmatic (non-tool) callers
# ---------------------------------------------------------------------------
async def search_with_key(
    query: str,
    count: int = 3,
    api_key: Optional[str] = None,
) -> str:
    """Execute a Brave search through the shared rate limiter.

    This is the entry point for internal callers (e.g.
    :class:`~web_search_context.WebSearchContextManager`) that already have a
    resolved API key and don't need the tool-context ceremony.

    On the first call (and on every later call until it succeeds) the shared
    key pool is loaded from ``Config.load()``. A failure to do so is logged
    and otherwise ignored, so a caller-supplied ``api_key`` still works.

    Args:
        query (str): Search query.
        count (int): Number of results to return.
        api_key (Optional[str]): Key to use instead of the shared pool.

    Returns:
        str: JSON string with the results or an ``"error"`` entry.
    """
    global _keys_loaded
    if not _keys_loaded:
        try:
            from config import Config
            _key_manager.load_keys(Config.load())
            _keys_loaded = True
        except Exception as exc:
            # Best-effort by design, but no longer silent. Without a
            # caller-supplied key the search depends on the pool, so that
            # case is worth a warning rather than a debug line.
            log = logger.debug if api_key else logger.warning
            log("Could not load pooled Brave API keys from Config: %s", exc)

    return await _rate_limiter.search(
        query=query,
        count=count,
        user_api_key=api_key,
    )
# ---------------------------------------------------------------------------
# v3 tool interface
# ---------------------------------------------------------------------------

TOOL_NAME = "brave_web_search"

TOOL_DESCRIPTION = (
    "Search the web using the Brave Search API. Returns titles, "
    "descriptions, URLs, and metadata for matching results. Supports search "
    "operators in the query (e.g. site:, filetype:, intitle:, lang:) for "
    "precise filtering."
)

# Help text for the "query" property, kept separate so the schema below stays
# readable.
_QUERY_DESCRIPTION = (
    "Search query. You can use Brave search operators for precise results. "
    "Examples: site:example.com (limit to a domain); filetype:pdf or ext:pdf "
    "(file type); intitle:keyword (term in title); inbody:\"exact phrase\" "
    "(term in body); lang:es or loc:ca (language/region); \"exact phrase\"; "
    "-term (exclude); +term (require); AND, OR, NOT for logic. "
    "E.g. \"climate change filetype:pdf site:edu\" or \"python asyncio site:docs.python.org\"."
)

# JSON-Schema description of each argument accepted by run().
_TOOL_PROPERTIES = {
    "query": {
        "type": "string",
        "description": _QUERY_DESCRIPTION,
    },
    "count": {
        "type": "integer",
        "description": "Number of results to return (1-20, default 10).",
    },
    "country": {
        "type": "string",
        "description": "Country code for localised results (e.g. US, GB, CA).",
    },
    "search_lang": {
        "type": "string",
        "description": "Language code for search results (e.g. en, es, fr).",
    },
    "ui_lang": {
        "type": "string",
        "description": "Language code for UI elements.",
    },
    "safesearch": {
        "type": "string",
        "enum": ["off", "moderate", "strict"],
        "description": "Safe-search level (default: off).",
    },
}

TOOL_PARAMETERS = {
    "type": "object",
    "properties": _TOOL_PROPERTIES,
    "required": ["query"],
}
async def run(
    query: str,
    count: int = 10,
    country: Optional[str] = None,
    search_lang: Optional[str] = None,
    ui_lang: Optional[str] = None,
    safesearch: str = "off",
    ctx=None,
) -> str:
    """Execute this tool and return the result.

    Loads the shared key pool from ``ctx.config`` on first use, looks up the
    calling user's own Brave key when the context provides Redis and a user
    id, and then submits the search through the shared rate limiter.

    Args:
        query (str): Search query or input string.
        count (int): Number of results to return.
        country (Optional[str]): Country code for localised results.
        search_lang (Optional[str]): Language code for search results.
        ui_lang (Optional[str]): Language code for UI elements.
        safesearch (str): Safe-search level.
        ctx: Tool execution context; ``config``, ``redis``, ``user_id`` and
            ``channel_id`` are read from it when present.

    Returns:
        str: JSON string with the results or an ``"error"`` entry.
    """
    global _keys_loaded

    # Read the context defensively and once, so a partial or missing ctx
    # degrades to "no config / no user" instead of raising AttributeError.
    config = getattr(ctx, "config", None) if ctx else None
    redis_client = getattr(ctx, "redis", None) if ctx else None
    user_id = getattr(ctx, "user_id", None) if ctx else None

    if not _keys_loaded and config:
        _key_manager.load_keys(config)
        _keys_loaded = True

    user_api_key = None
    if redis_client and user_id:
        try:
            from tools.manage_api_keys import get_user_api_key
            user_api_key = await get_user_api_key(
                user_id,
                "brave",
                redis_client=redis_client,
                channel_id=ctx.channel_id,
                config=config,
            )
        except Exception as exc:
            # Best-effort: the search proceeds with a pooled key, but the
            # failure is recorded instead of being silently dropped.
            logger.warning(
                "Could not look up the user's Brave API key; using pooled keys: %s", exc
            )

    return await _rate_limiter.search(
        query=query,
        count=count,
        country=country,
        search_lang=search_lang,
        ui_lang=ui_lang,
        safesearch=safesearch,
        user_api_key=user_api_key,
        user_id=user_id,
        redis_client=redis_client,
    )