"""Shared bulk agentic KG extraction: Redis message collect, chunking, LLM run.
Used by :mod:`scripts.kg_bulk_dump_and_extract` and the scheduled incremental
task in :mod:`background_tasks`.
"""
from __future__ import annotations
import asyncio
import jsonutil as json
import logging
import os
import time as _time
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Awaitable, Callable, Literal
from config import Config
from kg_agentic_extraction import (
KgBulkLlmClient,
create_kg_bulk_gemini_pool_client,
create_kg_bulk_openrouter_client,
messages_for_agentic_token_estimate,
prefetch_speaker_kg_context,
run_agentic_kg_extraction_chunk,
)
from knowledge_graph import KnowledgeGraphManager
from message_cache import MessageCache
from openrouter_client import OpenRouterClient
logger = logging.getLogger(__name__)
# Scope label passed as the channel scope for chunks that may mix several
# channels (the non-per-channel chunking path).
CROSS_CHANNEL_SCOPE = "cross_channel_bulk"
# Redis hash holding one incremental cursor (last processed timestamp) per
# ``{platform}:{channel_id}`` field.
KG_AGENTIC_BULK_LAST_TS_HASH = "stargazer:kg_agentic_bulk:last_ts"
# Hash fields requested with HMGET when hydrating a message. Order matters: the
# returned values are zipped back against this tuple.
_HASH_MSG_FIELDS = (
    "user_id",
    "user_name",
    "platform",
    "channel_id",
    "text",
    "timestamp",
    "message_id",
    "reply_to_id",
)
# Page size used when paging through a channel zset (ZRANGE / ZRANGEBYSCORE).
_ZSET_BATCH = 5000
# Number of HMGET commands queued per Redis pipeline execution.
_HMGET_BATCH = 400
# Accepted values for the ``cursor_bootstrap`` argument of
# ``collect_messages_from_redis``.
CursorBootstrap = Literal["full", "latest"]
[docs]
def cursor_field(platform: str, channel_id: str) -> str:
    """Return the cursor-hash field name for one channel.

    The field is simply ``platform`` and ``channel_id`` joined by a single
    colon. Because either part may itself contain colons, the result is a
    lookup key only and cannot be split back into its parts reliably.

    Args:
        platform: Platform identifier.
        channel_id: Channel identifier within that platform.

    Returns:
        str: ``"{platform}:{channel_id}"``.
    """
    return "{}:{}".format(platform, channel_id)
[docs]
def redis_ssl_kwargs_for_bulk(
    cfg: Config,
    *,
    redis_no_verify: bool,
) -> dict[str, Any]:
    """Return Redis SSL keyword arguments for a bulk KG run.

    A copy of ``cfg.redis_ssl_kwargs()`` is the baseline. Certificate and
    hostname verification are switched off only when both conditions hold:
    *redis_no_verify* is true and ``cfg.redis_url`` (trimmed, case-insensitive)
    starts with ``rediss://``.

    Args:
        cfg: Configuration object exposing ``redis_url`` and
            ``redis_ssl_kwargs()``.
        redis_no_verify: Request that TLS verification be disabled.

    Returns:
        dict[str, Any]: A fresh dict; the mapping returned by
        ``cfg.redis_ssl_kwargs()`` is never mutated.
    """
    ssl_options: dict[str, Any] = {}
    ssl_options.update(cfg.redis_ssl_kwargs())

    normalized_url = (cfg.redis_url or "").strip().lower()
    if not (redis_no_verify and normalized_url.startswith("rediss://")):
        return ssl_options

    # Imported lazily: only the no-verify TLS path needs the ssl constants.
    import ssl as _ssl

    ssl_options.update(
        ssl_cert_reqs=_ssl.CERT_NONE,
        ssl_check_hostname=False,
    )
    return ssl_options
[docs]
async def scan_channel_zset_keys(redis: Any) -> list[str]:
    """List every ``channel_msgs:*`` key in Redis, sorted.

    Keys are discovered with ``scan_iter(match="channel_msgs:*", count=500)``;
    ``bytes`` keys are decoded to ``str`` and the result is sorted so callers
    process channels in a stable order.

    Args:
        redis: Async Redis client exposing ``scan_iter``.

    Returns:
        list[str]: The matching keys in ascending order.
    """
    discovered = [
        raw.decode() if isinstance(raw, bytes) else raw
        async for raw in redis.scan_iter(match="channel_msgs:*", count=500)
    ]
    return sorted(discovered)
def _parse_zset_key(zset_key: str) -> tuple[str, str] | None:
    """Extract ``(platform, channel_id)`` from a channel zset key.

    The key is split on at most its first two colons, so any further colons
    remain part of the channel id. Keys with fewer than three segments are
    rejected.

    Args:
        zset_key: Key shaped like ``channel_msgs:{platform}:{channel_id}``.

    Returns:
        tuple[str, str] | None: The ``(platform, channel_id)`` pair, or
        ``None`` when the key has fewer than three colon-separated segments.
    """
    try:
        # A max-two split yields at most three parts; unpacking fails with
        # ValueError exactly when there are fewer than three.
        _prefix, platform, channel_id = zset_key.split(":", 2)
    except ValueError:
        return None
    return platform, channel_id
[docs]
async def cursor_hget(redis: Any, platform: str, channel_id: str) -> float | None:
    """Load the stored incremental cursor for one channel.

    Issues ``HGET`` against :data:`KG_AGENTIC_BULK_LAST_TS_HASH` with the field
    produced by :func:`cursor_field`. A ``bytes`` reply is decoded before being
    converted to ``float``.

    Args:
        redis: Async Redis client exposing ``hget``.
        platform: Platform identifier of the channel.
        channel_id: Channel identifier.

    Returns:
        float | None: The cursor timestamp, or ``None`` when the field is
        absent or its value cannot be converted to ``float``.
    """
    field = cursor_field(platform, channel_id)
    stored = await redis.hget(KG_AGENTIC_BULK_LAST_TS_HASH, field)
    if stored is None:
        return None
    if isinstance(stored, bytes):
        stored = stored.decode()
    try:
        return float(stored)
    except (TypeError, ValueError):
        # An unparsable cursor is treated the same as a missing one.
        return None
[docs]
async def cursor_hset(
    redis: Any,
    platform: str,
    channel_id: str,
    ts: float,
) -> None:
    """Store the incremental cursor for one channel.

    Writes ``str(ts)`` with ``HSET`` into :data:`KG_AGENTIC_BULK_LAST_TS_HASH`
    under the field produced by :func:`cursor_field`.

    Args:
        redis: Async Redis client exposing ``hset``.
        platform: Platform identifier of the channel.
        channel_id: Channel identifier.
        ts: Cursor timestamp to persist.
    """
    field = cursor_field(platform, channel_id)
    await redis.hset(KG_AGENTIC_BULK_LAST_TS_HASH, field, str(ts))
[docs]
async def fetch_messages_for_zset(
    redis: Any,
    zset_key: str,
) -> list[dict[str, Any]]:
    """Hydrate every message referenced by one channel zset.

    The sorted set is read by rank with ``ZRANGE`` in windows of
    :data:`_ZSET_BATCH` members. Each window's members (message hash keys) are
    decoded and passed to :func:`_hmget_message_dicts`; the resulting dicts are
    appended page by page.

    Args:
        redis: Async Redis client.
        zset_key: The channel zset key to read.

    Returns:
        list[dict[str, Any]]: Hydrated message dicts for the whole zset.
    """
    hydrated: list[dict[str, Any]] = []
    rank = 0
    while True:
        last_rank = rank + _ZSET_BATCH - 1
        members: list[Any] = await redis.zrange(zset_key, rank, last_rank)
        if not members:
            break
        member_keys = [
            member.decode() if isinstance(member, bytes) else member
            for member in members
        ]
        hydrated += await _hmget_message_dicts(redis, zset_key, member_keys)
        rank += len(members)
        # A short page means the end of the zset was reached.
        if len(members) < _ZSET_BATCH:
            break
    return hydrated
[docs]
async def fetch_messages_for_zset_after(
    redis: Any,
    zset_key: str,
    after_ts_exclusive: float | None,
) -> list[dict[str, Any]]:
    """Hydrate the messages of one channel zset scored above a bound.

    With no bound (``None``) the call is forwarded to
    :func:`fetch_messages_for_zset`. Otherwise the zset is read with
    ``ZRANGEBYSCORE`` from the exclusive minimum ``(<bound>`` to ``+inf``,
    paging with ``start``/``num`` in windows of :data:`_ZSET_BATCH`, and each
    page is hydrated through :func:`_hmget_message_dicts`.

    Args:
        redis: Async Redis client.
        zset_key: The channel zset key to read.
        after_ts_exclusive: Exclusive lower score bound, or ``None`` for the
            full zset.

    Returns:
        list[dict[str, Any]]: Hydrated message dicts scored above the bound.
    """
    if after_ts_exclusive is None:
        return await fetch_messages_for_zset(redis, zset_key)

    # A leading "(" makes the Redis score bound exclusive.
    lower_bound = f"({after_ts_exclusive}"
    hydrated: list[dict[str, Any]] = []
    skipped = 0
    while True:
        members: list[Any] = await redis.zrangebyscore(
            zset_key,
            lower_bound,
            "+inf",
            start=skipped,
            num=_ZSET_BATCH,
        )
        if not members:
            break
        member_keys = [
            member.decode() if isinstance(member, bytes) else member
            for member in members
        ]
        hydrated += await _hmget_message_dicts(redis, zset_key, member_keys)
        skipped += len(members)
        # A short page means there is nothing further above the bound.
        if len(members) < _ZSET_BATCH:
            break
    return hydrated
async def _hmget_message_dicts(
    redis: Any,
    zset_key: str,
    msg_keys: list[str],
) -> list[dict[str, Any]]:
    """Hydrate message hash keys into message dicts via pipelined ``HMGET``.

    For every key, the fields in :data:`_HASH_MSG_FIELDS` are requested through
    a Redis pipeline, :data:`_HMGET_BATCH` keys per execution. Rows that come
    back empty or with every field ``None`` are skipped. ``bytes`` values are
    decoded, ``timestamp`` is coerced to ``float``, and two bookkeeping fields
    are added: ``redis_msg_key`` (the source hash key) and ``zset_key`` (the
    owning channel index).

    A ``timestamp`` that is missing, unparsable, or non-finite becomes ``0.0``
    (a warning is logged for the latter two). Previously a single unparsable
    value raised ``ValueError`` and aborted hydration of the entire channel.

    Args:
        redis: Async Redis client exposing ``pipeline()``.
        zset_key: The owning channel index key, stamped onto each result.
        msg_keys: Message hash keys to hydrate.

    Returns:
        list[dict[str, Any]]: Hydrated message dicts, in *msg_keys* order, with
        missing hashes omitted.
    """
    import math

    out: list[dict[str, Any]] = []
    for start in range(0, len(msg_keys), _HMGET_BATCH):
        batch = msg_keys[start : start + _HMGET_BATCH]
        pipe = redis.pipeline()
        for mk in batch:
            pipe.hmget(mk, *_HASH_MSG_FIELDS)
        rows = await pipe.execute()
        for mk, values in zip(batch, rows):
            if not values or all(v is None for v in values):
                continue
            msg: dict[str, Any] = {
                field: (v.decode() if isinstance(v, bytes) else v)
                for field, v in zip(_HASH_MSG_FIELDS, values)
            }
            raw_ts = msg.get("timestamp")
            try:
                ts = float(raw_ts or 0)
            except (TypeError, ValueError):
                ts = math.nan
            if not math.isfinite(ts):
                # One corrupt field must not fail the whole channel fetch, and
                # NaN/inf must not reach sorting or cursor bookkeeping.
                logger.warning(
                    "Unusable timestamp %r on message %s; using 0",
                    raw_ts,
                    mk,
                )
                ts = 0.0
            msg["timestamp"] = ts
            msg["redis_msg_key"] = mk
            msg["zset_key"] = zset_key
            out.append(msg)
    return out
[docs]
async def collect_messages_from_redis(
    redis: Any,
    *,
    incremental: bool = False,
    cursor_bootstrap: CursorBootstrap = "full",
    channel_pairs: list[tuple[str, str]] | None = None,
) -> tuple[list[dict[str, Any]], int]:
    """Load messages from Redis channel zsets.

    When *channel_pairs* is given, only those ``(platform, channel_id)`` zsets
    are read. Duplicate pairs are collapsed (first occurrence wins) so a
    channel is never fetched, and its messages never returned, twice. The pairs
    are used as given rather than re-parsed from the generated key. Without
    *channel_pairs*, every ``channel_msgs:*`` key found by
    :func:`scan_channel_zset_keys` is read, and keys that
    :func:`_parse_zset_key` rejects are skipped.

    When *incremental* is true, each channel's cursor from
    :data:`KG_AGENTIC_BULK_LAST_TS_HASH` is used as an exclusive lower bound;
    a channel with no cursor is read in full. With *cursor_bootstrap* set to
    ``"latest"``, ``bootstrap_latest_cursor_no_extract`` is called for each
    channel before its cursor is read.

    A failure while fetching one zset is logged and that zset is skipped;
    failures while bootstrapping or reading a cursor propagate.

    Args:
        redis: Async Redis client.
        incremental: Honor per-channel cursors instead of reading everything.
        cursor_bootstrap: ``"full"`` or ``"latest"``; only consulted when
            *incremental* is true.
        channel_pairs: Optional explicit list of channels to read.

    Returns:
        tuple[list[dict[str, Any]], int]: The messages sorted by
        ``(timestamp, platform, channel_id, message_id)``, and the number of
        zset keys considered (distinct requested pairs, or scanned keys).
    """
    targets: list[tuple[str, str, str]] = []
    if channel_pairs is not None:
        # dict.fromkeys de-duplicates while preserving the caller's order.
        for platform, channel_id in dict.fromkeys(channel_pairs):
            targets.append(
                (f"channel_msgs:{platform}:{channel_id}", platform, channel_id),
            )
        n_scanned = len(targets)
    else:
        zkeys = await scan_channel_zset_keys(redis)
        n_scanned = len(zkeys)
        for zk in zkeys:
            parsed = _parse_zset_key(zk)
            if parsed:
                targets.append((zk, *parsed))

    all_msgs: list[dict[str, Any]] = []
    for zk, platform, channel_id in targets:
        after_ts: float | None = None
        if incremental:
            if cursor_bootstrap == "latest":
                await bootstrap_latest_cursor_no_extract(
                    redis,
                    zk,
                    platform,
                    channel_id,
                )
            after_ts = await cursor_hget(redis, platform, channel_id)
        try:
            # A ``None`` bound makes this a full-backlog read, which covers
            # both the non-incremental case and a channel without a cursor.
            part = await fetch_messages_for_zset_after(redis, zk, after_ts)
        except Exception:
            # Best-effort per channel: one bad zset must not abort the run.
            logger.warning("Failed zset %s", zk, exc_info=True)
            continue
        all_msgs.extend(part)

    all_msgs.sort(
        key=lambda m: (
            m["timestamp"],
            str(m.get("platform", "")),
            str(m.get("channel_id", "")),
            str(m.get("message_id", "")),
        ),
    )
    return all_msgs, n_scanned
def _unique_channel_pairs(keys: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Return the distinct ``(platform, channel_id)`` pairs in sorted order.

    Args:
        keys: Channel pairs, possibly repeated.

    Returns:
        list[tuple[str, str]]: Each distinct pair once, ascending.
    """
    distinct = list(set(keys))
    distinct.sort()
    return distinct
def _unique_speaker_pairs(
    keys: list[tuple[str, str]],
) -> list[tuple[str, str]]:
    """Collapse ``(user_id, user_name)`` pairs to one name per distinct speaker.

    Ids and names are trimmed and entries with an empty id are dropped. For
    each id the first *non-empty* name wins; ``"?"`` is used only when no line
    for that speaker carries a name. Previously a blank name on a speaker's
    first line locked in ``"?"`` even if later lines had the real name.

    Args:
        keys: ``(user_id, user_name)`` pairs, possibly repeated or blank.

    Returns:
        list[tuple[str, str]]: One ``(user_id, user_name)`` pair per distinct
        id, sorted by id.
    """
    names: dict[str, str] = {}
    for uid, name in keys:
        speaker_id = (uid or "").strip()
        if not speaker_id:
            continue
        # Keep looking until this speaker has a non-empty name recorded.
        if not names.get(speaker_id):
            names[speaker_id] = (name or "").strip()
    return [(sid, names[sid] or "?") for sid in sorted(names)]
def _subset_channel_metadata(
    full: dict[str, dict[str, str]] | None,
    pairs: list[tuple[str, str]],
) -> dict[str, dict[str, str]] | None:
    """Select the metadata entries belonging to the given channels.

    *full* is keyed by ``"{platform}:{channel_id}"``. Each selected entry is
    shallow-copied, so editing the result does not touch *full*.

    Args:
        full: Metadata for all channels, or ``None``.
        pairs: ``(platform, channel_id)`` pairs to keep.

    Returns:
        dict[str, dict[str, str]] | None: The matching entries, or ``None``
        when *full* is empty/``None`` or no pair matches.
    """
    if not full:
        return None
    wanted_keys = (f"{plat}:{cid}" for plat, cid in pairs)
    subset = {key: dict(full[key]) for key in wanted_keys if key in full}
    return subset or None
def _max_ts_per_channel(
    line_channel_keys: list[tuple[str, str]],
    line_timestamps: list[float],
) -> dict[tuple[str, str], float]:
    """Return the largest timestamp observed for each channel.

    The two lists are walked in lockstep (stopping at the shorter one). For
    every ``(platform, channel_id)`` key the greatest paired timestamp is kept.

    Args:
        line_channel_keys: Per-line channel keys.
        line_timestamps: Per-line timestamps, aligned with the keys.

    Returns:
        dict[tuple[str, str], float]: Maximum timestamp per channel key, in
        order of first appearance.
    """
    latest: dict[tuple[str, str], float] = {}
    for channel_key, ts in zip(line_channel_keys, line_timestamps):
        try:
            current = latest[channel_key]
        except KeyError:
            # First line seen for this channel.
            latest[channel_key] = ts
            continue
        if ts > current:
            latest[channel_key] = ts
    return latest
def _speaker_prefetch_placeholder(max_chars: int) -> str:
    """Return filler text standing in for the speaker-KG prefetch section.

    The placeholder is a fixed two-line header followed by enough ``·``
    characters to reach ``int(max_chars)`` characters in total. When the header
    alone is already longer than that, the header is returned without padding.

    Args:
        max_chars: Target size in characters.

    Returns:
        str: The placeholder, or ``""`` when *max_chars* is zero or negative.
    """
    if max_chars <= 0:
        return ""
    banner = "".join(
        (
            "## Existing knowledge graph (speakers — prefetch)\n",
            "_Placeholder sizing for token budget._\n\n",
        ),
    )
    filler_len = int(max_chars) - len(banner)
    if filler_len < 0:
        filler_len = 0
    return banner + "·" * filler_len
[docs]
async def token_count_conversation(
    counter: KgBulkLlmClient,
    conversation_text: str,
    *,
    channel_scope: str,
    reserve: int,
    config: Config | None = None,
    chunk_channel_pairs: list[tuple[str, str]] | None = None,
    chunk_speaker_pairs: list[tuple[str, str]] | None = None,
    speaker_kg_prefetch: str = "",
    channel_metadata: dict[str, dict[str, str]] | None = None,
) -> int:
    """Estimate input tokens for a candidate chunk, including *reserve*.

    The prompt messages are assembled by
    :func:`kg_agentic_extraction.messages_for_agentic_token_estimate` (always
    with ``chunk_index=0``) and handed to ``counter.count_input_tokens``. When
    the counter returns ``None``, the estimate falls back to
    ``max(1, len(conversation_text) // 3)``.

    Args:
        counter: Client whose ``count_input_tokens`` is awaited.
        conversation_text: Joined transcript lines of the candidate chunk.
        channel_scope: Passed to prompt assembly as ``channel_id``.
        reserve: Token headroom added to whichever estimate is used.
        config: Optional configuration forwarded to prompt assembly.
        chunk_channel_pairs: Channels represented in the chunk.
        chunk_speaker_pairs: Speakers represented in the chunk.
        speaker_kg_prefetch: Speaker-KG prefetch text (or a sizing placeholder).
        channel_metadata: Channel metadata subset for the chunk.

    Returns:
        int: Estimated input tokens plus *reserve*.
    """
    prompt_messages = messages_for_agentic_token_estimate(
        conversation_text,
        channel_id=channel_scope,
        chunk_index=0,
        config=config,
        chunk_channel_pairs=chunk_channel_pairs,
        chunk_speaker_pairs=chunk_speaker_pairs,
        speaker_kg_prefetch=speaker_kg_prefetch,
        channel_metadata=channel_metadata,
    )
    counted = await counter.count_input_tokens(prompt_messages)
    if counted is None:
        # No count available: use the rough three-characters-per-token rule.
        rough = max(1, len(conversation_text) // 3)
        return rough + reserve
    return int(counted) + reserve
[docs]
async def chunk_message_lines(
    lines: list[str],
    line_channel_keys: list[tuple[str, str]],
    line_speaker_keys: list[tuple[str, str]],
    line_timestamps: list[float],
    counter: KgBulkLlmClient,
    *,
    max_tokens: int,
    channel_scope: str,
    reserve: int,
    config: Config | None = None,
    speaker_prefetch_placeholder: str = "",
    channel_metadata: dict[str, dict[str, str]] | None = None,
) -> list[
    tuple[
        list[str],
        list[tuple[str, str]],
        list[tuple[str, str]],
        float,
        float,
        dict[tuple[str, str], float],
    ]
]:
    """Split transcript lines into contiguous chunks that fit a token budget.

    Starting from the first unconsumed line, a binary search over the run
    length finds the longest contiguous run whose estimate from
    :func:`token_count_conversation` is at most *max_tokens*. If even a single
    line is over budget, that one line is still emitted as its own chunk so the
    loop always advances.

    For each probe the run's distinct channels (:func:`_unique_channel_pairs`),
    speakers (:func:`_unique_speaker_pairs`) and metadata subset
    (:func:`_subset_channel_metadata`) are passed to the estimator. Each
    emitted chunk also carries its min/max timestamp and the per-channel
    maximum timestamps from :func:`_max_ts_per_channel`.

    Args:
        lines: Transcript lines to pack, in the order they should be chunked.
        line_channel_keys: Per-line ``(platform, channel_id)`` keys.
        line_speaker_keys: Per-line ``(user_id, user_name)`` keys.
        line_timestamps: Per-line timestamps.
        counter: Client used only for token counting.
        max_tokens: Largest estimate a chunk may have.
        channel_scope: Scope label forwarded to the estimator.
        reserve: Token headroom forwarded to the estimator.
        config: Optional configuration forwarded to the estimator.
        speaker_prefetch_placeholder: Prefetch sizing text forwarded to the
            estimator.
        channel_metadata: Full channel metadata map; subset per probe.

    Returns:
        A list of ``(lines, channel_pairs, speaker_pairs, ts_lo, ts_hi,
        cursor_map)`` tuples, one per chunk, in input order.

    Raises:
        ValueError: If any of the per-line lists differs in length from
            *lines*.
    """
    total = len(lines)
    if len(line_channel_keys) != total or len(line_speaker_keys) != total:
        raise ValueError("line_*_keys length must match lines")
    if len(line_timestamps) != total:
        raise ValueError("line_timestamps length must match lines")

    async def _fits(start: int, count: int) -> bool:
        """Report whether ``count`` lines from ``start`` stay within budget."""
        stop = start + count
        channel_pairs = _unique_channel_pairs(line_channel_keys[start:stop])
        speaker_pairs = _unique_speaker_pairs(line_speaker_keys[start:stop])
        estimate = await token_count_conversation(
            counter,
            "\n".join(lines[start:stop]),
            channel_scope=channel_scope,
            reserve=reserve,
            config=config,
            chunk_channel_pairs=channel_pairs,
            chunk_speaker_pairs=speaker_pairs,
            speaker_kg_prefetch=speaker_prefetch_placeholder,
            channel_metadata=_subset_channel_metadata(
                channel_metadata,
                channel_pairs,
            ),
        )
        return estimate <= max_tokens

    packed = []
    start = 0
    while start < total:
        # Binary search for the longest run starting at ``start`` that fits.
        shortest, longest, take = 1, total - start, 0
        while shortest <= longest:
            probe = (shortest + longest) // 2
            if await _fits(start, probe):
                take = probe
                shortest = probe + 1
            else:
                longest = probe - 1
        # Nothing fits: still consume one line so the outer loop terminates.
        take = max(take, 1)

        stop = start + take
        chunk_keys = line_channel_keys[start:stop]
        chunk_stamps = line_timestamps[start:stop]
        packed.append(
            (
                lines[start:stop],
                _unique_channel_pairs(chunk_keys),
                _unique_speaker_pairs(line_speaker_keys[start:stop]),
                min(chunk_stamps) if chunk_stamps else 0.0,
                max(chunk_stamps) if chunk_stamps else 0.0,
                _max_ts_per_channel(chunk_keys, chunk_stamps),
            ),
        )
        start = stop
    return packed
[docs]
@dataclass
class KgBulkPipelineParams:
    """Tuning knobs for one :func:`run_agentic_bulk_pipeline` run.

    Holds everything the pipeline needs besides its runtime inputs (config,
    Redis client, messages): output location, dry-run switches, token budgets,
    chunking mode, resume/limit bounds, LLM backend selection, channel-metadata
    and speaker-KG prefetch settings, and the incremental-cursor flag.

    Field order is part of the generated ``__init__`` signature; append new
    fields rather than reordering.
    """

    # Directory receiving messages.jsonl, manifest.json, chunks/ and
    # extraction_run.json; created if missing.
    out_dir: Path
    # Stop after writing messages.jsonl and manifest.json.
    dump_only: bool = False
    # Write chunk files, then stop before any LLM extraction.
    dry_run_chunks: bool = False
    # Run extraction with persist_extraction=False; cursors are not advanced.
    dry_run_llm: bool = False
    # Maximum estimated input tokens per chunk.
    chunk_tokens: int = 250_000
    # Base token headroom added to every chunk estimate.
    token_reserve: int = 100_000
    # When > 0, only the first N messages are processed.
    max_messages: int = 0
    # When > 0, only the first N chunks are kept.
    chunks_max: int = 0
    # Chunk each (platform, channel_id) separately instead of one mixed stream.
    per_channel: bool = False
    # Chunks with an index below this value are skipped during extraction.
    resume_from_chunk: int = 0
    # Tool-round limit passed to the extraction LLM client.
    max_tool_rounds: int = 48
    # "gemini" selects the Gemini pool client; any other value selects the
    # OpenRouter client.
    bulk_llm_backend: str = "gemini"
    # Call the supplied load_channel_metadata callable and budget extra reserve.
    fetch_channel_metadata: bool = False
    # NOTE(review): not read by run_agentic_bulk_pipeline in this module;
    # presumably consumed by the caller's metadata loader — confirm.
    channel_metadata_ttl_days: float = 7.0
    # NOTE(review): not read by run_agentic_bulk_pipeline in this module;
    # presumably consumed by the caller's metadata loader — confirm.
    discord_platform_type: str = "discord"
    # Fetch existing speaker KG context for each chunk before extraction.
    prefetch_speaker_kg: bool = False
    # Forwarded to prefetch_speaker_kg_context as max_speakers.
    prefetch_max_speakers: int = 8
    # Forwarded to prefetch_speaker_kg_context as hits_per_speaker.
    prefetch_hits_per_speaker: int = 3
    # Sizes the prefetch placeholder and the extra token reserve.
    # NOTE(review): not forwarded to prefetch_speaker_kg_context — confirm the
    # real prefetch is bounded by the same budget.
    prefetch_max_chars: int = 400_000
    # Forwarded to prefetch_speaker_kg_context as min_score.
    prefetch_min_score: float = 0.0
    # Passed to redis_ssl_kwargs_for_bulk. NOTE(review): defaulting to True
    # disables TLS certificate/hostname verification for rediss:// URLs unless
    # the caller opts out — confirm this default is intended.
    redis_no_verify: bool = True
    # Advance per-channel cursors after each chunk that extracts cleanly.
    incremental: bool = False
[docs]
# Strong references to in-flight fire-and-forget observability tasks. The event
# loop keeps only weak references to tasks, so an otherwise unreferenced task
# can be garbage-collected before it completes.
_PENDING_OBS_TASKS: set[asyncio.Task[Any]] = set()


def _on_obs_task_done(task: asyncio.Task[Any]) -> None:
    """Release a finished observability task and log (never raise) its failure."""
    _PENDING_OBS_TASKS.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.debug("kg_bulk observability event failed", exc_info=exc)


def _publish_pipeline_event(
    status: str,
    *,
    started: float,
    channels: int,
    messages: int,
    n_zsets: int,
    out_dir: Path,
    extra: dict[str, Any] | None = None,
) -> None:
    """Schedule the pipeline debug event without blocking or raising."""
    payload: dict[str, Any] = {
        "channels_processed": channels,
        "messages_processed": messages,
        "n_zsets": n_zsets,
        "out_dir": str(out_dir),
    }
    if extra:
        payload.update(extra)
    try:
        from observability import publish_debug_event

        task = asyncio.create_task(
            publish_debug_event(
                "kg_bulk_incremental",
                "kg_bulk_runner",
                phase="pipeline",
                status=status,
                duration_ms=(_time.monotonic() - started) * 1000,
                preview=f"channels={channels} messages={messages} out_dir={out_dir}",
                payload=payload,
            ),
            name="obs_kg_bulk",
        )
    except Exception:
        # Observability is best-effort; record why it was skipped instead of
        # swallowing the error silently.
        logger.debug("Could not schedule kg_bulk observability event", exc_info=True)
        return
    _PENDING_OBS_TASKS.add(task)
    task.add_done_callback(_on_obs_task_done)


def _create_bulk_llm_client(
    cfg: Config,
    backend: str,
    *,
    max_tool_rounds: int,
) -> KgBulkLlmClient:
    """Build the bulk LLM client: Gemini pool for ``"gemini"``, else OpenRouter."""
    if backend == "gemini":
        return create_kg_bulk_gemini_pool_client(max_tool_rounds=max_tool_rounds)
    return create_kg_bulk_openrouter_client(
        cfg.api_key,
        gemini_api_key=cfg.gemini_api_key or "",
        max_tool_rounds=max_tool_rounds,
        top_p=cfg.top_p,
    )


def _utc_iso_or_blank(ts: float) -> str:
    """Format a positive epoch timestamp as UTC ISO-8601; ``""`` otherwise."""
    if ts > 0:
        return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat()
    return ""


def _write_chunk_artifacts(
    chunks_dir: Path,
    chunk_groups: list[Any],
    global_channel_metadata: dict[str, dict[str, str]] | None,
) -> None:
    """Write one ``chunk_NNNNN.json`` file per chunk into *chunks_dir*."""
    for idx, (
        ch_scope,
        chunk_lines,
        chunk_pairs,
        chunk_spk,
        ts_lo,
        ts_hi,
        _cmap,
    ) in enumerate(chunk_groups):
        text = "\n".join(chunk_lines)
        meta_rows = _subset_channel_metadata(
            global_channel_metadata or None,
            list(chunk_pairs),
        )
        piece = {
            "index": idx,
            "channel_scope": ch_scope,
            "channels_in_chunk": [list(p) for p in chunk_pairs],
            "speakers_in_chunk": [list(p) for p in chunk_spk],
            "channel_metadata": meta_rows if meta_rows is not None else {},
            "line_count": len(chunk_lines),
            "char_count": len(text),
            "written_at": datetime.now(timezone.utc).isoformat(),
            "time_range_hint_utc": [
                _utc_iso_or_blank(ts_lo),
                _utc_iso_or_blank(ts_hi),
            ],
            "conversation_text": text,
        }
        (chunks_dir / f"chunk_{idx:05d}.json").write_text(
            json.dumps(piece, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )


async def run_agentic_bulk_pipeline(
    cfg: Config,
    redis: Any,
    messages: list[dict[str, Any]],
    n_zsets_scanned: int,
    params: KgBulkPipelineParams,
    *,
    load_channel_metadata: (
        Callable[
            [],
            Awaitable[dict[str, dict[str, str]]],
        ]
        | None
    ) = None,
) -> None:
    """Dump messages, chunk them, run agentic KG extraction, update cursors.

    Steps, in order:

    1. Write ``messages.jsonl`` (without ``embedding`` fields) and
       ``manifest.json`` under ``params.out_dir``. With ``params.dump_only``
       the run ends here after scheduling a ``dump_only`` debug event.
    2. Optionally load channel metadata, then pack the formatted lines into
       token-budgeted chunks with :func:`chunk_message_lines`, either per
       channel or across all channels, and write them under ``chunks/``. With
       ``params.dry_run_chunks`` the run ends here.
    3. Run :func:`run_agentic_kg_extraction_chunk` for every chunk at or after
       ``params.resume_from_chunk``. When ``params.incremental`` is set and the
       run is not ``dry_run_llm``, each chunk that reports no errors and no
       parse error advances its channels' cursors via :func:`cursor_hset`.
    4. Write ``extraction_run.json`` and schedule an ``ok`` debug event.

    Every client opened here (token counter, message cache, extraction client)
    is closed on every exit path, including failures during chunking or setup;
    a failing ``close()`` no longer prevents the remaining closes. Exceptions
    from the pipeline itself still propagate to the caller.

    Args:
        cfg: Loaded configuration.
        redis: Async Redis client used for cursor updates.
        messages: Message dicts to process, already in processing order.
        n_zsets_scanned: Number of zsets scanned, recorded in the manifest and
            the debug event.
        params: Run parameters.
        load_channel_metadata: Optional async callable returning channel
            metadata keyed by ``"{platform}:{channel_id}"``; awaited only when
            ``params.fetch_channel_metadata`` is true.
    """
    from contextlib import AsyncExitStack

    out_dir = Path(params.out_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    started = _time.monotonic()

    msgs = messages
    if params.max_messages and params.max_messages > 0:
        msgs = msgs[: params.max_messages]

    jsonl_path = out_dir / "messages.jsonl"
    with jsonl_path.open("w", encoding="utf-8") as fj:
        for m in msgs:
            rec = {k: v for k, v in m.items() if k != "embedding"}
            rec["llm_style_line"] = format_llm_style_line(rec)
            fj.write(json.dumps(rec, ensure_ascii=False, default=str) + "\n")

    lines = [format_llm_style_line(m) for m in msgs]
    line_channel_keys = [
        (str(m.get("platform", "")), str(m.get("channel_id", ""))) for m in msgs
    ]
    line_speaker_keys = [
        (str(m.get("user_id", "")), str(m.get("user_name", ""))) for m in msgs
    ]
    line_timestamps = [float(m.get("timestamp") or 0) for m in msgs]
    platforms_channels = sorted(
        {f"{m.get('platform')}:{m.get('channel_id')}" for m in msgs},
    )
    zsets_in_msgs = len(
        {str(m.get("zset_key", "") or "") for m in msgs if m.get("zset_key")},
    )
    meta = {
        "message_count": len(msgs),
        "zset_count": zsets_in_msgs,
        "zset_count_redis_scanned": n_zsets_scanned,
        # Use the coerced timestamps so a message lacking the field cannot
        # raise KeyError here, matching the tolerant handling above.
        "time_min": min(line_timestamps, default=0),
        "time_max": max(line_timestamps, default=0),
        "platforms_channels": platforms_channels,
    }
    (out_dir / "manifest.json").write_text(
        json.dumps(meta, indent=2, default=str),
        encoding="utf-8",
    )

    if params.dump_only:
        logger.info("Dump-only: wrote %s and manifest.", jsonl_path)
        _publish_pipeline_event(
            "dump_only",
            started=started,
            channels=len(platforms_channels),
            messages=len(msgs),
            n_zsets=n_zsets_scanned,
            out_dir=out_dir,
        )
        return

    global_channel_metadata: dict[str, dict[str, str]] = {}
    if params.fetch_channel_metadata and load_channel_metadata is not None:
        global_channel_metadata = await load_channel_metadata()

    bulk_backend = (params.bulk_llm_backend or "gemini").strip().lower()

    # Headroom on top of the transcript: base reserve, plus room for the
    # speaker prefetch and channel metadata sections when enabled.
    reserve = int(params.token_reserve)
    if params.prefetch_speaker_kg:
        reserve += max(0, int(params.prefetch_max_chars) // 3)
    if params.fetch_channel_metadata:
        n_disc = sum(
            1
            for x in platforms_channels
            if x.startswith(("discord:", "discord-self:"))
        )
        reserve += min(120_000, 2_000 * max(1, n_disc))
    max_tok = int(params.chunk_tokens)
    prefetch_placeholder = ""
    if params.prefetch_speaker_kg:
        prefetch_placeholder = _speaker_prefetch_placeholder(
            int(params.prefetch_max_chars),
        )

    run_log: list[dict[str, Any]] = []
    # Callbacks run in reverse registration order on exit, i.e. extraction
    # client, then message cache, then token counter.
    async with AsyncExitStack() as cleanup:
        token_counter = _create_bulk_llm_client(
            cfg,
            bulk_backend,
            max_tool_rounds=1,
        )
        cleanup.push_async_callback(token_counter.close)

        chunk_groups: list[
            tuple[
                str,
                list[str],
                list[tuple[str, str]],
                list[tuple[str, str]],
                float,
                float,
                dict[tuple[str, str], float],
            ]
        ] = []
        if params.per_channel:
            grouped: dict[
                tuple[str, str],
                list[tuple[str, tuple[str, str], tuple[str, str], float]],
            ] = defaultdict(list)
            # Reuse the per-line lists built above instead of re-formatting
            # every message a second time.
            for line, key, speaker, ts in zip(
                lines,
                line_channel_keys,
                line_speaker_keys,
                line_timestamps,
            ):
                grouped[key].append((line, key, speaker, ts))
            for (plat, cid), rows in sorted(grouped.items()):
                ch_scope = f"{plat}:{cid}" if plat or cid else CROSS_CHANNEL_SCOPE
                subchunks = await chunk_message_lines(
                    [row[0] for row in rows],
                    [row[1] for row in rows],
                    [row[2] for row in rows],
                    [row[3] for row in rows],
                    token_counter,
                    max_tokens=max_tok,
                    channel_scope=ch_scope,
                    reserve=reserve,
                    config=cfg,
                    speaker_prefetch_placeholder=prefetch_placeholder,
                    channel_metadata=global_channel_metadata or None,
                )
                for bl, bp, spk, ts_lo, ts_hi, cmap in subchunks:
                    chunk_groups.append((ch_scope, bl, bp, spk, ts_lo, ts_hi, cmap))
        else:
            subchunks = await chunk_message_lines(
                lines,
                line_channel_keys,
                line_speaker_keys,
                line_timestamps,
                token_counter,
                max_tokens=max_tok,
                channel_scope=CROSS_CHANNEL_SCOPE,
                reserve=reserve,
                config=cfg,
                speaker_prefetch_placeholder=prefetch_placeholder,
                channel_metadata=global_channel_metadata or None,
            )
            for bl, bp, spk, ts_lo, ts_hi, cmap in subchunks:
                chunk_groups.append(
                    (CROSS_CHANNEL_SCOPE, bl, bp, spk, ts_lo, ts_hi, cmap),
                )

        cm = int(params.chunks_max or 0)
        if cm > 0:
            chunk_groups = chunk_groups[:cm]
            logger.info(
                "--chunks-max=%d: using first %d chunk(s)",
                cm,
                len(chunk_groups),
            )

        chunks_dir = out_dir / "chunks"
        chunks_dir.mkdir(exist_ok=True)
        _write_chunk_artifacts(chunks_dir, chunk_groups, global_channel_metadata)

        if params.dry_run_chunks:
            logger.info(
                "Dry-run: wrote %d chunk files; skipping LLM.",
                len(chunk_groups),
            )
            return
        if params.dry_run_llm:
            logger.info(
                "Dry-run LLM: running agentic extraction without persisting to FalkorDB.",
            )

        or_for_kg = OpenRouterClient(
            api_key=cfg.api_key,
            model=cfg.model,
            base_url=cfg.llm_base_url,
            gemini_api_key=cfg.gemini_api_key or "",
            top_p=cfg.top_p,
        )
        cache2 = MessageCache(
            redis_url=cfg.redis_url,
            openrouter_client=or_for_kg,
            embedding_model=cfg.embedding_model,
            ssl_kwargs=redis_ssl_kwargs_for_bulk(
                cfg,
                redis_no_verify=params.redis_no_verify,
            ),
        )
        cleanup.push_async_callback(cache2.close)
        kg = KnowledgeGraphManager(
            redis_client=cache2.redis_client,
            openrouter=or_for_kg,
            embedding_model=cfg.embedding_model,
            admin_user_ids=set(cfg.admin_user_ids) if cfg.admin_user_ids else None,
        )
        await kg.ensure_indexes()
        bulk_client = _create_bulk_llm_client(
            cfg,
            bulk_backend,
            max_tool_rounds=params.max_tool_rounds,
        )
        cleanup.push_async_callback(bulk_client.close)

        start_idx = max(0, int(params.resume_from_chunk))
        pc_hint = ", ".join(platforms_channels)
        for idx, (
            ch_scope,
            chunk_lines,
            chunk_pairs,
            chunk_spk,
            ts_lo,
            ts_hi,
            cmap,
        ) in enumerate(chunk_groups):
            if idx < start_idx:
                continue
            meta_slice = _subset_channel_metadata(
                global_channel_metadata or None,
                list(chunk_pairs),
            )
            prefetch_text = ""
            if params.prefetch_speaker_kg:
                prefetch_text = await prefetch_speaker_kg_context(
                    kg,
                    list(chunk_spk),
                    max_speakers=int(params.prefetch_max_speakers),
                    hits_per_speaker=int(params.prefetch_hits_per_speaker),
                    min_score=float(params.prefetch_min_score),
                )
            stats = await run_agentic_kg_extraction_chunk(
                conversation_text="\n".join(chunk_lines),
                channel_id=ch_scope,
                kg_manager=kg,
                bulk_client=bulk_client,
                user_id="000000000000",
                chunk_index=idx,
                time_start_iso=_utc_iso_or_blank(ts_lo),
                time_end_iso=_utc_iso_or_blank(ts_hi),
                platforms_channels=pc_hint,
                config=cfg,
                chunk_channel_pairs=list(chunk_pairs),
                chunk_speaker_pairs=list(chunk_spk),
                speaker_kg_prefetch=prefetch_text,
                channel_metadata=meta_slice,
                persist_extraction=not params.dry_run_llm,
            )
            stats["chunk_index"] = idx
            stats["channel_scope"] = ch_scope
            run_log.append(stats)
            logger.info("Chunk %d scope=%s stats=%s", idx, ch_scope, stats)

            err = int(stats.get("errors", 0) or 0)
            parse_err = bool(stats.get("parse_error"))
            # Advance cursors only for chunks that extracted cleanly and were
            # actually persisted.
            if (
                params.incremental
                and not params.dry_run_llm
                and err == 0
                and not parse_err
                and cmap
            ):
                for (plat, cid), tmax in cmap.items():
                    if tmax <= 0:
                        continue
                    await cursor_hset(redis, plat, cid, tmax)

        (out_dir / "extraction_run.json").write_text(
            json.dumps(run_log, indent=2, default=str),
            encoding="utf-8",
        )

    channels_done = len(
        {e.get("channel_scope", "") for e in run_log if isinstance(e, dict)}
    )
    _publish_pipeline_event(
        "ok",
        started=started,
        channels=channels_done,
        messages=len(msgs),
        n_zsets=n_zsets_scanned,
        out_dir=out_dir,
        extra={"chunks_run": len(run_log)},
    )
[docs]
def resolve_bulk_backend(cli_value: str | None) -> str:
    """Resolve which LLM backend the bulk pipeline should use.

    Precedence is the CLI value, then the ``KG_BULK_LLM_BACKEND`` environment
    variable, then ``"gemini"``. Each candidate is trimmed and lowercased
    *before* it is tested, so a blank or whitespace-only value falls through to
    the next source instead of resolving to an empty backend name.

    Args:
        cli_value: Explicit backend from the CLI, or ``None``/blank to defer to
            the environment and default.

    Returns:
        str: The lowercased, non-empty backend name.
    """
    for candidate in (cli_value, os.environ.get("KG_BULK_LLM_BACKEND")):
        normalized = (candidate or "").strip().lower()
        if normalized:
            return normalized
    return "gemini"