"""Shared bulk agentic KG extraction: Redis message collect, chunking, LLM run.
Used by :mod:`scripts.kg_bulk_dump_and_extract` and the scheduled incremental
task in :mod:`background_tasks`.
"""
from __future__ import annotations
import json
import logging
import os
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Awaitable, Callable, Literal
from config import Config
from kg_agentic_extraction import (
KgBulkLlmClient,
create_kg_bulk_gemini_pool_client,
create_kg_bulk_openrouter_client,
messages_for_agentic_token_estimate,
prefetch_speaker_kg_context,
run_agentic_kg_extraction_chunk,
)
from knowledge_graph import KnowledgeGraphManager
from message_cache import MessageCache
from openrouter_client import OpenRouterClient
logger = logging.getLogger(__name__)
# Synthetic channel-scope label used when chunking spans all channels at once.
CROSS_CHANNEL_SCOPE = "cross_channel_bulk"
# Redis hash holding incremental cursors:
# field "<platform>:<channel_id>" -> last-processed message timestamp (str float).
KG_AGENTIC_BULK_LAST_TS_HASH = "stargazer:kg_agentic_bulk:last_ts"
# Message-hash fields fetched via HMGET for every message key.
_HASH_MSG_FIELDS = (
    "user_id",
    "user_name",
    "platform",
    "channel_id",
    "text",
    "timestamp",
    "message_id",
    "reply_to_id",
)
# Page size for ZRANGE / ZRANGEBYSCORE scans over a channel zset.
_ZSET_BATCH = 5000
# Number of HMGETs pipelined per Redis round trip.
_HMGET_BATCH = 400
# Cursor bootstrap strategy: "full" processes the backlog; "latest" seeds
# missing cursors to the current zset max without extracting history.
CursorBootstrap = Literal["full", "latest"]
def cursor_field(platform: str, channel_id: str) -> str:
    """Return the Redis hash field name for a (platform, channel) cursor."""
    return "{}:{}".format(platform, channel_id)
def redis_ssl_kwargs_for_bulk(
    cfg: Config, *, redis_no_verify: bool,
) -> dict[str, Any]:
    """Build Redis SSL kwargs from *cfg*, optionally relaxing verification.

    Certificate checks are disabled only when *redis_no_verify* is set AND
    the configured URL actually uses TLS (``rediss://``).
    """
    ssl_opts: dict[str, Any] = dict(cfg.redis_ssl_kwargs())
    normalized_url = (cfg.redis_url or "").strip().lower()
    if redis_no_verify and normalized_url.startswith("rediss://"):
        import ssl as _ssl
        ssl_opts["ssl_cert_reqs"] = _ssl.CERT_NONE
        ssl_opts["ssl_check_hostname"] = False
    return ssl_opts
async def scan_channel_zset_keys(redis: Any) -> list[str]:
    """Scan Redis for all ``channel_msgs:*`` zset keys, sorted ascending."""
    found: list[str] = []
    async for raw in redis.scan_iter(match="channel_msgs:*", count=500):
        found.append(raw.decode() if isinstance(raw, bytes) else raw)
    return sorted(found)
def _parse_zset_key(zset_key: str) -> tuple[str, str] | None:
parts = zset_key.split(":", 2)
if len(parts) < 3:
return None
return parts[1], parts[2]
async def cursor_hget(redis: Any, platform: str, channel_id: str) -> float | None:
    """Read the stored incremental-cursor timestamp for one channel.

    Returns ``None`` when no cursor is stored or the value cannot be
    parsed as a float.
    """
    stored = await redis.hget(
        KG_AGENTIC_BULK_LAST_TS_HASH,
        cursor_field(platform, channel_id),
    )
    if stored is None:
        return None
    text = stored.decode() if isinstance(stored, bytes) else stored
    try:
        return float(text)
    except (TypeError, ValueError):
        return None
async def cursor_hset(
    redis: Any,
    platform: str,
    channel_id: str,
    ts: float,
) -> None:
    """Persist *ts* as the incremental cursor for one (platform, channel)."""
    field = cursor_field(platform, channel_id)
    await redis.hset(KG_AGENTIC_BULK_LAST_TS_HASH, field, str(ts))
async def fetch_messages_for_zset(
    redis: Any,
    zset_key: str,
) -> list[dict[str, Any]]:
    """Fetch every message hash referenced by *zset_key*, in zset order."""
    collected: list[dict[str, Any]] = []
    start = 0
    while True:
        members: list[Any] = await redis.zrange(
            zset_key, start, start + _ZSET_BATCH - 1,
        )
        if not members:
            break
        decoded = [m.decode() if isinstance(m, bytes) else m for m in members]
        collected.extend(await _hmget_message_dicts(redis, zset_key, decoded))
        start += len(members)
        # A short page means the zset is exhausted.
        if len(members) < _ZSET_BATCH:
            break
    return collected
async def fetch_messages_for_zset_after(
    redis: Any,
    zset_key: str,
    after_ts_exclusive: float | None,
) -> list[dict[str, Any]]:
    """Fetch messages with score strictly greater than *after_ts_exclusive*.

    ``None`` means "no cursor stored": delegate to the full fetch.
    """
    if after_ts_exclusive is None:
        return await fetch_messages_for_zset(redis, zset_key)
    # The "(" prefix makes the lower bound exclusive in ZRANGEBYSCORE.
    lower_bound = f"({after_ts_exclusive}"
    collected: list[dict[str, Any]] = []
    start = 0
    while True:
        members: list[Any] = await redis.zrangebyscore(
            zset_key,
            lower_bound,
            "+inf",
            start=start,
            num=_ZSET_BATCH,
        )
        if not members:
            break
        decoded = [m.decode() if isinstance(m, bytes) else m for m in members]
        collected.extend(await _hmget_message_dicts(redis, zset_key, decoded))
        start += len(members)
        # A short page means the range is exhausted.
        if len(members) < _ZSET_BATCH:
            break
    return collected
async def _hmget_message_dicts(
    redis: Any,
    zset_key: str,
    msg_keys: list[str],
) -> list[dict[str, Any]]:
    """Pipeline HMGETs for *msg_keys* and build normalized message dicts.

    Keys whose hash is missing entirely (every field ``None``) are skipped.
    Adds ``redis_msg_key`` / ``zset_key`` provenance fields and coerces
    ``timestamp`` to float (0.0 when absent).
    """
    results: list[dict[str, Any]] = []
    for start in range(0, len(msg_keys), _HMGET_BATCH):
        chunk = msg_keys[start: start + _HMGET_BATCH]
        pipe = redis.pipeline()
        for key in chunk:
            pipe.hmget(key, *_HASH_MSG_FIELDS)
        rows = await pipe.execute()
        for key, row in zip(chunk, rows):
            if not row or all(field is None for field in row):
                continue
            msg = {
                name: (val.decode() if isinstance(val, bytes) else val)
                for name, val in zip(_HASH_MSG_FIELDS, row)
            }
            msg["timestamp"] = float(msg.get("timestamp") or 0)
            msg["redis_msg_key"] = key
            msg["zset_key"] = zset_key
            results.append(msg)
    return results
async def collect_messages_from_redis(
    redis: Any,
    *,
    incremental: bool = False,
    cursor_bootstrap: CursorBootstrap = "full",
    channel_pairs: list[tuple[str, str]] | None = None,
) -> tuple[list[dict[str, Any]], int]:
    """Load messages from Redis channel zsets.

    *channel_pairs*: if set, only these ``(platform, channel_id)`` zsets;
    otherwise all ``channel_msgs:*`` keys.
    When *incremental* is true, uses :data:`KG_AGENTIC_BULK_LAST_TS_HASH`
    per channel with exclusive lower bound. *cursor_bootstrap* ``latest``
    seeds missing cursors to the current zset max without returning backlog
    messages.

    Returns ``(messages, n_zsets_considered)`` with messages sorted by
    (timestamp, platform, channel_id, message_id) for a stable global order.
    """
    if channel_pairs is not None:
        zkeys = [f"channel_msgs:{p}:{c}" for p, c in channel_pairs]
        allowed: set[tuple[str, str]] | None = set(channel_pairs)
    else:
        zkeys = await scan_channel_zset_keys(redis)
        allowed = None
    n_scanned = len(zkeys)
    all_msgs: list[dict[str, Any]] = []
    for zk in zkeys:
        parsed = _parse_zset_key(zk)
        if not parsed:
            continue
        platform, channel_id = parsed
        if allowed is not None and (platform, channel_id) not in allowed:
            continue
        after_ts: float | None = None
        if incremental:
            if cursor_bootstrap == "latest":
                await bootstrap_latest_cursor_no_extract(
                    redis, zk, platform, channel_id,
                )
            after_ts = await cursor_hget(redis, platform, channel_id)
        try:
            # Fix: collapsed the former three-way branch — the "incremental
            # with no stored cursor" arm was identical to the plain full
            # fetch, so two arms suffice (after_ts is only non-None when
            # incremental is true).
            if after_ts is not None:
                part = await fetch_messages_for_zset_after(redis, zk, after_ts)
            else:
                part = await fetch_messages_for_zset(redis, zk)
            all_msgs.extend(part)
        except Exception:
            # Best-effort per zset: one broken channel must not abort the run.
            logger.warning("Failed zset %s", zk, exc_info=True)
    all_msgs.sort(
        key=lambda m: (
            m["timestamp"],
            str(m.get("platform", "")),
            str(m.get("channel_id", "")),
            str(m.get("message_id", "")),
        ),
    )
    return all_msgs, n_scanned
def _unique_channel_pairs(keys: list[tuple[str, str]]) -> list[tuple[str, str]]:
return sorted(set(keys))
def _unique_speaker_pairs(
keys: list[tuple[str, str]],
) -> list[tuple[str, str]]:
seen: dict[str, str] = {}
for uid, name in keys:
u = (uid or "").strip()
if not u:
continue
if u not in seen:
seen[u] = (name or "").strip() or "?"
return sorted(seen.items(), key=lambda x: x[0])
def _subset_channel_metadata(
full: dict[str, dict[str, str]] | None,
pairs: list[tuple[str, str]],
) -> dict[str, dict[str, str]] | None:
if not full:
return None
out: dict[str, dict[str, str]] = {}
for plat, cid in pairs:
k = f"{plat}:{cid}"
if k in full:
out[k] = dict(full[k])
return out or None
def _max_ts_per_channel(
line_channel_keys: list[tuple[str, str]],
line_timestamps: list[float],
) -> dict[tuple[str, str], float]:
acc: dict[tuple[str, str], float] = {}
for k, t in zip(line_channel_keys, line_timestamps):
if k not in acc or t > acc[k]:
acc[k] = t
return acc
def _speaker_prefetch_placeholder(max_chars: int) -> str:
if max_chars <= 0:
return ""
header = (
"## Existing knowledge graph (speakers — prefetch)\n"
"_Placeholder sizing for token budget._\n\n"
)
pad = max(0, int(max_chars) - len(header))
return header + ("·" * pad)
async def token_count_conversation(
    counter: KgBulkLlmClient,
    conversation_text: str,
    *,
    channel_scope: str,
    reserve: int,
    config: Config | None = None,
    chunk_channel_pairs: list[tuple[str, str]] | None = None,
    chunk_speaker_pairs: list[tuple[str, str]] | None = None,
    speaker_kg_prefetch: str = "",
    channel_metadata: dict[str, dict[str, str]] | None = None,
) -> int:
    """Estimate prompt tokens for one chunk and add *reserve* headroom.

    Builds the same message list the extraction prompt would use and asks
    the backend to count input tokens; falls back to a ~3-chars-per-token
    heuristic when the backend cannot count.
    """
    prompt_msgs = messages_for_agentic_token_estimate(
        conversation_text,
        channel_id=channel_scope,
        chunk_index=0,
        config=config,
        chunk_channel_pairs=chunk_channel_pairs,
        chunk_speaker_pairs=chunk_speaker_pairs,
        speaker_kg_prefetch=speaker_kg_prefetch,
        channel_metadata=channel_metadata,
    )
    counted = await counter.count_input_tokens(prompt_msgs)
    if counted is None:
        return max(1, len(conversation_text) // 3) + reserve
    return int(counted) + reserve
async def chunk_message_lines(
    lines: list[str],
    line_channel_keys: list[tuple[str, str]],
    line_speaker_keys: list[tuple[str, str]],
    line_timestamps: list[float],
    counter: KgBulkLlmClient,
    *,
    max_tokens: int,
    channel_scope: str,
    reserve: int,
    config: Config | None = None,
    speaker_prefetch_placeholder: str = "",
    channel_metadata: dict[str, dict[str, str]] | None = None,
) -> list[
    tuple[
        list[str],
        list[tuple[str, str]],
        list[tuple[str, str]],
        float,
        float,
        dict[tuple[str, str], float],
    ]
]:
    """Pack lines into token-budget chunks with per-chunk time + cursor hints.

    Lines are consumed in order; each chunk is the largest contiguous prefix
    of the remaining lines whose estimated prompt size (via
    :func:`token_count_conversation`, which already includes *reserve*)
    fits *max_tokens*, found by binary search on the prefix length.

    Returns tuples of ``(chunk_lines, unique_channel_pairs,
    unique_speaker_pairs, ts_min, ts_max, max_ts_per_channel)``.
    Raises :class:`ValueError` when the parallel input lists differ in
    length from *lines*.
    """
    chunks: list[
        tuple[
            list[str],
            list[tuple[str, str]],
            list[tuple[str, str]],
            float,
            float,
            dict[tuple[str, str], float],
        ]
    ] = []
    i = 0
    nlines = len(lines)
    if len(line_channel_keys) != nlines or len(line_speaker_keys) != nlines:
        raise ValueError("line_*_keys length must match lines")
    if len(line_timestamps) != nlines:
        raise ValueError("line_timestamps length must match lines")
    while i < nlines:
        # Binary-search the largest prefix length `best` of lines[i:]
        # whose full prompt estimate fits the budget.
        low, high = 1, nlines - i
        best = 0
        while low <= high:
            mid = (low + high) // 2
            block = lines[i: i + mid]
            text = "\n".join(block)
            slice_keys = line_channel_keys[i: i + mid]
            slice_spk = line_speaker_keys[i: i + mid]
            pairs = _unique_channel_pairs(slice_keys)
            spairs = _unique_speaker_pairs(slice_spk)
            meta_slice = _subset_channel_metadata(channel_metadata, pairs)
            # Estimate counts the whole prompt envelope for this slice
            # (channel/speaker lists, metadata, prefetch placeholder),
            # not just the raw conversation text.
            tc = await token_count_conversation(
                counter,
                text,
                channel_scope=channel_scope,
                reserve=reserve,
                config=config,
                chunk_channel_pairs=pairs,
                chunk_speaker_pairs=spairs,
                speaker_kg_prefetch=speaker_prefetch_placeholder,
                channel_metadata=meta_slice,
            )
            if tc <= max_tokens:
                best = mid
                low = mid + 1
            else:
                high = mid - 1
        if best == 0:
            # A single line can exceed the budget on its own; emit it
            # anyway so the loop always makes progress.
            best = 1
        slice_keys = line_channel_keys[i: i + best]
        slice_spk = line_speaker_keys[i: i + best]
        slice_ts = line_timestamps[i: i + best]
        ts_lo = min(slice_ts) if slice_ts else 0.0
        ts_hi = max(slice_ts) if slice_ts else 0.0
        # Per-channel max timestamps feed the incremental cursor updates.
        cmap = _max_ts_per_channel(slice_keys, slice_ts)
        chunks.append(
            (
                lines[i: i + best],
                _unique_channel_pairs(slice_keys),
                _unique_speaker_pairs(slice_spk),
                ts_lo,
                ts_hi,
                cmap,
            ),
        )
        i += best
    return chunks
@dataclass
class KgBulkPipelineParams:
    """Configuration for :func:`run_agentic_bulk_pipeline`."""

    # Directory receiving messages.jsonl, manifest.json, chunks/, extraction_run.json.
    out_dir: Path
    # Stop after writing messages.jsonl + manifest.json (no chunking, no LLM).
    dump_only: bool = False
    # Write chunk artifact files, then stop before any LLM extraction.
    dry_run_chunks: bool = False
    # Run extraction without persisting to the graph; cursor updates are skipped too.
    dry_run_llm: bool = False
    # Token budget per chunk (the estimate includes prompt scaffolding + reserve).
    chunk_tokens: int = 250_000
    # Base headroom added to every chunk's token estimate.
    token_reserve: int = 100_000
    # Cap on input messages; 0 disables the cap.
    max_messages: int = 0
    # Process only the first N chunks; 0 disables the cap.
    chunks_max: int = 0
    # Chunk each (platform, channel) separately instead of one global stream.
    per_channel: bool = False
    # Skip chunks below this index (crash/resume support).
    resume_from_chunk: int = 0
    # Max agentic tool-call rounds per extraction chunk.
    max_tool_rounds: int = 48
    # Bulk LLM backend selector; "gemini" or anything else -> OpenRouter.
    bulk_llm_backend: str = "gemini"
    # Invoke the load_channel_metadata callback and inject metadata into prompts.
    fetch_channel_metadata: bool = False
    # NOTE(review): not referenced in this module — presumably consumed by
    # the metadata loader callback; verify against callers.
    channel_metadata_ttl_days: float = 7.0
    # NOTE(review): not referenced in this module — verify against callers.
    discord_platform_type: str = "discord"
    # Speaker-KG prefetch: inject existing-graph context per chunk.
    prefetch_speaker_kg: bool = False
    prefetch_max_speakers: int = 8
    prefetch_hits_per_speaker: int = 3
    # Character cap for the prefetch section (also sizes the token reserve).
    prefetch_max_chars: int = 400_000
    prefetch_min_score: float = 0.0
    # Disable TLS certificate verification for rediss:// URLs.
    redis_no_verify: bool = True
    # Use per-channel cursors and advance them after each clean chunk.
    incremental: bool = False
async def run_agentic_bulk_pipeline(
    cfg: Config,
    redis: Any,
    messages: list[dict[str, Any]],
    n_zsets_scanned: int,
    params: KgBulkPipelineParams,
    *,
    load_channel_metadata: Callable[
        [],
        Awaitable[dict[str, dict[str, str]]],
    ]
    | None = None,
) -> None:
    """Chunk (optional disk artifacts), run agentic KG extraction, update cursors.

    *load_channel_metadata*: optional async callable returning Discord (or other)
    channel metadata dict; only invoked when *params.fetch_channel_metadata*.
    After each successful extraction chunk, updates incremental cursors
    per channel when *params.incremental* is true.
    """
    # ---- Stage 1: dump artifacts (messages.jsonl + manifest.json). ----
    out_dir = Path(params.out_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    msgs = messages
    if params.max_messages and params.max_messages > 0:
        msgs = msgs[: params.max_messages]
    jsonl_path = out_dir / "messages.jsonl"
    with jsonl_path.open("w", encoding="utf-8") as fj:
        for m in msgs:
            # Drop bulky embedding vectors from the on-disk dump.
            rec = {k: v for k, v in m.items() if k != "embedding"}
            rec["llm_style_line"] = format_llm_style_line(rec)
            fj.write(json.dumps(rec, ensure_ascii=False, default=str) + "\n")
    # Parallel per-line arrays consumed by chunk_message_lines.
    lines = [format_llm_style_line(m) for m in msgs]
    line_channel_keys = [
        (str(m.get("platform", "")), str(m.get("channel_id", ""))) for m in msgs
    ]
    line_speaker_keys = [
        (str(m.get("user_id", "")), str(m.get("user_name", ""))) for m in msgs
    ]
    line_timestamps = [float(m.get("timestamp") or 0) for m in msgs]
    platforms_channels = sorted(
        {f"{m.get('platform')}:{m.get('channel_id')}" for m in msgs},
    )
    # Channels actually present in the messages may be fewer than scanned.
    zsets_in_msgs = len(
        {
            str(m.get("zset_key", "") or "")
            for m in msgs
            if m.get("zset_key")
        },
    )
    meta = {
        "message_count": len(msgs),
        "zset_count": zsets_in_msgs,
        "zset_count_redis_scanned": n_zsets_scanned,
        "time_min": min((m["timestamp"] for m in msgs), default=0),
        "time_max": max((m["timestamp"] for m in msgs), default=0),
        "platforms_channels": platforms_channels,
    }
    (out_dir / "manifest.json").write_text(
        json.dumps(meta, indent=2, default=str),
        encoding="utf-8",
    )
    if params.dump_only:
        logger.info("Dump-only: wrote %s and manifest.", jsonl_path)
        return
    # ---- Stage 2: token-budget sizing + chunking. ----
    global_channel_metadata: dict[str, dict[str, str]] = {}
    if params.fetch_channel_metadata and load_channel_metadata is not None:
        global_channel_metadata = await load_channel_metadata()
    bulk_backend = (params.bulk_llm_backend or "gemini").strip().lower()
    # Single-round client used only for token counting during chunking.
    if bulk_backend == "gemini":
        token_counter = create_kg_bulk_gemini_pool_client(max_tool_rounds=1)
    else:
        token_counter = create_kg_bulk_openrouter_client(
            cfg.api_key,
            gemini_api_key=cfg.gemini_api_key or "",
            max_tool_rounds=1,
        )
    scope = CROSS_CHANNEL_SCOPE
    reserve = int(params.token_reserve)
    if params.prefetch_speaker_kg:
        # Reserve room for the prefetch block (~3 chars per token).
        reserve += max(0, int(params.prefetch_max_chars) // 3)
    if params.fetch_channel_metadata:
        # Rough allowance for Discord channel metadata in the prompt,
        # capped at 120k extra tokens.
        n_disc = sum(
            1
            for x in platforms_channels
            if x.startswith("discord:") or x.startswith("discord-self:")
        )
        reserve += min(120_000, 2_000 * max(1, n_disc))
    max_tok = int(params.chunk_tokens)
    prefetch_placeholder = ""
    if params.prefetch_speaker_kg:
        prefetch_placeholder = _speaker_prefetch_placeholder(
            int(params.prefetch_max_chars),
        )
    # Each entry: (scope, lines, channel_pairs, speaker_pairs, ts_lo, ts_hi,
    # per-channel max-ts map used for incremental cursor updates).
    chunk_groups: list[
        tuple[
            str,
            list[str],
            list[tuple[str, str]],
            list[tuple[str, str]],
            float,
            float,
            dict[tuple[str, str], float],
        ]
    ] = []
    if params.per_channel:
        # Group messages per (platform, channel) and chunk each group alone.
        grouped: dict[
            tuple[str, str],
            list[
                tuple[
                    str,
                    tuple[str, str],
                    tuple[str, str],
                    float,
                ]
            ],
        ] = defaultdict(list)
        for m in msgs:
            k = (str(m.get("platform", "")), str(m.get("channel_id", "")))
            sp = (str(m.get("user_id", "")), str(m.get("user_name", "")))
            grouped[k].append(
                (
                    format_llm_style_line(m),
                    k,
                    sp,
                    float(m.get("timestamp") or 0),
                ),
            )
        for (plat, cid), triples in sorted(grouped.items()):
            ch_scope = f"{plat}:{cid}" if plat or cid else CROSS_CHANNEL_SCOPE
            ls = [t[0] for t in triples]
            lck = [t[1] for t in triples]
            lsk = [t[2] for t in triples]
            lts = [t[3] for t in triples]
            subchunks = await chunk_message_lines(
                ls,
                lck,
                lsk,
                lts,
                token_counter,
                max_tokens=max_tok,
                channel_scope=ch_scope,
                reserve=reserve,
                config=cfg,
                speaker_prefetch_placeholder=prefetch_placeholder,
                channel_metadata=global_channel_metadata or None,
            )
            for bl, bp, spk, ts_lo, ts_hi, cmap in subchunks:
                chunk_groups.append((ch_scope, bl, bp, spk, ts_lo, ts_hi, cmap))
    else:
        # One cross-channel stream chunked under the shared scope label.
        subchunks = await chunk_message_lines(
            lines,
            line_channel_keys,
            line_speaker_keys,
            line_timestamps,
            token_counter,
            max_tokens=max_tok,
            channel_scope=scope,
            reserve=reserve,
            config=cfg,
            speaker_prefetch_placeholder=prefetch_placeholder,
            channel_metadata=global_channel_metadata or None,
        )
        for bl, bp, spk, ts_lo, ts_hi, cmap in subchunks:
            chunk_groups.append((scope, bl, bp, spk, ts_lo, ts_hi, cmap))
    cm = int(params.chunks_max or 0)
    if cm > 0:
        chunk_groups = chunk_groups[:cm]
        logger.info(
            "--chunks-max=%d: using first %d chunk(s)",
            cm,
            len(chunk_groups),
        )
    # ---- Stage 3: write one JSON artifact per chunk. ----
    chunks_dir = out_dir / "chunks"
    chunks_dir.mkdir(exist_ok=True)
    for idx, (
        ch_scope,
        chunk_lines,
        chunk_pairs,
        chunk_spk,
        ts_lo,
        ts_hi,
        _cmap,
    ) in enumerate(chunk_groups):
        text = "\n".join(chunk_lines)
        t0 = datetime.now(timezone.utc).isoformat()
        # Empty string when no positive timestamp is available.
        t_start = (
            datetime.fromtimestamp(ts_lo, tz=timezone.utc).isoformat()
            if ts_lo > 0
            else ""
        )
        t_end = (
            datetime.fromtimestamp(ts_hi, tz=timezone.utc).isoformat()
            if ts_hi > 0
            else ""
        )
        meta_rows = _subset_channel_metadata(
            global_channel_metadata or None,
            list(chunk_pairs),
        )
        piece = {
            "index": idx,
            "channel_scope": ch_scope,
            "channels_in_chunk": [list(p) for p in chunk_pairs],
            "speakers_in_chunk": [list(p) for p in chunk_spk],
            "channel_metadata": meta_rows if meta_rows is not None else {},
            "line_count": len(chunk_lines),
            "char_count": len(text),
            "written_at": t0,
            "time_range_hint_utc": [t_start, t_end],
            "conversation_text": text,
        }
        (chunks_dir / f"chunk_{idx:05d}.json").write_text(
            json.dumps(piece, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
    if params.dry_run_chunks:
        logger.info(
            "Dry-run: wrote %d chunk files; skipping LLM.", len(chunk_groups),
        )
        await token_counter.close()
        return
    if params.dry_run_llm:
        logger.info(
            "Dry-run LLM: running agentic extraction without persisting to FalkorDB.",
        )
    # ---- Stage 4: build extraction clients + KG manager. ----
    or_for_kg = OpenRouterClient(
        api_key=cfg.api_key,
        model=cfg.model,
        base_url=cfg.llm_base_url,
        gemini_api_key=cfg.gemini_api_key or "",
    )
    cache2 = MessageCache(
        redis_url=cfg.redis_url,
        openrouter_client=or_for_kg,
        embedding_model=cfg.embedding_model,
        ssl_kwargs=redis_ssl_kwargs_for_bulk(
            cfg, redis_no_verify=params.redis_no_verify,
        ),
    )
    kg = KnowledgeGraphManager(
        redis_client=cache2.redis_client,
        openrouter=or_for_kg,
        embedding_model=cfg.embedding_model,
        admin_user_ids=set(cfg.admin_user_ids) if cfg.admin_user_ids else None,
    )
    await kg.ensure_indexes()
    if bulk_backend == "gemini":
        bulk_client = create_kg_bulk_gemini_pool_client(
            max_tool_rounds=params.max_tool_rounds,
        )
    else:
        bulk_client = create_kg_bulk_openrouter_client(
            cfg.api_key,
            gemini_api_key=cfg.gemini_api_key or "",
            max_tool_rounds=params.max_tool_rounds,
        )
    run_log: list[dict[str, Any]] = []
    start_idx = max(0, int(params.resume_from_chunk))
    pc_hint = ", ".join(platforms_channels)
    # ---- Stage 5: per-chunk extraction + cursor updates. ----
    try:
        for idx, (
            ch_scope,
            chunk_lines,
            chunk_pairs,
            chunk_spk,
            ts_lo,
            ts_hi,
            cmap,
        ) in enumerate(chunk_groups):
            if idx < start_idx:
                # Resume support: skip chunks already processed.
                continue
            text = "\n".join(chunk_lines)
            t_start = (
                datetime.fromtimestamp(ts_lo, tz=timezone.utc).isoformat()
                if ts_lo > 0
                else ""
            )
            t_end = (
                datetime.fromtimestamp(ts_hi, tz=timezone.utc).isoformat()
                if ts_hi > 0
                else ""
            )
            meta_slice = _subset_channel_metadata(
                global_channel_metadata or None,
                list(chunk_pairs),
            )
            prefetch_text = ""
            if params.prefetch_speaker_kg:
                # Pull existing KG context for this chunk's speakers.
                prefetch_text = await prefetch_speaker_kg_context(
                    kg,
                    list(chunk_spk),
                    max_speakers=int(params.prefetch_max_speakers),
                    hits_per_speaker=int(params.prefetch_hits_per_speaker),
                    min_score=float(params.prefetch_min_score),
                )
            # NOTE(review): "000000000000" looks like a sentinel system user
            # id for bulk runs — confirm run_agentic_kg_extraction_chunk
            # expects this value.
            stats = await run_agentic_kg_extraction_chunk(
                conversation_text=text,
                channel_id=ch_scope,
                kg_manager=kg,
                bulk_client=bulk_client,
                user_id="000000000000",
                chunk_index=idx,
                time_start_iso=t_start,
                time_end_iso=t_end,
                platforms_channels=pc_hint,
                config=cfg,
                chunk_channel_pairs=list(chunk_pairs),
                chunk_speaker_pairs=list(chunk_spk),
                speaker_kg_prefetch=prefetch_text,
                channel_metadata=meta_slice,
                persist_extraction=not params.dry_run_llm,
            )
            stats["chunk_index"] = idx
            stats["channel_scope"] = ch_scope
            run_log.append(stats)
            logger.info("Chunk %d scope=%s stats=%s", idx, ch_scope, stats)
            # Advance incremental cursors only after a clean, persisted chunk.
            err = int(stats.get("errors", 0) or 0)
            parse_err = bool(stats.get("parse_error"))
            if (
                params.incremental
                and not params.dry_run_llm
                and err == 0
                and not parse_err
                and cmap
            ):
                for (plat, cid), tmax in cmap.items():
                    if tmax <= 0:
                        continue
                    await cursor_hset(redis, plat, cid, tmax)
            # Rewritten each chunk so the run log survives a mid-run failure.
            (out_dir / "extraction_run.json").write_text(
                json.dumps(run_log, indent=2, default=str),
                encoding="utf-8",
            )
    finally:
        await bulk_client.close()
        await cache2.close()
        await token_counter.close()
def resolve_bulk_backend(cli_value: str | None) -> str:
    """Resolve the bulk LLM backend name.

    Precedence: CLI value, then ``KG_BULK_LLM_BACKEND`` env var, then
    ``"gemini"``; the winner is trimmed and lower-cased.
    """
    chosen = cli_value or os.environ.get("KG_BULK_LLM_BACKEND", "") or "gemini"
    return chosen.strip().lower()