Source code for kg_bulk_runner

"""Shared bulk agentic KG extraction: Redis message collect, chunking, LLM run.

Used by :mod:`scripts.kg_bulk_dump_and_extract` and the scheduled incremental
task in :mod:`background_tasks`.
"""

from __future__ import annotations

import asyncio
import jsonutil as json
import logging
import os
import time as _time
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Awaitable, Callable, Literal

from config import Config
from kg_agentic_extraction import (
    KgBulkLlmClient,
    create_kg_bulk_gemini_pool_client,
    create_kg_bulk_openrouter_client,
    messages_for_agentic_token_estimate,
    prefetch_speaker_kg_context,
    run_agentic_kg_extraction_chunk,
)
from knowledge_graph import KnowledgeGraphManager
from message_cache import MessageCache
from openrouter_client import OpenRouterClient

logger = logging.getLogger(__name__)

CROSS_CHANNEL_SCOPE = "cross_channel_bulk"

KG_AGENTIC_BULK_LAST_TS_HASH = "stargazer:kg_agentic_bulk:last_ts"

_HASH_MSG_FIELDS = (
    "user_id",
    "user_name",
    "platform",
    "channel_id",
    "text",
    "timestamp",
    "message_id",
    "reply_to_id",
)
_ZSET_BATCH = 5000
_HMGET_BATCH = 400

CursorBootstrap = Literal["full", "latest"]


def cursor_field(platform: str, channel_id: str) -> str:
    """Return the hash field name that stores one channel's cursor.

    The field is ``{platform}:{channel_id}`` and addresses the channel's
    last-processed timestamp inside :data:`KG_AGENTIC_BULK_LAST_TS_HASH`.
    *channel_id* may itself contain colons, so the result is a left-anchored
    composition and not a reversible encoding.

    Used by :func:`cursor_hget`, :func:`cursor_hset` and
    :func:`bootstrap_latest_cursor_no_extract`.

    Args:
        platform: Platform identifier (e.g. ``"discord"``).
        channel_id: Channel identifier within that platform.

    Returns:
        str: The ``"{platform}:{channel_id}"`` field name.
    """
    return "{}:{}".format(platform, channel_id)
def redis_ssl_kwargs_for_bulk(
    cfg: Config,
    *,
    redis_no_verify: bool,
) -> dict[str, Any]:
    """Assemble the Redis TLS keyword arguments used by the bulk KG pipeline.

    A copy of ``cfg.redis_ssl_kwargs()`` is the starting point. Certificate
    and hostname verification are switched off only when *redis_no_verify* is
    true **and** ``cfg.redis_url`` (trimmed, case-insensitive) starts with
    ``rediss://``; otherwise the copy is returned unchanged. The mapping
    returned by ``cfg.redis_ssl_kwargs()`` is never mutated.

    Args:
        cfg: Configuration object exposing ``redis_url`` and
            ``redis_ssl_kwargs()``.
        redis_no_verify: When true, accept unverifiable broker certificates
            on ``rediss://`` URLs by setting ``ssl_cert_reqs=CERT_NONE`` and
            ``ssl_check_hostname=False``.

    Returns:
        dict[str, Any]: Keyword arguments for the Redis client constructor.
    """
    ssl_options: dict[str, Any] = dict(cfg.redis_ssl_kwargs())
    normalized_url = (cfg.redis_url or "").strip().lower()
    uses_tls = normalized_url.startswith("rediss://")
    if not (redis_no_verify and uses_tls):
        return ssl_options

    # Imported lazily: only the no-verify path needs the ssl constants.
    import ssl as _ssl

    ssl_options.update(
        ssl_cert_reqs=_ssl.CERT_NONE,
        ssl_check_hostname=False,
    )
    return ssl_options
def format_llm_style_line(m: dict[str, Any]) -> str:
    """Render one message dict as a single LLM-readable transcript line.

    The line carries the UTC timestamp, platform/channel, speaker name and
    id, message id, an optional reply-to id, and the text, e.g.
    ``[<iso>] [platform:discord channel:123] Alice (42) [Message ID: 9] : hi``.
    Missing fields degrade to empty strings (``"?"`` for an unknown name).

    A timestamp that cannot be converted falls back to the Unix epoch. That
    covers non-numeric values, NaN, and infinite or out-of-range values, for
    which ``datetime.fromtimestamp`` raises ``OverflowError``.

    This is a pure formatter with no I/O.

    Args:
        m: Message mapping using the keys ``timestamp``, ``platform``,
            ``channel_id``, ``user_id``, ``user_name``, ``message_id``,
            ``reply_to_id`` and ``text``. Every key is optional.

    Returns:
        str: The formatted single-line representation of the message.
    """
    try:
        ts_dt = datetime.fromtimestamp(
            float(m.get("timestamp") or 0),
            tz=timezone.utc,
        ).isoformat()
    except (TypeError, ValueError, OverflowError, OSError):
        # OverflowError is what fromtimestamp raises for inf / huge values;
        # without it a single bad record would abort the whole formatting run.
        ts_dt = "1970-01-01T00:00:00+00:00"

    plat = str(m.get("platform", "") or "")
    ch = str(m.get("channel_id", "") or "")
    uid = str(m.get("user_id", "") or "")
    uname = str(m.get("user_name", "") or "?")
    mid = str(m.get("message_id", "") or "")
    rid = str(m.get("reply_to_id", "") or "")
    text = str(m.get("text", "") or "")

    prefix = (
        f"[{ts_dt}] [platform:{plat} channel:{ch}] {uname} ({uid})"
        f" [Message ID: {mid}]"
    )
    if rid:
        # The reply marker is emitted only when a reply target is present.
        prefix += f" [Replying to: {rid}]"
    return prefix + " : " + text
async def scan_channel_zset_keys(redis: Any) -> list[str]:
    """Discover every per-channel message-index key in Redis.

    Walks the keyspace through ``redis.scan_iter`` with the pattern
    ``channel_msgs:*`` (``count=500``), decodes keys that arrive as bytes and
    returns them sorted so the processing order is deterministic.

    :func:`collect_messages_from_redis` calls this when no explicit
    ``channel_pairs`` filter is supplied.

    Args:
        redis: Async Redis client exposing ``scan_iter``.

    Returns:
        list[str]: Sorted ``channel_msgs:{platform}:{channel_id}`` keys.
    """
    discovered = [
        raw.decode() if isinstance(raw, bytes) else raw
        async for raw in redis.scan_iter(match="channel_msgs:*", count=500)
    ]
    return sorted(discovered)
def _parse_zset_key(zset_key: str) -> tuple[str, str] | None:
    """Extract ``(platform, channel_id)`` from a channel index key.

    The key is split on at most two colons, so a channel id that itself
    contains colons stays intact in the last segment. A key with fewer than
    three segments is treated as malformed. The leading prefix segment is not
    validated.

    Args:
        zset_key: Key of the form ``channel_msgs:{platform}:{channel_id}``.

    Returns:
        tuple[str, str] | None: ``(platform, channel_id)``, or ``None`` when
        the key has fewer than three colon-separated segments.
    """
    try:
        # With maxsplit=2 there are never more than three parts, so the
        # unpacking can only fail when segments are missing.
        _prefix, platform, channel_id = zset_key.split(":", 2)
    except ValueError:
        return None
    return platform, channel_id
async def cursor_hget(redis: Any, platform: str, channel_id: str) -> float | None:
    """Look up the stored last-processed timestamp for one channel.

    Issues ``HGET`` against :data:`KG_AGENTIC_BULK_LAST_TS_HASH` using the
    field built by :func:`cursor_field`. A bytes reply is decoded before it is
    converted to ``float``.

    :func:`collect_messages_from_redis` uses the result as the exclusive lower
    bound for incremental reads; ``None`` makes it read the whole channel.

    Args:
        redis: Async Redis client.
        platform: Platform identifier for the channel.
        channel_id: Channel identifier.

    Returns:
        float | None: The cursor timestamp, or ``None`` when no value is
        stored or the stored value is not a valid float.
    """
    stored = await redis.hget(
        KG_AGENTIC_BULK_LAST_TS_HASH,
        cursor_field(platform, channel_id),
    )
    if stored is None:
        return None
    if isinstance(stored, bytes):
        stored = stored.decode()
    try:
        cursor = float(stored)
    except (TypeError, ValueError):
        # An unparsable cursor is reported the same way as a missing one.
        return None
    return cursor
async def cursor_hset(
    redis: Any,
    platform: str,
    channel_id: str,
    ts: float,
) -> None:
    """Store the last-processed timestamp for one channel.

    Writes ``str(ts)`` with ``HSET`` into :data:`KG_AGENTIC_BULK_LAST_TS_HASH`
    under the field produced by :func:`cursor_field`. The value is written
    unconditionally; no comparison against an existing cursor is made here.

    Args:
        redis: Async Redis client.
        platform: Platform identifier for the channel.
        channel_id: Channel identifier.
        ts: New cursor timestamp to store.
    """
    field_name = cursor_field(platform, channel_id)
    await redis.hset(KG_AGENTIC_BULK_LAST_TS_HASH, field_name, str(ts))
async def bootstrap_latest_cursor_no_extract(
    redis: Any,
    zset_key: str,
    platform: str,
    channel_id: str,
) -> None:
    """Seed a missing cursor with the channel's newest message score.

    Implements ``cursor_bootstrap="latest"``: if the channel has no field in
    :data:`KG_AGENTIC_BULK_LAST_TS_HASH`, the highest score currently in
    *zset_key* (read with ``ZREVRANGE 0 0 WITHSCORES``) becomes its cursor, so
    a later incremental read with that exclusive bound skips the existing
    backlog. Nothing is written when a cursor field already exists or when the
    sorted set is empty.

    Called by :func:`collect_messages_from_redis` for each channel of an
    incremental run that requests ``cursor_bootstrap="latest"``.

    Args:
        redis: Async Redis client.
        zset_key: The ``channel_msgs:{platform}:{channel_id}`` sorted-set key.
        platform: Platform identifier for the channel.
        channel_id: Channel identifier.
    """
    field_name = cursor_field(platform, channel_id)
    already_set = await redis.hget(KG_AGENTIC_BULK_LAST_TS_HASH, field_name)
    if already_set is not None:
        # Any stored value (even an unparsable one) counts as an existing cursor.
        return

    newest = await redis.zrevrange(zset_key, 0, 0, withscores=True)
    if newest:
        _member, score = newest[0]
        await redis.hset(
            KG_AGENTIC_BULK_LAST_TS_HASH,
            field_name,
            str(float(score)),
        )
        logger.info(
            "KG bulk incremental: bootstrap=latest for %s:%s cursor=%s",
            platform,
            channel_id,
            score,
        )
async def fetch_messages_for_zset(
    redis: Any,
    zset_key: str,
) -> list[dict[str, Any]]:
    """Load every message referenced by one channel's sorted set.

    The set is read with ``ZRANGE`` in rank windows of :data:`_ZSET_BATCH`
    members. Each page of member keys is decoded where necessary and hydrated
    through :func:`_hmget_message_dicts`. Paging stops at the first empty or
    short page. Results are appended in page order;
    :func:`collect_messages_from_redis` re-sorts the combined output.

    Called by :func:`fetch_messages_for_zset_after` when no lower bound is
    given, and by :func:`collect_messages_from_redis` for full-backlog reads.

    Args:
        redis: Async Redis client.
        zset_key: The ``channel_msgs:{platform}:{channel_id}`` sorted-set key.

    Returns:
        list[dict[str, Any]]: Hydrated message dicts for the whole channel.
    """
    hydrated: list[dict[str, Any]] = []
    rank = 0
    more_pages = True
    while more_pages:
        page: list[Any] = await redis.zrange(
            zset_key,
            rank,
            rank + _ZSET_BATCH - 1,
        )
        if page:
            member_keys = [
                member.decode() if isinstance(member, bytes) else member
                for member in page
            ]
            hydrated.extend(
                await _hmget_message_dicts(redis, zset_key, member_keys)
            )
            rank += len(page)
        # A short (or empty) page means the end of the set was reached.
        more_pages = bool(page) and len(page) >= _ZSET_BATCH
    return hydrated
async def fetch_messages_for_zset_after(
    redis: Any,
    zset_key: str,
    after_ts_exclusive: float | None,
) -> list[dict[str, Any]]:
    """Load the messages of one channel that are newer than a cursor.

    With ``after_ts_exclusive=None`` the call is forwarded to
    :func:`fetch_messages_for_zset`. Otherwise the sorted set is paged with
    ``ZRANGEBYSCORE`` from the exclusive bound ``(<ts>`` to ``+inf`` in
    windows of :data:`_ZSET_BATCH` members, and each page is hydrated through
    :func:`_hmget_message_dicts`. Paging stops at the first empty or short
    page.

    Called by :func:`collect_messages_from_redis` for incremental reads on
    channels that already have a stored cursor.

    Args:
        redis: Async Redis client.
        zset_key: The ``channel_msgs:{platform}:{channel_id}`` sorted-set key.
        after_ts_exclusive: Exclusive lower-bound score; ``None`` means no
            bound.

    Returns:
        list[dict[str, Any]]: Hydrated message dicts with score above the
        bound.
    """
    if after_ts_exclusive is None:
        return await fetch_messages_for_zset(redis, zset_key)

    # The leading "(" marks the lower bound as exclusive in Redis range syntax.
    lower_bound = f"({after_ts_exclusive}"
    hydrated: list[dict[str, Any]] = []
    skipped = 0
    more_pages = True
    while more_pages:
        page: list[Any] = await redis.zrangebyscore(
            zset_key,
            lower_bound,
            "+inf",
            start=skipped,
            num=_ZSET_BATCH,
        )
        if page:
            member_keys = [
                member.decode() if isinstance(member, bytes) else member
                for member in page
            ]
            hydrated.extend(
                await _hmget_message_dicts(redis, zset_key, member_keys)
            )
            skipped += len(page)
        more_pages = bool(page) and len(page) >= _ZSET_BATCH
    return hydrated
async def _hmget_message_dicts(
    redis: Any,
    zset_key: str,
    msg_keys: list[str],
) -> list[dict[str, Any]]:
    """Hydrate message hash keys into message dicts via pipelined ``HMGET``.

    The fields in :data:`_HASH_MSG_FIELDS` are read for each key, with
    :data:`_HMGET_BATCH` keys queued per pipeline. Bytes values are decoded.
    A key whose reply is empty or all ``None`` (for example a hash that no
    longer exists) is skipped. Each surviving dict gets ``timestamp`` coerced
    to ``float`` and two bookkeeping entries: ``redis_msg_key`` (the source
    hash key) and ``zset_key`` (the owning channel index).

    A missing or empty ``timestamp`` becomes ``0.0``. A value that cannot be
    converted also becomes ``0.0`` and is logged at warning level; previously
    the conversion error propagated and caused the caller to discard every
    message of the channel because of one malformed record.

    Called by :func:`fetch_messages_for_zset` and
    :func:`fetch_messages_for_zset_after`.

    Args:
        redis: Async Redis client exposing ``pipeline()``.
        zset_key: The owning channel index key, stamped onto each result.
        msg_keys: Message hash keys to hydrate.

    Returns:
        list[dict[str, Any]]: Hydrated message dicts (skipped keys omitted).
    """
    out: list[dict[str, Any]] = []
    for start in range(0, len(msg_keys), _HMGET_BATCH):
        batch = msg_keys[start : start + _HMGET_BATCH]
        pipe = redis.pipeline()
        for mk in batch:
            pipe.hmget(mk, *_HASH_MSG_FIELDS)
        rows = await pipe.execute()
        for mk, values in zip(batch, rows):
            if not values or all(v is None for v in values):
                continue
            msg: dict[str, Any] = {
                field: (value.decode() if isinstance(value, bytes) else value)
                for field, value in zip(_HASH_MSG_FIELDS, values)
            }
            raw_ts = msg.get("timestamp")
            try:
                msg["timestamp"] = float(raw_ts or 0)
            except (TypeError, ValueError):
                # Keep the record (same treatment as a missing timestamp)
                # instead of failing the whole channel read.
                logger.warning(
                    "KG bulk: unparseable timestamp %r on %s; using 0.0",
                    raw_ts,
                    mk,
                )
                msg["timestamp"] = 0.0
            msg["redis_msg_key"] = mk
            msg["zset_key"] = zset_key
            out.append(msg)
    return out
async def collect_messages_from_redis(
    redis: Any,
    *,
    incremental: bool = False,
    cursor_bootstrap: CursorBootstrap = "full",
    channel_pairs: list[tuple[str, str]] | None = None,
) -> tuple[list[dict[str, Any]], int]:
    """Load messages from the Redis channel sorted sets.

    When *channel_pairs* is given, only the ``channel_msgs:{platform}:{channel_id}``
    keys built from those pairs are read; otherwise every key found by
    :func:`scan_channel_zset_keys` is read. Keys that
    :func:`_parse_zset_key` rejects are skipped.

    In incremental mode each channel's cursor from
    :data:`KG_AGENTIC_BULK_LAST_TS_HASH` is used as an exclusive lower bound;
    a channel without a usable cursor is read in full. With
    ``cursor_bootstrap="latest"`` a missing cursor is first seeded via
    :func:`bootstrap_latest_cursor_no_extract`, so the existing backlog of
    that channel is not returned.

    A failure while fetching one channel is logged at warning level and that
    channel is skipped; the remaining channels are still processed.

    Args:
        redis: Async Redis client.
        incremental: Honour per-channel cursors when true.
        cursor_bootstrap: ``"full"`` or ``"latest"``; only consulted in
            incremental mode.
        channel_pairs: Optional ``(platform, channel_id)`` filter.

    Returns:
        tuple[list[dict[str, Any]], int]: The messages sorted by
        ``(timestamp, platform, channel_id, message_id)`` and the number of
        candidate sorted-set keys (including any that were skipped or failed).
    """
    allowed: set[tuple[str, str]] | None
    if channel_pairs is None:
        zkeys = await scan_channel_zset_keys(redis)
        allowed = None
    else:
        zkeys = [f"channel_msgs:{p}:{c}" for p, c in channel_pairs]
        allowed = set(channel_pairs)
    n_scanned = len(zkeys)

    def _order_key(m: dict[str, Any]) -> tuple[Any, str, str, str]:
        # Timestamp first; the string fields only break ties deterministically.
        return (
            m["timestamp"],
            str(m.get("platform", "")),
            str(m.get("channel_id", "")),
            str(m.get("message_id", "")),
        )

    collected: list[dict[str, Any]] = []
    for zk in zkeys:
        parsed = _parse_zset_key(zk)
        if not parsed:
            continue
        platform, channel_id = parsed
        if allowed is not None and (platform, channel_id) not in allowed:
            continue

        after_ts: float | None = None
        if incremental:
            if cursor_bootstrap == "latest":
                await bootstrap_latest_cursor_no_extract(
                    redis,
                    zk,
                    platform,
                    channel_id,
                )
            after_ts = await cursor_hget(redis, platform, channel_id)

        try:
            if incremental and after_ts is not None:
                part = await fetch_messages_for_zset_after(redis, zk, after_ts)
            else:
                # Non-incremental runs and cursor-less channels read everything.
                part = await fetch_messages_for_zset(redis, zk)
            collected.extend(part)
        except Exception:
            logger.warning("Failed zset %s", zk, exc_info=True)

    collected.sort(key=_order_key)
    return collected, n_scanned
def _unique_channel_pairs(keys: list[tuple[str, str]]) -> list[tuple[str, str]]: """Deduplicate and sort a list of ``(platform, channel_id)`` pairs. This is a pure helper called from :func:`chunk_message_lines` to summarize which channels appear in a candidate or final chunk, used both for the token estimate and to record the chunk's channel set. Args: keys: ``(platform, channel_id)`` pairs, possibly with duplicates. Returns: list[tuple[str, str]]: The distinct pairs in sorted order. """ return sorted(set(keys)) def _unique_speaker_pairs( keys: list[tuple[str, str]], ) -> list[tuple[str, str]]: """Collapse ``(user_id, user_name)`` pairs to one name per distinct speaker. Trims ids and names, drops entries with an empty id, and for each id keeps the first non-empty name seen (falling back to ``"?"``). The result is sorted by user id so chunk speaker sets are stable. This is a pure helper called from :func:`chunk_message_lines` to build the per-chunk speaker roster used for token estimation and for downstream speaker KG prefetch. Args: keys: ``(user_id, user_name)`` pairs, possibly with duplicates or blanks. Returns: list[tuple[str, str]]: Distinct ``(user_id, user_name)`` pairs sorted by id. """ seen: dict[str, str] = {} for uid, name in keys: u = (uid or "").strip() if not u: continue if u not in seen: seen[u] = (name or "").strip() or "?" return sorted(seen.items(), key=lambda x: x[0]) def _subset_channel_metadata( full: dict[str, dict[str, str]] | None, pairs: list[tuple[str, str]], ) -> dict[str, dict[str, str]] | None: """Narrow a channel-metadata map to only the channels in *pairs*. Given the global metadata keyed by ``"{platform}:{channel_id}"``, returns a shallow copy containing just the entries whose channel appears in *pairs*, or ``None`` when the input is empty or nothing matches (so callers can omit the section entirely). 
This is a pure helper used in :func:`chunk_message_lines` (to size token budgets for a candidate slice) and twice in :func:`run_agentic_bulk_pipeline` (to embed metadata in each chunk artifact and to pass per-chunk metadata into extraction). Args: full: The full metadata map keyed by ``"{platform}:{channel_id}"``, or ``None``. pairs: ``(platform, channel_id)`` pairs to retain. Returns: dict[str, dict[str, str]] | None: The filtered metadata, or ``None`` when empty. """ if not full: return None out: dict[str, dict[str, str]] = {} for plat, cid in pairs: k = f"{plat}:{cid}" if k in full: out[k] = dict(full[k]) return out or None def _max_ts_per_channel( line_channel_keys: list[tuple[str, str]], line_timestamps: list[float], ) -> dict[tuple[str, str], float]: """Compute the maximum timestamp seen per channel across parallel lists. Walks the zipped ``(channel_key, timestamp)`` stream and keeps the largest timestamp for each ``(platform, channel_id)``. The result is the chunk-boundary cursor map: after a chunk extracts cleanly, these per-channel maxima become the new incremental cursors. This is a pure helper called from :func:`chunk_message_lines` to attach a ``cmap`` to each emitted chunk; that map is later consumed by :func:`run_agentic_bulk_pipeline` to call :func:`cursor_hset`. It is also covered by ``tests/test_kg_bulk_cursors.py``. Args: line_channel_keys: Per-line ``(platform, channel_id)`` keys. line_timestamps: Per-line timestamps, aligned with *line_channel_keys*. Returns: dict[tuple[str, str], float]: Max timestamp per ``(platform, channel_id)``. """ acc: dict[tuple[str, str], float] = {} for k, t in zip(line_channel_keys, line_timestamps): if k not in acc or t > acc[k]: acc[k] = t return acc def _speaker_prefetch_placeholder(max_chars: int) -> str: """Build a fixed-size filler string approximating the speaker-KG prefetch. 
When speaker KG prefetch is enabled, the real prefetch text is only computed per chunk at extraction time, but the chunk packer must reserve room for it up front. This returns a header plus middot padding totaling about *max_chars* characters so token estimation during packing reflects the eventual prompt size. Returns ``""`` when *max_chars* is non-positive. This is a pure helper called once in :func:`run_agentic_bulk_pipeline` to produce the ``speaker_prefetch_placeholder`` passed through to :func:`chunk_message_lines` / :func:`token_count_conversation`. Args: max_chars: Target placeholder size in characters (the prefetch char budget). Returns: str: A placeholder string of roughly *max_chars* characters, or empty. """ if max_chars <= 0: return "" header = ( "## Existing knowledge graph (speakers — prefetch)\n" "_Placeholder sizing for token budget._\n\n" ) pad = max(0, int(max_chars) - len(header)) return header + ("·" * pad)
async def token_count_conversation(
    counter: KgBulkLlmClient,
    conversation_text: str,
    *,
    channel_scope: str,
    reserve: int,
    config: Config | None = None,
    chunk_channel_pairs: list[tuple[str, str]] | None = None,
    chunk_speaker_pairs: list[tuple[str, str]] | None = None,
    speaker_kg_prefetch: str = "",
    channel_metadata: dict[str, dict[str, str]] | None = None,
) -> int:
    """Estimate the input-token cost of a candidate conversation chunk.

    The prompt for *conversation_text* is assembled with
    :func:`kg_agentic_extraction.messages_for_agentic_token_estimate`
    (always with ``chunk_index=0``) and handed to
    ``counter.count_input_tokens``. When that returns ``None`` a rough
    ``len(conversation_text) // 3`` estimate (at least 1) is used instead.
    *reserve* is added to the result in both cases.

    Args:
        counter: Bulk LLM client used for token counting.
        conversation_text: Joined transcript lines of the candidate chunk.
        channel_scope: Scope label forwarded as the prompt's ``channel_id``.
        reserve: Token headroom added to every estimate.
        config: Optional :class:`config.Config` forwarded to prompt assembly.
        chunk_channel_pairs: Channels represented in the chunk.
        chunk_speaker_pairs: Speakers represented in the chunk.
        speaker_kg_prefetch: Speaker-KG prefetch text (or placeholder) to size.
        channel_metadata: Channel metadata subset for the chunk.

    Returns:
        int: Estimated input tokens including *reserve*.
    """
    prompt_messages = messages_for_agentic_token_estimate(
        conversation_text,
        channel_id=channel_scope,
        chunk_index=0,
        config=config,
        chunk_channel_pairs=chunk_channel_pairs,
        chunk_speaker_pairs=chunk_speaker_pairs,
        speaker_kg_prefetch=speaker_kg_prefetch,
        channel_metadata=channel_metadata,
    )
    counted = await counter.count_input_tokens(prompt_messages)
    if counted is None:
        # Client could not count: fall back to a character-based heuristic.
        return max(1, len(conversation_text) // 3) + reserve
    return int(counted) + reserve
async def chunk_message_lines(
    lines: list[str],
    line_channel_keys: list[tuple[str, str]],
    line_speaker_keys: list[tuple[str, str]],
    line_timestamps: list[float],
    counter: KgBulkLlmClient,
    *,
    max_tokens: int,
    channel_scope: str,
    reserve: int,
    config: Config | None = None,
    speaker_prefetch_placeholder: str = "",
    channel_metadata: dict[str, dict[str, str]] | None = None,
) -> list[
    tuple[
        list[str],
        list[tuple[str, str]],
        list[tuple[str, str]],
        float,
        float,
        dict[tuple[str, str], float],
    ]
]:
    """Pack transcript lines, in order, into chunks that fit a token budget.

    Starting at the first unpacked line, a binary search over the run length
    finds the largest contiguous run whose estimate from
    :func:`token_count_conversation` is at most *max_tokens*. If not even a
    single line fits, one line is taken anyway so the loop always advances.
    The search assumes the estimate does not shrink as lines are added.

    Args:
        lines: Formatted transcript lines to pack.
        line_channel_keys: Per-line ``(platform, channel_id)`` keys.
        line_speaker_keys: Per-line ``(user_id, user_name)`` keys.
        line_timestamps: Per-line timestamps.
        counter: Bulk LLM client used for token counting.
        max_tokens: Maximum estimated input tokens allowed per chunk.
        channel_scope: Scope label forwarded to the token estimate.
        reserve: Token headroom forwarded to the token estimate.
        config: Optional :class:`config.Config` forwarded to the estimate.
        speaker_prefetch_placeholder: Filler text sizing the speaker-KG
            prefetch section during estimation.
        channel_metadata: Full channel metadata map; narrowed per candidate
            run with :func:`_subset_channel_metadata`.

    Returns:
        One tuple per chunk: ``(lines, channel_pairs, speaker_pairs, ts_lo,
        ts_hi, cursor_map)`` where ``cursor_map`` is the per-channel maximum
        timestamp from :func:`_max_ts_per_channel`.

    Raises:
        ValueError: If any of the per-line lists differs in length from
            *lines*.
    """
    total = len(lines)
    if len(line_channel_keys) != total or len(line_speaker_keys) != total:
        raise ValueError("line_*_keys length must match lines")
    if len(line_timestamps) != total:
        raise ValueError("line_timestamps length must match lines")

    async def _estimate(start: int, count: int) -> int:
        # Price the candidate run lines[start : start + count].
        stop = start + count
        channel_pairs = _unique_channel_pairs(line_channel_keys[start:stop])
        speaker_pairs = _unique_speaker_pairs(line_speaker_keys[start:stop])
        return await token_count_conversation(
            counter,
            "\n".join(lines[start:stop]),
            channel_scope=channel_scope,
            reserve=reserve,
            config=config,
            chunk_channel_pairs=channel_pairs,
            chunk_speaker_pairs=speaker_pairs,
            speaker_kg_prefetch=speaker_prefetch_placeholder,
            channel_metadata=_subset_channel_metadata(
                channel_metadata,
                channel_pairs,
            ),
        )

    packed: list[
        tuple[
            list[str],
            list[tuple[str, str]],
            list[tuple[str, str]],
            float,
            float,
            dict[tuple[str, str], float],
        ]
    ] = []
    cursor = 0
    while cursor < total:
        lo, hi = 1, total - cursor
        fit = 0
        while lo <= hi:
            probe = (lo + hi) // 2
            if await _estimate(cursor, probe) <= max_tokens:
                fit = probe
                lo = probe + 1
            else:
                hi = probe - 1

        # An oversized single line still forms its own chunk.
        take = fit or 1
        stop = cursor + take
        chan_keys = line_channel_keys[cursor:stop]
        spk_keys = line_speaker_keys[cursor:stop]
        stamps = line_timestamps[cursor:stop]
        packed.append(
            (
                lines[cursor:stop],
                _unique_channel_pairs(chan_keys),
                _unique_speaker_pairs(spk_keys),
                min(stamps) if stamps else 0.0,
                max(stamps) if stamps else 0.0,
                _max_ts_per_channel(chan_keys, stamps),
            ),
        )
        cursor = stop
    return packed
@dataclass
class KgBulkPipelineParams:
    """Settings for a single :func:`run_agentic_bulk_pipeline` invocation.

    Everything except *out_dir* has a default. The comments on each field
    describe how the pipeline in this module uses the value; fields marked
    "not read in this module" are carried for callers/collaborators.
    """

    # --- output and run mode -------------------------------------------
    # Directory for messages.jsonl, manifest.json, chunks/ and the run log.
    out_dir: Path
    # Stop after writing messages.jsonl and manifest.json.
    dump_only: bool = False
    # Stop after writing the chunk files (no LLM extraction).
    dry_run_chunks: bool = False
    # Run extraction with persistence disabled; cursors are not advanced.
    dry_run_llm: bool = False

    # --- chunk sizing and bounds ---------------------------------------
    # Maximum estimated input tokens per chunk.
    chunk_tokens: int = 250_000
    # Base token headroom added to every chunk estimate.
    token_reserve: int = 100_000
    # Keep only the first N messages when > 0.
    max_messages: int = 0
    # Keep only the first N chunks when > 0.
    chunks_max: int = 0
    # Chunk each (platform, channel) separately instead of cross-channel.
    per_channel: bool = False
    # Skip chunks whose index is below this value during extraction.
    resume_from_chunk: int = 0

    # --- LLM backend ----------------------------------------------------
    # Tool-round limit handed to the extraction client.
    max_tool_rounds: int = 48
    # "gemini" selects the Gemini pool client; anything else OpenRouter.
    bulk_llm_backend: str = "gemini"

    # --- channel metadata ----------------------------------------------
    # Invoke the caller-supplied metadata loader and reserve tokens for it.
    fetch_channel_metadata: bool = False
    # Not read in this module — presumably a cache TTL for the loader.
    channel_metadata_ttl_days: float = 7.0
    # Not read in this module — presumably consumed by the metadata loader.
    discord_platform_type: str = "discord"

    # --- speaker KG prefetch -------------------------------------------
    # Prefetch existing speaker KG context for each chunk.
    prefetch_speaker_kg: bool = False
    prefetch_max_speakers: int = 8
    prefetch_hits_per_speaker: int = 3
    # Sizes the packing placeholder and the extra reserve (chars // 3).
    prefetch_max_chars: int = 400_000
    prefetch_min_score: float = 0.0

    # --- redis / cursors -----------------------------------------------
    # Forwarded to redis_ssl_kwargs_for_bulk for the extraction cache.
    redis_no_verify: bool = True
    # Advance per-channel cursors after cleanly extracted chunks.
    incremental: bool = False
def _publish_kg_bulk_debug_event(
    *,
    status: str,
    started_at: float,
    channels: int,
    messages: int,
    n_zsets: int,
    out_dir: Path,
    extra_payload: dict[str, Any] | None = None,
) -> None:
    """Schedule a best-effort ``kg_bulk_incremental`` observability event.

    Must be called from a running event loop because the publish coroutine
    is scheduled with ``asyncio.create_task``. Any failure (including the
    ``observability`` module being unavailable) is logged at debug level
    and otherwise ignored, so observability can never fail the pipeline.

    Args:
        status: Event status label (``"dump_only"`` or ``"ok"``).
        started_at: ``time.monotonic()`` value captured at pipeline start.
        channels: Channel count reported in the preview and payload.
        messages: Message count reported in the preview and payload.
        n_zsets: Number of zsets scanned, reported in the payload.
        out_dir: Output directory reported in the preview and payload.
        extra_payload: Extra payload entries appended after the standard ones.
    """
    try:
        from observability import publish_debug_event

        payload: dict[str, Any] = {
            "channels_processed": channels,
            "messages_processed": messages,
            "n_zsets": n_zsets,
            "out_dir": str(out_dir),
        }
        if extra_payload:
            payload.update(extra_payload)
        asyncio.create_task(
            publish_debug_event(
                "kg_bulk_incremental",
                "kg_bulk_runner",
                phase="pipeline",
                status=status,
                duration_ms=(_time.monotonic() - started_at) * 1000,
                preview=f"channels={channels} messages={messages} out_dir={out_dir}",
                payload=payload,
            ),
            name="obs_kg_bulk",
        )
    except Exception:
        logger.debug("kg_bulk debug event publish failed", exc_info=True)


async def run_agentic_bulk_pipeline(
    cfg: Config,
    redis: Any,
    messages: list[dict[str, Any]],
    n_zsets_scanned: int,
    params: KgBulkPipelineParams,
    *,
    load_channel_metadata: (
        Callable[
            [],
            Awaitable[dict[str, dict[str, str]]],
        ]
        | None
    ) = None,
) -> None:
    """Dump messages, chunk them, run agentic KG extraction, update cursors.

    Steps, in order:

    1. Write ``messages.jsonl`` and ``manifest.json`` under
       ``params.out_dir`` (stop here when ``params.dump_only``).
    2. Optionally load channel metadata via *load_channel_metadata* (only
       when ``params.fetch_channel_metadata`` is set).
    3. Pack the transcript into token-budget chunks with
       :func:`chunk_message_lines` (per channel or cross-channel) and write
       one ``chunks/chunk_NNNNN.json`` per chunk (stop here when
       ``params.dry_run_chunks``).
    4. Run :func:`run_agentic_kg_extraction_chunk` for each chunk from
       ``params.resume_from_chunk`` onward and write ``extraction_run.json``.

    Cursor handling: when ``params.incremental`` is true and this is not an
    LLM dry run, each chunk that finishes with no errors and no parse error
    advances the cursor of every channel it contains. Once a chunk fails,
    the channels in that chunk are *stalled* for the rest of the run: later
    successful chunks do not advance them, because doing so would move the
    cursor past the failed chunk's messages and they would never be retried.

    All LLM/cache clients created here are closed on every exit path,
    including when chunking or setup raises.

    Args:
        cfg: Application configuration (API keys, model, redis URL, ...).
        redis: Redis client used for cursor updates.
        messages: Message dicts to process, expected in chronological order.
        n_zsets_scanned: Number of zsets scanned by the collector (recorded
            in the manifest and the debug event only).
        params: Run settings.
        load_channel_metadata: Optional async callable returning channel
            metadata keyed by ``"{platform}:{channel_id}"``.
    """
    # Local import, in keeping with the function-scope imports used for
    # optional collaborators in this module.
    from contextlib import AsyncExitStack

    out_dir = Path(params.out_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    _t0 = _time.monotonic()

    msgs = messages
    if params.max_messages and params.max_messages > 0:
        msgs = msgs[: params.max_messages]

    jsonl_path = out_dir / "messages.jsonl"
    with jsonl_path.open("w", encoding="utf-8") as fj:
        for m in msgs:
            rec = {k: v for k, v in m.items() if k != "embedding"}
            rec["llm_style_line"] = format_llm_style_line(rec)
            fj.write(json.dumps(rec, ensure_ascii=False, default=str) + "\n")

    lines = [format_llm_style_line(m) for m in msgs]
    line_channel_keys = [
        (str(m.get("platform", "")), str(m.get("channel_id", ""))) for m in msgs
    ]
    line_speaker_keys = [
        (str(m.get("user_id", "")), str(m.get("user_name", ""))) for m in msgs
    ]
    line_timestamps = [float(m.get("timestamp") or 0) for m in msgs]

    platforms_channels = sorted(
        {f"{m.get('platform')}:{m.get('channel_id')}" for m in msgs},
    )
    zsets_in_msgs = len(
        {str(m.get("zset_key", "") or "") for m in msgs if m.get("zset_key")},
    )
    meta = {
        "message_count": len(msgs),
        "zset_count": zsets_in_msgs,
        "zset_count_redis_scanned": n_zsets_scanned,
        "time_min": min((m["timestamp"] for m in msgs), default=0),
        "time_max": max((m["timestamp"] for m in msgs), default=0),
        "platforms_channels": platforms_channels,
    }
    (out_dir / "manifest.json").write_text(
        json.dumps(meta, indent=2, default=str),
        encoding="utf-8",
    )

    if params.dump_only:
        logger.info("Dump-only: wrote %s and manifest.", jsonl_path)
        _publish_kg_bulk_debug_event(
            status="dump_only",
            started_at=_t0,
            channels=len(platforms_channels),
            messages=len(msgs),
            n_zsets=n_zsets_scanned,
            out_dir=out_dir,
        )
        return

    global_channel_metadata: dict[str, dict[str, str]] = {}
    if params.fetch_channel_metadata and load_channel_metadata is not None:
        global_channel_metadata = await load_channel_metadata()

    bulk_backend = (params.bulk_llm_backend or "gemini").strip().lower()
    run_log: list[dict[str, Any]] = []

    # The exit stack closes clients in reverse creation order (extraction
    # client, message cache, token counter) on every exit path, and keeps
    # closing the remaining ones even if one close() raises.
    async with AsyncExitStack() as cleanup:
        if bulk_backend == "gemini":
            token_counter = create_kg_bulk_gemini_pool_client(max_tool_rounds=1)
        else:
            token_counter = create_kg_bulk_openrouter_client(
                cfg.api_key,
                gemini_api_key=cfg.gemini_api_key or "",
                max_tool_rounds=1,
                top_p=cfg.top_p,
            )
        cleanup.push_async_callback(token_counter.close)

        scope = CROSS_CHANNEL_SCOPE
        reserve = int(params.token_reserve)
        if params.prefetch_speaker_kg:
            # Same chars // 3 heuristic as the token-count fallback.
            reserve += max(0, int(params.prefetch_max_chars) // 3)
        if params.fetch_channel_metadata:
            n_disc = sum(
                1
                for x in platforms_channels
                if x.startswith(("discord:", "discord-self:"))
            )
            reserve += min(120_000, 2_000 * max(1, n_disc))
        max_tok = int(params.chunk_tokens)

        prefetch_placeholder = ""
        if params.prefetch_speaker_kg:
            prefetch_placeholder = _speaker_prefetch_placeholder(
                int(params.prefetch_max_chars),
            )

        chunk_groups: list[
            tuple[
                str,
                list[str],
                list[tuple[str, str]],
                list[tuple[str, str]],
                float,
                float,
                dict[tuple[str, str], float],
            ]
        ] = []

        if params.per_channel:
            grouped: dict[
                tuple[str, str],
                list[
                    tuple[
                        str,
                        tuple[str, str],
                        tuple[str, str],
                        float,
                    ]
                ],
            ] = defaultdict(list)
            for m in msgs:
                k = (str(m.get("platform", "")), str(m.get("channel_id", "")))
                sp = (str(m.get("user_id", "")), str(m.get("user_name", "")))
                grouped[k].append(
                    (
                        format_llm_style_line(m),
                        k,
                        sp,
                        float(m.get("timestamp") or 0),
                    ),
                )
            for (plat, cid), rows in sorted(grouped.items()):
                ch_scope = f"{plat}:{cid}" if plat or cid else CROSS_CHANNEL_SCOPE
                subchunks = await chunk_message_lines(
                    [r[0] for r in rows],
                    [r[1] for r in rows],
                    [r[2] for r in rows],
                    [r[3] for r in rows],
                    token_counter,
                    max_tokens=max_tok,
                    channel_scope=ch_scope,
                    reserve=reserve,
                    config=cfg,
                    speaker_prefetch_placeholder=prefetch_placeholder,
                    channel_metadata=global_channel_metadata or None,
                )
                for bl, bp, spk, ts_lo, ts_hi, cmap in subchunks:
                    chunk_groups.append(
                        (ch_scope, bl, bp, spk, ts_lo, ts_hi, cmap),
                    )
        else:
            subchunks = await chunk_message_lines(
                lines,
                line_channel_keys,
                line_speaker_keys,
                line_timestamps,
                token_counter,
                max_tokens=max_tok,
                channel_scope=scope,
                reserve=reserve,
                config=cfg,
                speaker_prefetch_placeholder=prefetch_placeholder,
                channel_metadata=global_channel_metadata or None,
            )
            for bl, bp, spk, ts_lo, ts_hi, cmap in subchunks:
                chunk_groups.append((scope, bl, bp, spk, ts_lo, ts_hi, cmap))

        cm = int(params.chunks_max or 0)
        if cm > 0:
            chunk_groups = chunk_groups[:cm]
            logger.info(
                "--chunks-max=%d: using first %d chunk(s)",
                cm,
                len(chunk_groups),
            )

        chunks_dir = out_dir / "chunks"
        chunks_dir.mkdir(exist_ok=True)
        for idx, (
            ch_scope,
            chunk_lines,
            chunk_pairs,
            chunk_spk,
            ts_lo,
            ts_hi,
            _cmap,
        ) in enumerate(chunk_groups):
            text = "\n".join(chunk_lines)
            written_at = datetime.now(timezone.utc).isoformat()
            t_start = (
                datetime.fromtimestamp(ts_lo, tz=timezone.utc).isoformat()
                if ts_lo > 0
                else ""
            )
            t_end = (
                datetime.fromtimestamp(ts_hi, tz=timezone.utc).isoformat()
                if ts_hi > 0
                else ""
            )
            meta_rows = _subset_channel_metadata(
                global_channel_metadata or None,
                list(chunk_pairs),
            )
            piece = {
                "index": idx,
                "channel_scope": ch_scope,
                "channels_in_chunk": [list(p) for p in chunk_pairs],
                "speakers_in_chunk": [list(p) for p in chunk_spk],
                "channel_metadata": meta_rows if meta_rows is not None else {},
                "line_count": len(chunk_lines),
                "char_count": len(text),
                "written_at": written_at,
                "time_range_hint_utc": [t_start, t_end],
                "conversation_text": text,
            }
            (chunks_dir / f"chunk_{idx:05d}.json").write_text(
                json.dumps(piece, ensure_ascii=False, indent=2),
                encoding="utf-8",
            )

        if params.dry_run_chunks:
            logger.info(
                "Dry-run: wrote %d chunk files; skipping LLM.",
                len(chunk_groups),
            )
            return

        if params.dry_run_llm:
            logger.info(
                "Dry-run LLM: running agentic extraction without persisting to FalkorDB.",
            )

        or_for_kg = OpenRouterClient(
            api_key=cfg.api_key,
            model=cfg.model,
            base_url=cfg.llm_base_url,
            gemini_api_key=cfg.gemini_api_key or "",
            top_p=cfg.top_p,
        )
        cache2 = MessageCache(
            redis_url=cfg.redis_url,
            openrouter_client=or_for_kg,
            embedding_model=cfg.embedding_model,
            ssl_kwargs=redis_ssl_kwargs_for_bulk(
                cfg,
                redis_no_verify=params.redis_no_verify,
            ),
        )
        cleanup.push_async_callback(cache2.close)

        kg = KnowledgeGraphManager(
            redis_client=cache2.redis_client,
            openrouter=or_for_kg,
            embedding_model=cfg.embedding_model,
            admin_user_ids=set(cfg.admin_user_ids) if cfg.admin_user_ids else None,
        )
        await kg.ensure_indexes()

        if bulk_backend == "gemini":
            bulk_client = create_kg_bulk_gemini_pool_client(
                max_tool_rounds=params.max_tool_rounds,
            )
        else:
            bulk_client = create_kg_bulk_openrouter_client(
                cfg.api_key,
                gemini_api_key=cfg.gemini_api_key or "",
                max_tool_rounds=params.max_tool_rounds,
                top_p=cfg.top_p,
            )
        cleanup.push_async_callback(bulk_client.close)

        start_idx = max(0, int(params.resume_from_chunk))
        pc_hint = ", ".join(platforms_channels)
        advance_cursors = params.incremental and not params.dry_run_llm
        # Channels that appeared in a failed chunk during this run; their
        # cursors must not move any further (see docstring).
        stalled_channels: set[tuple[str, str]] = set()

        for idx, (
            ch_scope,
            chunk_lines,
            chunk_pairs,
            chunk_spk,
            ts_lo,
            ts_hi,
            cmap,
        ) in enumerate(chunk_groups):
            if idx < start_idx:
                continue
            text = "\n".join(chunk_lines)
            t_start = (
                datetime.fromtimestamp(ts_lo, tz=timezone.utc).isoformat()
                if ts_lo > 0
                else ""
            )
            t_end = (
                datetime.fromtimestamp(ts_hi, tz=timezone.utc).isoformat()
                if ts_hi > 0
                else ""
            )
            meta_slice = _subset_channel_metadata(
                global_channel_metadata or None,
                list(chunk_pairs),
            )
            prefetch_text = ""
            if params.prefetch_speaker_kg:
                prefetch_text = await prefetch_speaker_kg_context(
                    kg,
                    list(chunk_spk),
                    max_speakers=int(params.prefetch_max_speakers),
                    hits_per_speaker=int(params.prefetch_hits_per_speaker),
                    min_score=float(params.prefetch_min_score),
                )
            stats = await run_agentic_kg_extraction_chunk(
                conversation_text=text,
                channel_id=ch_scope,
                kg_manager=kg,
                bulk_client=bulk_client,
                user_id="000000000000",
                chunk_index=idx,
                time_start_iso=t_start,
                time_end_iso=t_end,
                platforms_channels=pc_hint,
                config=cfg,
                chunk_channel_pairs=list(chunk_pairs),
                chunk_speaker_pairs=list(chunk_spk),
                speaker_kg_prefetch=prefetch_text,
                channel_metadata=meta_slice,
                persist_extraction=not params.dry_run_llm,
            )
            stats["chunk_index"] = idx
            stats["channel_scope"] = ch_scope
            run_log.append(stats)
            logger.info("Chunk %d scope=%s stats=%s", idx, ch_scope, stats)

            err = int(stats.get("errors", 0) or 0)
            parse_err = bool(stats.get("parse_error"))
            if not (advance_cursors and cmap):
                continue
            if err == 0 and not parse_err:
                for (plat, cid), tmax in cmap.items():
                    if tmax <= 0 or (plat, cid) in stalled_channels:
                        continue
                    await cursor_hset(redis, plat, cid, tmax)
            else:
                newly_stalled = set(cmap) - stalled_channels
                if newly_stalled:
                    logger.warning(
                        "Chunk %d failed; holding cursors for %d channel(s) "
                        "so its messages are retried",
                        idx,
                        len(newly_stalled),
                    )
                stalled_channels.update(newly_stalled)

        (out_dir / "extraction_run.json").write_text(
            json.dumps(run_log, indent=2, default=str),
            encoding="utf-8",
        )

    channels_done = len(
        {e.get("channel_scope", "") for e in run_log if isinstance(e, dict)},
    )
    _publish_kg_bulk_debug_event(
        status="ok",
        started_at=_t0,
        channels=channels_done,
        messages=len(msgs),
        n_zsets=n_zsets_scanned,
        out_dir=out_dir,
        extra_payload={"chunks_run": len(run_log)},
    )
def resolve_bulk_backend(cli_value: str | None) -> str:
    """Resolve which LLM backend the bulk pipeline should use.

    Precedence: *cli_value*, then the ``KG_BULK_LLM_BACKEND`` environment
    variable, then ``"gemini"``. Each candidate is stripped before it is
    considered, so a blank or whitespace-only value counts as "not set" and
    falls through to the next source instead of yielding an empty backend
    name.

    Args:
        cli_value: Explicit backend name, or ``None`` to defer to the
            environment/default.

    Returns:
        str: The lowercased, stripped backend name; never empty.
    """
    candidates = (cli_value, os.environ.get("KG_BULK_LLM_BACKEND"))
    for candidate in candidates:
        cleaned = (candidate or "").strip()
        if cleaned:
            return cleaned.lower()
    return "gemini"