# Source code for kg_bulk_runner

"""Shared bulk agentic KG extraction: Redis message collect, chunking, LLM run.

Used by :mod:`scripts.kg_bulk_dump_and_extract` and the scheduled incremental
task in :mod:`background_tasks`.
"""

from __future__ import annotations

import json
import logging
import os
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Awaitable, Callable, Literal

from config import Config
from kg_agentic_extraction import (
    KgBulkLlmClient,
    create_kg_bulk_gemini_pool_client,
    create_kg_bulk_openrouter_client,
    messages_for_agentic_token_estimate,
    prefetch_speaker_kg_context,
    run_agentic_kg_extraction_chunk,
)
from knowledge_graph import KnowledgeGraphManager
from message_cache import MessageCache
from openrouter_client import OpenRouterClient

logger = logging.getLogger(__name__)

# Scope label used when one extraction chunk mixes messages from many channels.
CROSS_CHANNEL_SCOPE = "cross_channel_bulk"

# Redis hash holding one incremental-extraction cursor (last processed message
# timestamp, stored as a string) per "<platform>:<channel_id>" field.
KG_AGENTIC_BULK_LAST_TS_HASH = "stargazer:kg_agentic_bulk:last_ts"

# Message-hash fields HMGET'd back from Redis for each cached message.
_HASH_MSG_FIELDS = (
    "user_id",
    "user_name",
    "platform",
    "channel_id",
    "text",
    "timestamp",
    "message_id",
    "reply_to_id",
)
_ZSET_BATCH = 5000  # zset members fetched per ZRANGE/ZRANGEBYSCORE page
_HMGET_BATCH = 400  # message hashes pipelined per Redis round trip

# How a missing incremental cursor is seeded: "full" processes the whole
# backlog; "latest" seeds the cursor at the current zset max and skips it.
CursorBootstrap = Literal["full", "latest"]


def cursor_field(platform: str, channel_id: str) -> str:
    """Hash-field name for a channel's incremental cursor."""
    return ":".join((platform, channel_id))
def redis_ssl_kwargs_for_bulk(
    cfg: Config,
    *,
    redis_no_verify: bool,
) -> dict[str, Any]:
    """Build Redis SSL kwargs, optionally disabling TLS verification.

    Starts from ``cfg.redis_ssl_kwargs()``; when *redis_no_verify* is set and
    the configured URL uses the ``rediss://`` scheme, certificate and
    hostname checks are turned off.
    """
    out: dict[str, Any] = dict(cfg.redis_ssl_kwargs())
    url = (cfg.redis_url or "").strip().lower()
    if not (redis_no_verify and url.startswith("rediss://")):
        return out
    import ssl as _ssl

    out["ssl_cert_reqs"] = _ssl.CERT_NONE
    out["ssl_check_hostname"] = False
    return out
def format_llm_style_line(m: dict[str, Any]) -> str:
    """Render one cached message dict as a single prompt line.

    Format: ``[<iso ts>] [platform:<p> channel:<c>] <name> (<uid>)
    [Message ID: <mid>] [Replying to: <rid>]? : <text>``. Unparseable
    timestamps fall back to the Unix epoch.
    """
    try:
        ts_iso = datetime.fromtimestamp(
            float(m.get("timestamp") or 0),
            tz=timezone.utc,
        ).isoformat()
    except (TypeError, ValueError, OSError):
        ts_iso = "1970-01-01T00:00:00+00:00"

    def _field(key: str, fallback: str = "") -> str:
        # Empty/None values collapse to the fallback before stringifying.
        return str(m.get(key, "") or fallback)

    head = (
        f"[{ts_iso}] [platform:{_field('platform')} channel:{_field('channel_id')}]"
        f" {_field('user_name', '?')} ({_field('user_id')})"
        f" [Message ID: {_field('message_id')}]"
    )
    reply_id = _field("reply_to_id")
    if reply_id:
        head += f" [Replying to: {reply_id}]"
    return f"{head} : {_field('text')}"
async def scan_channel_zset_keys(redis: Any) -> list[str]:
    """Return every ``channel_msgs:*`` zset key name, sorted."""
    found: list[str] = []
    async for raw in redis.scan_iter(match="channel_msgs:*", count=500):
        found.append(raw.decode() if isinstance(raw, bytes) else raw)
    return sorted(found)
def _parse_zset_key(zset_key: str) -> tuple[str, str] | None:
    """Split ``channel_msgs:<platform>:<channel_id>`` into its pair.

    Returns ``None`` when the key has fewer than three colon-separated parts.
    """
    try:
        _prefix, platform, channel_id = zset_key.split(":", 2)
    except ValueError:
        return None
    return platform, channel_id
async def cursor_hget(redis: Any, platform: str, channel_id: str) -> float | None:
    """Read a channel's incremental cursor timestamp, or ``None``.

    Returns ``None`` when the field is absent or not parseable as a float.
    """
    raw = await redis.hget(
        KG_AGENTIC_BULK_LAST_TS_HASH,
        cursor_field(platform, channel_id),
    )
    if raw is None:
        return None
    text = raw.decode() if isinstance(raw, bytes) else raw
    try:
        return float(text)
    except (TypeError, ValueError):
        return None
async def cursor_hset(
    redis: Any,
    platform: str,
    channel_id: str,
    ts: float,
) -> None:
    """Persist a channel's incremental cursor timestamp (stored as a string)."""
    field = cursor_field(platform, channel_id)
    await redis.hset(KG_AGENTIC_BULK_LAST_TS_HASH, field, str(ts))
async def bootstrap_latest_cursor_no_extract(
    redis: Any,
    zset_key: str,
    platform: str,
    channel_id: str,
) -> None:
    """If no cursor, set it to current max zset score and skip backlog."""
    field = cursor_field(platform, channel_id)
    if await redis.hget(KG_AGENTIC_BULK_LAST_TS_HASH, field) is not None:
        # Cursor already seeded; leave it alone.
        return
    top = await redis.zrevrange(zset_key, 0, 0, withscores=True)
    if not top:
        # Empty zset: nothing to bootstrap from.
        return
    _member, score = top[0]
    await redis.hset(KG_AGENTIC_BULK_LAST_TS_HASH, field, str(float(score)))
    logger.info(
        "KG bulk incremental: bootstrap=latest for %s:%s cursor=%s",
        platform,
        channel_id,
        score,
    )
async def fetch_messages_for_zset(
    redis: Any,
    zset_key: str,
) -> list[dict[str, Any]]:
    """Load every message hash referenced by *zset_key*, in zset order."""
    messages: list[dict[str, Any]] = []
    offset = 0
    while True:
        page: list[Any] = await redis.zrange(
            zset_key,
            offset,
            offset + _ZSET_BATCH - 1,
        )
        if not page:
            break
        decoded = [k.decode() if isinstance(k, bytes) else k for k in page]
        messages.extend(await _hmget_message_dicts(redis, zset_key, decoded))
        offset += len(page)
        if len(page) < _ZSET_BATCH:
            # Short page means we reached the end; skip the extra round trip.
            break
    return messages
async def fetch_messages_for_zset_after(
    redis: Any,
    zset_key: str,
    after_ts_exclusive: float | None,
) -> list[dict[str, Any]]:
    """Like :func:`fetch_messages_for_zset`, but only entries whose score is
    strictly greater than *after_ts_exclusive* (``None`` = no lower bound)."""
    if after_ts_exclusive is None:
        return await fetch_messages_for_zset(redis, zset_key)
    messages: list[dict[str, Any]] = []
    offset = 0
    lower = f"({after_ts_exclusive}"  # "(" prefix = exclusive bound in Redis.
    while True:
        page: list[Any] = await redis.zrangebyscore(
            zset_key,
            lower,
            "+inf",
            start=offset,
            num=_ZSET_BATCH,
        )
        if not page:
            break
        decoded = [k.decode() if isinstance(k, bytes) else k for k in page]
        messages.extend(await _hmget_message_dicts(redis, zset_key, decoded))
        offset += len(page)
        if len(page) < _ZSET_BATCH:
            break
    return messages
async def _hmget_message_dicts(
    redis: Any,
    zset_key: str,
    msg_keys: list[str],
) -> list[dict[str, Any]]:
    """Pipeline HMGETs for *msg_keys* and build message dicts.

    Keys whose hash is gone (all fields ``None``) are skipped — the hash can
    expire while the zset still references it. Each dict is annotated with
    ``redis_msg_key`` and ``zset_key``, and ``timestamp`` is coerced to float.
    """
    results: list[dict[str, Any]] = []
    for start in range(0, len(msg_keys), _HMGET_BATCH):
        batch = msg_keys[start:start + _HMGET_BATCH]
        pipe = redis.pipeline()
        for key in batch:
            pipe.hmget(key, *_HASH_MSG_FIELDS)
        rows = await pipe.execute()
        for key, values in zip(batch, rows):
            if not values or all(v is None for v in values):
                continue
            record: dict[str, Any] = {
                field: (v.decode() if isinstance(v, bytes) else v)
                for field, v in zip(_HASH_MSG_FIELDS, values)
            }
            record["timestamp"] = float(record.get("timestamp") or 0)
            record["redis_msg_key"] = key
            record["zset_key"] = zset_key
            results.append(record)
    return results
async def collect_messages_from_redis(
    redis: Any,
    *,
    incremental: bool = False,
    cursor_bootstrap: CursorBootstrap = "full",
    channel_pairs: list[tuple[str, str]] | None = None,
) -> tuple[list[dict[str, Any]], int]:
    """Load messages from Redis channel zsets.

    *channel_pairs*: if set, only these ``(platform, channel_id)`` zsets;
    otherwise all ``channel_msgs:*`` keys. When *incremental* is true, uses
    :data:`KG_AGENTIC_BULK_LAST_TS_HASH` per channel with exclusive lower
    bound. *cursor_bootstrap* ``latest`` seeds missing cursors to the current
    zset max without returning backlog messages.

    Returns the globally sorted message list and the number of zsets scanned.
    """
    if channel_pairs is None:
        zkeys = await scan_channel_zset_keys(redis)
        allowed: set[tuple[str, str]] | None = None
    else:
        zkeys = [f"channel_msgs:{p}:{c}" for p, c in channel_pairs]
        allowed = set(channel_pairs)
    n_scanned = len(zkeys)

    collected: list[dict[str, Any]] = []
    for zk in zkeys:
        parsed = _parse_zset_key(zk)
        if parsed is None:
            continue
        platform, channel_id = parsed
        if allowed is not None and (platform, channel_id) not in allowed:
            continue

        after_ts: float | None = None
        if incremental:
            if cursor_bootstrap == "latest":
                await bootstrap_latest_cursor_no_extract(
                    redis, zk, platform, channel_id,
                )
            after_ts = await cursor_hget(redis, platform, channel_id)

        try:
            if incremental and after_ts is not None:
                part = await fetch_messages_for_zset_after(redis, zk, after_ts)
            else:
                # Full fetch: non-incremental runs, or no cursor stored yet.
                part = await fetch_messages_for_zset(redis, zk)
            collected.extend(part)
        except Exception:
            # Best-effort sweep: one broken channel must not abort the rest.
            logger.warning("Failed zset %s", zk, exc_info=True)

    # Deterministic global order: time, then channel identity, then id.
    collected.sort(
        key=lambda m: (
            m["timestamp"],
            str(m.get("platform", "")),
            str(m.get("channel_id", "")),
            str(m.get("message_id", "")),
        ),
    )
    return collected, n_scanned
def _unique_channel_pairs(keys: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Sorted, de-duplicated ``(platform, channel_id)`` pairs."""
    return sorted(set(keys))


def _unique_speaker_pairs(
    keys: list[tuple[str, str]],
) -> list[tuple[str, str]]:
    """De-dupe speakers by user id, sorted by id.

    Keeps the first name seen per id; blank ids are dropped and blank names
    fall back to ``"?"``.
    """
    seen: dict[str, str] = {}
    for uid, name in keys:
        key = (uid or "").strip()
        if key:
            seen.setdefault(key, (name or "").strip() or "?")
    return sorted(seen.items(), key=lambda item: item[0])


def _subset_channel_metadata(
    full: dict[str, dict[str, str]] | None,
    pairs: list[tuple[str, str]],
) -> dict[str, dict[str, str]] | None:
    """Copy metadata rows for *pairs* out of *full*; ``None`` when empty."""
    if not full:
        return None
    subset = {
        f"{plat}:{cid}": dict(full[f"{plat}:{cid}"])
        for plat, cid in pairs
        if f"{plat}:{cid}" in full
    }
    return subset or None


def _max_ts_per_channel(
    line_channel_keys: list[tuple[str, str]],
    line_timestamps: list[float],
) -> dict[tuple[str, str], float]:
    """Max timestamp seen per channel pair (for cursor advancement)."""
    acc: dict[tuple[str, str], float] = {}
    for key, ts in zip(line_channel_keys, line_timestamps):
        acc[key] = ts if key not in acc else max(acc[key], ts)
    return acc


def _speaker_prefetch_placeholder(max_chars: int) -> str:
    """Dummy text of ~*max_chars* chars used only for token-budget sizing."""
    if max_chars <= 0:
        return ""
    header = (
        "## Existing knowledge graph (speakers — prefetch)\n"
        "_Placeholder sizing for token budget._\n\n"
    )
    return header + "·" * max(0, int(max_chars) - len(header))
async def token_count_conversation(
    counter: KgBulkLlmClient,
    conversation_text: str,
    *,
    channel_scope: str,
    reserve: int,
    config: Config | None = None,
    chunk_channel_pairs: list[tuple[str, str]] | None = None,
    chunk_speaker_pairs: list[tuple[str, str]] | None = None,
    speaker_kg_prefetch: str = "",
    channel_metadata: dict[str, dict[str, str]] | None = None,
) -> int:
    """Estimate the prompt token count for one chunk, plus *reserve* headroom.

    Builds the same message list the extraction prompt would use and asks the
    backend to count it; falls back to a chars/3 heuristic when counting is
    unavailable.
    """
    msgs = messages_for_agentic_token_estimate(
        conversation_text,
        channel_id=channel_scope,
        chunk_index=0,
        config=config,
        chunk_channel_pairs=chunk_channel_pairs,
        chunk_speaker_pairs=chunk_speaker_pairs,
        speaker_kg_prefetch=speaker_kg_prefetch,
        channel_metadata=channel_metadata,
    )
    counted = await counter.count_input_tokens(msgs)
    if counted is None:
        # Backend could not count: ~3 chars per token, at least one token.
        return max(1, len(conversation_text) // 3) + reserve
    return int(counted) + reserve
async def chunk_message_lines(
    lines: list[str],
    line_channel_keys: list[tuple[str, str]],
    line_speaker_keys: list[tuple[str, str]],
    line_timestamps: list[float],
    counter: KgBulkLlmClient,
    *,
    max_tokens: int,
    channel_scope: str,
    reserve: int,
    config: Config | None = None,
    speaker_prefetch_placeholder: str = "",
    channel_metadata: dict[str, dict[str, str]] | None = None,
) -> list[
    tuple[
        list[str],
        list[tuple[str, str]],
        list[tuple[str, str]],
        float,
        float,
        dict[tuple[str, str], float],
    ]
]:
    """Pack lines into token-budget chunks with per-chunk time + cursor hints.

    Greedy packing: at each position a binary search finds the longest run of
    remaining lines whose estimated prompt (including *reserve* and the
    prefetch placeholder sizing) fits *max_tokens*. Each chunk tuple carries
    its lines, unique channel pairs, unique speaker pairs, min/max timestamps,
    and a per-channel max-timestamp map for incremental cursor updates.
    """
    total = len(lines)
    if len(line_channel_keys) != total or len(line_speaker_keys) != total:
        raise ValueError("line_*_keys length must match lines")
    if len(line_timestamps) != total:
        raise ValueError("line_timestamps length must match lines")

    chunks: list[
        tuple[
            list[str],
            list[tuple[str, str]],
            list[tuple[str, str]],
            float,
            float,
            dict[tuple[str, str], float],
        ]
    ] = []
    pos = 0
    while pos < total:
        # Binary-search the largest line count that stays within budget.
        lo, hi, take = 1, total - pos, 0
        while lo <= hi:
            mid = (lo + hi) // 2
            cand_channels = line_channel_keys[pos:pos + mid]
            cand_speakers = line_speaker_keys[pos:pos + mid]
            pairs = _unique_channel_pairs(cand_channels)
            speaker_pairs = _unique_speaker_pairs(cand_speakers)
            estimate = await token_count_conversation(
                counter,
                "\n".join(lines[pos:pos + mid]),
                channel_scope=channel_scope,
                reserve=reserve,
                config=config,
                chunk_channel_pairs=pairs,
                chunk_speaker_pairs=speaker_pairs,
                speaker_kg_prefetch=speaker_prefetch_placeholder,
                channel_metadata=_subset_channel_metadata(
                    channel_metadata, pairs,
                ),
            )
            if estimate <= max_tokens:
                take = mid
                lo = mid + 1
            else:
                hi = mid - 1
        # Always advance, even when a single line blows the budget.
        take = take or 1

        sel_channels = line_channel_keys[pos:pos + take]
        sel_speakers = line_speaker_keys[pos:pos + take]
        sel_ts = line_timestamps[pos:pos + take]
        chunks.append(
            (
                lines[pos:pos + take],
                _unique_channel_pairs(sel_channels),
                _unique_speaker_pairs(sel_speakers),
                min(sel_ts) if sel_ts else 0.0,
                max(sel_ts) if sel_ts else 0.0,
                _max_ts_per_channel(sel_channels, sel_ts),
            ),
        )
        pos += take
    return chunks
@dataclass
class KgBulkPipelineParams:
    """Tunables for one bulk agentic KG extraction run."""

    out_dir: Path                          # where dump/chunk artifacts go
    dump_only: bool = False                # write dump + manifest, then stop
    dry_run_chunks: bool = False           # write chunk files, skip LLM
    dry_run_llm: bool = False              # run LLM but do not persist to KG
    chunk_tokens: int = 250_000            # token budget per chunk
    token_reserve: int = 100_000           # headroom added to every estimate
    max_messages: int = 0                  # 0 = no cap
    chunks_max: int = 0                    # 0 = no cap
    per_channel: bool = False              # chunk each channel separately
    resume_from_chunk: int = 0             # skip chunks below this index
    max_tool_rounds: int = 48
    bulk_llm_backend: str = "gemini"       # "gemini" or OpenRouter fallback
    fetch_channel_metadata: bool = False
    channel_metadata_ttl_days: float = 7.0
    discord_platform_type: str = "discord"
    prefetch_speaker_kg: bool = False
    prefetch_max_speakers: int = 8
    prefetch_hits_per_speaker: int = 3
    prefetch_max_chars: int = 400_000
    prefetch_min_score: float = 0.0
    redis_no_verify: bool = True
    incremental: bool = False              # advance per-channel cursors
async def run_agentic_bulk_pipeline(
    cfg: Config,
    redis: Any,
    messages: list[dict[str, Any]],
    n_zsets_scanned: int,
    params: KgBulkPipelineParams,
    *,
    load_channel_metadata: Callable[
        [],
        Awaitable[dict[str, dict[str, str]]],
    ] | None = None,
) -> None:
    """Chunk (optional disk artifacts), run agentic KG extraction, update cursors.

    *load_channel_metadata*: optional async callable returning Discord (or
    other) channel metadata dict; only invoked when
    *params.fetch_channel_metadata*.

    After each successful extraction chunk, updates incremental cursors per
    channel when *params.incremental* is true.

    BUGFIX: the LLM/token-counter/cache clients are now closed via nested
    try/finally blocks, so an exception during chunking or client setup no
    longer leaks open connections (previously only the extraction loop was
    guarded by a single trailing ``finally``).
    """
    out_dir = Path(params.out_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    msgs = messages
    if params.max_messages and params.max_messages > 0:
        msgs = msgs[: params.max_messages]

    # --- Disk artifacts: raw message dump + manifest -----------------------
    jsonl_path = out_dir / "messages.jsonl"
    with jsonl_path.open("w", encoding="utf-8") as fj:
        for m in msgs:
            # Drop bulky embeddings; keep the rendered prompt line for audit.
            rec = {k: v for k, v in m.items() if k != "embedding"}
            rec["llm_style_line"] = format_llm_style_line(rec)
            fj.write(json.dumps(rec, ensure_ascii=False, default=str) + "\n")

    lines = [format_llm_style_line(m) for m in msgs]
    line_channel_keys = [
        (str(m.get("platform", "")), str(m.get("channel_id", "")))
        for m in msgs
    ]
    line_speaker_keys = [
        (str(m.get("user_id", "")), str(m.get("user_name", "")))
        for m in msgs
    ]
    line_timestamps = [float(m.get("timestamp") or 0) for m in msgs]
    platforms_channels = sorted(
        {f"{m.get('platform')}:{m.get('channel_id')}" for m in msgs},
    )
    zsets_in_msgs = len(
        {
            str(m.get("zset_key", "") or "")
            for m in msgs
            if m.get("zset_key")
        },
    )
    meta = {
        "message_count": len(msgs),
        "zset_count": zsets_in_msgs,
        "zset_count_redis_scanned": n_zsets_scanned,
        "time_min": min((m["timestamp"] for m in msgs), default=0),
        "time_max": max((m["timestamp"] for m in msgs), default=0),
        "platforms_channels": platforms_channels,
    }
    (out_dir / "manifest.json").write_text(
        json.dumps(meta, indent=2, default=str),
        encoding="utf-8",
    )
    if params.dump_only:
        logger.info("Dump-only: wrote %s and manifest.", jsonl_path)
        return

    global_channel_metadata: dict[str, dict[str, str]] = {}
    if params.fetch_channel_metadata and load_channel_metadata is not None:
        global_channel_metadata = await load_channel_metadata()

    bulk_backend = (params.bulk_llm_backend or "gemini").strip().lower()
    # A single-round client is sufficient for token counting while chunking.
    if bulk_backend == "gemini":
        token_counter = create_kg_bulk_gemini_pool_client(max_tool_rounds=1)
    else:
        token_counter = create_kg_bulk_openrouter_client(
            cfg.api_key,
            gemini_api_key=cfg.gemini_api_key or "",
            max_tool_rounds=1,
        )
    try:
        scope = CROSS_CHANNEL_SCOPE
        reserve = int(params.token_reserve)
        if params.prefetch_speaker_kg:
            # Rough chars->tokens headroom for the speaker-KG prefetch block.
            reserve += max(0, int(params.prefetch_max_chars) // 3)
        if params.fetch_channel_metadata:
            n_disc = sum(
                1
                for x in platforms_channels
                if x.startswith("discord:") or x.startswith("discord-self:")
            )
            # Headroom for channel metadata in the prompt, capped.
            reserve += min(120_000, 2_000 * max(1, n_disc))
        max_tok = int(params.chunk_tokens)
        prefetch_placeholder = ""
        if params.prefetch_speaker_kg:
            prefetch_placeholder = _speaker_prefetch_placeholder(
                int(params.prefetch_max_chars),
            )

        # --- Build chunk groups:
        # (scope, lines, channel pairs, speaker pairs, ts_lo, ts_hi, cursor map)
        chunk_groups: list[
            tuple[
                str,
                list[str],
                list[tuple[str, str]],
                list[tuple[str, str]],
                float,
                float,
                dict[tuple[str, str], float],
            ]
        ] = []
        if params.per_channel:
            grouped: dict[
                tuple[str, str],
                list[
                    tuple[
                        str,
                        tuple[str, str],
                        tuple[str, str],
                        float,
                    ]
                ],
            ] = defaultdict(list)
            for m in msgs:
                k = (str(m.get("platform", "")), str(m.get("channel_id", "")))
                sp = (str(m.get("user_id", "")), str(m.get("user_name", "")))
                grouped[k].append(
                    (
                        format_llm_style_line(m),
                        k,
                        sp,
                        float(m.get("timestamp") or 0),
                    ),
                )
            for (plat, cid), rows in sorted(grouped.items()):
                ch_scope = f"{plat}:{cid}" if plat or cid else CROSS_CHANNEL_SCOPE
                subchunks = await chunk_message_lines(
                    [r[0] for r in rows],
                    [r[1] for r in rows],
                    [r[2] for r in rows],
                    [r[3] for r in rows],
                    token_counter,
                    max_tokens=max_tok,
                    channel_scope=ch_scope,
                    reserve=reserve,
                    config=cfg,
                    speaker_prefetch_placeholder=prefetch_placeholder,
                    channel_metadata=global_channel_metadata or None,
                )
                for bl, bp, spk, ts_lo, ts_hi, cmap in subchunks:
                    chunk_groups.append(
                        (ch_scope, bl, bp, spk, ts_lo, ts_hi, cmap),
                    )
        else:
            subchunks = await chunk_message_lines(
                lines,
                line_channel_keys,
                line_speaker_keys,
                line_timestamps,
                token_counter,
                max_tokens=max_tok,
                channel_scope=scope,
                reserve=reserve,
                config=cfg,
                speaker_prefetch_placeholder=prefetch_placeholder,
                channel_metadata=global_channel_metadata or None,
            )
            for bl, bp, spk, ts_lo, ts_hi, cmap in subchunks:
                chunk_groups.append((scope, bl, bp, spk, ts_lo, ts_hi, cmap))

        cm = int(params.chunks_max or 0)
        if cm > 0:
            chunk_groups = chunk_groups[:cm]
            logger.info(
                "--chunks-max=%d: using first %d chunk(s)",
                cm,
                len(chunk_groups),
            )

        # --- Persist per-chunk artifacts ----------------------------------
        chunks_dir = out_dir / "chunks"
        chunks_dir.mkdir(exist_ok=True)
        for idx, (
            ch_scope,
            chunk_lines,
            chunk_pairs,
            chunk_spk,
            ts_lo,
            ts_hi,
            _cmap,
        ) in enumerate(chunk_groups):
            text = "\n".join(chunk_lines)
            t0 = datetime.now(timezone.utc).isoformat()
            t_start = (
                datetime.fromtimestamp(ts_lo, tz=timezone.utc).isoformat()
                if ts_lo > 0
                else ""
            )
            t_end = (
                datetime.fromtimestamp(ts_hi, tz=timezone.utc).isoformat()
                if ts_hi > 0
                else ""
            )
            meta_rows = _subset_channel_metadata(
                global_channel_metadata or None,
                list(chunk_pairs),
            )
            piece = {
                "index": idx,
                "channel_scope": ch_scope,
                "channels_in_chunk": [list(p) for p in chunk_pairs],
                "speakers_in_chunk": [list(p) for p in chunk_spk],
                "channel_metadata": meta_rows if meta_rows is not None else {},
                "line_count": len(chunk_lines),
                "char_count": len(text),
                "written_at": t0,
                "time_range_hint_utc": [t_start, t_end],
                "conversation_text": text,
            }
            (chunks_dir / f"chunk_{idx:05d}.json").write_text(
                json.dumps(piece, ensure_ascii=False, indent=2),
                encoding="utf-8",
            )

        if params.dry_run_chunks:
            logger.info(
                "Dry-run: wrote %d chunk files; skipping LLM.",
                len(chunk_groups),
            )
            return
        if params.dry_run_llm:
            logger.info(
                "Dry-run LLM: running agentic extraction without persisting to FalkorDB.",
            )

        # --- KG-side clients ----------------------------------------------
        or_for_kg = OpenRouterClient(
            api_key=cfg.api_key,
            model=cfg.model,
            base_url=cfg.llm_base_url,
            gemini_api_key=cfg.gemini_api_key or "",
        )
        cache2 = MessageCache(
            redis_url=cfg.redis_url,
            openrouter_client=or_for_kg,
            embedding_model=cfg.embedding_model,
            ssl_kwargs=redis_ssl_kwargs_for_bulk(
                cfg,
                redis_no_verify=params.redis_no_verify,
            ),
        )
        try:
            kg = KnowledgeGraphManager(
                redis_client=cache2.redis_client,
                openrouter=or_for_kg,
                embedding_model=cfg.embedding_model,
                admin_user_ids=(
                    set(cfg.admin_user_ids) if cfg.admin_user_ids else None
                ),
            )
            await kg.ensure_indexes()

            if bulk_backend == "gemini":
                bulk_client = create_kg_bulk_gemini_pool_client(
                    max_tool_rounds=params.max_tool_rounds,
                )
            else:
                bulk_client = create_kg_bulk_openrouter_client(
                    cfg.api_key,
                    gemini_api_key=cfg.gemini_api_key or "",
                    max_tool_rounds=params.max_tool_rounds,
                )
            try:
                run_log: list[dict[str, Any]] = []
                start_idx = max(0, int(params.resume_from_chunk))
                pc_hint = ", ".join(platforms_channels)
                for idx, (
                    ch_scope,
                    chunk_lines,
                    chunk_pairs,
                    chunk_spk,
                    ts_lo,
                    ts_hi,
                    cmap,
                ) in enumerate(chunk_groups):
                    if idx < start_idx:
                        continue
                    text = "\n".join(chunk_lines)
                    t_start = (
                        datetime.fromtimestamp(
                            ts_lo, tz=timezone.utc,
                        ).isoformat()
                        if ts_lo > 0
                        else ""
                    )
                    t_end = (
                        datetime.fromtimestamp(
                            ts_hi, tz=timezone.utc,
                        ).isoformat()
                        if ts_hi > 0
                        else ""
                    )
                    meta_slice = _subset_channel_metadata(
                        global_channel_metadata or None,
                        list(chunk_pairs),
                    )
                    prefetch_text = ""
                    if params.prefetch_speaker_kg:
                        prefetch_text = await prefetch_speaker_kg_context(
                            kg,
                            list(chunk_spk),
                            max_speakers=int(params.prefetch_max_speakers),
                            hits_per_speaker=int(
                                params.prefetch_hits_per_speaker,
                            ),
                            min_score=float(params.prefetch_min_score),
                        )
                    stats = await run_agentic_kg_extraction_chunk(
                        conversation_text=text,
                        channel_id=ch_scope,
                        kg_manager=kg,
                        bulk_client=bulk_client,
                        user_id="000000000000",
                        chunk_index=idx,
                        time_start_iso=t_start,
                        time_end_iso=t_end,
                        platforms_channels=pc_hint,
                        config=cfg,
                        chunk_channel_pairs=list(chunk_pairs),
                        chunk_speaker_pairs=list(chunk_spk),
                        speaker_kg_prefetch=prefetch_text,
                        channel_metadata=meta_slice,
                        persist_extraction=not params.dry_run_llm,
                    )
                    stats["chunk_index"] = idx
                    stats["channel_scope"] = ch_scope
                    run_log.append(stats)
                    logger.info(
                        "Chunk %d scope=%s stats=%s", idx, ch_scope, stats,
                    )

                    err = int(stats.get("errors", 0) or 0)
                    parse_err = bool(stats.get("parse_error"))
                    if (
                        params.incremental
                        and not params.dry_run_llm
                        and err == 0
                        and not parse_err
                        and cmap
                    ):
                        # Advance each channel's cursor only after a clean,
                        # persisted chunk so failed chunks get retried.
                        for (plat, cid), tmax in cmap.items():
                            if tmax <= 0:
                                continue
                            await cursor_hset(redis, plat, cid, tmax)
                    # Checkpoint the run log after every chunk — this is what
                    # makes --resume-from-chunk usable after a crash.
                    (out_dir / "extraction_run.json").write_text(
                        json.dumps(run_log, indent=2, default=str),
                        encoding="utf-8",
                    )
            finally:
                await bulk_client.close()
        finally:
            await cache2.close()
    finally:
        await token_counter.close()
def resolve_bulk_backend(cli_value: str | None) -> str:
    """Pick the bulk LLM backend, lowercased.

    Precedence: CLI value, then the ``KG_BULK_LLM_BACKEND`` environment
    variable, then ``"gemini"``.
    """
    chosen = cli_value or os.environ.get("KG_BULK_LLM_BACKEND", "") or "gemini"
    return chosen.strip().lower()