"""Jinja2-based system prompt renderer with SSTI hardening.
Loads a ``.j2`` template file once at startup and renders it on each call
with room-specific and tool-specific context variables. Uses a
:class:`~jinja2.sandbox.SandboxedEnvironment` and recursively sanitises
user-controllable values to prevent server-side template injection.
"""
from __future__ import annotations
import logging
import re
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from jinja2.sandbox import SandboxedEnvironment
from jinja2 import FileSystemLoader
logger = logging.getLogger(__name__)
# Matches every two-character Jinja2 delimiter: ``{{``/``}}`` (expressions),
# ``{%``/``%}`` (statements) and ``{#``/``#}`` (comments).
_JINJA_META_RE = re.compile(r"\{\{|\}\}|\{%|%\}|\{#|#\}")

# Full-width lookalike for each delimiter. Built once at import time instead
# of being reconstructed on every regex match.
_JINJA_META_SUBSTITUTES = {
    "{{": "\uff5b\uff5b",
    "}}": "\uff5d\uff5d",
    "{%": "\uff5b\uff05",
    "%}": "\uff05\uff5d",
    "{#": "\uff5b\uff03",
    "#}": "\uff03\uff5d",
}


def sanitize_context(value: Any) -> Any:
    """Recursively strip Jinja2 metacharacters from user-controllable strings.

    Replaces ``{{``, ``}}``, ``{%``, ``%}``, ``{#``, ``#}`` with
    Unicode full-width lookalikes so they cannot be interpreted as
    template syntax if an ``| tojson`` filter is ever omitted.

    Non-string leaves (ints, floats, bools, ``None``) pass through
    unchanged. Dicts, lists and tuples are walked recursively.

    Parameters
    ----------
    value:
        A string, a dict/list/tuple (possibly nested), or any other object.

    Returns
    -------
    Any
        A sanitised copy of *value*. Lists and tuples keep their concrete
        type (including namedtuples); dicts are returned as plain ``dict``
        objects with both string keys and values sanitised. Objects of any
        other type are returned as-is.
    """
    if isinstance(value, str):
        return _JINJA_META_RE.sub(_replace_meta, value)
    if isinstance(value, dict):
        # String keys are sanitised too: they can be just as user-controlled
        # as values and may be printed by a template iterating ``.items()``.
        return {
            (sanitize_context(k) if isinstance(k, str) else k): sanitize_context(v)
            for k, v in value.items()
        }
    if isinstance(value, (list, tuple)):
        sanitized = [sanitize_context(v) for v in value]
        # Namedtuple constructors take positional fields, not one iterable,
        # so ``type(value)(sanitized)`` would raise TypeError for them.
        if isinstance(value, tuple) and hasattr(value, "_make"):
            return type(value)._make(sanitized)
        return type(value)(sanitized)
    return value


def _replace_meta(match: re.Match) -> str:
    """Map each Jinja2 metacharacter pair to a safe full-width substitute."""
    return _JINJA_META_SUBSTITUTES[match.group()]
[docs]
class PromptRenderer:
    """Render a Jinja2 system-prompt template with per-request context.

    Uses :class:`~jinja2.sandbox.SandboxedEnvironment` to prevent
    template injection even if a caller accidentally passes unsanitised
    user data.

    Parameters
    ----------
    template_path:
        Path to the ``.j2`` template file (e.g. ``"system_prompt.j2"``).
    default_extras:
        Optional dict of variables injected into **every** render call
        (e.g. the list of registered tools). Per-call context takes
        precedence over these defaults.
    """

    def __init__(
        self,
        template_path: str | Path,
        default_extras: dict[str, Any] | None = None,
    ) -> None:
        """Load the template and remember the default context variables.

        Parameters
        ----------
        template_path:
            Location of the template file. Its parent directory becomes the
            search path of the Jinja2 file-system loader.
        default_extras:
            Variables merged into every render. The dict is stored by
            reference (not copied), so later mutations by the caller are
            picked up by subsequent :meth:`render` calls.

        Raises
        ------
        FileNotFoundError
            If *template_path* does not point to an existing regular file.
        """
        path = Path(template_path)
        # ``is_file()`` rather than ``exists()``: a directory would pass an
        # existence check and only fail later, inside the Jinja2 loader.
        if not path.is_file():
            raise FileNotFoundError(
                f"System prompt template not found: {path}",
            )
        self._env = SandboxedEnvironment(
            loader=FileSystemLoader(str(path.parent)),
            keep_trailing_newline=True,
            trim_blocks=True,
            lstrip_blocks=True,
        )
        self._template = self._env.get_template(path.name)
        # Explicit ``None`` check: ``default_extras or {}`` would silently
        # swap a caller-supplied *empty* dict for a fresh one, so entries the
        # caller adds to their dict afterwards would never be rendered.
        self.default_extras: dict[str, Any] = (
            {} if default_extras is None else default_extras
        )
        logger.info("Loaded system prompt template from %s", path)

    def render(self, context: dict[str, Any] | None = None) -> str:
        """Render the template with the supplied *context*.

        All values in *context* and in ``default_extras`` are passed through
        ``sanitize_context`` before rendering.

        The following keys are automatically injected if not already present:

        * ``current_date`` -- today's date in ``YYYY-MM-DD`` format (UTC).

        Keys from *default_extras* (set at init or later) are included but
        can be overridden by *context*.

        Parameters
        ----------
        context:
            Per-call template variables; ``None`` or an empty dict means
            only the defaults and ``current_date`` are used.

        Returns
        -------
        str
            The rendered template text.
        """
        ctx: dict[str, Any] = {}
        # Defaults go in first so that per-call values overwrite them below.
        ctx.update(sanitize_context(dict(self.default_extras)))
        if context:
            ctx.update(sanitize_context(context))
        # Only fills the date when neither source supplied one.
        ctx.setdefault(
            "current_date",
            datetime.now(timezone.utc).strftime("%Y-%m-%d"),
        )
        return self._template.render(ctx)