# action_scheduler.py # Smart Action Scheduler for Bjorn - queue-only implementation # Handles trigger evaluation, requirements checking, and queue management. # # Invariants we enforce: # - At most ONE "active" row per (action_name, mac_address, COALESCE(port,0)) # where active ∈ {'scheduled','pending','running'}. # - Retries for failed entries are coordinated by cleanup_queue() (with backoff) # and never compete with trigger-based enqueues. # # Runtime knobs (from shared.py): # shared_data.retry_success_actions : bool (default False) # shared_data.retry_failed_actions : bool (default True) # # These take precedence over cooldown / rate-limit for NON-interval triggers. from __future__ import annotations import json import time import threading from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Optional, Tuple from init_shared import shared_data from logger import Logger logger = Logger(name="action_scheduler.py") # ---------- UTC helpers (match SQLite's UTC CURRENT_TIMESTAMP) ---------- def _utcnow() -> datetime: """Naive UTC datetime to compare with SQLite TEXT timestamps (UTC).""" return datetime.now(timezone.utc).replace(tzinfo=None) def _utcnow_str() -> str: """UTC 'YYYY-MM-DD HH:MM:SS' string to compare against SQLite TEXT.""" return _utcnow().strftime("%Y-%m-%d %H:%M:%S") def _db_ts(dt: datetime) -> str: """Format any datetime as 'YYYY-MM-DD HH:MM:SS' (UTC expected).""" return dt.strftime("%Y-%m-%d %H:%M:%S") # Service → fallback ports (used when port_services table has nothing for a host) SERVICE_PORTS: Dict[str, List[str]] = { "ssh": ["22"], "http": ["80", "8080"], "https": ["443"], "smb": ["445"], "ftp": ["21"], "telnet": ["23"], "mysql": ["3306"], "mssql": ["1433"], "postgres": ["5432"], "rdp": ["3389"], "vnc": ["5900"], } class ActionScheduler: """ Smart scheduler that evaluates triggers and enqueues actions. Does NOT execute actions - that's the orchestrator's job. """ def __init__(self, shared_data_): self.shared_data = shared_data_ self.db = shared_data_.db # Controller MAC for global actions self.ctrl_mac = (self.shared_data.get_raspberry_mac() or "__GLOBAL__").lower() self._ensure_host_exists(self.ctrl_mac) # Runtime flags self.running = True self.check_interval = 5 # seconds between iterations # Action definition cache self._action_definitions: Dict[str, Dict[str, Any]] = {} self._last_cache_refresh = 0.0 self._cache_ttl = 60.0 # seconds # Memory for global actions self._last_global_runs: Dict[str, float] = {} # Actions Studio last source type self._last_source_is_studio: Optional[bool] = None # Enforce DB invariants (idempotent) self._ensure_db_invariants() logger.info("ActionScheduler initialized") # --------------------------------------------------------------------- loop def run(self): """Main scheduler loop.""" logger.info("ActionScheduler starting main loop") while self.running and not self.shared_data.orchestrator_should_exit: try: # Refresh action cache if needed self._refresh_cache_if_needed() # 1) Promote scheduled actions that are due self._promote_scheduled_to_pending() # 2) Publish next scheduled occurrences for interval actions self._publish_all_upcoming() # 3) Evaluate global on_start actions self._evaluate_global_actions() # 4) Evaluate per-host triggers self.evaluate_all_triggers() # 5) Queue maintenance self.cleanup_queue() self.update_priorities() time.sleep(self.check_interval) except Exception as e: logger.error(f"Error in scheduler loop: {e}") time.sleep(self.check_interval) logger.info("ActionScheduler stopped") def stop(self): """Stop the scheduler.""" logger.info("Stopping ActionScheduler...") self.running = False # --------------------------------------------------------------- definitions # ---------- replace this method ---------- def _refresh_cache_if_needed(self): """Refresh action definitions cache if expired or source flipped.""" now = time.time() use_studio = bool(getattr(self.shared_data, "use_actions_studio", False)) # Refresh if TTL expired or the source changed (actions ↔ studio_actions) if (now - self._last_cache_refresh > self._cache_ttl) or (self._last_source_is_studio != use_studio): self._refresh_action_cache(use_studio=use_studio) self._last_cache_refresh = now self._last_source_is_studio = use_studio # ---------- replace this method ---------- def _refresh_action_cache(self, use_studio: Optional[bool] = None): """Reload action definitions from database, from 'actions' or 'studio' view.""" if use_studio is None: use_studio = bool(getattr(self.shared_data, "use_actions_studio", False)) try: if use_studio: # Primary: studio actions = self.db.list_studio_actions() source = "studio" else: # Primary: plain actions actions = self.db.list_actions() source = "actions" # Build cache (expect same action schema: b_class, b_trigger, b_action, etc.) self._action_definitions = {a["b_class"]: a for a in actions} logger.info(f"Refreshed action cache from '{source}': {len(self._action_definitions)} actions") except AttributeError as e: # Fallback if the chosen method isn't available on the DB adapter if use_studio and hasattr(self.db, "list_actions"): logger.warning(f"DB has no list_studio_actions(); falling back to list_actions(): {e}") try: actions = self.db.list_actions() self._action_definitions = {a["b_class"]: a for a in actions} logger.info(f"Refreshed action cache from 'actions' (fallback): {len(self._action_definitions)} actions") return except Exception as ee: logger.error(f"Fallback list_actions() failed: {ee}") else: logger.error(f"Action cache refresh failed (no suitable DB method): {e}") except Exception as e: logger.error(f"Failed to refresh action cache: {e}") # ------------------------------------------------------------------ helpers def _ensure_db_invariants(self): """ Create a partial UNIQUE index that forbids more than one active entry for the same (action_name, mac_address, COALESCE(port,0)). """ try: self.db.execute( """ CREATE UNIQUE INDEX IF NOT EXISTS uq_action_active ON action_queue(action_name, mac_address, COALESCE(port,0)) WHERE status IN ('scheduled','pending','running') """ ) except Exception as e: # If the SQLite build does not support partial/expression indexes, # we still have app-level guards (NOT EXISTS inserts). But this # index is recommended to make the invariant bulletproof. logger.warning(f"Could not create unique partial index (fallback to app-level guards): {e}") def _promote_scheduled_to_pending(self): """Promote due scheduled actions to pending status.""" try: promoted = self.db.promote_due_scheduled_to_pending() if promoted: logger.debug(f"Promoted {promoted} scheduled action(s) to pending") except Exception as e: logger.error(f"Failed to promote scheduled actions: {e}") def _ensure_host_exists(self, mac: str): """Ensure host exists in database (idempotent).""" if not mac: return try: self.db.execute( """ INSERT INTO hosts (mac_address, alive, updated_at) VALUES (?, 0, CURRENT_TIMESTAMP) ON CONFLICT(mac_address) DO UPDATE SET updated_at = CURRENT_TIMESTAMP """, (mac,), ) except Exception: pass # ---------------------------------------------------------- interval logic def _parse_interval_seconds(self, trigger: str) -> int: """Parse interval from trigger string 'on_interval:SECONDS'.""" if not trigger or not trigger.startswith("on_interval:"): return 0 try: return max(0, int(trigger.split(":", 1)[1] or 0)) except Exception: return 0 def _publish_all_upcoming(self): """ Publish next scheduled occurrence for all interval actions. NOTE: By design, the runtime flags do not cancel interval publishing. """ # Global interval actions for action in self._action_definitions.values(): if (action.get("b_action") or "normal") != "global": continue if int(action.get("b_enabled", 1) or 1) != 1: continue trigger = (action.get("b_trigger") or "").strip() interval = self._parse_interval_seconds(trigger) if interval <= 0: continue self._publish_next_schedule_for_global(action, interval) # Per-host interval actions try: hosts = self.db.get_all_hosts() except Exception: hosts = [] for host in hosts: if not host.get("alive"): continue mac = host.get("mac_address") or "" if not mac: continue for action in self._action_definitions.values(): if (action.get("b_action") or "normal") == "global": continue if int(action.get("b_enabled", 1) or 1) != 1: continue trigger = (action.get("b_trigger") or "").strip() interval = self._parse_interval_seconds(trigger) if interval <= 0: continue self._publish_next_schedule_for_host(host, action, interval) def _publish_next_schedule_for_global(self, action_def: Dict[str, Any], interval: int): """Publish next scheduled occurrence for a global action.""" try: action_name = action_def["b_class"] mac = self.ctrl_mac # Already active? active = self.db.query( """ SELECT 1 FROM action_queue WHERE action_name=? AND mac_address=? AND status IN ('scheduled','pending','running') LIMIT 1 """, (action_name, mac), ) if active: return # Next occurrence immediately after last completion, else now (UTC) last = self._get_last_global_execution_time(action_name) next_run = _utcnow() if not last else (last + timedelta(seconds=interval)) scheduled_for = _db_ts(next_run) inserted = self.db.ensure_scheduled_occurrence( action_name=action_name, next_run_at=scheduled_for, mac=mac, ip="0.0.0.0", priority=int(action_def.get("b_priority", 40) or 40), trigger="scheduler", tags=action_def.get("b_tags", []), metadata={"interval": interval, "is_global": True}, max_retries=int(action_def.get("b_max_retries", 3) or 3), ) if inserted: logger.debug(f"Scheduled global '{action_name}' at {scheduled_for}") except Exception as e: logger.error(f"Failed to publish global schedule: {e}") def _publish_next_schedule_for_host(self, host: Dict[str, Any], action_def: Dict[str, Any], interval: int): """Publish next scheduled occurrence for a per-host action.""" try: mac = host.get("mac_address") or "" if not mac: return self._ensure_host_exists(mac) action_name = action_def["b_class"] # Already active? active = self.db.query( """ SELECT 1 FROM action_queue WHERE action_name=? AND mac_address=? AND status IN ('scheduled','pending','running') LIMIT 1 """, (action_name, mac), ) if active: return # Next occurrence immediately after last completion, else now (UTC) last = self._get_last_execution_time(mac, action_name) next_run = _utcnow() if not last else (last + timedelta(seconds=interval)) scheduled_for = _db_ts(next_run) inserted = self.db.ensure_scheduled_occurrence( action_name=action_name, next_run_at=scheduled_for, mac=mac, ip=(host.get("ips") or "").split(";")[0] if host.get("ips") else "", priority=int(action_def.get("b_priority", 40) or 40), trigger="scheduler", tags=action_def.get("b_tags", []), metadata={"interval": interval, "is_global": False}, max_retries=int(action_def.get("b_max_retries", 3) or 3), ) if inserted: logger.debug(f"Scheduled '{action_name}' for {mac} at {scheduled_for}") except Exception as e: logger.error(f"Failed to publish host schedule: {e}") # ------------------------------------------------------------ global start def _evaluate_global_actions(self): """Evaluate and queue global actions with on_start trigger.""" self._globals_lock = getattr(self, "_globals_lock", threading.Lock()) with self._globals_lock: try: for action in self._action_definitions.values(): if (action.get("b_action") or "normal") != "global": continue if int(action.get("b_enabled", 1)) != 1: continue trigger = (action.get("b_trigger") or "").strip() if trigger != "on_start": continue action_name = action["b_class"] # Already executed at least once? last = self._get_last_global_execution_time(action_name) if last is not None: continue # Already queued? existing = self.db.query( """ SELECT 1 FROM action_queue WHERE action_name=? AND status IN ('scheduled','pending','running') LIMIT 1 """, (action_name,), ) if existing: continue # Queue the action self._queue_global_action(action) self._last_global_runs[action_name] = time.time() logger.info(f"Queued global action {action_name}") except Exception as e: logger.error(f"Error evaluating global actions: {e}") def _queue_global_action(self, action_def: Dict[str, Any]): """Queue a global action for execution (idempotent insert).""" action_name = action_def["b_class"] mac = self.ctrl_mac ip = "0.0.0.0" timeout = int(action_def.get("b_timeout", 300) or 300) expires_at = _db_ts(_utcnow() + timedelta(seconds=timeout)) metadata = { "trigger": action_def.get("b_trigger", ""), "requirements": action_def.get("b_requires", ""), "timeout": timeout, "is_global": True, } try: self._ensure_host_exists(mac) # Guard with NOT EXISTS to avoid races self.db.execute( """ INSERT INTO action_queue ( action_name, mac_address, ip, port, hostname, service, priority, status, max_retries, expires_at, trigger_source, tags, metadata ) SELECT ?, ?, ?, NULL, NULL, NULL, ?, 'pending', ?, ?, ?, ?, ? WHERE NOT EXISTS ( SELECT 1 FROM action_queue WHERE action_name=? AND mac_address=? AND COALESCE(port,0)=0 AND status IN ('scheduled','pending','running') ) """, ( action_name, mac, ip, int(action_def.get("b_priority", 50) or 50), int(action_def.get("b_max_retries", 3) or 3), expires_at, action_def.get("b_trigger", ""), json.dumps(action_def.get("b_tags", [])), json.dumps(metadata), action_name, mac, ), ) except Exception as e: logger.error(f"Failed to queue global action {action_name}: {e}") # ------------------------------------------------------------- host path def evaluate_all_triggers(self): """Evaluate triggers for all hosts.""" hosts = self.db.get_all_hosts() # include dead hosts for on_leave trigger for host in hosts: mac = host["mac_address"] for action_name, action_def in self._action_definitions.items(): # Skip global actions if (action_def.get("b_action") or "normal") == "global": continue # Skip disabled actions if not int(action_def.get("b_enabled", 1)): continue trigger = (action_def.get("b_trigger") or "").strip() if not trigger: continue # Skip interval triggers (handled elsewhere) if trigger.startswith("on_interval:"): continue # Evaluate trigger if not evaluate_trigger(trigger, host, action_def): continue # Evaluate requirements requires = action_def.get("b_requires", "") if requires and not evaluate_requirements(requires, host, action_def): continue # Resolve target port/service target_port, target_service = self._resolve_target_port_service(mac, host, action_def) # Decide if we should enqueue if not self._should_queue_action(mac, action_name, action_def, target_port): continue # Queue the action self._queue_action(host, action_def, target_port, target_service) logger.info(f"Queued {action_name} for {mac} (port={target_port}, service={target_service})") def _resolve_target_port_service( self, mac: str, host: Dict[str, Any], action_def: Dict[str, Any] ) -> Tuple[Optional[int], Optional[str]]: """Resolve target port and service for action (service wins over port when present).""" ports = _normalize_ports(host.get("ports")) target_port: Optional[int] = None target_service: Optional[str] = None # Try b_service first if action_def.get("b_service"): try: services = ( json.loads(action_def["b_service"]) if isinstance(action_def["b_service"], str) else action_def["b_service"] ) except Exception: services = [] if services: for svc in services: row = self.db.query( "SELECT port FROM port_services " "WHERE mac_address=? AND state='open' AND LOWER(service)=? " "ORDER BY last_seen DESC LIMIT 1", (mac, str(svc).lower()), ) if row: target_port = int(row[0]["port"]) target_service = str(svc).lower() break # Fallback to b_port if target_port is None and action_def.get("b_port"): if str(action_def["b_port"]) in ports: target_port = int(action_def["b_port"]) return target_port, target_service # ----------------------------------------------------- re-queue policy core def _get_last_status(self, mac: str, action_name: str, target_port: Optional[int]) -> Optional[str]: """ Return last known status for (mac, action, port), considering the chronological fields (completed_at > started_at > scheduled_for > created_at). """ self_port = 0 if target_port is None else int(target_port) row = self.db.query( """ SELECT status FROM action_queue WHERE mac_address=? AND action_name=? AND COALESCE(port,0)=? ORDER BY datetime(COALESCE(completed_at, started_at, scheduled_for, created_at)) DESC LIMIT 1 """, (mac, action_name, self_port), ) return row[0]["status"] if row else None def _should_queue_action( self, mac: str, action_name: str, action_def: Dict[str, Any], target_port: Optional[int] ) -> bool: """ Decide if we should enqueue a new job. Evaluation order: 0) no duplicate active job 1) runtime flags (retry_success_actions / retry_failed_actions) 1-bis) do NOT enqueue if a retryable failed exists (let cleanup_queue() handle it) 2) cooldown 3) rate limit """ self_port = 0 if target_port is None else int(target_port) # 0) Duplicate protection (active) existing = self.db.query( """ SELECT 1 FROM action_queue WHERE mac_address=? AND action_name=? AND COALESCE(port,0)=? AND status IN ('scheduled','pending','running') LIMIT 1 """, (mac, action_name, self_port), ) if existing: return False # 1) Runtime flags take precedence allow_success = bool(getattr(self.shared_data, "retry_success_actions", False)) allow_failed = bool(getattr(self.shared_data, "retry_failed_actions", True)) last_status = self._get_last_status(mac, action_name, target_port) if last_status == "success" and not allow_success: return False if last_status == "failed" and not allow_failed: return False # 1-bis) If a retryable failed exists, let cleanup_queue() requeue it (avoid duplicates) if allow_failed: retryable = self.db.query( """ SELECT 1 FROM action_queue WHERE mac_address=? AND action_name=? AND COALESCE(port,0)=? AND status='failed' AND retry_count < max_retries AND COALESCE(error_message,'') != 'expired' LIMIT 1 """, (mac, action_name, self_port), ) if retryable: return False # 2) Cooldown (UTC) cooldown = int(action_def.get("b_cooldown", 0) or 0) if cooldown > 0: last_exec = self._get_last_execution_time(mac, action_name) if last_exec and (_utcnow() - last_exec).total_seconds() < cooldown: return False # 3) Rate limit (UTC) rate_limit = (action_def.get("b_rate_limit") or "").strip() if rate_limit and not self._check_rate_limit(mac, action_name, rate_limit): return False return True def _queue_action( self, host: Dict[str, Any], action_def: Dict[str, Any], target_port: Optional[int], target_service: Optional[str] ): """Queue action for execution (idempotent insert with NOT EXISTS guard).""" action_name = action_def["b_class"] mac = host["mac_address"] timeout = int(action_def.get("b_timeout", 300) or 300) expires_at = _db_ts(_utcnow() + timedelta(seconds=timeout)) self_port = 0 if target_port is None else int(target_port) metadata = { "trigger": action_def.get("b_trigger", ""), "requirements": action_def.get("b_requires", ""), "is_global": False, "timeout": timeout, "ports_snapshot": host.get("ports") or "", } try: self.db.execute( """ INSERT INTO action_queue ( action_name, mac_address, ip, port, hostname, service, priority, status, max_retries, expires_at, trigger_source, tags, metadata ) SELECT ?, ?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ? WHERE NOT EXISTS ( SELECT 1 FROM action_queue WHERE mac_address=? AND action_name=? AND COALESCE(port,0)=? AND status IN ('scheduled','pending','running') ) """, ( action_name, mac, (host.get("ips") or "").split(";")[0] if host.get("ips") else "", target_port, (host.get("hostnames") or "").split(";")[0] if host.get("hostnames") else "", target_service, int(action_def.get("b_priority", 50) or 50), int(action_def.get("b_max_retries", 3) or 3), expires_at, action_def.get("b_trigger", ""), json.dumps(action_def.get("b_tags", [])), json.dumps(metadata), mac, action_name, self_port, ), ) except Exception as e: logger.error(f"Failed to queue {action_name} for {mac}: {e}") # ------------------------------------------------------------- last times def _get_last_execution_time(self, mac: str, action_name: str) -> Optional[datetime]: """Get last execution time (DB read only; naive UTC).""" row = self.db.query( """ SELECT completed_at FROM action_queue WHERE mac_address=? AND action_name=? AND status IN ('success','failed') ORDER BY completed_at DESC LIMIT 1 """, (mac, action_name), ) if row and row[0].get("completed_at"): try: return datetime.fromisoformat(row[0]["completed_at"]) except Exception: return None return None def _get_last_global_execution_time(self, action_name: str) -> Optional[datetime]: """Get last global action execution time (naive UTC).""" row = self.db.query( """ SELECT completed_at FROM action_queue WHERE action_name=? AND status IN ('success','failed') ORDER BY completed_at DESC LIMIT 1 """, (action_name,), ) if row and row[0].get("completed_at"): try: return datetime.fromisoformat(row[0]["completed_at"]) except Exception: return None return None # ------------------------------------------------------------- constraints def _check_rate_limit(self, mac: str, action_name: str, rate_limit: str) -> bool: """ Check "X/SECONDS" rate-limit (count based on created_at). Returns True if action is allowed to queue. """ try: max_count, period = rate_limit.split("/") max_count = int(max_count) period = int(period) since = _db_ts(_utcnow() - timedelta(seconds=period)) count = self.db.query( """ SELECT COUNT(*) AS c FROM action_queue WHERE mac_address=? AND action_name=? AND created_at >= ? """, (mac, action_name, since), )[0]["c"] return int(count) < max_count except Exception: # Invalid format -> do not block return True # -------------------------------------------------------------- maintenance def cleanup_queue(self): """Clean up queue: timeouts, retries, purge old entries.""" try: now_iso = _utcnow_str() # 1) Expire pending actions self.db.execute( """ UPDATE action_queue SET status='failed', completed_at=CURRENT_TIMESTAMP, error_message=COALESCE(error_message,'expired') WHERE status='pending' AND expires_at IS NOT NULL AND expires_at < ? """, (now_iso,), ) # 2) Timeout running actions self.db.execute( """ UPDATE action_queue SET status='failed', completed_at=CURRENT_TIMESTAMP, error_message=COALESCE(error_message,'timeout') WHERE status='running' AND started_at IS NOT NULL AND datetime(started_at, '+' || COALESCE( CAST(json_extract(metadata, '$.timeout') AS INTEGER), 900 ) || ' seconds') <= datetime('now') """ ) # 3) Retry failed actions with exponential backoff if bool(getattr(self.shared_data, "retry_failed_actions", True)): # Only if no active job exists for the same (action, mac, port) self.db.execute( """ UPDATE action_queue AS a SET status='pending', retry_count = retry_count + 1, scheduled_for = datetime( 'now', '+' || ( CASE WHEN (60 * (1 << retry_count)) > 900 THEN 900 ELSE (60 * (1 << retry_count)) END ) || ' seconds' ), error_message = NULL, started_at = NULL, completed_at = NULL WHERE a.status='failed' AND a.retry_count < a.max_retries AND COALESCE(a.error_message,'') != 'expired' AND NOT EXISTS ( SELECT 1 FROM action_queue b WHERE b.mac_address=a.mac_address AND b.action_name=a.action_name AND COALESCE(b.port,0)=COALESCE(a.port,0) AND b.status IN ('scheduled','pending','running') ) """ ) # 4) Purge old completed entries old_date = _db_ts(_utcnow() - timedelta(days=7)) self.db.execute( """ DELETE FROM action_queue WHERE status IN ('success','failed','cancelled','expired') AND completed_at < ? """, (old_date,), ) except Exception as e: logger.error(f"Failed to cleanup queue: {e}") def update_priorities(self): """Boost priority for actions waiting too long (anti-starvation).""" try: self.db.execute( """ UPDATE action_queue SET priority = MIN(100, priority + 1) WHERE status='pending' AND julianday('now') - julianday(created_at) > 0.0417 """ ) except Exception as e: logger.error(f"Failed to update priorities: {e}") # =================================================================== helpers == def _normalize_ports(raw) -> List[str]: """Normalize ports to list of strings.""" if not raw: return [] if isinstance(raw, list): return [str(p).split("/")[0] for p in raw if p is not None and str(p) != ""] if isinstance(raw, int): return [str(raw)] if isinstance(raw, str): s = raw.strip() if not s: return [] if s.startswith("[") and s.endswith("]"): try: arr = json.loads(s) return [str(p).split("/")[0] for p in arr] except Exception: pass if ";" in s: return [p.strip().split("/")[0] for p in s.split(";") if p.strip()] return [s.split("/")[0]] return [str(raw)] def _has_open_service(mac: str, svc: str, host: Dict[str, Any]) -> bool: """Check if service is open for host (port_services first, then fallback list).""" svc = (svc or "").lower().strip() # Check port_services table first rows = shared_data.db.query( "SELECT 1 FROM port_services WHERE mac_address=? AND state='open' AND LOWER(service)=? LIMIT 1", (mac, svc), ) if rows: return True # Fallback to known port numbers ports = set(_normalize_ports(host.get("ports"))) for p in SERVICE_PORTS.get(svc, []): if p in ports: return True return False def _last_presence_event_for_mac(mac: str) -> Optional[str]: """Get last presence event for MAC (PresenceJoin/PresenceLeave).""" rows = shared_data.db.query( """ SELECT action_name FROM action_queue WHERE mac_address=? AND action_name IN ('PresenceJoin','PresenceLeave') ORDER BY datetime(COALESCE(completed_at, started_at, scheduled_for, created_at)) DESC LIMIT 1 """, (mac,), ) return rows[0]["action_name"] if rows else None # --------------------------------------------------------------- trigger eval -- def evaluate_trigger(trigger: str, host: Dict[str, Any], action_def: Dict[str, Any]) -> bool: """ Evaluate trigger condition for host. Supported triggers: - on_start, on_host_alive, on_host_dead - on_port_change, on_new_port:PORT - on_service:SERVICE, on_web_service - on_success:ACTION, on_failure:ACTION - on_cred_found:SERVICE - on_mac_is:MAC, on_essid_is:ESSID, on_ip_is:IP - on_has_cve[:CVE], on_has_cpe[:CPE] - on_all:[...], on_any:[...] """ try: mac = host["mac_address"] s = (trigger or "").strip() if not s: return False # Combined triggers if s.startswith("on_all:"): try: arr = json.loads(s.split(":", 1)[1]) except Exception: return False return all(evaluate_trigger(t, host, action_def) for t in arr) if s.startswith("on_any:"): try: arr = json.loads(s.split(":", 1)[1]) except Exception: return False return any(evaluate_trigger(t, host, action_def) for t in arr) # Skip interval triggers if s.startswith("on_interval:"): return False # Parse trigger name and parameter if ":" in s: name, param = s.split(":", 1) name = name.strip() param = (param or "").strip() else: name, param = s, "" # Aliases if name == "on_alive": name = "on_host_alive" if name == "on_dead": name = "on_host_dead" # Join/Leave events if name == "on_join": if not bool(host.get("alive")): return False last = _last_presence_event_for_mac(mac) return last != "PresenceJoin" if name == "on_leave": if bool(host.get("alive")): return False last = _last_presence_event_for_mac(mac) return last != "PresenceLeave" # Basic triggers if name == "on_start" or name == "on_new_host": r = shared_data.db.query( """ SELECT 1 FROM action_queue WHERE mac_address=? AND action_name=? AND status IN ('success','failed') LIMIT 1 """, (mac, action_def["b_class"]), ) return not bool(r) if name == "on_host_alive": return bool(host.get("alive")) if name == "on_host_dead": return not bool(host.get("alive")) # Skip port/service triggers for dead hosts if not bool(host.get("alive")) and name in {"on_service", "on_web_service", "on_new_port", "on_port_change"}: return False # Port triggers if name == "on_port_change": cur = set(_normalize_ports(host.get("ports"))) prev = set(_normalize_ports(host.get("previous_ports"))) return cur != prev if name == "on_new_port": port = str(param) cur = set(_normalize_ports(host.get("ports"))) prev = set(_normalize_ports(host.get("previous_ports"))) return port in cur and port not in prev # Action status triggers if name == "on_success": parent = param r = shared_data.db.query( """ SELECT 1 FROM action_queue WHERE mac_address=? AND action_name=? AND status='success' ORDER BY completed_at DESC LIMIT 1 """, (mac, parent), ) return bool(r) if name == "on_failure": parent = param r = shared_data.db.query( """ SELECT 1 FROM action_queue WHERE mac_address=? AND action_name=? AND status='failed' ORDER BY completed_at DESC LIMIT 1 """, (mac, parent), ) return bool(r) # Service triggers if name == "on_cred_found": service = param.lower() r = shared_data.db.query( "SELECT 1 FROM creds WHERE mac_address=? AND LOWER(service)=? LIMIT 1", (mac, service), ) return bool(r) if name == "on_service": return _has_open_service(mac, param, host) if name == "on_web_service": return _has_open_service(mac, "http", host) or _has_open_service(mac, "https", host) # Identity triggers if name == "on_mac_is": return str(mac).lower() == param.lower() if name == "on_essid_is": return (host.get("essid") or "") == param if name == "on_ip_is": ips = (host.get("ips") or "").split(";") if host.get("ips") else [] return param in ips # Vulnerability triggers if name == "on_has_cve": if not param: r = shared_data.db.query( "SELECT 1 FROM vulnerabilities WHERE mac_address=? AND is_active=1 LIMIT 1", (mac,), ) return bool(r) r = shared_data.db.query( "SELECT 1 FROM vulnerabilities WHERE mac_address=? AND vuln_id=? AND is_active=1 LIMIT 1", (mac, param), ) return bool(r) if name == "on_has_cpe": if not param: r = shared_data.db.query( "SELECT 1 FROM detected_software WHERE mac_address=? AND is_active=1 LIMIT 1", (mac,), ) return bool(r) r = shared_data.db.query( "SELECT 1 FROM detected_software WHERE mac_address=? AND cpe=? AND is_active=1 LIMIT 1", (mac, param), ) return bool(r) # Unknown trigger logger.debug(f"Unknown trigger: {name}") return False except Exception as e: logger.error(f"Error evaluating trigger '{trigger}': {e}") return False # ---------------------------------------------------------- requirements eval -- def evaluate_requirements(requires: Any, host: Dict[str, Any], action_def: Dict[str, Any]) -> bool: """Evaluate requirements for action.""" if requires is None: return True # Already an object if isinstance(requires, (dict, list)): return evaluate_requirements_object(requires, host, action_def) s = str(requires).strip() if not s: return True # JSON string if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")): try: obj = json.loads(s) return evaluate_requirements_object(obj, host, action_def) except Exception: pass # Legacy "Action:status" format if ":" in s: a, st = s.split(":", 1) obj = {"action": a.strip(), "status": st.strip()} return evaluate_requirements_object(obj, host, action_def) return True def evaluate_requirements_object(req: Any, host: Dict[str, Any], action_def: Dict[str, Any]) -> bool: """ Evaluate requirements object. Supported: - {"all": [...]} / {"any": [...]} / {"not": {...}} - {"action": "ACTION", "status": "STATUS", "scope": "host|global"} - {"has_port": PORT} - {"has_cred": "SERVICE"} - {"has_cve": "CVE"} - {"has_cpe": "CPE"} - {"mac_is": "MAC"} - {"essid_is": "ESSID"} - {"service_is_open": "SERVICE"} """ mac = host["mac_address"] # Combinators if isinstance(req, dict) and "all" in req: return all(evaluate_requirements_object(x, host, action_def) for x in (req.get("all") or [])) if isinstance(req, dict) and "any" in req: return any(evaluate_requirements_object(x, host, action_def) for x in (req.get("any") or [])) if isinstance(req, dict) and "not" in req: return not evaluate_requirements_object(req.get("not"), host, action_def) # Atomic requirements if isinstance(req, dict): if "action" in req: action = str(req.get("action") or "").strip() status = str(req.get("status") or "success").strip() scope = str(req.get("scope") or "host").strip().lower() if scope == "global": r = shared_data.db.query( """ SELECT 1 FROM action_queue WHERE action_name=? AND status=? ORDER BY completed_at DESC LIMIT 1 """, (action, status), ) return bool(r) # Host scope r = shared_data.db.query( """ SELECT 1 FROM action_queue WHERE mac_address=? AND action_name=? AND status=? ORDER BY completed_at DESC LIMIT 1 """, (mac, action, status), ) return bool(r) if "has_port" in req: want = str(req.get("has_port")) return want in set(_normalize_ports(host.get("ports"))) if "has_cred" in req: svc = str(req.get("has_cred") or "").lower() r = shared_data.db.query( "SELECT 1 FROM creds WHERE mac_address=? AND LOWER(service)=? LIMIT 1", (mac, svc), ) return bool(r) if "has_cve" in req: cve = str(req.get("has_cve") or "") r = shared_data.db.query( "SELECT 1 FROM vulnerabilities WHERE mac_address=? AND vuln_id=? AND is_active=1 LIMIT 1", (mac, cve), ) return bool(r) if "has_cpe" in req: cpe = str(req.get("has_cpe") or "") r = shared_data.db.query( "SELECT 1 FROM detected_software WHERE mac_address=? AND cpe=? AND is_active=1 LIMIT 1", (mac, cpe), ) return bool(r) if "mac_is" in req: return str(mac).lower() == str(req.get("mac_is") or "").lower() if "essid_is" in req: return (host.get("essid") or "") == str(req.get("essid_is") or "") if "service_is_open" in req: svc = str(req.get("service_is_open") or "").lower() return _has_open_service(mac, svc, host) # Default: truthy return bool(req)