# Source code for tools.import_mcp_tool

"""Import an MCP tool from GitHub and convert it to a native Stargazer tool.

Clones the target repository, scans for MCP tool definitions, sends the
source to Cursor for conversion and then a second pass for security
auditing, and finally installs the result into the bot's tool registry
with embeddings and documentation.

Requires the UNSANDBOXED_EXEC privilege.
"""

from __future__ import annotations

import ast
import asyncio
import json
import logging
import os
import re
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING

from tools.alter_privileges import has_privilege, PRIVILEGES

if TYPE_CHECKING:
    from tool_context import ToolContext

logger = logging.getLogger(__name__)

TOOL_NAME = "import_mcp_tool"
TOOL_DESCRIPTION = (
    "Import an MCP (Model Context Protocol) tool from a GitHub repository and "
    "convert it into a native Stargazer tool. Clones the repo, uses Cursor AI "
    "to strip all MCP protocol boilerplate and rewrite the tool in the bot's "
    "native format, runs a security/privacy audit pass, then installs the tool "
    "with automatic registry reload, classifier embeddings, and documentation. "
    "Requires UNSANDBOXED_EXEC privilege. Accepts any GitHub repo URL containing "
    "an MCP server (Python FastMCP, raw mcp SDK, or TypeScript MCP servers)."
)

# SECURITY NOTE(review): this key used to be hardcoded in source.  It is now
# read from the environment; the literal fallback only keeps existing deploys
# working.  A key committed to source control must be treated as leaked —
# rotate it and remove the fallback.
CURSOR_API_KEY = os.environ.get(
    "CURSOR_API_KEY",
    "crsr_cc13d4b85df021a45f0d147b45784bf9285317816b227760510d130ebd49ff8b",
)
DEFAULT_MODEL = "gpt-5.3-codex-spark-preview"
# Seconds allowed for each cursor-agent pass (conversion and audit).
CURSOR_TIMEOUT = 3600

# Substrings whose presence marks a file as MCP-related; each hit adds 1 to
# the file's "mcp_score" during repository scanning.
MCP_INDICATORS = [
    "fastmcp", "FastMCP", "mcp.server", "mcp.types",
    "@mcp.tool", "@server.tool", "McpServer",
    "CallToolRequestSchema", "ListToolsRequestSchema",
    "from mcp", "import mcp",
]

# File extensions considered as candidate MCP tool source.
SOURCE_EXTENSIONS = {
    ".py", ".ts", ".js", ".mjs", ".tsx", ".jsx",
}

# Build/dependency manifests collected as extra conversion context.
CONFIG_FILENAMES = {
    "pyproject.toml", "package.json", "requirements.txt",
    "setup.py", "setup.cfg", "Cargo.toml",
}

# Budgets applied when assembling the prompt context block.
MAX_CONTEXT_CHARS = 80_000   # total context sent to Cursor
MAX_FILE_CHARS = 8_000       # per-file truncation
MAX_FILES = 15               # max source files included

# JSON-Schema description of this tool's own parameters.
TOOL_PARAMETERS = {
    "type": "object",
    "properties": {
        "github_url": {
            "type": "string",
            "description": (
                "GitHub URL of the MCP tool repository "
                "(e.g. 'https://github.com/user/mcp-weather')."
            ),
        },
        "tool_name": {
            "type": "string",
            "description": (
                "Override name for the generated tool file (without .py). "
                "If omitted, auto-detected from the repo name."
            ),
        },
        "model": {
            "type": "string",
            "description": (
                f"Cursor model to use for conversion. Default: {DEFAULT_MODEL}."
            ),
        },
    },
    "required": ["github_url"],
}


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _normalise_github_url(url: str) -> str:
    """Ensure the URL is a clonable HTTPS GitHub URL."""
    url = url.strip().rstrip("/")
    url = re.sub(r"#.*$", "", url)
    url = re.sub(r"/tree/[^/]+.*$", "", url)
    if not url.endswith(".git"):
        url += ".git"
    if url.startswith("git@"):
        url = url.replace("git@github.com:", "https://github.com/")
    return url


def _derive_tool_name(github_url: str) -> str:
    """Derive a snake_case tool name from the repo URL."""
    match = re.search(r"github\.com/[^/]+/([^/.]+)", github_url)
    name = match.group(1) if match else "imported_tool"
    name = re.sub(r"^mcp[-_]?", "", name, flags=re.IGNORECASE)
    name = re.sub(r"[-_]?mcp$", "", name, flags=re.IGNORECASE)
    name = re.sub(r"[^a-zA-Z0-9]", "_", name).strip("_").lower()
    return name or "imported_tool"


async def _clone_repo(url: str, dest: str) -> tuple[bool, str]:
    """Shallow-clone a GitHub repo.  Returns (success, message).

    On failure the message holds git's stderr.  If the clone exceeds the
    120 s budget, the git process is killed and reaped (the original code
    leaked the hung subprocess and let TimeoutError escape).
    """
    proc = await asyncio.create_subprocess_exec(
        "git", "clone", "--depth", "1", url, dest,
        stdout=subprocess.PIPE, stderr=subprocess.PIPE,
    )
    try:
        _, stderr = await asyncio.wait_for(proc.communicate(), timeout=120)
    except asyncio.TimeoutError:
        proc.kill()        # don't leak a hung git process
        await proc.wait()  # reap it before reporting failure
        return False, "git clone timed out after 120 seconds"
    if proc.returncode != 0:
        return False, stderr.decode(errors="replace").strip()
    return True, ""


def _scan_repo(repo_dir: str) -> tuple[list[dict], list[dict]]:
    """Walk the repo collecting MCP-relevant source files and config files.

    Returns (source_files, config_files); each entry is a dict with
    "path" (relative to repo_dir), "content" (truncated to MAX_FILE_CHARS),
    and "mcp_score" (number of MCP indicator substrings found).
    """
    skip_dirs = {"node_modules", ".git", "__pycache__", "venv", ".venv", "dist", "build"}
    sources: list[dict] = []
    configs: list[dict] = []

    for root, dirs, files in os.walk(repo_dir):
        # Prune vendored/generated directories in place so os.walk skips them.
        dirs[:] = [d for d in dirs if d not in skip_dirs]

        for fname in files:
            full_path = os.path.join(root, fname)
            rel_path = os.path.relpath(full_path, repo_dir)

            if fname in CONFIG_FILENAMES:
                try:
                    text = Path(full_path).read_text(errors="replace")
                except OSError:
                    continue
                configs.append({
                    "path": rel_path,
                    "content": text[:MAX_FILE_CHARS],
                    "mcp_score": 0,
                })
                continue

            if os.path.splitext(fname)[1].lower() not in SOURCE_EXTENSIONS:
                continue

            try:
                text = Path(full_path).read_text(errors="replace")
            except OSError:
                continue

            score = sum(indicator in text for indicator in MCP_INDICATORS)
            # Keep at most a handful of zero-signal files as extra context.
            if score == 0 and len(sources) >= 5:
                continue

            sources.append({
                "path": rel_path,
                "content": text[:MAX_FILE_CHARS],
                "mcp_score": score,
            })

    # Strongest MCP candidates first; cap the total file count.
    sources.sort(key=lambda entry: entry["mcp_score"], reverse=True)
    return sources[:MAX_FILES], configs


def _build_context_block(source_files: list[dict], config_files: list[dict]) -> str:
    """Assemble one text block of config + source context for the prompt.

    Config files come first, then source files (already sorted by MCP
    score).  Each category stops adding chunks once the running total
    would exceed MAX_CONTEXT_CHARS.
    """
    chunks: list[str] = []
    used = 0

    def _take(text: str) -> bool:
        # Append *text* if it fits in the remaining budget.
        nonlocal used
        if used + len(text) > MAX_CONTEXT_CHARS:
            return False
        chunks.append(text)
        used += len(text)
        return True

    for entry in config_files:
        if not _take(f"--- {entry['path']} ---\n{entry['content']}\n"):
            break

    for entry in source_files:
        header = f"--- {entry['path']} (MCP signals: {entry['mcp_score']}) ---"
        if not _take(f"{header}\n{entry['content']}\n"):
            break

    return "\n".join(chunks)


# Canonical single-tool example embedded verbatim in the conversion prompt so
# Cursor can mirror the target file layout.  Treat as a template string: do
# not reformat its contents.
_REFERENCE_TOOL = '''\
"""Tool: example_tool
Example of the target single-tool format.
"""
import json
import aiohttp

TOOL_NAME = "example_tool"
TOOL_DESCRIPTION = "One-paragraph description the LLM reads at selection time."
TOOL_PARAMETERS = {
    "type": "object",
    "properties": {
        "query": {
            "type": "string",
            "description": "What to search for.",
        },
        "limit": {
            "type": "integer",
            "description": "Max results (default 5).",
        },
    },
    "required": ["query"],
}

async def run(query: str, limit: int = 5) -> str:
    """Core logic — must be async, must return str."""
    async with aiohttp.ClientSession() as session:
        async with session.get("https://api.example.com/search",
                               params={"q": query, "n": limit}) as resp:
            data = await resp.json()
    return json.dumps({"results": data})
'''

# Canonical multi-tool example embedded verbatim in the conversion prompt
# (used when the MCP server defines several tools).  Treat as a template
# string: do not reformat its contents.
_MULTI_TOOL_REFERENCE = '''\
"""Multi-tool file format (use when the MCP server defines multiple tools)."""
import json

async def _search(query: str, limit: int = 5) -> str:
    return json.dumps({"results": []})

async def _lookup(id: str) -> str:
    return json.dumps({"item": {}})

TOOLS = [
    {
        "name": "example_search",
        "description": "Search for items.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Search query."},
                "limit": {"type": "integer", "description": "Max results."},
            },
            "required": ["query"],
        },
        "handler": _search,
    },
    {
        "name": "example_lookup",
        "description": "Look up an item by ID.",
        "parameters": {
            "type": "object",
            "properties": {
                "id": {"type": "string", "description": "Item ID."},
            },
            "required": ["id"],
        },
        "handler": _lookup,
    },
]
'''


def _build_conversion_prompt(
    tool_name: str,
    context_block: str,
) -> str:
    """Build the first-pass Cursor prompt.

    Instructs the agent to convert the MCP source in *context_block* into
    a single native tool file at ``tools/<tool_name>.py``, following the
    embedded single-/multi-tool reference formats.  The returned text is
    sent verbatim to cursor-agent.
    """
    return f"""\
You are converting an MCP (Model Context Protocol) tool into a native Stargazer bot tool.

## Source MCP code

{context_block}

## Target format

The output must be a SINGLE Python file written to: tools/{tool_name}.py

### Single-tool format (if the MCP server has one tool):

{_REFERENCE_TOOL}

### Multi-tool format (if the MCP server defines multiple tools):

{_MULTI_TOOL_REFERENCE}

## Conversion rules

1. Extract every tool defined in the MCP server.
2. Convert parameter schemas to JSON Schema (type/properties/required/description).
3. Preserve the core business logic.  Rewrite MCP-specific I/O (TextContent, etc.) to return plain str (JSON-serialised where appropriate).
4. All handler functions MUST be async and MUST return str.
5. Parameter names in the function signature must match the JSON Schema properties.
6. For non-standard dependencies, wrap imports in try/except ImportError with a feature flag:
   try:
       import some_lib
       SOME_LIB_AVAILABLE = True
   except ImportError:
       SOME_LIB_AVAILABLE = False
7. If the tool needs to make HTTP requests, prefer aiohttp or httpx (both already installed).
8. Strip ALL MCP protocol machinery: no FastMCP, no Server, no TextContent, no JSON-RPC, no mcp imports.
9. Include a module docstring.  TOOL_DESCRIPTION must be comprehensive and self-contained.
10. If the MCP server is TypeScript/JavaScript, rewrite the logic in Python.
11. Do NOT include any test code or if __name__ == "__main__" blocks.

Write the file now to tools/{tool_name}.py.  Do not create any other files.
"""


def _build_audit_prompt(tool_name: str) -> str:
    """Build the second-pass Cursor prompt.

    Asks the agent to security-audit and fix ``tools/<tool_name>.py`` in
    place and to append a machine-readable ``_AUDIT_REPORT`` dict, which
    :func:`_extract_audit_report` later parses.  The returned text is sent
    verbatim to cursor-agent.
    """
    return f"""\
You are a security auditor.  Review the file tools/{tool_name}.py which was \
just generated by converting an untrusted MCP tool from a public GitHub repo.

This code is about to run INSIDE a production chatbot process with access to \
Redis, API keys, user data, and the full filesystem.  Your job is to find and \
FIX every security and privacy flaw.

## Checklist — flag and fix ALL of the following:

1. DATA EXFILTRATION: Remove/rewrite any outbound HTTP/socket calls that \
send bot internals, user data, API keys, tokens, Redis contents, filesystem \
paths, or environment variables to external endpoints.  Calls to the tool's \
own legitimate API are fine; calls to unknown/hardcoded suspicious URLs are not.

2. CODE INJECTION: Remove or rewrite any use of eval(), exec(), compile(), \
__import__(), importlib dynamic imports, pickle.loads, yaml.unsafe_load, \
subprocess (unless essential — if so, add a comment explaining why), \
os.system, ctypes, or similar.

3. FILESYSTEM ABUSE: Ensure the tool does not read/write outside its expected \
scope.  No reading ~/.ssh, /etc/passwd, config.yaml, .env, or credential \
files.  No open() on user-controllable paths without validation.

4. NETWORK ABUSE: No listening sockets, no reverse shells, no connecting to \
hardcoded attacker-controlled endpoints.

5. CREDENTIAL HARVESTING: Ensure no parameters, return values, or side \
effects expose or log API keys, tokens, passwords, or Redis data.

6. DEPENDENCY POISONING: Flag any pip install, subprocess calls that install \
packages, or runtime dependency fetching.  Dependencies should be documented \
in comments, not silently installed.

7. PRIVILEGE ESCALATION: Ensure the tool does not access ctx.redis, \
ctx.config, ctx.adapter, or ctx.tool_registry unless genuinely needed.  If it \
does, verify the access is scoped and safe.

8. OBFUSCATION: Flag base64-encoded strings, hex-encoded payloads, ROT13, \
or any encoded blobs that could hide malicious logic.  Decode and inspect.

## Output format

After making your fixes, add a MODULE-LEVEL constant at the very end of the \
file (after all other code) in exactly this format:

_AUDIT_REPORT = {{
    "issues_found": <int>,
    "issues_fixed": <int>,
    "details": [
        "<one-line description of each finding and fix>",
    ],
}}

If the code is clean, set issues_found and issues_fixed to 0 and details to \
["No security issues found."].

Fix everything in-place in tools/{tool_name}.py now.
"""


async def _run_cursor(
    prompt: str,
    model: str,
    cwd: str,
    timeout: int = CURSOR_TIMEOUT,
) -> tuple[bool, dict | None, str]:
    """Invoke cursor-agent non-interactively.

    Returns (success, response_json, stderr).  Still raises
    asyncio.TimeoutError on timeout — callers rely on catching it — but
    now kills the agent process first, so a timed-out agent cannot keep
    running (and editing files) unattended.
    """
    cmd = [
        "cursor-agent",
        "--api-key", CURSOR_API_KEY,
        "--model", model,
        "--output-format", "json",
        "--yolo",
        "-p", prompt,
    ]
    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        cwd=cwd,
    )
    try:
        stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
    except asyncio.TimeoutError:
        proc.kill()        # stop the detached agent before propagating
        await proc.wait()  # reap it to avoid a zombie
        raise

    stdout_text = stdout.decode(errors="replace") if stdout else ""
    stderr_text = stderr.decode(errors="replace") if stderr else ""

    response: dict | None = None
    if stdout_text.strip():
        try:
            response = json.loads(stdout_text)
        except json.JSONDecodeError:
            # Keep a slice of non-JSON output so callers can diagnose.
            response = {"raw_output": stdout_text[:4000]}

    return proc.returncode == 0, response, stderr_text


def _validate_tool_file(filepath: str) -> tuple[bool, list[str]]:
    """Parse and validate a generated tool file.  Returns (valid, errors)."""
    errors: list[str] = []
    try:
        source = Path(filepath).read_text()
    except OSError as exc:
        return False, [f"Cannot read file: {exc}"]

    try:
        tree = ast.parse(source)
    except SyntaxError as exc:
        return False, [f"Syntax error: {exc}"]

    has_registration = False
    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Name) and target.id in ("TOOL_NAME", "TOOLS"):
                    has_registration = True

    if not has_registration:
        errors.append("No TOOL_NAME or TOOLS assignment found")

    has_func = any(
        isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
        for n in ast.walk(tree)
    )
    if not has_func:
        errors.append("No function definitions found")

    return len(errors) == 0, errors


def _extract_audit_report(filepath: str) -> dict:
    """Try to extract the _AUDIT_REPORT dict from the generated file."""
    default = {"issues_found": -1, "issues_fixed": -1, "details": ["Audit report not found in file."]}
    try:
        source = Path(filepath).read_text()
        tree = ast.parse(source)
    except Exception:
        return default

    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Name) and target.id == "_AUDIT_REPORT":
                    try:
                        return ast.literal_eval(node.value)
                    except Exception:
                        return default
    return default


def _extract_tool_names(filepath: str) -> list[str]:
    """Extract registered tool names from a generated file."""
    names: list[str] = []
    try:
        source = Path(filepath).read_text()
        tree = ast.parse(source)
    except Exception:
        return names

    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            for target in node.targets:
                if not isinstance(target, ast.Name):
                    continue
                if target.id == "TOOL_NAME":
                    try:
                        names.append(ast.literal_eval(node.value))
                    except Exception:
                        pass
                elif target.id == "TOOLS":
                    try:
                        tools_list = ast.literal_eval(node.value)
                        if isinstance(tools_list, list):
                            for t in tools_list:
                                if isinstance(t, dict) and "name" in t:
                                    names.append(t["name"])
                    except Exception:
                        pass
    return names


# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------

async def run(
    github_url: str,
    tool_name: str = "",
    model: str = "",
    ctx: ToolContext | None = None,
) -> str:
    """Import an MCP tool repo and install it as a native tool.

    Pipeline: privilege check -> shallow clone -> scan for MCP files ->
    Cursor conversion pass -> AST validation -> Cursor security-audit pass
    (fail closed: the generated file is deleted if the audit cannot
    complete) -> registry reload, embedding refresh, docs generation.

    Returns a JSON string with ``success`` plus either an ``error`` or an
    install summary (tool names, audit report, reload/embedding/docs
    statuses).
    """
    user_id = getattr(ctx, "user_id", "") or ""
    redis = getattr(ctx, "redis", None)
    config = getattr(ctx, "config", None)
    if not await has_privilege(redis, user_id, PRIVILEGES["UNSANDBOXED_EXEC"], config):
        return json.dumps({
            "success": False,
            "error": "Requires UNSANDBOXED_EXEC privilege.",
        })

    if not github_url or not github_url.strip():
        return json.dumps({"success": False, "error": "Empty github_url."})

    clone_url = _normalise_github_url(github_url)
    tool_name = (tool_name.strip() if tool_name else "") or _derive_tool_name(github_url)
    tool_name = re.sub(r"[^a-zA-Z0-9_]", "_", tool_name).lower().strip("_") or "imported_tool"
    selected_model = (model.strip() if model else "") or DEFAULT_MODEL

    bot_dir = os.getcwd()
    tool_filepath = os.path.join(bot_dir, "tools", f"{tool_name}.py")
    if os.path.exists(tool_filepath):
        return json.dumps({
            "success": False,
            "error": f"tools/{tool_name}.py already exists. Choose a different name or delete it first.",
        })

    tmp_dir = tempfile.mkdtemp(prefix="mcp_import_")
    try:
        # ---- 1. Clone ----
        logger.info("Cloning %s into %s", clone_url, tmp_dir)
        clone_dest = os.path.join(tmp_dir, "repo")
        ok, err_msg = await _clone_repo(clone_url, clone_dest)
        if not ok:
            return json.dumps({
                "success": False,
                "error": f"git clone failed: {err_msg}",
            })

        # ---- 2. Scan ----
        source_files, config_files = _scan_repo(clone_dest)
        mcp_files = [f for f in source_files if f["mcp_score"] > 0]
        if not mcp_files:
            return json.dumps({
                "success": False,
                "error": "No MCP tool definitions found in the repository.",
                "files_scanned": [f["path"] for f in source_files],
            })

        context_block = _build_context_block(source_files, config_files)
        logger.info(
            "Scanned repo: %d source files (%d with MCP signals), %d config files, %d chars context",
            len(source_files), len(mcp_files), len(config_files), len(context_block),
        )

        # ---- 3. Conversion via Cursor ----
        conversion_prompt = _build_conversion_prompt(tool_name, context_block)
        logger.info("Running Cursor conversion pass (model=%s)...", selected_model)
        try:
            conv_ok, conv_resp, conv_stderr = await _run_cursor(
                conversion_prompt, selected_model, bot_dir,
            )
        except asyncio.TimeoutError:
            return json.dumps({
                "success": False,
                "error": "Cursor conversion timed out.",
            })
        if not conv_ok:
            return json.dumps({
                "success": False,
                "error": "Cursor conversion failed.",
                "cursor_response": conv_resp,
                "stderr": conv_stderr[:2000],
            })

        # ---- 4. Verify conversion output ----
        if not os.path.exists(tool_filepath):
            return json.dumps({
                "success": False,
                "error": f"Cursor did not create tools/{tool_name}.py.",
                "cursor_response": conv_resp,
            })

        valid, validation_errors = _validate_tool_file(tool_filepath)
        if not valid:
            os.remove(tool_filepath)
            return json.dumps({
                "success": False,
                "error": "Generated tool failed validation.",
                "validation_errors": validation_errors,
                "cursor_response": conv_resp,
            })

        # ---- 5. Security audit via Cursor (fail closed) ----
        audit_prompt = _build_audit_prompt(tool_name)
        logger.info("Running Cursor security audit pass...")
        try:
            audit_ok, audit_resp, audit_stderr = await _run_cursor(
                audit_prompt, selected_model, bot_dir,
            )
        except asyncio.TimeoutError:
            logger.error("Security audit timed out — deleting tool file (fail closed)")
            if os.path.exists(tool_filepath):
                os.remove(tool_filepath)
            return json.dumps({
                "success": False,
                "error": "Security audit timed out. Tool file deleted (fail closed).",
            })
        if not audit_ok:
            logger.error("Security audit Cursor pass failed — deleting tool file")
            if os.path.exists(tool_filepath):
                os.remove(tool_filepath)
            return json.dumps({
                "success": False,
                "error": "Security audit Cursor pass failed. Tool file deleted (fail closed).",
                "cursor_response": audit_resp,
                "stderr": audit_stderr[:2000],
            })

        # Re-validate after audit modifications
        valid, validation_errors = _validate_tool_file(tool_filepath)
        if not valid:
            if os.path.exists(tool_filepath):
                os.remove(tool_filepath)
            return json.dumps({
                "success": False,
                "error": "Tool file invalid after security audit pass.",
                "validation_errors": validation_errors,
            })

        audit_report = _extract_audit_report(tool_filepath)
        logger.info(
            "Audit complete: %d found, %d fixed",
            audit_report.get("issues_found", -1),
            audit_report.get("issues_fixed", -1),
        )

        # ---- 6. Extract tool names ----
        registered_names = _extract_tool_names(tool_filepath)
        if not registered_names:
            registered_names = [tool_name]

        # ---- 7. Reload tool registry ----
        reload_success = False
        registry = getattr(ctx, "tool_registry", None) if ctx else None
        if registry is not None and config is not None:
            try:
                from tool_loader import load_tools
                # Preserve permissions across the reload: load_tools rebuilds
                # _tools but must not reset previously granted permissions.
                old_permissions = dict(registry._permissions)
                registry._tools.clear()
                registry.invalidate_cache()
                load_tools(getattr(config, "tools_dir", "tools"), registry)
                registry._permissions = old_permissions
                reload_success = True
                logger.info("Tool registry reloaded after importing %s", tool_name)
            except Exception as exc:
                logger.error("Registry reload failed: %s", exc, exc_info=True)

        # ---- 8. Refresh embeddings (best effort) ----
        embedding_success = False
        if registered_names:
            try:
                from classifiers.refresh_tool_embeddings import refresh_tool_embeddings
                tools_dir = getattr(config, "tools_dir", "tools") if config else "tools"
                embedding_success = await refresh_tool_embeddings(
                    tool_names=registered_names,
                    tools_dir=tools_dir,
                )
                if embedding_success:
                    logger.info("Embeddings refreshed for: %s", ", ".join(registered_names))
            except Exception as exc:
                logger.error("Embedding refresh failed: %s", exc, exc_info=True)

        # ---- 9. Generate docs (best effort) ----
        docs_generated = False
        try:
            from tools.write_python_tool import _generate_docs_stub
            module_name = tool_name
            docs_generated = _generate_docs_stub(module_name)
        except Exception as exc:
            logger.error("Docs generation failed: %s", exc, exc_info=True)

        return json.dumps({
            "success": True,
            "tool_file": f"tools/{tool_name}.py",
            "tool_names": registered_names,
            "source_repo": github_url,
            "reload_success": reload_success,
            "embedding_success": embedding_success,
            "docs_generated": docs_generated,
            "audit": {
                "passed": audit_report.get("issues_found", -1) >= 0,
                "issues_found": audit_report.get("issues_found", 0),
                "issues_fixed": audit_report.get("issues_fixed", 0),
                "details": audit_report.get("details", []),
            },
            "mcp_files_found": [f["path"] for f in mcp_files],
        })
    except Exception as exc:
        logger.error("import_mcp_tool failed: %s", exc, exc_info=True)
        # Fail closed: never leave a half-imported tool file behind.
        if os.path.exists(tool_filepath):
            os.remove(tool_filepath)
        return json.dumps({"success": False, "error": f"Unexpected error: {exc}"})
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)