# Source code for tools.cisa_kev_tools

"""Search CISA's Known Exploited Vulnerabilities (KEV) catalog.

Uses the official JSON feed from https://www.cisa.gov/known-exploited-vulnerabilities
"""

from __future__ import annotations

import asyncio
import json
import logging
import re
import time
from datetime import date
from typing import Any, TYPE_CHECKING

import aiohttp

if TYPE_CHECKING:
    from tool_context import ToolContext

# Module-level logger; run() uses it to record feed download failures.
logger = logging.getLogger(__name__)

# Location of the official KEV JSON feed downloaded by _fetch_json().
CISA_KEV_JSON_URL = (
    "https://www.cisa.gov/sites/default/files/feeds/known_exploited_vulnerabilities.json"
)
# Maximum age, in seconds, at which _get_catalog() still serves the cached feed.
CACHE_TTL_SEC = 3600
# Default value of run()'s `limit`, and the ceiling that run() clamps `limit` to.
DEFAULT_LIMIT = 50
MAX_LIMIT = 200
# Total timeout, in seconds, applied to the feed download request.
FETCH_TIMEOUT_SEC = 60

# Keys permitted on each vulnerability record in field_filters (feed schema).
_VULN_FIELD_KEYS = frozenset({
    "cveID",
    "vendorProject",
    "product",
    "vulnerabilityName",
    "dateAdded",
    "shortDescription",
    "requiredAction",
    "dueDate",
    "knownRansomwareCampaignUse",
    "notes",
    "cwes",
})

# In-process cache of the downloaded feed, read and written by _get_catalog().
# _cache_lock is held across both the freshness check and the download, so
# callers are serialized while a fetch is in flight.
_cache_lock = asyncio.Lock()
# Last successfully downloaded feed document; None until the first fetch.
_cache_payload: dict[str, Any] | None = None
# Clock reading _get_catalog() associates with _cache_payload; 0.0 until set.
_cache_time: float = 0.0

# Tool identifier and human-readable description.
# NOTE(review): presumably consumed by the host application's tool registry
# together with TOOL_PARAMETERS and run() — confirm against the loader.
TOOL_NAME = "search_cisa_kev_catalog"
TOOL_DESCRIPTION = (
    "Search the U.S. CISA Known Exploited Vulnerabilities (KEV) catalog — "
    "https://www.cisa.gov/known-exploited-vulnerabilities — using the official "
    "public JSON feed. Filter by CVE, vendor, product, text fields, CWE, date ranges, "
    "or arbitrary allowed fields via field_filters (keys: cveID, vendorProject, "
    "product, vulnerabilityName, dateAdded, shortDescription, requiredAction, "
    "dueDate, knownRansomwareCampaignUse, notes, cwes). "
    "String matching is case-insensitive substring except cveID / cve_id (exact after "
    "CVE normalization). Date bounds use ISO YYYY-MM-DD on dateAdded and dueDate. "
    "Results are cached about 1 hour; set force_refresh true to refetch immediately."
)

# JSON-Schema-style property definitions, one entry per keyword argument of
# run() (the `ctx` argument is intentionally not listed here).
_PARAM_PROPS: dict[str, dict[str, Any]] = {
    "cve_id": {
        "type": "string",
        "description": "Exact CVE after normalization (e.g. CVE-2024-1234 or 2024-1234).",
    },
    "vendor_project": {
        "type": "string",
        "description": "Case-insensitive substring match on vendorProject.",
    },
    "product": {
        "type": "string",
        "description": "Case-insensitive substring match on product.",
    },
    "vulnerability_name": {
        "type": "string",
        "description": "Case-insensitive substring match on vulnerabilityName.",
    },
    "short_description": {
        "type": "string",
        "description": "Case-insensitive substring match on shortDescription.",
    },
    "required_action": {
        "type": "string",
        "description": "Case-insensitive substring match on requiredAction.",
    },
    "notes": {
        "type": "string",
        "description": "Case-insensitive substring match on notes.",
    },
    "known_ransomware_campaign_use": {
        "type": "string",
        "description": "Case-insensitive substring on knownRansomwareCampaignUse (e.g. Known, Unknown).",
    },
    "cwe": {
        "type": "string",
        "description": "Match if any entry in cwes equals or contains this (e.g. CWE-502).",
    },
    "date_added_on_or_after": {
        "type": "string",
        "description": "Inclusive lower bound on dateAdded (YYYY-MM-DD).",
    },
    "date_added_on_or_before": {
        "type": "string",
        "description": "Inclusive upper bound on dateAdded (YYYY-MM-DD).",
    },
    "due_on_or_after": {
        "type": "string",
        "description": "Inclusive lower bound on dueDate (YYYY-MM-DD).",
    },
    "due_on_or_before": {
        "type": "string",
        "description": "Inclusive upper bound on dueDate (YYYY-MM-DD).",
    },
    "field_filters": {
        "type": "object",
        "description": (
            "Extra filters: keys must be feed field names; values are strings. "
            "Substring match for text fields; cveID is exact (normalized); "
            "cwes matches if any listed CWE contains the string."
        ),
        "additionalProperties": {"type": "string"},
    },
    "limit": {
        "type": "integer",
        "description": f"Max vulnerabilities to return (default {DEFAULT_LIMIT}, max {MAX_LIMIT}).",
    },
    "force_refresh": {
        "type": "boolean",
        "description": "If true, bypass cache and download the latest feed.",
    },
}

# Top-level parameter schema: an object whose properties are all optional
# (the "required" list is empty).
TOOL_PARAMETERS = {
    "type": "object",
    "properties": _PARAM_PROPS,
    "required": [],
}


def _normalize_cve(raw: str) -> str:
    t = raw.strip().upper()
    if t.startswith("CVE-"):
        return t
    if re.fullmatch(r"\d{4}-\d+", t):
        return "CVE-" + t
    return t


def _parse_iso_date(label: str, s: str) -> date:
    s = s.strip()
    try:
        y, m, d = (int(p) for p in s.split("-", 2))
        return date(y, m, d)
    except (ValueError, AttributeError) as e:
        raise ValueError(f"{label} must be YYYY-MM-DD, got {s!r}") from e


def _vuln_date_iso(field: str, vuln: dict[str, Any]) -> date | None:
    raw = vuln.get(field)
    if not raw or not isinstance(raw, str):
        return None
    part = raw.strip()[:10]
    try:
        return _parse_iso_date(field, part)
    except ValueError:
        return None


def _substr_match(hay: Any, needle: str) -> bool:
    if hay is None:
        return False
    return needle.lower() in str(hay).lower()


def _cve_exact(vuln: dict[str, Any], normalized: str) -> bool:
    cid = vuln.get("cveID")
    if not isinstance(cid, str):
        return False
    return _normalize_cve(cid) == normalized


def _cwe_match(vuln: dict[str, Any], needle: str) -> bool:
    cwes = vuln.get("cwes")
    if not isinstance(cwes, list):
        return False
    nl = needle.lower()
    for c in cwes:
        if not isinstance(c, str):
            continue
        cl = c.lower()
        if nl == cl or nl in cl or cl in nl:
            return True
    return False


def _field_filter_value(vuln: dict[str, Any], key: str, value: str) -> bool:
    """Apply a single ``field_filters`` entry to a vulnerability record.

    ``cveID`` is compared exactly after CVE normalization, ``cwes`` uses the
    CWE list matcher, and every other key is a case-insensitive substring
    test against the record's value for that key.
    """
    if key == "cwes":
        return _cwe_match(vuln, value)
    if key == "cveID":
        wanted = _normalize_cve(value)
        return _cve_exact(vuln, wanted)
    field_value = vuln.get(key)
    return _substr_match(field_value, value)


async def _fetch_json() -> dict[str, Any]:
    """Download the KEV feed and return it as a decoded JSON object.

    Returns:
        The top-level JSON object of the feed.

    Raises:
        RuntimeError: If the server answers with a status other than 200
            (the message carries up to 2000 characters of the response
            body), or if the decoded document is not a JSON object.

    Errors raised by aiohttp or by JSON decoding propagate unchanged.
    """
    timeout = aiohttp.ClientTimeout(total=FETCH_TIMEOUT_SEC)
    headers = {"User-Agent": "Stargazer-KEV-Tool/1.0 (+https://www.cisa.gov/known-exploited-vulnerabilities)"}
    async with aiohttp.ClientSession(headers=headers) as session:
        async with session.get(CISA_KEV_JSON_URL, timeout=timeout) as resp:
            if resp.status != 200:
                text = (await resp.text())[:2000]
                raise RuntimeError(f"CISA KEV HTTP {resp.status}: {text}")
            # content_type=None disables aiohttp's Content-Type check, so the
            # body is decoded even when not served as application/json.
            payload = await resp.json(content_type=None)
    if not isinstance(payload, dict):
        # Honour the declared return type: callers use dict methods on the
        # result, so a list/null/scalar document must fail here, loudly.
        raise RuntimeError(
            f"CISA KEV feed is not a JSON object (got {type(payload).__name__})"
        )
    return payload


async def _get_catalog(force_refresh: bool) -> dict[str, Any]:
    """Return the KEV feed, served from the module cache when still fresh.

    Args:
        force_refresh: When true, skip the cache and download again.

    Returns:
        The cached payload if one exists, is younger than ``CACHE_TTL_SEC``
        and *force_refresh* is false; otherwise a newly downloaded payload,
        which also replaces the cache.

    Raises:
        Whatever ``_fetch_json`` raises. The existing cache entry is left
        untouched in that case.
    """
    global _cache_payload, _cache_time
    # The lock spans the freshness check and the download, so a caller that
    # waited behind an in-flight fetch re-checks the cache before fetching.
    async with _cache_lock:
        # Age is measured on the monotonic clock so that wall-clock
        # adjustments (NTP steps, manual changes) can neither extend nor
        # prematurely expire the TTL. The payload check comes first because
        # the initial _cache_time of 0.0 is meaningless on that clock.
        if (
            not force_refresh
            and _cache_payload is not None
            and time.monotonic() - _cache_time < CACHE_TTL_SEC
        ):
            return _cache_payload
        payload = await _fetch_json()
        _cache_payload = payload
        # Stamp after the download finishes, so the TTL counts from the
        # moment the data was actually stored.
        _cache_time = time.monotonic()
        return payload


def _validate_field_filters(ff: dict[str, Any] | None) -> str | None:
    if not ff:
        return None
    if not isinstance(ff, dict):
        return "field_filters must be an object mapping field names to strings."
    bad = [k for k in ff if k not in _VULN_FIELD_KEYS]
    if bad:
        return (
            "Unknown field_filters keys: "
            + ", ".join(sorted(bad))
            + ". Allowed: "
            + ", ".join(sorted(_VULN_FIELD_KEYS))
        )
    for k, v in ff.items():
        if not isinstance(v, str):
            return f"field_filters[{k!r}] must be a string."
    return None


def _clean_arg(raw: Any) -> str:
    """Return ``str(raw).strip()`` for a truthy argument, else ``""`` (filter not set)."""
    return str(raw).strip() if raw else ""


def _date_field_within(
    vuln: dict[str, Any],
    field: str,
    lower: date | None,
    upper: date | None,
) -> bool:
    """Check a record's date field against inclusive bounds; a missing date fails any set bound."""
    if lower is None and upper is None:
        return True
    found = _vuln_date_iso(field, vuln)
    if found is None:
        return False
    if lower is not None and found < lower:
        return False
    return upper is None or found <= upper


async def run(
    cve_id: str = "",
    vendor_project: str = "",
    product: str = "",
    vulnerability_name: str = "",
    short_description: str = "",
    required_action: str = "",
    notes: str = "",
    known_ransomware_campaign_use: str = "",
    cwe: str = "",
    date_added_on_or_after: str = "",
    date_added_on_or_before: str = "",
    due_on_or_after: str = "",
    due_on_or_before: str = "",
    field_filters: dict[str, Any] | None = None,
    limit: int = DEFAULT_LIMIT,
    force_refresh: bool = False,
    ctx: ToolContext | None = None,
) -> str:
    """Search the KEV catalog and return JSON results.

    Every filter that is set must match for a record to be returned; empty
    or whitespace-only arguments are ignored.

    Args:
        cve_id: Exact CVE identifier, compared after normalization.
        vendor_project: Case-insensitive substring of ``vendorProject``.
        product: Case-insensitive substring of ``product``.
        vulnerability_name: Case-insensitive substring of ``vulnerabilityName``.
        short_description: Case-insensitive substring of ``shortDescription``.
        required_action: Case-insensitive substring of ``requiredAction``.
        notes: Case-insensitive substring of ``notes``.
        known_ransomware_campaign_use: Case-insensitive substring of
            ``knownRansomwareCampaignUse``.
        cwe: CWE text matched against the record's ``cwes`` list.
        date_added_on_or_after: Inclusive lower bound on ``dateAdded``.
        date_added_on_or_before: Inclusive upper bound on ``dateAdded``.
        due_on_or_after: Inclusive lower bound on ``dueDate``.
        due_on_or_before: Inclusive upper bound on ``dueDate``.
        field_filters: Extra ``{feed field name: text}`` filters.
        limit: Maximum number of records returned; clamped to
            ``1..MAX_LIMIT``. Integer-like values such as ``50.0`` or
            ``"50"`` are accepted.
        force_refresh: Bypass the cached feed and download it again.
        ctx: Accepted for interface compatibility; unused.

    Returns:
        A JSON string. On success it contains ``success: true``, feed
        metadata (``catalogVersion``, ``dateReleased``, ``total_in_catalog``),
        ``match_count`` (matches before truncation), ``limit``, ``truncated``
        and the ``vulnerabilities`` list. On failure it contains
        ``success: false`` and an ``error`` message; invalid arguments and
        feed download failures are reported this way rather than raised.
    """
    del ctx  # unused

    err = _validate_field_filters(field_filters)
    if err:
        return json.dumps({"success": False, "error": err})

    # Tool arguments arrive from JSON, where an integer may show up as 50.0
    # or "50"; coerce once so the clamp and the slice below cannot fail.
    try:
        limit = int(limit)
    except (TypeError, ValueError, OverflowError):
        return json.dumps(
            {"success": False, "error": f"limit must be an integer, got {limit!r}"}
        )
    limit = max(1, min(limit, MAX_LIMIT))

    # Parse the date bounds up front so a bad value is reported before any
    # network access happens.
    bounds: dict[str, date | None] = {}
    for label, raw in (
        ("date_added_on_or_after", date_added_on_or_after),
        ("date_added_on_or_before", date_added_on_or_before),
        ("due_on_or_after", due_on_or_after),
        ("due_on_or_before", due_on_or_before),
    ):
        text = _clean_arg(raw)
        if not text:
            bounds[label] = None
            continue
        try:
            bounds[label] = _parse_iso_date(label, text)
        except ValueError as e:
            return json.dumps({"success": False, "error": str(e)})

    try:
        payload = await _get_catalog(force_refresh)
    except Exception as e:
        logger.exception("CISA KEV fetch failed")
        return json.dumps({"success": False, "error": f"Failed to fetch KEV catalog: {e}"})

    # Guard the container type too: a non-object document has no .get().
    vulns = payload.get("vulnerabilities") if isinstance(payload, dict) else None
    if not isinstance(vulns, list):
        return json.dumps({
            "success": False,
            "error": "Unexpected feed shape: missing vulnerabilities array.",
        })

    # Prepare every loop-invariant filter value once, outside the scan.
    wanted_cve = _clean_arg(cve_id)
    if wanted_cve:
        wanted_cve = _normalize_cve(wanted_cve)
    wanted_cwe = _clean_arg(cwe)
    text_filters = [
        (feed_key, needle)
        for feed_key, needle in (
            ("vendorProject", _clean_arg(vendor_project)),
            ("product", _clean_arg(product)),
            ("vulnerabilityName", _clean_arg(vulnerability_name)),
            ("shortDescription", _clean_arg(short_description)),
            ("requiredAction", _clean_arg(required_action)),
            ("notes", _clean_arg(notes)),
            ("knownRansomwareCampaignUse", _clean_arg(known_ransomware_campaign_use)),
        )
        if needle
    ]
    added_lo = bounds["date_added_on_or_after"]
    added_hi = bounds["date_added_on_or_before"]
    due_lo = bounds["due_on_or_after"]
    due_hi = bounds["due_on_or_before"]

    matches: list[dict[str, Any]] = []
    for v in vulns:
        if not isinstance(v, dict):
            continue
        if wanted_cve and not _cve_exact(v, wanted_cve):
            continue
        if not all(_substr_match(v.get(key), needle) for key, needle in text_filters):
            continue
        if wanted_cwe and not _cwe_match(v, wanted_cwe):
            continue
        if not _date_field_within(v, "dateAdded", added_lo, added_hi):
            continue
        if not _date_field_within(v, "dueDate", due_lo, due_hi):
            continue
        if field_filters and not all(
            _field_filter_value(v, fk, fv) for fk, fv in field_filters.items()
        ):
            continue
        matches.append(v)

    # The whole catalog is scanned (no early exit at `limit`) because
    # match_count reports the number of matches before truncation.
    total_matches = len(matches)
    return json.dumps(
        {
            "success": True,
            "catalogVersion": payload.get("catalogVersion"),
            "dateReleased": payload.get("dateReleased"),
            "total_in_catalog": payload.get("count"),
            "match_count": total_matches,
            "limit": limit,
            "truncated": total_matches > limit,
            "vulnerabilities": matches[:limit],
        },
        indent=2,
        ensure_ascii=False,
    )