# Source code for tools.cisa_kev_tools

"""Search CISA's Known Exploited Vulnerabilities (KEV) catalog.

Uses the official JSON feed from https://www.cisa.gov/known-exploited-vulnerabilities
"""

from __future__ import annotations

import asyncio
import jsonutil as json
import logging
import re
import time
from datetime import date
from typing import Any, TYPE_CHECKING

import aiohttp

if TYPE_CHECKING:
    from tool_context import ToolContext

logger = logging.getLogger(__name__)

# URL of the KEV JSON feed downloaded by _fetch_json (the official feed named
# in the module docstring).
CISA_KEV_JSON_URL = "https://www.cisa.gov/sites/default/files/feeds/known_exploited_vulnerabilities.json"
# Maximum age, in seconds, of the cached feed before _get_catalog refetches it.
CACHE_TTL_SEC = 3600
# Number of records run() returns when the caller does not pass a limit.
DEFAULT_LIMIT = 50
# Upper clamp applied by run() to the caller-supplied limit.
MAX_LIMIT = 200
# Total timeout, in seconds, for the feed download performed by _fetch_json.
FETCH_TIMEOUT_SEC = 60

# Keys permitted on each vulnerability record in field_filters (feed schema).
# _validate_field_filters rejects any field_filters key that is not listed here.
_VULN_FIELD_KEYS = frozenset(
    {
        "cveID",
        "vendorProject",
        "product",
        "vulnerabilityName",
        "dateAdded",
        "shortDescription",
        "requiredAction",
        "dueDate",
        "knownRansomwareCampaignUse",
        "notes",
        "cwes",
    }
)

# Process-wide cache state, read and written only by _get_catalog.
# The lock serializes cache checks and feed downloads across concurrent calls.
_cache_lock = asyncio.Lock()
# Most recently downloaded feed payload; None until the first successful fetch.
_cache_payload: dict[str, Any] | None = None
# Timestamp recorded by _get_catalog for the cached payload; 0.0 until then.
_cache_time: float = 0.0

# Tool registration metadata: the name of the tool and the human-readable
# description of what run() searches and how its filters behave.
# NOTE(review): presumably read by the tool loader that dispatches run() --
# the loader is not visible in this file, so confirm how these are consumed.
TOOL_NAME = "search_cisa_kev_catalog"
# Built by implicit string-literal concatenation; this is user-facing text, so
# edits here change what callers of the tool are told.
TOOL_DESCRIPTION = (
    "Search the U.S. CISA Known Exploited Vulnerabilities (KEV) catalog — "
    "https://www.cisa.gov/known-exploited-vulnerabilities — using the official "
    "public JSON feed. Filter by CVE, vendor, product, text fields, CWE, date ranges, "
    "or arbitrary allowed fields via field_filters (keys: cveID, vendorProject, "
    "product, vulnerabilityName, dateAdded, shortDescription, requiredAction, "
    "dueDate, knownRansomwareCampaignUse, notes, cwes). "
    "String matching is case-insensitive substring except cveID / cve_id (exact after "
    "CVE normalization). Date bounds use ISO YYYY-MM-DD on dateAdded and dueDate. "
    "Results are cached about 1 hour; set force_refresh true to refetch immediately."
)

# JSON-Schema-style description of every argument accepted by run(). Each key
# matches a run() keyword parameter (ctx is deliberately absent: it is not a
# caller-supplied argument).
_PARAM_PROPS: dict[str, dict[str, Any]] = {
    "cve_id": {
        "type": "string",
        "description": "Exact CVE after normalization (e.g. CVE-2024-1234 or 2024-1234).",
    },
    "vendor_project": {
        "type": "string",
        "description": "Case-insensitive substring match on vendorProject.",
    },
    "product": {
        "type": "string",
        "description": "Case-insensitive substring match on product.",
    },
    "vulnerability_name": {
        "type": "string",
        "description": "Case-insensitive substring match on vulnerabilityName.",
    },
    "short_description": {
        "type": "string",
        "description": "Case-insensitive substring match on shortDescription.",
    },
    "required_action": {
        "type": "string",
        "description": "Case-insensitive substring match on requiredAction.",
    },
    "notes": {
        "type": "string",
        "description": "Case-insensitive substring match on notes.",
    },
    "known_ransomware_campaign_use": {
        "type": "string",
        "description": "Case-insensitive substring on knownRansomwareCampaignUse (e.g. Known, Unknown).",
    },
    "cwe": {
        "type": "string",
        "description": "Match if any entry in cwes equals or contains this (e.g. CWE-502).",
    },
    "date_added_on_or_after": {
        "type": "string",
        "description": "Inclusive lower bound on dateAdded (YYYY-MM-DD).",
    },
    "date_added_on_or_before": {
        "type": "string",
        "description": "Inclusive upper bound on dateAdded (YYYY-MM-DD).",
    },
    "due_on_or_after": {
        "type": "string",
        "description": "Inclusive lower bound on dueDate (YYYY-MM-DD).",
    },
    "due_on_or_before": {
        "type": "string",
        "description": "Inclusive upper bound on dueDate (YYYY-MM-DD).",
    },
    "field_filters": {
        "type": "object",
        "description": (
            "Extra filters: keys must be feed field names; values are strings. "
            "Substring match for text fields; cveID is exact (normalized); "
            "cwes matches if any listed CWE contains the string."
        ),
        "additionalProperties": {"type": "string"},
    },
    "limit": {
        "type": "integer",
        # Interpolated once at import time from the module constants.
        "description": f"Max vulnerabilities to return (default {DEFAULT_LIMIT}, max {MAX_LIMIT}).",
    },
    "force_refresh": {
        "type": "boolean",
        "description": "If true, bypass cache and download the latest feed.",
    },
}

# Top-level parameter schema for the tool; no argument is mandatory, so a call
# with no filters is valid.
TOOL_PARAMETERS = {
    "type": "object",
    "properties": _PARAM_PROPS,
    "required": [],
}


def _normalize_cve(raw: str) -> str:
    """Normalize a CVE identifier to canonical ``CVE-YYYY-NNNN`` form.

    Upper-cases and trims the input, returns it unchanged if it already starts
    with ``CVE-``, prepends ``CVE-`` to a bare ``YYYY-NNNN`` token, and
    otherwise returns the trimmed/upper-cased value as-is so callers can still
    compare it. This is purely string manipulation with no I/O or side effects.

    Interactions: called by :func:`_cve_exact` (both on the user-supplied needle
    and on each record's ``cveID``) and by :func:`_field_filter_value` when the
    filter key is ``cveID``; also invoked directly in :func:`run` to normalize
    the ``cve_id`` argument before exact matching.

    Called by no external module; used only within this file by ``_cve_exact``,
    ``_field_filter_value``, and ``run``.

    Args:
        raw: A raw CVE string such as ``"CVE-2024-1234"`` or ``"2024-1234"``.

    Returns:
        str: The normalized identifier (canonical ``CVE-...`` form when
        recognizable, otherwise the stripped/upper-cased input).
    """
    t = raw.strip().upper()
    if t.startswith("CVE-"):
        return t
    if re.fullmatch(r"\d{4}-\d+", t):
        return "CVE-" + t
    return t


def _parse_iso_date(label: str, s: str) -> date:
    """Parse an ISO ``YYYY-MM-DD`` string into a :class:`datetime.date`.

    Splits on ``-`` into at most three integer components and constructs a
    ``date``; on any parse failure it raises a ``ValueError`` whose message
    includes the supplied ``label`` so the caller can identify which argument
    was malformed. No I/O or side effects.

    Interactions: invoked by :func:`_vuln_date_iso` (where the resulting
    ``ValueError`` is caught and swallowed to mean "unparseable record date")
    and by :func:`run` when validating the four user-supplied date-bound
    arguments (where the ``ValueError`` is surfaced to the user as an error
    JSON payload).

    Called by no external module; used only within this file by
    ``_vuln_date_iso`` and ``run``.

    Args:
        label: Human-readable name of the field being parsed, embedded into the
            error message (e.g. ``"date_added_on_or_after"`` or ``"dueDate"``).
        s: The date string to parse, expected as ``YYYY-MM-DD``.

    Returns:
        date: The parsed calendar date.

    Raises:
        ValueError: If ``s`` is not a valid ``YYYY-MM-DD`` date.
    """
    s = s.strip()
    try:
        y, m, d = (int(p) for p in s.split("-", 2))
        return date(y, m, d)
    except (ValueError, AttributeError) as e:
        raise ValueError(f"{label} must be YYYY-MM-DD, got {s!r}") from e


def _vuln_date_iso(field: str, vuln: dict[str, Any]) -> date | None:
    """Extract and parse a date field from one KEV vulnerability record.

    Reads ``vuln[field]``, returning ``None`` if it is missing or not a string;
    otherwise it takes the leading 10 characters (the ``YYYY-MM-DD`` prefix) and
    parses them. Parse failures are caught and converted to ``None`` so a single
    malformed record date never aborts a search.

    Interactions: delegates to :func:`_parse_iso_date` for the actual parsing
    and catches the ``ValueError`` it may raise. Called by :func:`run` for the
    ``dateAdded`` and ``dueDate`` bound comparisons.

    Called by no external module; used only within this file by ``run``.

    Args:
        field: The record key to read, e.g. ``"dateAdded"`` or ``"dueDate"``.
        vuln: A single vulnerability record from the feed's
            ``vulnerabilities`` array.

    Returns:
        date | None: The parsed date, or ``None`` if the field is absent,
        non-string, or unparseable.
    """
    raw = vuln.get(field)
    if not raw or not isinstance(raw, str):
        return None
    part = raw.strip()[:10]
    try:
        return _parse_iso_date(field, part)
    except ValueError:
        return None


def _substr_match(hay: Any, needle: str) -> bool:
    """Return whether ``needle`` is a case-insensitive substring of ``hay``.

    Treats a ``None`` haystack as a non-match; otherwise it stringifies ``hay``
    and performs a lower-cased substring test. Pure, side-effect-free.

    Interactions: this is the core text-matching primitive. Called by
    :func:`_field_filter_value` for non-CVE/non-CWE field filters, and directly
    by :func:`run` for the ``vendor_project``, ``product``,
    ``vulnerability_name``, ``short_description``, ``required_action``,
    ``notes``, and ``known_ransomware_campaign_use`` arguments.

    Called by no external module; used only within this file by
    ``_field_filter_value`` and ``run``.

    Args:
        hay: The value to search within (any type; coerced via ``str()``).
        needle: The substring to look for (matched case-insensitively).

    Returns:
        bool: ``True`` if ``needle`` occurs in ``hay`` ignoring case, else
        ``False`` (including when ``hay`` is ``None``).
    """
    if hay is None:
        return False
    return needle.lower() in str(hay).lower()


def _cve_exact(vuln: dict[str, Any], normalized: str) -> bool:
    """Return whether a record's ``cveID`` exactly equals a normalized CVE.

    Reads ``vuln["cveID"]``, returns ``False`` if it is missing or not a string,
    and otherwise normalizes the record's CVE and compares it for equality with
    the already-normalized ``normalized`` argument. Pure, side-effect-free.

    Interactions: calls :func:`_normalize_cve` on the record value so both sides
    of the comparison are canonicalized. Called by :func:`_field_filter_value`
    for the ``cveID`` filter key, and directly by :func:`run` for the ``cve_id``
    argument (which the caller normalizes before passing in).

    Called by no external module; used only within this file by
    ``_field_filter_value`` and ``run``.

    Args:
        vuln: A single vulnerability record from the feed.
        normalized: The target CVE id, already passed through
            :func:`_normalize_cve` by the caller.

    Returns:
        bool: ``True`` if the record's normalized ``cveID`` equals
        ``normalized``, else ``False``.
    """
    cid = vuln.get("cveID")
    if not isinstance(cid, str):
        return False
    return _normalize_cve(cid) == normalized


def _cwe_match(vuln: dict[str, Any], needle: str) -> bool:
    """Return whether any CWE on a record relates to ``needle``.

    Reads the record's ``cwes`` list (returning ``False`` if it is absent or not
    a list) and, for each string entry, matches when the lower-cased needle and
    CWE are equal or either contains the other. This containment-in-either-
    direction logic lets ``CWE-502`` match ``CWE-502`` and lets a partial token
    match a longer label. Pure, side-effect-free.

    Interactions: called by :func:`_field_filter_value` for the ``cwes`` filter
    key and directly by :func:`run` for the ``cwe`` argument.

    Called by no external module; used only within this file by
    ``_field_filter_value`` and ``run``.

    Args:
        vuln: A single vulnerability record from the feed.
        needle: The CWE token to match (e.g. ``"CWE-502"``), case-insensitive.

    Returns:
        bool: ``True`` if any listed CWE equals or overlaps ``needle``, else
        ``False``.
    """
    cwes = vuln.get("cwes")
    if not isinstance(cwes, list):
        return False
    nl = needle.lower()
    for c in cwes:
        if not isinstance(c, str):
            continue
        cl = c.lower()
        if nl == cl or nl in cl or cl in nl:
            return True
    return False


def _field_filter_value(vuln: dict[str, Any], key: str, value: str) -> bool:
    """Check one ``field_filters`` entry against one vulnerability record.

    The matching rule depends on the field:

    * ``cwes`` -- CWE overlap via ``_cwe_match``.
    * ``cveID`` -- exact match via ``_cve_exact`` after the filter value has
      been normalized with ``_normalize_cve``.
    * any other key -- case-insensitive substring match on the record's value
      via ``_substr_match``.

    Used within this file by ``run``, once per ``field_filters`` entry; a
    record is kept only if every entry passes.

    Args:
        vuln: One vulnerability record from the feed.
        key: Feed field name; ``run`` validates the keys against
            ``_VULN_FIELD_KEYS`` before calling this.
        value: Text to filter the field against.

    Returns:
        bool: ``True`` when the record satisfies this filter, else ``False``.
    """
    if key == "cwes":
        return _cwe_match(vuln, value)
    if key != "cveID":
        return _substr_match(vuln.get(key), value)
    wanted_cve = _normalize_cve(value)
    return _cve_exact(vuln, wanted_cve)


async def _fetch_json() -> dict[str, Any]:
    """Fetch the KEV feed from :data:`CISA_KEV_JSON_URL` and decode it as JSON.

    A new :class:`aiohttp.ClientSession` is created for each call and sends an
    identifying ``User-Agent`` header. The request is limited to
    :data:`FETCH_TIMEOUT_SEC` seconds in total. The body is decoded with
    ``content_type=None``, i.e. without checking the response's content type.
    Nothing is cached here; ``_get_catalog`` is the only caller in this file
    and adds the caching.

    Returns:
        dict[str, Any]: The decoded feed payload.

    Raises:
        RuntimeError: If the response status is not ``200``; the message holds
            the status and the first 2000 characters of the body.

    Transport and timeout errors raised by aiohttp are not caught here and
    propagate to the caller.
    """
    user_agent = "Stargazer-KEV-Tool/1.0 (+https://www.cisa.gov/known-exploited-vulnerabilities)"
    request_timeout = aiohttp.ClientTimeout(total=FETCH_TIMEOUT_SEC)
    async with aiohttp.ClientSession(headers={"User-Agent": user_agent}) as session:
        async with session.get(CISA_KEV_JSON_URL, timeout=request_timeout) as resp:
            if resp.status == 200:
                return await resp.json(content_type=None)
            # Keep the error message bounded even if the server sends a large page.
            body_preview = (await resp.text())[:2000]
            raise RuntimeError(f"CISA KEV HTTP {resp.status}: {body_preview}")


async def _get_catalog(force_refresh: bool) -> dict[str, Any]:
    """Return the KEV catalog payload, served from a process-wide cache.

    All work happens under the module-level :data:`_cache_lock`, so concurrent
    callers wait for one download instead of each starting their own. The
    cached :data:`_cache_payload` is returned when it exists, ``force_refresh``
    is false, and it is younger than :data:`CACHE_TTL_SEC`; otherwise the feed
    is downloaded with ``_fetch_json`` and the cache is replaced.

    Cache age is measured with ``time.monotonic()`` rather than the wall
    clock, so system clock adjustments can neither keep a stale payload alive
    nor expire a fresh one early. :data:`_cache_time` therefore holds a
    monotonic reading, not an epoch timestamp.

    If the download fails, the exception propagates and the previous cache
    contents are left untouched.

    Used within this file by ``run``.

    Args:
        force_refresh: If ``True``, ignore the cached payload and refetch.

    Returns:
        dict[str, Any]: The catalog payload, cached or freshly downloaded.

    Raises:
        RuntimeError: Propagated from ``_fetch_json`` on a non-200 response.
        Exception: Any transport or timeout error raised by ``_fetch_json``.
    """
    global _cache_payload, _cache_time
    async with _cache_lock:
        now = time.monotonic()
        cache_is_fresh = (
            _cache_payload is not None and now - _cache_time < CACHE_TTL_SEC
        )
        if cache_is_fresh and not force_refresh:
            return _cache_payload
        payload = await _fetch_json()
        _cache_payload = payload
        # Stamp with the pre-download reading: the TTL then counts from when
        # the request started, which errs on the side of refreshing sooner.
        _cache_time = now
        return payload


def _validate_field_filters(ff: dict[str, Any] | None) -> str | None:
    """Validate the ``field_filters`` argument and return an error message or None.

    Accepts ``None``/empty as valid (no extra filters). Otherwise it checks that
    ``ff`` is a dict, that every key is a known feed field in
    :data:`_VULN_FIELD_KEYS`, and that every value is a string, returning a
    human-readable error string on the first violation. Returning ``None`` means
    the filters are well-formed. Pure, side-effect-free.

    Interactions: consults the :data:`_VULN_FIELD_KEYS` allow-list. Called by
    :func:`run` before any catalog work; a non-``None`` return is serialized
    straight into the tool's failure JSON.

    Called by no external module; used only within this file by ``run``.

    Args:
        ff: The user-supplied ``field_filters`` mapping (or ``None``).

    Returns:
        str | None: An error description if the filters are invalid, otherwise
        ``None``.
    """
    if not ff:
        return None
    if not isinstance(ff, dict):
        return "field_filters must be an object mapping field names to strings."
    bad = [k for k in ff if k not in _VULN_FIELD_KEYS]
    if bad:
        return (
            "Unknown field_filters keys: "
            + ", ".join(sorted(bad))
            + ". Allowed: "
            + ", ".join(sorted(_VULN_FIELD_KEYS))
        )
    for k, v in ff.items():
        if not isinstance(v, str):
            return f"field_filters[{k!r}] must be a string."
    return None


def _clean_text_arg(value: Any) -> str:
    """Return ``value`` as a stripped string, or ``""`` when it is falsy."""
    return str(value).strip() if value else ""


def _date_in_bounds(
    vuln: dict[str, Any],
    field: str,
    lower: date | None,
    upper: date | None,
) -> bool:
    """Return whether ``vuln[field]`` lies within the inclusive date bounds.

    With no bound set the record always passes. With at least one bound set,
    a record whose date is missing or unparseable fails.
    """
    if lower is None and upper is None:
        return True
    # Parse the record's date once, even when both bounds are present.
    value = _vuln_date_iso(field, vuln)
    if value is None:
        return False
    if lower is not None and value < lower:
        return False
    if upper is not None and value > upper:
        return False
    return True


async def run(
    cve_id: str = "",
    vendor_project: str = "",
    product: str = "",
    vulnerability_name: str = "",
    short_description: str = "",
    required_action: str = "",
    notes: str = "",
    known_ransomware_campaign_use: str = "",
    cwe: str = "",
    date_added_on_or_after: str = "",
    date_added_on_or_before: str = "",
    due_on_or_after: str = "",
    due_on_or_before: str = "",
    field_filters: dict[str, Any] | None = None,
    limit: int = DEFAULT_LIMIT,
    force_refresh: bool = False,
    ctx: ToolContext | None = None,
) -> str:
    """Search the CISA Known Exploited Vulnerabilities catalog and return JSON.

    Entry point of the ``search_cisa_kev_catalog`` tool. The request is
    validated first (``field_filters``, ``limit``, the four date bounds), then
    the catalog is obtained through ``_get_catalog`` (cached, or downloaded on
    a miss or when ``force_refresh`` is set), and finally every supplied
    filter is applied to every record as a logical AND. Empty or
    whitespace-only filter arguments are ignored.

    Matching rules: ``cve_id`` is exact after normalization; the text
    arguments are case-insensitive substring matches; ``cwe`` uses CWE overlap
    matching; the date bounds are inclusive and a record with a missing or
    unparseable date fails any bound on that field; each ``field_filters``
    entry is evaluated by ``_field_filter_value``.

    All filter values are normalized once before the record loop rather than
    once per record. Problems with the request or the feed are reported as a
    JSON error payload instead of being raised.

    Args:
        cve_id: Exact CVE id (normalized before matching).
        vendor_project: Substring filter on ``vendorProject``.
        product: Substring filter on ``product``.
        vulnerability_name: Substring filter on ``vulnerabilityName``.
        short_description: Substring filter on ``shortDescription``.
        required_action: Substring filter on ``requiredAction``.
        notes: Substring filter on ``notes``.
        known_ransomware_campaign_use: Substring filter on
            ``knownRansomwareCampaignUse``.
        cwe: CWE token matched against the record's ``cwes`` list.
        date_added_on_or_after: Inclusive lower ``dateAdded`` bound
            (``YYYY-MM-DD``).
        date_added_on_or_before: Inclusive upper ``dateAdded`` bound.
        due_on_or_after: Inclusive lower ``dueDate`` bound.
        due_on_or_before: Inclusive upper ``dueDate`` bound.
        field_filters: Extra feed-field filters; keys are restricted to
            :data:`_VULN_FIELD_KEYS` and values must be strings.
        limit: Maximum records to return, clamped to ``[1, MAX_LIMIT]``.
            ``None`` selects :data:`DEFAULT_LIMIT`; values convertible with
            ``int()`` are accepted.
        force_refresh: If ``True``, bypass the cache and refetch the feed.
        ctx: Tool context; accepted but unused.

    Returns:
        str: On success, a JSON document with ``success``, ``catalogVersion``,
        ``dateReleased``, ``total_in_catalog``, ``match_count`` (all matches,
        before the cap), ``limit``, ``truncated`` and the capped
        ``vulnerabilities`` list. On failure, ``{"success": false, "error":
        ...}``.
    """
    del ctx  # unused

    err = _validate_field_filters(field_filters)
    if err:
        return json.dumps({"success": False, "error": err})

    # Tool arguments may arrive as strings, floats or null; coerce instead of
    # failing later in the comparison or the slice.
    if limit is None:
        limit = DEFAULT_LIMIT
    try:
        limit = int(limit)
    except (TypeError, ValueError, OverflowError):
        return json.dumps(
            {"success": False, "error": f"limit must be an integer, got {limit!r}"}
        )
    limit = max(1, min(limit, MAX_LIMIT))

    parsed_bounds: dict[str, date] = {}
    for label, raw in (
        ("date_added_on_or_after", date_added_on_or_after),
        ("date_added_on_or_before", date_added_on_or_before),
        ("due_on_or_after", due_on_or_after),
        ("due_on_or_before", due_on_or_before),
    ):
        text = _clean_text_arg(raw)
        if not text:
            continue
        try:
            parsed_bounds[label] = _parse_iso_date(label, text)
        except ValueError as e:
            return json.dumps({"success": False, "error": str(e)})
        except OverflowError:
            # date() signals a component too large for a C int this way.
            return json.dumps(
                {
                    "success": False,
                    "error": f"{label} must be YYYY-MM-DD, got {text!r}",
                }
            )
    added_from = parsed_bounds.get("date_added_on_or_after")
    added_to = parsed_bounds.get("date_added_on_or_before")
    due_from = parsed_bounds.get("due_on_or_after")
    due_to = parsed_bounds.get("due_on_or_before")

    try:
        payload = await _get_catalog(force_refresh)
    except Exception as e:
        # Tool boundary: report any fetch failure to the caller as JSON.
        logger.exception("CISA KEV fetch failed")
        return json.dumps(
            {"success": False, "error": f"Failed to fetch KEV catalog: {e}"}
        )

    vulns = payload.get("vulnerabilities") if isinstance(payload, dict) else None
    if not isinstance(vulns, list):
        return json.dumps(
            {
                "success": False,
                "error": "Unexpected feed shape: missing vulnerabilities array.",
            }
        )

    # Normalize every filter value once, outside the per-record loop.
    cve_text = _clean_text_arg(cve_id)
    wanted_cve = _normalize_cve(cve_text) if cve_text else ""
    cwe_needle = _clean_text_arg(cwe)
    substring_filters = [
        (feed_key, needle)
        for feed_key, needle in (
            ("vendorProject", _clean_text_arg(vendor_project)),
            ("product", _clean_text_arg(product)),
            ("vulnerabilityName", _clean_text_arg(vulnerability_name)),
            ("shortDescription", _clean_text_arg(short_description)),
            ("requiredAction", _clean_text_arg(required_action)),
            ("notes", _clean_text_arg(notes)),
            (
                "knownRansomwareCampaignUse",
                _clean_text_arg(known_ransomware_campaign_use),
            ),
        )
        if needle
    ]

    matches: list[dict[str, Any]] = []
    for v in vulns:
        if not isinstance(v, dict):
            continue
        if wanted_cve and not _cve_exact(v, wanted_cve):
            continue
        if not all(
            _substr_match(v.get(feed_key), needle)
            for feed_key, needle in substring_filters
        ):
            continue
        if cwe_needle and not _cwe_match(v, cwe_needle):
            continue
        if not _date_in_bounds(v, "dateAdded", added_from, added_to):
            continue
        if not _date_in_bounds(v, "dueDate", due_from, due_to):
            continue
        if field_filters and not all(
            _field_filter_value(v, fk, fv) for fk, fv in field_filters.items()
        ):
            continue
        matches.append(v)

    total_matches = len(matches)
    return json.dumps(
        {
            "success": True,
            "catalogVersion": payload.get("catalogVersion"),
            "dateReleased": payload.get("dateReleased"),
            "total_in_catalog": payload.get("count"),
            "match_count": total_matches,
            "limit": limit,
            "truncated": total_matches > limit,
            "vulnerabilities": matches[:limit],
        },
        indent=2,
        ensure_ascii=False,
    )