Files
Bjorn/mcp_server.py
infinition b759ab6d4b Add LLM configuration and MCP server management UI and backend functionality
- Implemented a new SPA page for LLM Bridge and MCP Server settings in `llm-config.js`.
- Added functionality for managing LLM and MCP configurations, including toggling, saving settings, and testing connections.
- Created HTTP endpoints in `llm_utils.py` for handling LLM chat, status checks, and MCP server configuration.
- Integrated model fetching from LaRuche and Ollama backends.
- Enhanced error handling and logging for better debugging and user feedback.
2026-03-16 20:33:22 +01:00

334 lines
12 KiB
Python

# mcp_server.py
# Model Context Protocol server for Bjorn.
# Exposes Bjorn's database and actions as MCP tools consumable by any MCP client
# (Claude Desktop, custom agents, etc.).
#
# Transport: HTTP SSE (default, port configurable) or stdio.
# Requires: pip install mcp
# Gracefully no-ops if mcp is not installed.
import json
import threading
import time
from typing import Any, Dict, List, Optional
from logger import Logger
logger = Logger(name="mcp_server.py", level=20)
# ---------------------------------------------------------------------------
# Lazy shared_data import (avoids circular imports at module level)
# ---------------------------------------------------------------------------
_shared_data = None


def _sd():
    """Return the cached ``shared_data`` object, importing it on first use.

    The import from ``init_shared`` happens inside the function rather than at
    module import time; the result is stored in the module-level
    ``_shared_data`` so later calls skip the import.
    """
    global _shared_data
    if _shared_data is not None:
        return _shared_data
    from init_shared import shared_data
    _shared_data = shared_data
    return _shared_data
def _tool_allowed(name: str) -> bool:
    """Return True when *name* is listed in the ``mcp_allowed_tools`` config entry.

    Args:
        name: Tool name to look up.

    Returns:
        True if the configured allow-list contains *name*; False when the
        entry is missing, empty, or null.
    """
    # ``or ()`` covers an entry that is present but null/empty: without it the
    # membership test below raises TypeError instead of denying the tool.
    allowed = _sd().config.get("mcp_allowed_tools", []) or ()
    return name in allowed
# ---------------------------------------------------------------------------
# Tool implementations (pure functions, no MCP deps)
# ---------------------------------------------------------------------------
def _impl_get_hosts(alive_only: bool = True) -> str:
    """Serialise rows of the ``hosts`` table to a JSON string.

    Args:
        alive_only: When true, only rows with ``alive=1`` are selected.

    Returns:
        A JSON array of row dicts ordered by ``ip`` (``[]`` when the query
        yields nothing), or ``{"error": "..."}`` if anything raised.
    """
    try:
        alive_filter = " WHERE alive=1" if alive_only else ""
        statement = (
            "SELECT ip, mac, hostname, os, alive, ports_open FROM hosts"
            + alive_filter
            + " ORDER BY ip"
        )
        rows = _sd().db.query(statement, ())
        if not rows:
            return json.dumps([], default=str)
        # default=str stringifies any value json cannot encode natively.
        return json.dumps([dict(row) for row in rows], default=str)
    except Exception as exc:
        return json.dumps({"error": str(exc)})
def _impl_get_vulnerabilities(host_ip: Optional[str] = None, limit: int = 100) -> str:
    """Serialise rows of the ``vulnerabilities`` table to a JSON string.

    Args:
        host_ip: When truthy, only rows whose ``ip`` equals this value are
            selected.
        limit: Value bound to the query's ``LIMIT`` placeholder.

    Returns:
        A JSON array of row dicts ordered by ``severity`` descending, or
        ``{"error": "..."}`` if anything raised.
    """
    try:
        if not host_ip:
            statement = ("SELECT v.ip, v.port, v.cve_id, v.severity, v.description "
                         "FROM vulnerabilities v ORDER BY v.severity DESC LIMIT ?")
            binds = (limit,)
        else:
            statement = ("SELECT v.ip, v.port, v.cve_id, v.severity, v.description "
                         "FROM vulnerabilities v WHERE v.ip=? ORDER BY v.severity DESC LIMIT ?")
            binds = (host_ip, limit)
        rows = _sd().db.query(statement, binds)
        records = [dict(row) for row in rows] if rows else []
        return json.dumps(records, default=str)
    except Exception as exc:
        return json.dumps({"error": str(exc)})
def _impl_get_credentials(service: Optional[str] = None, limit: int = 100) -> str:
    """Serialise rows of the ``credentials`` table to a JSON string.

    Args:
        service: When truthy, only rows whose ``service`` equals this value
            are selected.
        limit: Value bound to the query's ``LIMIT`` placeholder.

    Returns:
        A JSON array of row dicts ordered by ``found_at`` descending, or
        ``{"error": "..."}`` if anything raised.
    """
    by_service_sql = ("SELECT ip, port, service, username, password, found_at "
                      "FROM credentials WHERE service=? ORDER BY found_at DESC LIMIT ?")
    all_services_sql = ("SELECT ip, port, service, username, password, found_at "
                        "FROM credentials ORDER BY found_at DESC LIMIT ?")
    try:
        database = _sd().db
        rows = (
            database.query(by_service_sql, (service, limit))
            if service
            else database.query(all_services_sql, (limit,))
        )
        if rows:
            return json.dumps([dict(row) for row in rows], default=str)
        return json.dumps([], default=str)
    except Exception as exc:
        return json.dumps({"error": str(exc)})
def _impl_get_action_history(limit: int = 50, action_name: Optional[str] = None) -> str:
    """Serialise rows of the ``action_history`` table to a JSON string.

    Args:
        limit: Value bound to the query's ``LIMIT`` placeholder.
        action_name: When truthy, only rows whose ``action_name`` equals this
            value are selected.

    Returns:
        A JSON array of row dicts ordered by ``started_at`` descending, or
        ``{"error": "..."}`` if anything raised.
    """
    try:
        if action_name:
            statement, binds = (
                "SELECT action_name, target_ip, status, result, started_at, finished_at "
                "FROM action_history WHERE action_name=? ORDER BY started_at DESC LIMIT ?"
            ), (action_name, limit)
        else:
            statement, binds = (
                "SELECT action_name, target_ip, status, result, started_at, finished_at "
                "FROM action_history ORDER BY started_at DESC LIMIT ?"
            ), (limit,)
        rows = _sd().db.query(statement, binds)
        history = [dict(row) for row in rows] if rows else []
        return json.dumps(history, default=str)
    except Exception as exc:
        return json.dumps({"error": str(exc)})
def _impl_get_status() -> str:
    """Return a JSON object describing the current state held in shared data.

    ``operation_mode`` is read directly (a missing attribute ends up in the
    error payload); every other field falls back to a default when the
    attribute is absent.

    Returns:
        A JSON object string, or ``{"error": "..."}`` if anything raised.
    """
    # (output key, shared_data attribute, fallback) in output order.
    optional_fields = (
        ("active_action", "active_action", None),
        ("bjorn_status", "bjorn_status_text", "IDLE"),
        ("bjorn_says", "bjorn_says", ""),
        ("hosts_discovered", "target_count", 0),
        ("vulnerabilities", "vuln_count", 0),
        ("credentials", "cred_count", 0),
        ("current_ip", "current_ip", ""),
        ("current_ssid", "current_ssid", ""),
    )
    try:
        shared = _sd()
        snapshot = {"operation_mode": shared.operation_mode}
        for key, attribute, fallback in optional_fields:
            snapshot[key] = getattr(shared, attribute, fallback)
        return json.dumps(snapshot)
    except Exception as exc:
        return json.dumps({"error": str(exc)})
_MCP_PRIORITY = 80  # Higher than scheduler default (40) and queue_action default (50)


def _impl_run_action(action_name: str, target_ip: str, target_mac: str = "") -> str:
    """Queue a Bjorn action with MCP priority boost. Returns queue confirmation.

    Args:
        action_name: Name of the action to queue; must be a non-blank string.
        target_ip: IP address handed to the queue as the action's target.
        target_mac: MAC address of the target. When empty and ``target_ip`` is
            set, it is looked up in the ``hosts`` table (best-effort: a failed
            lookup is logged and the action is queued with an empty MAC).

    Returns:
        A JSON string: ``{"status": "queued", ...}`` on success, or
        ``{"error": "..."}`` when validation or queueing fails.
    """
    # Reject a missing action name up front instead of queueing a row that
    # names no action.
    if not isinstance(action_name, str) or not action_name.strip():
        return json.dumps({"error": "action_name is required."})
    try:
        sd = _sd()
        # Resolve MAC from IP if not supplied
        mac = target_mac or ""
        if not mac and target_ip:
            try:
                # NOTE(review): this query uses columns `mac_address`/`ip`,
                # while _impl_get_hosts selects `mac` from the same table —
                # confirm which matches the real schema.
                row = sd.db.query_one(
                    "SELECT mac_address FROM hosts WHERE ip=? LIMIT 1", (target_ip,)
                )
                if row:
                    mac = row["mac_address"]
            except Exception as lookup_err:
                # Still best-effort, but no longer silent: a failing lookup
                # means the action is queued without a MAC.
                logger.warning(
                    f"MCP run_action: MAC lookup for {target_ip} failed: {lookup_err}"
                )
        sd.db.queue_action(
            action_name=action_name,
            mac=mac,
            ip=target_ip,
            priority=_MCP_PRIORITY,
            trigger="mcp",
            metadata={"decision_method": "mcp", "decision_origin": "mcp"},
        )
        # Wake the orchestrator immediately (it sleeps up to 5 s when idle)
        try:
            sd.queue_event.set()
        except Exception as wake_err:
            # The action is already queued; failing to wake only delays it.
            logger.debug(f"MCP run_action: could not signal queue_event: {wake_err}")
        return json.dumps({
            "status": "queued",
            "action": action_name,
            "target": target_ip,
            "priority": _MCP_PRIORITY,
        })
    except Exception as e:
        return json.dumps({"error": str(e)})
def _impl_query_db(sql: str, params: Optional[List] = None) -> str:
    """Run a read-only SELECT query. Non-SELECT statements are rejected.

    Args:
        sql: Statement text. After stripping surrounding whitespace it must
            start with ``SELECT`` (case-insensitive).
        params: Positional bind parameters as a list or tuple. ``None`` or any
            empty value means "no binds".

    Returns:
        A JSON array of row dicts, or ``{"error": "..."}`` when the statement
        or parameters are rejected or the query raises.
    """
    # NOTE(review): this is a prefix test only. It assumes the database layer
    # executes a single statement per call — confirm db.query() does not use
    # executescript().
    if not isinstance(sql, str) or not sql.strip().upper().startswith("SELECT"):
        return json.dumps({"error": "Only SELECT queries are allowed."})

    # Accept real sequences only: tuple() on a str or dict would silently bind
    # its characters or keys instead of the intended values.
    if not params:
        binds = ()
    elif isinstance(params, (list, tuple)):
        binds = tuple(params)
    else:
        return json.dumps({"error": "params must be a JSON array of bind parameters."})

    try:
        rows = _sd().db.query(sql, binds)
        return json.dumps([dict(r) for r in rows] if rows else [], default=str)
    except Exception as e:
        return json.dumps({"error": str(e)})
# ---------------------------------------------------------------------------
# MCP Server setup (requires `pip install mcp`)
# ---------------------------------------------------------------------------
def _build_mcp_server():
    """Build and return a FastMCP server instance, or None if mcp not available.

    All tools are always registered. Each wrapper consults ``_tool_allowed``
    at call time and answers with an error payload when its tool is not in
    the allow-list; otherwise it delegates to the matching ``_impl_*``
    function. The wrappers' names, signatures and docstrings are kept as-is
    because they are handed to FastMCP via ``@mcp.tool()``.
    """
    try:
        from mcp.server.fastmcp import FastMCP
    except ImportError:
        logger.warning("mcp package not installed — MCP server disabled. "
                       "Run: pip install mcp")
        return None

    instructions = (
        "Bjorn is a Raspberry Pi network security tool. "
        "Use these tools to query discovered hosts, vulnerabilities, credentials, "
        "and action history, or to queue new actions."
    )
    try:
        mcp = FastMCP(name="bjorn", version="1.0.0", instructions=instructions)
    except TypeError:
        # NOTE(review): `version` is not among FastMCP's documented
        # constructor arguments in the mcp SDK; releases with an explicit
        # keyword list reject it. Retry without it rather than aborting
        # start-up — confirm against the pinned mcp version.
        mcp = FastMCP(name="bjorn", instructions=instructions)

    # Same payload for every disabled tool; serialised once.
    tool_disabled = json.dumps({"error": "Tool disabled in Bjorn MCP config."})

    # ---- Tool registrations ----------------------------------------
    @mcp.tool()
    def get_hosts(alive_only: bool = True) -> str:
        """Return all network hosts discovered by Bjorn's scanner.
        Set alive_only=false to include hosts that are currently offline."""
        if not _tool_allowed("get_hosts"):
            return tool_disabled
        return _impl_get_hosts(alive_only)

    @mcp.tool()
    def get_vulnerabilities(host_ip: str = "", limit: int = 100) -> str:
        """Return discovered vulnerabilities. Optionally filter by host_ip."""
        if not _tool_allowed("get_vulnerabilities"):
            return tool_disabled
        # An empty string means "no filter" for the implementation.
        return _impl_get_vulnerabilities(host_ip or None, limit)

    @mcp.tool()
    def get_credentials(service: str = "", limit: int = 100) -> str:
        """Return captured credentials. Optionally filter by service (ssh, ftp, smb…)."""
        if not _tool_allowed("get_credentials"):
            return tool_disabled
        return _impl_get_credentials(service or None, limit)

    @mcp.tool()
    def get_action_history(limit: int = 50, action_name: str = "") -> str:
        """Return the history of executed actions, most recent first."""
        if not _tool_allowed("get_action_history"):
            return tool_disabled
        return _impl_get_action_history(limit, action_name or None)

    @mcp.tool()
    def get_status() -> str:
        """Return Bjorn's current operational status, counters, and active action."""
        if not _tool_allowed("get_status"):
            return tool_disabled
        return _impl_get_status()

    @mcp.tool()
    def run_action(action_name: str, target_ip: str, target_mac: str = "") -> str:
        """Queue a Bjorn action (e.g. ssh_bruteforce) against target_ip.
        The action will be executed by Bjorn's orchestrator."""
        if not _tool_allowed("run_action"):
            return tool_disabled
        return _impl_run_action(action_name, target_ip, target_mac)

    @mcp.tool()
    def query_db(sql: str, params: str = "[]") -> str:
        """Run a read-only SELECT query against Bjorn's SQLite database.
        params must be a JSON array of bind parameters."""
        if not _tool_allowed("query_db"):
            return tool_disabled
        try:
            p = json.loads(params)
        except Exception:
            # Deliberate best-effort: unparsable params mean "no binds".
            p = []
        return _impl_query_db(sql, p)

    return mcp
# ---------------------------------------------------------------------------
# Server lifecycle
# ---------------------------------------------------------------------------
_server_thread: Optional[threading.Thread] = None
_mcp_instance = None


def start(block: bool = False) -> bool:
    """
    Start the MCP server in a daemon thread.

    Args:
        block: If True, run in the calling thread (for stdio mode).

    Returns:
        True if the server was launched or a previously started server
        thread is still alive; False if MCP is disabled in config or the
        mcp package is unavailable.
    """
    global _server_thread, _mcp_instance
    sd = _sd()
    if not sd.config.get("mcp_enabled", False):
        logger.debug("MCP server disabled in config (mcp_enabled=False)")
        return False

    # A second launch would try to bind the same port and replace the handle
    # of the live thread with one that dies immediately.
    if _server_thread is not None and _server_thread.is_alive():
        logger.debug("MCP server already running; start() ignored")
        return True

    mcp = _build_mcp_server()
    if mcp is None:
        return False
    _mcp_instance = mcp
    transport = sd.config.get("mcp_transport", "http")
    port = int(sd.config.get("mcp_port", 8765))

    def _run():
        # Runs the server until it exits; errors are logged, not raised.
        try:
            if transport == "stdio":
                logger.info("MCP server starting (stdio transport)")
                mcp.run(transport="stdio")
            else:
                logger.info(f"MCP server starting (HTTP SSE transport, port {port})")
                # FastMCP HTTP SSE — runs uvicorn internally.
                # FastMCP.run() in the mcp SDK has no `port` parameter; the
                # listening port is read from the instance settings, so it
                # is set there. The original call is kept as a fallback for
                # objects that expose no such settings.
                settings = getattr(mcp, "settings", None)
                if settings is not None and hasattr(settings, "port"):
                    settings.port = port
                    mcp.run(transport="sse")
                else:
                    mcp.run(transport="sse", port=port)
        except Exception as e:
            logger.error(f"MCP server error: {e}")

    if block:
        _run()
        return True

    _server_thread = threading.Thread(target=_run, daemon=True, name="MCPServer")
    _server_thread.start()
    logger.info(f"MCP server thread started (transport={transport})")
    return True
def stop() -> None:
    """Drop the module's handle on the server thread.

    Nothing is sent to the thread itself: it is a daemon thread and, as the
    log message states, ends with the process. After this call
    ``_server_thread`` is ``None``.
    """
    global _server_thread
    thread = _server_thread
    if thread is not None and thread.is_alive():
        logger.info("MCP server thread stopping (daemon — will exit with process)")
    _server_thread = None
def is_running() -> bool:
    """Return True while a server thread is recorded and still alive."""
    thread = _server_thread
    if thread is None:
        return False
    return thread.is_alive()
def server_status() -> Dict[str, Any]:
    """Return the MCP configuration values together with the live thread state.

    Returns:
        A dict with the keys ``enabled``, ``running``, ``transport``, ``port``
        and ``allowed_tools``. ``running`` comes from ``is_running()``; the
        rest are read from the shared config with the same defaults used
        elsewhere in this module.
    """
    sd = _sd()
    return dict(
        enabled=sd.config.get("mcp_enabled", False),
        running=is_running(),
        transport=sd.config.get("mcp_transport", "http"),
        port=sd.config.get("mcp_port", 8765),
        allowed_tools=sd.config.get("mcp_allowed_tools", []),
    )