mirror of
https://github.com/infinition/Bjorn.git
synced 2026-03-19 02:00:24 +00:00
feat: Implement package management utilities with JSON endpoints for listing and uninstalling packages feat: Create plugin management utilities with endpoints for listing, configuring, and installing plugins feat: Develop schedule and trigger management utilities with CRUD operations for schedules and triggers
392 lines
14 KiB
Python
392 lines
14 KiB
Python
"""script_scheduler.py - Background daemon for scheduled scripts and conditional triggers."""
|
|
|
|
import json
|
|
import threading
|
|
import time
|
|
import subprocess
|
|
import os
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional, Dict, Any, List
|
|
|
|
from logger import Logger
|
|
|
|
# Module-level logger shared by the condition evaluators and the daemon,
# created at DEBUG level.
logger = Logger(level=logging.DEBUG, name="script_scheduler")
|
|
|
|
|
|
def evaluate_conditions(node: dict, db) -> bool:
    """Recursively evaluate a condition tree (AND/OR groups + leaf conditions).

    Group node: ``{"type": "group", "op": "AND" | "OR", "children": [...]}``.
    ``op`` is case-insensitive and defaults to "AND" when missing or empty;
    any other value is treated as OR. A group with no children is True.
    Children are evaluated lazily: AND stops at the first False child and OR
    at the first True one, so leaves that would query the database are not
    run once the outcome is decided.

    Leaf node: any other dict; its ``source`` key selects the evaluator
    (action_result, hosts_with_port, hosts_alive, cred_found, has_vuln,
    db_count, time_after, time_before). Unknown sources log a warning and
    evaluate to False.

    Args:
        node: Tree node. Anything that is not a non-empty dict is False.
        db: Database handle forwarded to the DB-backed leaf evaluators.

    Returns:
        Whether the (sub)tree is satisfied.
    """
    if not node or not isinstance(node, dict):
        return False

    if node.get("type", "condition") == "group":
        op = str(node.get("op") or "AND").upper()
        children = node.get("children") or []
        if not children:
            return True
        # Generator (not a list) so all()/any() short-circuit and skip the
        # remaining children's DB queries once the result is known.
        outcomes = (evaluate_conditions(child, db) for child in children)
        return all(outcomes) if op == "AND" else any(outcomes)

    # Leaf condition
    source = node.get("source", "")

    if source == "action_result":
        return _eval_action_result(node, db)
    elif source == "hosts_with_port":
        return _eval_hosts_with_port(node, db)
    elif source == "hosts_alive":
        return _eval_hosts_alive(node, db)
    elif source == "cred_found":
        return _eval_cred_found(node, db)
    elif source == "has_vuln":
        return _eval_has_vuln(node, db)
    elif source == "db_count":
        return _eval_db_count(node, db)
    elif source == "time_after":
        return _eval_time_after(node)
    elif source == "time_before":
        return _eval_time_before(node)

    logger.warning(f"Unknown condition source: {source}")
    return False
|
|
|
|
|
|
def _compare(actual, check, expected):
|
|
"""Generic numeric comparison."""
|
|
try:
|
|
actual = float(actual)
|
|
expected = float(expected)
|
|
except (ValueError, TypeError):
|
|
return str(actual) == str(expected)
|
|
|
|
if check == "eq": return actual == expected
|
|
if check == "neq": return actual != expected
|
|
if check == "gt": return actual > expected
|
|
if check == "lt": return actual < expected
|
|
if check == "gte": return actual >= expected
|
|
if check == "lte": return actual <= expected
|
|
return False
|
|
|
|
|
|
def _eval_action_result(node, db):
|
|
"""Check last result of a specific action in the action_queue."""
|
|
action = node.get("action", "")
|
|
check = node.get("check", "eq")
|
|
value = node.get("value", "success")
|
|
row = db.query_one(
|
|
"SELECT status FROM action_queue WHERE action_name=? ORDER BY updated_at DESC LIMIT 1",
|
|
(action,)
|
|
)
|
|
if not row:
|
|
return False
|
|
return _compare(row["status"], check, value)
|
|
|
|
|
|
def _eval_hosts_with_port(node, db):
    """Compare the number of alive hosts with a given port open to a threshold.

    Node keys: ``port`` (required), ``check`` (default "gt") and ``value``
    (default 0), i.e. by default "at least one alive host has the port open".

    ``hosts.ports`` is a semicolon-separated list, so the port must match a
    whole list element: first (``22;...``), middle (``...;22;...``), last
    (``...;22``) or sole entry (``22``).

    Returns False with a warning when ``port`` is missing or blank: an empty
    port would otherwise match every alive host whose ``ports`` column is
    empty (via ``ports=''``).
    """
    raw_port = node.get("port")
    port = "" if raw_port is None else str(raw_port).strip()
    if not port:
        logger.warning("hosts_with_port: no port specified")
        return False

    check = node.get("check", "gt")
    value = node.get("value", 0)

    # ports column is semicolon-separated
    row = db.query_one(
        "SELECT COUNT(1) c FROM hosts WHERE alive=1 AND (ports LIKE ? OR ports LIKE ? OR ports LIKE ? OR ports=?)",
        (f"{port};%", f"%;{port};%", f"%;{port}", port),
    )
    count = row["c"] if row else 0
    return _compare(count, check, value)
|
|
|
|
|
|
def _eval_hosts_alive(node, db):
    """Compare the number of hosts with ``alive=1`` to the node's threshold.

    ``check`` defaults to "gt" and ``value`` to 0 ("any host is alive").
    """
    comparison = node.get("check", "gt")
    threshold = node.get("value", 0)
    result = db.query_one("SELECT COUNT(1) c FROM hosts WHERE alive=1")
    alive_count = 0 if not result else result["c"]
    return _compare(alive_count, comparison, threshold)
|
|
|
|
|
|
def _eval_cred_found(node, db):
|
|
"""Check if credentials exist for a service."""
|
|
service = node.get("service", "")
|
|
row = db.query_one("SELECT COUNT(1) c FROM creds WHERE service=?", (service,))
|
|
return (row["c"] if row else 0) > 0
|
|
|
|
|
|
def _eval_has_vuln(node, db):
|
|
"""Check if any vulnerabilities exist."""
|
|
row = db.query_one("SELECT COUNT(1) c FROM vulnerabilities WHERE active=1")
|
|
return (row["c"] if row else 0) > 0
|
|
|
|
|
|
def _eval_db_count(node, db):
    """Compare the row count of a whitelisted table, optionally filtered.

    Node keys:
        table: must be one of ALLOWED_TABLES; anything else logs a warning
            and returns False.
        where: dict of ``{column: value}`` equality filters, AND-ed together.
            Values are bound as SQL parameters. Column names cannot be bound,
            so each must be a plain ASCII identifier (letters, digits and
            underscores, not starting with a digit). An invalid name logs a
            warning and returns False instead of being dropped, because
            dropping it would silently widen the count to the whole table.
        check / value: passed to ``_compare`` (defaults "gt" / 0).

    Returns:
        The comparison result, or False when the node is invalid.
    """
    ALLOWED_TABLES = {"hosts", "creds", "vulnerabilities", "action_queue", "services"}

    table = node.get("table", "")
    if not isinstance(table, str) or table not in ALLOWED_TABLES:
        logger.warning(f"db_count: table '{table}' not in whitelist")
        return False

    where = node.get("where") or {}
    if not isinstance(where, dict):
        logger.warning("db_count: 'where' must be an object of column/value pairs")
        return False

    check = node.get("check", "gt")
    value = node.get("value", 0)

    # Build parameterized WHERE clause
    conditions = []
    params = []
    for column, wanted in where.items():
        # Underscores are allowed (e.g. action_name); anything that is not a
        # bare identifier is rejected since it is interpolated into the SQL.
        if not (isinstance(column, str) and column.isascii() and column.isidentifier()):
            logger.warning(f"db_count: invalid column name '{column}'")
            return False
        conditions.append(f"{column}=?")
        params.append(wanted)

    sql = f"SELECT COUNT(1) c FROM {table}"
    if conditions:
        sql += " WHERE " + " AND ".join(conditions)

    row = db.query_one(sql, tuple(params))
    count = row["c"] if row else 0
    return _compare(count, check, value)
|
|
|
|
|
|
def _eval_time_after(node):
|
|
"""Check if current time is after a given hour:minute."""
|
|
hour = int(node.get("hour", 0))
|
|
minute = int(node.get("minute", 0))
|
|
now = datetime.now()
|
|
return (now.hour, now.minute) >= (hour, minute)
|
|
|
|
|
|
def _eval_time_before(node):
|
|
"""Check if current time is before a given hour:minute."""
|
|
hour = int(node.get("hour", 23))
|
|
minute = int(node.get("minute", 59))
|
|
now = datetime.now()
|
|
return (now.hour, now.minute) < (hour, minute)
|
|
|
|
|
|
class ScriptSchedulerDaemon(threading.Thread):
    """Background thread that fires scheduled scripts and conditional triggers.

    After a 10 s start-up delay, every ``check_interval`` seconds (30 by
    default) the daemon:

    * fetches due schedules via ``db.get_due_schedules()`` and launches each
      one whose optional condition tree evaluates true;
    * evaluates the condition tree of every active trigger that is not on
      cooldown and launches its script when the tree is true.

    Each launch runs ``_execute_scheduled`` in its own daemon worker thread.
    A concurrency slot is claimed in the dispatching thread *before* the
    worker starts, so at most ``MAX_CONCURRENT_SCRIPTS`` scheduler-launched
    scripts run at the same time.
    """

    MAX_PENDING_EVENTS = 100
    MAX_CONCURRENT_SCRIPTS = 4
    # Seconds a timed-out script gets to exit after SIGTERM before SIGKILL.
    TERMINATE_GRACE_SECONDS = 10

    def __init__(self, shared_data):
        """Bind to ``shared_data``; its ``db`` is read here, and
        ``actions_dir`` / ``current_dir`` are read when a script is launched."""
        super().__init__(daemon=True, name="ScriptScheduler")
        self.shared_data = shared_data
        self.db = shared_data.db
        # Deliberately not named ``_stop``: CPython's threading.Thread has a
        # private ``_stop()`` method that join()/is_alive() call once the
        # thread has finished; shadowing it with an Event makes them raise
        # TypeError ("'Event' object is not callable").
        self._stop_event = threading.Event()
        self.check_interval = 30
        # Bounded buffer filled by notify_action_complete(); within this class
        # it is only cleared at the end of each trigger pass.
        self._pending_action_events = []
        self._events_lock = threading.Lock()
        # Number of claimed concurrency slots (running or starting workers).
        self._active_threads = 0
        self._threads_lock = threading.Lock()

    def run(self):
        """Thread body: wait 10 s for boot, then tick until stop() is called."""
        logger.info("ScriptSchedulerDaemon started (30s tick)")
        # Initial delay to let the system boot; wait() returns True if stop()
        # was called in the meantime.
        if self._stop_event.wait(10):
            return
        while not self._stop_event.is_set():
            try:
                self._check_schedules()
                self._check_triggers()
            except Exception as e:
                # One failing tick must not kill the scheduler loop.
                logger.error(f"Scheduler tick error: {e}")
            self._stop_event.wait(self.check_interval)
        logger.info("ScriptSchedulerDaemon stopped")

    def stop(self):
        """Ask the loop to exit at its next wake-up and return immediately.

        Scripts that are already running are not interrupted.
        """
        self._stop_event.set()

    def notify_action_complete(self, action_name: str, mac: str, success: bool):
        """Called from orchestrator when an action finishes. Queues an event for next tick.

        The queue holds at most MAX_PENDING_EVENTS entries; the oldest one is
        dropped when it is full.
        """
        with self._events_lock:
            if len(self._pending_action_events) >= self.MAX_PENDING_EVENTS:
                self._pending_action_events.pop(0)
            self._pending_action_events.append({
                "action": action_name,
                "mac": mac,
                "success": success,
            })

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _parse_conditions(raw):
        """Decode a stored condition tree; return None when there is none.

        Strings are JSON-decoded after stripping, and empty or blank strings
        mean "no conditions". Non-string values are returned as-is, or None
        when falsy. Raises ``json.JSONDecodeError`` for malformed JSON.
        """
        if isinstance(raw, str):
            raw = raw.strip()
            if not raw:
                return None
            return json.loads(raw)
        return raw or None

    def _try_reserve_slot(self) -> bool:
        """Atomically claim a concurrency slot; False when all are in use."""
        with self._threads_lock:
            if self._active_threads >= self.MAX_CONCURRENT_SCRIPTS:
                return False
            self._active_threads += 1
            return True

    def _release_slot(self):
        """Return one previously claimed concurrency slot."""
        with self._threads_lock:
            self._active_threads = max(0, self._active_threads - 1)

    def _start_worker(self, sched_id: int, script_name: str, args: str):
        """Start a daemon worker for a script whose slot is already claimed.

        Once the thread is running it owns the slot and frees it when done.
        If ``start()`` raises, the caller must release the slot.
        """
        threading.Thread(
            target=self._run_reserved,
            args=(sched_id, script_name, args),
            daemon=True,
        ).start()

    def _run_reserved(self, sched_id: int, script_name: str, args: str):
        """Worker body for a pre-claimed slot: run the script, then free the slot."""
        try:
            self._execute_scheduled(sched_id, script_name, args)
        finally:
            self._release_slot()

    def _run_with_tracking(self, sched_id: int, script_name: str, args: str):
        """Claim a slot (without enforcing the limit), run the script, free it.

        Kept for callers that run a script in their own thread. The scheduler
        itself reserves slots up front via _try_reserve_slot/_start_worker.
        """
        with self._threads_lock:
            self._active_threads += 1
        self._run_reserved(sched_id, script_name, args)

    def _mark_run(self, sched_id: int, status: str, *details):
        """Forward to ``db.mark_schedule_run`` unless trigger-fired (id 0)."""
        if sched_id > 0:
            self.db.mark_schedule_run(sched_id, status, *details)

    def _terminate(self, process):
        """Stop a timed-out child, escalating from SIGTERM to SIGKILL.

        Every command is wrapped in ``sudo``. sudo relays SIGTERM to the
        command it runs, but SIGKILL cannot be caught and therefore cannot be
        relayed: killing the sudo process outright would leave the actual
        script running, orphaned. The grace-period communicate() also drains
        the output pipe so the script is not blocked writing while it exits.
        """
        process.terminate()
        try:
            process.communicate(timeout=self.TERMINATE_GRACE_SECONDS)
        except subprocess.TimeoutExpired:
            process.kill()
            # wait(), not communicate(): an orphaned grandchild may still
            # hold the pipe open, and reading it could block indefinitely.
            process.wait()

    # ---------------------------------------------------------------- schedules

    def _check_schedules(self):
        """Query due schedules and fire each in a separate thread."""
        try:
            due = self.db.get_due_schedules()
        except Exception as e:
            logger.error(f"Failed to query due schedules: {e}")
            return

        for sched in due:
            sched_id = sched["id"]
            script_name = sched["script_name"]
            args = sched.get("args", "") or ""

            # Check conditions if any. A tree that cannot be parsed or
            # evaluated blocks the run (fail closed), the same way an
            # evaluation error suppresses a trigger.
            try:
                conditions = self._parse_conditions(sched.get("conditions"))
                if conditions and not evaluate_conditions(conditions, self.db):
                    logger.debug(f"Schedule {sched_id} conditions not met, skipping")
                    continue
            except Exception as e:
                logger.warning(f"Schedule {sched_id} condition eval failed: {e}")
                continue

            # Claim the slot here, in the dispatching thread, so several due
            # schedules in one tick cannot overshoot MAX_CONCURRENT_SCRIPTS.
            if not self._try_reserve_slot():
                logger.debug(f"Skipping schedule {sched_id}: max concurrent scripts reached")
                continue

            try:
                logger.info(f"Firing scheduled script: {script_name} (schedule={sched_id})")
                self.db.mark_schedule_run(sched_id, "running")
                self._start_worker(sched_id, script_name, args)
            except Exception:
                self._release_slot()
                raise

    def _execute_scheduled(self, sched_id: int, script_name: str, args: str):
        """Run one script to completion and record the outcome.

        The action is looked up in ``db.list_actions()`` by ``b_class`` or
        ``b_module``. Custom scripts detected as "free" format run directly;
        everything else goes through ``action_runner.py``. stdout and stderr
        are captured together and the run is capped at one hour. On failure,
        the last three output lines (or "Exit code N" when there is no
        output) are recorded as the error.

        When ``sched_id`` is 0 (trigger-fired), schedule updates are skipped.
        """
        process = None
        try:
            # Look up the action in DB to determine format and path
            action = next(
                (a for a in self.db.list_actions()
                 if a["b_class"] == script_name or a["b_module"] == script_name),
                None,
            )
            if not action:
                self._mark_run(sched_id, "error", f"Action {script_name} not found")
                return

            module_name = action["b_module"]
            script_path = os.path.join(self.shared_data.actions_dir, f"{module_name}.py")
            if not os.path.exists(script_path):
                self._mark_run(sched_id, "error", f"Script file not found: {script_path}")
                return

            # Detect format for custom scripts
            from web_utils.script_utils import _detect_script_format
            is_custom = module_name.startswith("custom/")
            fmt = _detect_script_format(script_path) if is_custom else "bjorn"

            # Build command
            env = dict(os.environ)
            env["PYTHONUNBUFFERED"] = "1"
            env["BJORN_EMBEDDED"] = "1"

            if fmt == "free":
                cmd = ["sudo", "python3", "-u", script_path]
            else:
                runner_path = os.path.join(self.shared_data.current_dir, "action_runner.py")
                cmd = ["sudo", "python3", "-u", runner_path, module_name, action["b_class"]]
            if args:
                cmd.extend(args.split())

            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                universal_newlines=True, env=env, cwd=self.shared_data.current_dir,
            )

            # Wait for completion
            stdout, _ = process.communicate(timeout=3600)  # 1h max
            exit_code = process.returncode

            if exit_code == 0:
                self._mark_run(sched_id, "success")
                logger.info(f"Scheduled script {script_name} completed successfully")
            else:
                # Splitting "" yields [""], which would record an empty
                # message, so test the stripped output before taking the tail.
                output = (stdout or "").strip()
                error_msg = "\n".join(output.split("\n")[-3:]) if output else f"Exit code {exit_code}"
                self._mark_run(sched_id, "error", error_msg)
                logger.warning(f"Scheduled script {script_name} failed (code={exit_code})")

        except subprocess.TimeoutExpired:
            if process:
                self._terminate(process)
            self._mark_run(sched_id, "error", "Timeout (1h)")
            logger.error(f"Scheduled script {script_name} timed out")
        except Exception as e:
            self._mark_run(sched_id, "error", str(e))
            logger.error(f"Error executing scheduled script {script_name}: {e}")
        finally:
            # Ensure subprocess resources are released
            if process:
                try:
                    if process.stdout:
                        process.stdout.close()
                    if process.poll() is None:
                        process.kill()
                        process.wait()
                except Exception:
                    pass  # best-effort cleanup; the outcome is already recorded

    # ----------------------------------------------------------------- triggers

    def _check_triggers(self):
        """Evaluate conditions for active triggers and fire those that match."""
        try:
            triggers = self.db.get_active_triggers()
        except Exception as e:
            logger.error(f"Failed to query triggers: {e}")
            return

        for trig in triggers:
            trig_id = trig["id"]
            try:
                if self.db.is_trigger_on_cooldown(trig_id):
                    continue

                # Empty/blank conditions are skipped *before* JSON decoding;
                # json.loads("") would raise on every tick.
                conditions = self._parse_conditions(trig.get("conditions"))
                if not conditions:
                    continue

                if not evaluate_conditions(conditions, self.db):
                    continue

                # Respect concurrency limit (slot claimed before the worker starts)
                if not self._try_reserve_slot():
                    logger.debug(f"Skipping trigger {trig_id}: max concurrent scripts")
                    continue

                try:
                    script_name = trig["script_name"]
                    args = trig.get("args", "") or ""
                    logger.info(f"Trigger '{trig['trigger_name']}' fired -> {script_name}")
                    self.db.mark_trigger_fired(trig_id)
                    # sched_id 0 marks a trigger-fired run (no schedule row).
                    self._start_worker(0, script_name, args)
                except Exception:
                    self._release_slot()
                    raise
            except Exception as e:
                logger.warning(f"Trigger {trig_id} eval error: {e}")

        # Clear consumed events
        with self._events_lock:
            self._pending_action_events.clear()
|