Files
Bjorn/script_scheduler.py
infinition b0584a1a8e feat: Add login page with dynamic RGB effects and password toggle functionality
feat: Implement package management utilities with JSON endpoints for listing and uninstalling packages

feat: Create plugin management utilities with endpoints for listing, configuring, and installing plugins

feat: Develop schedule and trigger management utilities with CRUD operations for schedules and triggers
2026-03-19 00:40:04 +01:00

392 lines
14 KiB
Python

"""script_scheduler.py - Background daemon for scheduled scripts and conditional triggers."""
import json
import threading
import time
import subprocess
import os
import logging
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
from logger import Logger
logger = Logger(name="script_scheduler", level=logging.DEBUG)
def evaluate_conditions(node: dict, db) -> bool:
    """Recursively evaluate a condition tree (AND/OR groups + leaf conditions)."""
    if not isinstance(node, dict) or not node:
        return False
    if node.get("type", "condition") == "group":
        children = node.get("children", [])
        if not children:
            # An empty group is vacuously true.
            return True
        combine = all if node.get("op", "AND").upper() == "AND" else any
        # Evaluate every child eagerly (no short-circuit) so each leaf's
        # DB query runs exactly as before.
        child_results = [evaluate_conditions(child, db) for child in children]
        return combine(child_results)
    # Leaf condition: dispatch on the data source.
    leaf_handlers = {
        "action_result": lambda: _eval_action_result(node, db),
        "hosts_with_port": lambda: _eval_hosts_with_port(node, db),
        "hosts_alive": lambda: _eval_hosts_alive(node, db),
        "cred_found": lambda: _eval_cred_found(node, db),
        "has_vuln": lambda: _eval_has_vuln(node, db),
        "db_count": lambda: _eval_db_count(node, db),
        "time_after": lambda: _eval_time_after(node),
        "time_before": lambda: _eval_time_before(node),
    }
    source = node.get("source", "")
    handler = leaf_handlers.get(source)
    if handler is None:
        logger.warning(f"Unknown condition source: {source}")
        return False
    return handler()
def _compare(actual, check, expected):
"""Generic numeric comparison."""
try:
actual = float(actual)
expected = float(expected)
except (ValueError, TypeError):
return str(actual) == str(expected)
if check == "eq": return actual == expected
if check == "neq": return actual != expected
if check == "gt": return actual > expected
if check == "lt": return actual < expected
if check == "gte": return actual >= expected
if check == "lte": return actual <= expected
return False
def _eval_action_result(node, db):
"""Check last result of a specific action in the action_queue."""
action = node.get("action", "")
check = node.get("check", "eq")
value = node.get("value", "success")
row = db.query_one(
"SELECT status FROM action_queue WHERE action_name=? ORDER BY updated_at DESC LIMIT 1",
(action,)
)
if not row:
return False
return _compare(row["status"], check, value)
def _eval_hosts_with_port(node, db):
    """Compare the number of alive hosts exposing a given port."""
    port = str(node.get("port", ""))
    # The hosts.ports column is a semicolon-separated list, so match the
    # port at the start, in the middle, at the end, or as the sole entry.
    like_params = (f"{port};%", f"%;{port};%", f"%;{port}", port)
    rows = db.query(
        "SELECT COUNT(1) c FROM hosts WHERE alive=1 AND (ports LIKE ? OR ports LIKE ? OR ports LIKE ? OR ports=?)",
        like_params,
    )
    matched = rows[0]["c"] if rows else 0
    return _compare(matched, node.get("check", "gt"), node.get("value", 0))
def _eval_hosts_alive(node, db):
    """Compare the number of alive hosts against the node's threshold."""
    row = db.query_one("SELECT COUNT(1) c FROM hosts WHERE alive=1")
    alive = row["c"] if row else 0
    return _compare(alive, node.get("check", "gt"), node.get("value", 0))
def _eval_cred_found(node, db):
"""Check if credentials exist for a service."""
service = node.get("service", "")
row = db.query_one("SELECT COUNT(1) c FROM creds WHERE service=?", (service,))
return (row["c"] if row else 0) > 0
def _eval_has_vuln(node, db):
"""Check if any vulnerabilities exist."""
row = db.query_one("SELECT COUNT(1) c FROM vulnerabilities WHERE active=1")
return (row["c"] if row else 0) > 0
def _eval_db_count(node, db):
    """Count rows in a whitelisted table with simple equality conditions.

    node fields:
        table: must be one of ALLOWED_TABLES; anything else is rejected.
        where: dict of column -> value; each pair becomes `col=?` ANDed together.
        check/value: comparison applied to the resulting row count.
    """
    ALLOWED_TABLES = {"hosts", "creds", "vulnerabilities", "action_queue", "services"}
    table = node.get("table", "")
    if table not in ALLOWED_TABLES:
        logger.warning(f"db_count: table '{table}' not in whitelist")
        return False
    where = node.get("where", {})
    check = node.get("check", "gt")
    value = node.get("value", 0)
    # Build a parameterized WHERE clause. Column names cannot be bound as
    # SQL parameters, so restrict them to Python identifiers (letters,
    # digits, underscore) — still injection-safe. The previous isalnum()
    # check silently dropped legitimate snake_case columns such as
    # action_name or updated_at, making those filters vanish.
    conditions = []
    params = []
    for k, v in where.items():
        if isinstance(k, str) and k.isidentifier():
            conditions.append(f"{k}=?")
            params.append(v)
    sql = f"SELECT COUNT(1) c FROM {table}"
    if conditions:
        sql += " WHERE " + " AND ".join(conditions)
    row = db.query_one(sql, tuple(params))
    count = row["c"] if row else 0
    return _compare(count, check, value)
def _eval_time_after(node):
"""Check if current time is after a given hour:minute."""
hour = int(node.get("hour", 0))
minute = int(node.get("minute", 0))
now = datetime.now()
return (now.hour, now.minute) >= (hour, minute)
def _eval_time_before(node):
"""Check if current time is before a given hour:minute."""
hour = int(node.get("hour", 23))
minute = int(node.get("minute", 59))
now = datetime.now()
return (now.hour, now.minute) < (hour, minute)
class ScriptSchedulerDaemon(threading.Thread):
    """Lightweight 30s tick daemon for script schedules and conditional triggers."""
    # Cap on queued action-completion events held between ticks; the
    # oldest event is dropped once the cap is reached.
    MAX_PENDING_EVENTS = 100
    # Upper bound on scheduled/triggered scripts running simultaneously.
    MAX_CONCURRENT_SCRIPTS = 4
    def __init__(self, shared_data):
        """Store shared state and set up tick/event bookkeeping.

        shared_data: application-wide context object; its `db` handle is
        captured directly, and `actions_dir`/`current_dir` are read later
        when executing scripts.
        """
        super().__init__(daemon=True, name="ScriptScheduler")
        self.shared_data = shared_data
        self.db = shared_data.db
        self._stop = threading.Event()        # set() requests daemon shutdown
        self.check_interval = 30              # seconds between scheduler ticks
        self._pending_action_events = []      # filled by notify_action_complete()
        self._events_lock = threading.Lock()  # guards _pending_action_events
        self._active_threads = 0              # scripts currently running
        self._threads_lock = threading.Lock() # guards _active_threads
    def run(self):
        """Daemon main loop: tick every `check_interval` seconds until stopped.

        Each tick fires due schedules then evaluates conditional triggers.
        Exceptions are caught per tick so one bad iteration cannot kill
        the daemon thread.
        """
        logger.info("ScriptSchedulerDaemon started (30s tick)")
        # Initial delay to let the system boot.
        # Event.wait() returns True only if stop() was called during the delay.
        if self._stop.wait(10):
            return
        while not self._stop.is_set():
            try:
                self._check_schedules()
                self._check_triggers()
            except Exception as e:
                logger.error(f"Scheduler tick error: {e}")
            # Sleep until the next tick, waking immediately on stop().
            self._stop.wait(self.check_interval)
        logger.info("ScriptSchedulerDaemon stopped")
    def stop(self):
        """Signal the daemon loop to exit; `run()` returns at its next wakeup."""
        self._stop.set()
def notify_action_complete(self, action_name: str, mac: str, success: bool):
"""Called from orchestrator when an action finishes. Queues an event for next tick."""
with self._events_lock:
if len(self._pending_action_events) >= self.MAX_PENDING_EVENTS:
self._pending_action_events.pop(0)
self._pending_action_events.append({
"action": action_name,
"mac": mac,
"success": success,
})
def _check_schedules(self):
"""Query due schedules and fire each in a separate thread."""
try:
due = self.db.get_due_schedules()
except Exception as e:
logger.error(f"Failed to query due schedules: {e}")
return
for sched in due:
sched_id = sched["id"]
script_name = sched["script_name"]
args = sched.get("args", "") or ""
# Check conditions if any
conditions_raw = sched.get("conditions")
if conditions_raw:
try:
conditions = json.loads(conditions_raw) if isinstance(conditions_raw, str) else conditions_raw
if conditions and not evaluate_conditions(conditions, self.db):
logger.debug(f"Schedule {sched_id} conditions not met, skipping")
continue
except Exception as e:
logger.warning(f"Schedule {sched_id} condition eval failed: {e}")
# Respect concurrency limit
with self._threads_lock:
if self._active_threads >= self.MAX_CONCURRENT_SCRIPTS:
logger.debug(f"Skipping schedule {sched_id}: max concurrent scripts reached")
continue
logger.info(f"Firing scheduled script: {script_name} (schedule={sched_id})")
self.db.mark_schedule_run(sched_id, "running")
threading.Thread(
target=self._run_with_tracking,
args=(sched_id, script_name, args),
daemon=True
).start()
def _run_with_tracking(self, sched_id: int, script_name: str, args: str):
"""Thread wrapper that tracks active count for concurrency limiting."""
with self._threads_lock:
self._active_threads += 1
try:
self._execute_scheduled(sched_id, script_name, args)
finally:
with self._threads_lock:
self._active_threads = max(0, self._active_threads - 1)
    def _execute_scheduled(self, sched_id: int, script_name: str, args: str):
        """Run the script and record result. When sched_id is 0 (trigger-fired), skip schedule updates.

        Resolves `script_name` against the actions table (matching either
        b_class or b_module), builds the launch command, runs it with a
        one-hour timeout, and records success/error via mark_schedule_run
        — but only for real schedules (sched_id > 0).
        """
        process = None
        try:
            # Look up the action in DB to determine format and path
            action = None
            for a in self.db.list_actions():
                if a["b_class"] == script_name or a["b_module"] == script_name:
                    action = a
                    break
            if not action:
                if sched_id > 0:
                    self.db.mark_schedule_run(sched_id, "error", f"Action {script_name} not found")
                return
            module_name = action["b_module"]
            script_path = os.path.join(self.shared_data.actions_dir, f"{module_name}.py")
            if not os.path.exists(script_path):
                if sched_id > 0:
                    self.db.mark_schedule_run(sched_id, "error", f"Script file not found: {script_path}")
                return
            # Detect format for custom scripts
            # (non-custom actions always go through the "bjorn" runner below).
            from web_utils.script_utils import _detect_script_format
            is_custom = module_name.startswith("custom/")
            fmt = _detect_script_format(script_path) if is_custom else "bjorn"
            # Build command
            env = dict(os.environ)
            env["PYTHONUNBUFFERED"] = "1"  # unbuffered output for live capture
            env["BJORN_EMBEDDED"] = "1"
            if fmt == "free":
                # "free" scripts are standalone and run directly.
                cmd = ["sudo", "python3", "-u", script_path]
            else:
                # "bjorn" scripts run through action_runner.py with module/class args.
                runner_path = os.path.join(self.shared_data.current_dir, "action_runner.py")
                cmd = ["sudo", "python3", "-u", runner_path, module_name, action["b_class"]]
            if args:
                # NOTE(review): plain whitespace split — quoted arguments are
                # not supported here; confirm whether shlex.split is wanted.
                cmd.extend(args.split())
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                universal_newlines=True, env=env, cwd=self.shared_data.current_dir
            )
            # Wait for completion
            stdout, _ = process.communicate(timeout=3600) # 1h max
            exit_code = process.returncode
            if exit_code == 0:
                if sched_id > 0:
                    self.db.mark_schedule_run(sched_id, "success")
                logger.info(f"Scheduled script {script_name} completed successfully")
            else:
                # Store only the last few output lines as the error message.
                last_lines = (stdout or "").strip().split('\n')[-3:]
                error_msg = '\n'.join(last_lines) if last_lines else f"Exit code {exit_code}"
                if sched_id > 0:
                    self.db.mark_schedule_run(sched_id, "error", error_msg)
                logger.warning(f"Scheduled script {script_name} failed (code={exit_code})")
        except subprocess.TimeoutExpired:
            if process:
                process.kill()
                process.wait()
            if sched_id > 0:
                self.db.mark_schedule_run(sched_id, "error", "Timeout (1h)")
            logger.error(f"Scheduled script {script_name} timed out")
        except Exception as e:
            if sched_id > 0:
                self.db.mark_schedule_run(sched_id, "error", str(e))
            logger.error(f"Error executing scheduled script {script_name}: {e}")
        finally:
            # Ensure subprocess resources are released
            if process:
                try:
                    if process.stdout:
                        process.stdout.close()
                    if process.poll() is None:
                        process.kill()
                        process.wait()
                except Exception:
                    pass
def _check_triggers(self):
"""Evaluate conditions for active triggers."""
try:
triggers = self.db.get_active_triggers()
except Exception as e:
logger.error(f"Failed to query triggers: {e}")
return
for trig in triggers:
trig_id = trig["id"]
try:
if self.db.is_trigger_on_cooldown(trig_id):
continue
conditions = trig.get("conditions", "")
if isinstance(conditions, str):
conditions = json.loads(conditions)
if not conditions:
continue
if evaluate_conditions(conditions, self.db):
# Respect concurrency limit
with self._threads_lock:
if self._active_threads >= self.MAX_CONCURRENT_SCRIPTS:
logger.debug(f"Skipping trigger {trig_id}: max concurrent scripts")
continue
script_name = trig["script_name"]
args = trig.get("args", "") or ""
logger.info(f"Trigger '{trig['trigger_name']}' fired -> {script_name}")
self.db.mark_trigger_fired(trig_id)
threading.Thread(
target=self._run_with_tracking,
args=(0, script_name, args),
daemon=True
).start()
except Exception as e:
logger.warning(f"Trigger {trig_id} eval error: {e}")
# Clear consumed events
with self._events_lock:
self._pending_action_events.clear()