mirror of
https://github.com/infinition/Bjorn.git
synced 2026-03-19 02:00:24 +00:00
feat: Add login page with dynamic RGB effects and password toggle functionality
feat: Implement package management utilities with JSON endpoints for listing and uninstalling packages feat: Create plugin management utilities with endpoints for listing, configuring, and installing plugins feat: Develop schedule and trigger management utilities with CRUD operations for schedules and triggers
This commit is contained in:
391
script_scheduler.py
Normal file
391
script_scheduler.py
Normal file
@@ -0,0 +1,391 @@
|
||||
"""script_scheduler.py - Background daemon for scheduled scripts and conditional triggers."""
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import subprocess
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from logger import Logger
|
||||
|
||||
logger = Logger(name="script_scheduler", level=logging.DEBUG)
|
||||
|
||||
|
||||
def evaluate_conditions(node: dict, db) -> bool:
|
||||
"""Recursively evaluate a condition tree (AND/OR groups + leaf conditions)."""
|
||||
if not node or not isinstance(node, dict):
|
||||
return False
|
||||
|
||||
node_type = node.get("type", "condition")
|
||||
|
||||
if node_type == "group":
|
||||
op = node.get("op", "AND").upper()
|
||||
children = node.get("children", [])
|
||||
if not children:
|
||||
return True
|
||||
results = [evaluate_conditions(c, db) for c in children]
|
||||
return all(results) if op == "AND" else any(results)
|
||||
|
||||
# Leaf condition
|
||||
source = node.get("source", "")
|
||||
|
||||
if source == "action_result":
|
||||
return _eval_action_result(node, db)
|
||||
elif source == "hosts_with_port":
|
||||
return _eval_hosts_with_port(node, db)
|
||||
elif source == "hosts_alive":
|
||||
return _eval_hosts_alive(node, db)
|
||||
elif source == "cred_found":
|
||||
return _eval_cred_found(node, db)
|
||||
elif source == "has_vuln":
|
||||
return _eval_has_vuln(node, db)
|
||||
elif source == "db_count":
|
||||
return _eval_db_count(node, db)
|
||||
elif source == "time_after":
|
||||
return _eval_time_after(node)
|
||||
elif source == "time_before":
|
||||
return _eval_time_before(node)
|
||||
|
||||
logger.warning(f"Unknown condition source: {source}")
|
||||
return False
|
||||
|
||||
|
||||
def _compare(actual, check, expected):
|
||||
"""Generic numeric comparison."""
|
||||
try:
|
||||
actual = float(actual)
|
||||
expected = float(expected)
|
||||
except (ValueError, TypeError):
|
||||
return str(actual) == str(expected)
|
||||
|
||||
if check == "eq": return actual == expected
|
||||
if check == "neq": return actual != expected
|
||||
if check == "gt": return actual > expected
|
||||
if check == "lt": return actual < expected
|
||||
if check == "gte": return actual >= expected
|
||||
if check == "lte": return actual <= expected
|
||||
return False
|
||||
|
||||
|
||||
def _eval_action_result(node, db):
|
||||
"""Check last result of a specific action in the action_queue."""
|
||||
action = node.get("action", "")
|
||||
check = node.get("check", "eq")
|
||||
value = node.get("value", "success")
|
||||
row = db.query_one(
|
||||
"SELECT status FROM action_queue WHERE action_name=? ORDER BY updated_at DESC LIMIT 1",
|
||||
(action,)
|
||||
)
|
||||
if not row:
|
||||
return False
|
||||
return _compare(row["status"], check, value)
|
||||
|
||||
|
||||
def _eval_hosts_with_port(node, db):
|
||||
"""Count alive hosts with a specific port open."""
|
||||
port = str(node.get("port", ""))
|
||||
check = node.get("check", "gt")
|
||||
value = node.get("value", 0)
|
||||
# ports column is semicolon-separated
|
||||
rows = db.query(
|
||||
"SELECT COUNT(1) c FROM hosts WHERE alive=1 AND (ports LIKE ? OR ports LIKE ? OR ports LIKE ? OR ports=?)",
|
||||
(f"{port};%", f"%;{port};%", f"%;{port}", port)
|
||||
)
|
||||
count = rows[0]["c"] if rows else 0
|
||||
return _compare(count, check, value)
|
||||
|
||||
|
||||
def _eval_hosts_alive(node, db):
|
||||
"""Count alive hosts."""
|
||||
check = node.get("check", "gt")
|
||||
value = node.get("value", 0)
|
||||
row = db.query_one("SELECT COUNT(1) c FROM hosts WHERE alive=1")
|
||||
count = row["c"] if row else 0
|
||||
return _compare(count, check, value)
|
||||
|
||||
|
||||
def _eval_cred_found(node, db):
|
||||
"""Check if credentials exist for a service."""
|
||||
service = node.get("service", "")
|
||||
row = db.query_one("SELECT COUNT(1) c FROM creds WHERE service=?", (service,))
|
||||
return (row["c"] if row else 0) > 0
|
||||
|
||||
|
||||
def _eval_has_vuln(node, db):
|
||||
"""Check if any vulnerabilities exist."""
|
||||
row = db.query_one("SELECT COUNT(1) c FROM vulnerabilities WHERE active=1")
|
||||
return (row["c"] if row else 0) > 0
|
||||
|
||||
|
||||
def _eval_db_count(node, db):
|
||||
"""Count rows in a whitelisted table with simple conditions."""
|
||||
ALLOWED_TABLES = {"hosts", "creds", "vulnerabilities", "action_queue", "services"}
|
||||
table = node.get("table", "")
|
||||
if table not in ALLOWED_TABLES:
|
||||
logger.warning(f"db_count: table '{table}' not in whitelist")
|
||||
return False
|
||||
|
||||
where = node.get("where", {})
|
||||
check = node.get("check", "gt")
|
||||
value = node.get("value", 0)
|
||||
|
||||
# Build parameterized WHERE clause
|
||||
conditions = []
|
||||
params = []
|
||||
for k, v in where.items():
|
||||
# Only allow simple alphanumeric column names
|
||||
if k.isalnum():
|
||||
conditions.append(f"{k}=?")
|
||||
params.append(v)
|
||||
|
||||
sql = f"SELECT COUNT(1) c FROM {table}"
|
||||
if conditions:
|
||||
sql += " WHERE " + " AND ".join(conditions)
|
||||
|
||||
row = db.query_one(sql, tuple(params))
|
||||
count = row["c"] if row else 0
|
||||
return _compare(count, check, value)
|
||||
|
||||
|
||||
def _eval_time_after(node):
|
||||
"""Check if current time is after a given hour:minute."""
|
||||
hour = int(node.get("hour", 0))
|
||||
minute = int(node.get("minute", 0))
|
||||
now = datetime.now()
|
||||
return (now.hour, now.minute) >= (hour, minute)
|
||||
|
||||
|
||||
def _eval_time_before(node):
|
||||
"""Check if current time is before a given hour:minute."""
|
||||
hour = int(node.get("hour", 23))
|
||||
minute = int(node.get("minute", 59))
|
||||
now = datetime.now()
|
||||
return (now.hour, now.minute) < (hour, minute)
|
||||
|
||||
|
||||
class ScriptSchedulerDaemon(threading.Thread):
|
||||
"""Lightweight 30s tick daemon for script schedules and conditional triggers."""
|
||||
|
||||
MAX_PENDING_EVENTS = 100
|
||||
MAX_CONCURRENT_SCRIPTS = 4
|
||||
|
||||
def __init__(self, shared_data):
|
||||
super().__init__(daemon=True, name="ScriptScheduler")
|
||||
self.shared_data = shared_data
|
||||
self.db = shared_data.db
|
||||
self._stop = threading.Event()
|
||||
self.check_interval = 30
|
||||
self._pending_action_events = []
|
||||
self._events_lock = threading.Lock()
|
||||
self._active_threads = 0
|
||||
self._threads_lock = threading.Lock()
|
||||
|
||||
def run(self):
|
||||
logger.info("ScriptSchedulerDaemon started (30s tick)")
|
||||
# Initial delay to let the system boot
|
||||
if self._stop.wait(10):
|
||||
return
|
||||
while not self._stop.is_set():
|
||||
try:
|
||||
self._check_schedules()
|
||||
self._check_triggers()
|
||||
except Exception as e:
|
||||
logger.error(f"Scheduler tick error: {e}")
|
||||
self._stop.wait(self.check_interval)
|
||||
logger.info("ScriptSchedulerDaemon stopped")
|
||||
|
||||
def stop(self):
|
||||
self._stop.set()
|
||||
|
||||
def notify_action_complete(self, action_name: str, mac: str, success: bool):
|
||||
"""Called from orchestrator when an action finishes. Queues an event for next tick."""
|
||||
with self._events_lock:
|
||||
if len(self._pending_action_events) >= self.MAX_PENDING_EVENTS:
|
||||
self._pending_action_events.pop(0)
|
||||
self._pending_action_events.append({
|
||||
"action": action_name,
|
||||
"mac": mac,
|
||||
"success": success,
|
||||
})
|
||||
|
||||
def _check_schedules(self):
|
||||
"""Query due schedules and fire each in a separate thread."""
|
||||
try:
|
||||
due = self.db.get_due_schedules()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to query due schedules: {e}")
|
||||
return
|
||||
|
||||
for sched in due:
|
||||
sched_id = sched["id"]
|
||||
script_name = sched["script_name"]
|
||||
args = sched.get("args", "") or ""
|
||||
|
||||
# Check conditions if any
|
||||
conditions_raw = sched.get("conditions")
|
||||
if conditions_raw:
|
||||
try:
|
||||
conditions = json.loads(conditions_raw) if isinstance(conditions_raw, str) else conditions_raw
|
||||
if conditions and not evaluate_conditions(conditions, self.db):
|
||||
logger.debug(f"Schedule {sched_id} conditions not met, skipping")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"Schedule {sched_id} condition eval failed: {e}")
|
||||
|
||||
# Respect concurrency limit
|
||||
with self._threads_lock:
|
||||
if self._active_threads >= self.MAX_CONCURRENT_SCRIPTS:
|
||||
logger.debug(f"Skipping schedule {sched_id}: max concurrent scripts reached")
|
||||
continue
|
||||
|
||||
logger.info(f"Firing scheduled script: {script_name} (schedule={sched_id})")
|
||||
self.db.mark_schedule_run(sched_id, "running")
|
||||
|
||||
threading.Thread(
|
||||
target=self._run_with_tracking,
|
||||
args=(sched_id, script_name, args),
|
||||
daemon=True
|
||||
).start()
|
||||
|
||||
def _run_with_tracking(self, sched_id: int, script_name: str, args: str):
|
||||
"""Thread wrapper that tracks active count for concurrency limiting."""
|
||||
with self._threads_lock:
|
||||
self._active_threads += 1
|
||||
try:
|
||||
self._execute_scheduled(sched_id, script_name, args)
|
||||
finally:
|
||||
with self._threads_lock:
|
||||
self._active_threads = max(0, self._active_threads - 1)
|
||||
|
||||
def _execute_scheduled(self, sched_id: int, script_name: str, args: str):
|
||||
"""Run the script and record result. When sched_id is 0 (trigger-fired), skip schedule updates."""
|
||||
process = None
|
||||
try:
|
||||
# Look up the action in DB to determine format and path
|
||||
action = None
|
||||
for a in self.db.list_actions():
|
||||
if a["b_class"] == script_name or a["b_module"] == script_name:
|
||||
action = a
|
||||
break
|
||||
|
||||
if not action:
|
||||
if sched_id > 0:
|
||||
self.db.mark_schedule_run(sched_id, "error", f"Action {script_name} not found")
|
||||
return
|
||||
|
||||
module_name = action["b_module"]
|
||||
script_path = os.path.join(self.shared_data.actions_dir, f"{module_name}.py")
|
||||
|
||||
if not os.path.exists(script_path):
|
||||
if sched_id > 0:
|
||||
self.db.mark_schedule_run(sched_id, "error", f"Script file not found: {script_path}")
|
||||
return
|
||||
|
||||
# Detect format for custom scripts
|
||||
from web_utils.script_utils import _detect_script_format
|
||||
is_custom = module_name.startswith("custom/")
|
||||
fmt = _detect_script_format(script_path) if is_custom else "bjorn"
|
||||
|
||||
# Build command
|
||||
env = dict(os.environ)
|
||||
env["PYTHONUNBUFFERED"] = "1"
|
||||
env["BJORN_EMBEDDED"] = "1"
|
||||
|
||||
if fmt == "free":
|
||||
cmd = ["sudo", "python3", "-u", script_path]
|
||||
else:
|
||||
runner_path = os.path.join(self.shared_data.current_dir, "action_runner.py")
|
||||
cmd = ["sudo", "python3", "-u", runner_path, module_name, action["b_class"]]
|
||||
if args:
|
||||
cmd.extend(args.split())
|
||||
|
||||
process = subprocess.Popen(
|
||||
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||
universal_newlines=True, env=env, cwd=self.shared_data.current_dir
|
||||
)
|
||||
|
||||
# Wait for completion
|
||||
stdout, _ = process.communicate(timeout=3600) # 1h max
|
||||
exit_code = process.returncode
|
||||
|
||||
if exit_code == 0:
|
||||
if sched_id > 0:
|
||||
self.db.mark_schedule_run(sched_id, "success")
|
||||
logger.info(f"Scheduled script {script_name} completed successfully")
|
||||
else:
|
||||
last_lines = (stdout or "").strip().split('\n')[-3:]
|
||||
error_msg = '\n'.join(last_lines) if last_lines else f"Exit code {exit_code}"
|
||||
if sched_id > 0:
|
||||
self.db.mark_schedule_run(sched_id, "error", error_msg)
|
||||
logger.warning(f"Scheduled script {script_name} failed (code={exit_code})")
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
if process:
|
||||
process.kill()
|
||||
process.wait()
|
||||
if sched_id > 0:
|
||||
self.db.mark_schedule_run(sched_id, "error", "Timeout (1h)")
|
||||
logger.error(f"Scheduled script {script_name} timed out")
|
||||
except Exception as e:
|
||||
if sched_id > 0:
|
||||
self.db.mark_schedule_run(sched_id, "error", str(e))
|
||||
logger.error(f"Error executing scheduled script {script_name}: {e}")
|
||||
finally:
|
||||
# Ensure subprocess resources are released
|
||||
if process:
|
||||
try:
|
||||
if process.stdout:
|
||||
process.stdout.close()
|
||||
if process.poll() is None:
|
||||
process.kill()
|
||||
process.wait()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _check_triggers(self):
|
||||
"""Evaluate conditions for active triggers."""
|
||||
try:
|
||||
triggers = self.db.get_active_triggers()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to query triggers: {e}")
|
||||
return
|
||||
|
||||
for trig in triggers:
|
||||
trig_id = trig["id"]
|
||||
try:
|
||||
if self.db.is_trigger_on_cooldown(trig_id):
|
||||
continue
|
||||
|
||||
conditions = trig.get("conditions", "")
|
||||
if isinstance(conditions, str):
|
||||
conditions = json.loads(conditions)
|
||||
|
||||
if not conditions:
|
||||
continue
|
||||
|
||||
if evaluate_conditions(conditions, self.db):
|
||||
# Respect concurrency limit
|
||||
with self._threads_lock:
|
||||
if self._active_threads >= self.MAX_CONCURRENT_SCRIPTS:
|
||||
logger.debug(f"Skipping trigger {trig_id}: max concurrent scripts")
|
||||
continue
|
||||
|
||||
script_name = trig["script_name"]
|
||||
args = trig.get("args", "") or ""
|
||||
logger.info(f"Trigger '{trig['trigger_name']}' fired -> {script_name}")
|
||||
self.db.mark_trigger_fired(trig_id)
|
||||
|
||||
threading.Thread(
|
||||
target=self._run_with_tracking,
|
||||
args=(0, script_name, args),
|
||||
daemon=True
|
||||
).start()
|
||||
except Exception as e:
|
||||
logger.warning(f"Trigger {trig_id} eval error: {e}")
|
||||
|
||||
# Clear consumed events
|
||||
with self._events_lock:
|
||||
self._pending_action_events.clear()
|
||||
Reference in New Issue
Block a user