# action_scheduler.py
# Smart Action Scheduler for Bjorn - queue-only implementation
# Handles trigger evaluation, requirements checking, and queue management.
#
# Invariants we enforce:
#   - At most ONE "active" row per (action_name, mac_address, COALESCE(port,0))
#     where active ∈ {'scheduled','pending','running'}.
#   - Retries for failed entries are coordinated by cleanup_queue() (with backoff)
#     and never compete with trigger-based enqueues.
#
# Runtime knobs (from shared.py):
#   shared_data.retry_success_actions : bool (default False)
#   shared_data.retry_failed_actions  : bool (default True)
#
# These take precedence over cooldown / rate-limit for NON-interval triggers.
from __future__ import annotations
import json
import time
import threading
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple
from init_shared import shared_data
from logger import Logger
logger = Logger(name="action_scheduler.py")
# ---------- UTC helpers (match SQLite's UTC CURRENT_TIMESTAMP) ----------
def _utcnow() -> datetime:
    """Naive UTC datetime to compare with SQLite TEXT timestamps (UTC)."""
    return datetime.now(timezone.utc).replace(tzinfo=None)
def _utcnow_str() -> str:
    """UTC 'YYYY-MM-DD HH:MM:SS' string to compare against SQLite TEXT."""
    return _utcnow().strftime("%Y-%m-%d %H:%M:%S")
def _db_ts(dt: datetime) -> str:
    """Format any datetime as 'YYYY-MM-DD HH:MM:SS' (UTC expected)."""
    return dt.strftime("%Y-%m-%d %H:%M:%S")
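# Doctest-style sanity check (values follow from the strftime format above):
#   >>> _db_ts(datetime(2024, 1, 2, 3, 4, 5))
#   '2024-01-02 03:04:05'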
# Service → fallback ports (used when port_services table has nothing for a host)
SERVICE_PORTS: Dict[str, List[str]] = {
    "ssh": ["22"],
    "http": ["80", "8080"],
    "https": ["443"],
    "smb": ["445"],
    "ftp": ["21"],
    "telnet": ["23"],
    "mysql": ["3306"],
    "mssql": ["1433"],
    "postgres": ["5432"],
    "rdp": ["3389"],
    "vnc": ["5900"],
}
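# e.g. SERVICE_PORTS.get("http") -> ["80", "8080"]; unknown services get no
# fallback ports, so _has_open_service() then relies on port_services alone.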
class ActionScheduler:
"""
Smart scheduler that evaluates triggers and enqueues actions.
Does NOT execute actions - that's the orchestrator's job.
"""
def __init__(self, shared_data_):
self.shared_data = shared_data_
self.db = shared_data_.db
# Controller MAC for global actions
self.ctrl_mac = (self.shared_data.get_raspberry_mac() or "__GLOBAL__").lower()
self._ensure_host_exists(self.ctrl_mac)
# Runtime flags
self.running = True
self.check_interval = 5 # seconds between iterations
# Action definition cache
self._action_definitions: Dict[str, Dict[str, Any]] = {}
self._last_cache_refresh = 0.0
self._cache_ttl = 60.0 # seconds
# Memory for global actions
self._last_global_runs: Dict[str, float] = {}
# Actions Studio last source type
self._last_source_is_studio: Optional[bool] = None
# Enforce DB invariants (idempotent)
self._ensure_db_invariants()
logger.info("ActionScheduler initialized")
# --------------------------------------------------------------------- loop
def run(self):
"""Main scheduler loop."""
logger.info("ActionScheduler starting main loop")
while self.running and not self.shared_data.orchestrator_should_exit:
try:
# Refresh action cache if needed
self._refresh_cache_if_needed()
# 1) Promote scheduled actions that are due
self._promote_scheduled_to_pending()
# 2) Publish next scheduled occurrences for interval actions
self._publish_all_upcoming()
# 3) Evaluate global on_start actions
self._evaluate_global_actions()
# 4) Evaluate per-host triggers
self.evaluate_all_triggers()
# 5) Queue maintenance
self.cleanup_queue()
self.update_priorities()
time.sleep(self.check_interval)
except Exception as e:
logger.error(f"Error in scheduler loop: {e}")
time.sleep(self.check_interval)
logger.info("ActionScheduler stopped")
def stop(self):
"""Stop the scheduler."""
logger.info("Stopping ActionScheduler...")
self.running = False
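    # Typical wiring (illustrative sketch; the real call sites live in the
    # orchestrator, which this module deliberately does not import):
    #   scheduler = ActionScheduler(shared_data)
    #   t = threading.Thread(target=scheduler.run, name="action_scheduler", daemon=True)
    #   t.start()
    #   ...
    #   scheduler.stop()
    #   t.join(timeout=10)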
    # --------------------------------------------------------------- definitions
    def _refresh_cache_if_needed(self):
        """Refresh action definitions cache if expired or source flipped."""
        now = time.time()
        use_studio = bool(getattr(self.shared_data, "use_actions_studio", False))
        # Refresh if TTL expired or the source changed (actions ↔ studio_actions)
        if (now - self._last_cache_refresh > self._cache_ttl) or (self._last_source_is_studio != use_studio):
            self._refresh_action_cache(use_studio=use_studio)
            self._last_cache_refresh = now
            self._last_source_is_studio = use_studio
    def _refresh_action_cache(self, use_studio: Optional[bool] = None):
        """Reload action definitions from database, from 'actions' or 'studio' view."""
        if use_studio is None:
            use_studio = bool(getattr(self.shared_data, "use_actions_studio", False))
        try:
            if use_studio:
                # Primary: studio
                actions = self.db.list_studio_actions()
                source = "studio"
            else:
                # Primary: plain actions
                actions = self.db.list_actions()
                source = "actions"
            # Build cache (expect same action schema: b_class, b_trigger, b_action, etc.)
            self._action_definitions = {a["b_class"]: a for a in actions}
            logger.info(f"Refreshed action cache from '{source}': {len(self._action_definitions)} actions")
        except AttributeError as e:
            # Fallback if the chosen method isn't available on the DB adapter
            if use_studio and hasattr(self.db, "list_actions"):
                logger.warning(f"DB has no list_studio_actions(); falling back to list_actions(): {e}")
                try:
                    actions = self.db.list_actions()
                    self._action_definitions = {a["b_class"]: a for a in actions}
                    logger.info(f"Refreshed action cache from 'actions' (fallback): {len(self._action_definitions)} actions")
                    return
                except Exception as ee:
                    logger.error(f"Fallback list_actions() failed: {ee}")
            else:
                logger.error(f"Action cache refresh failed (no suitable DB method): {e}")
        except Exception as e:
            logger.error(f"Failed to refresh action cache: {e}")
    # ------------------------------------------------------------------ helpers
    def _ensure_db_invariants(self):
        """
        Create a partial UNIQUE index that forbids more than one active entry
        for the same (action_name, mac_address, COALESCE(port,0)).
        """
        try:
            self.db.execute(
                """
                CREATE UNIQUE INDEX IF NOT EXISTS uq_action_active
                ON action_queue(action_name, mac_address, COALESCE(port,0))
                WHERE status IN ('scheduled','pending','running')
                """
            )
        except Exception as e:
            # If the SQLite build does not support partial/expression indexes,
            # we still have app-level guards (NOT EXISTS inserts). But this
            # index is recommended to make the invariant bulletproof.
            logger.warning(f"Could not create unique partial index (fallback to app-level guards): {e}")
    def _promote_scheduled_to_pending(self):
        """Promote due scheduled actions to pending status."""
        try:
            promoted = self.db.promote_due_scheduled_to_pending()
            if promoted:
                logger.debug(f"Promoted {promoted} scheduled action(s) to pending")
        except Exception as e:
            logger.error(f"Failed to promote scheduled actions: {e}")
    def _ensure_host_exists(self, mac: str):
        """Ensure host exists in database (idempotent)."""
        if not mac:
            return
        try:
            self.db.execute(
                """
                INSERT INTO hosts (mac_address, alive, updated_at)
                VALUES (?, 0, CURRENT_TIMESTAMP)
                ON CONFLICT(mac_address) DO UPDATE SET
                    updated_at = CURRENT_TIMESTAMP
                """,
                (mac,),
            )
        except Exception:
            # Best-effort upsert; a failure here must never break scheduling.
            pass
    # ---------------------------------------------------------- interval logic
    def _parse_interval_seconds(self, trigger: str) -> int:
        """Parse interval from trigger string 'on_interval:SECONDS'."""
        if not trigger or not trigger.startswith("on_interval:"):
            return 0
        try:
            return max(0, int(trigger.split(":", 1)[1] or 0))
        except Exception:
            return 0
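    # e.g. _parse_interval_seconds("on_interval:3600") -> 3600
    #      _parse_interval_seconds("on_interval:oops") -> 0 (malformed -> disabled)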
    def _publish_all_upcoming(self):
        """
        Publish next scheduled occurrence for all interval actions.
        NOTE: By design, the runtime flags do not cancel interval publishing.
        """
        # Global interval actions
        for action in self._action_definitions.values():
            if (action.get("b_action") or "normal") != "global":
                continue
            if int(action.get("b_enabled", 1) or 1) != 1:
                continue
            trigger = (action.get("b_trigger") or "").strip()
            interval = self._parse_interval_seconds(trigger)
            if interval <= 0:
                continue
            self._publish_next_schedule_for_global(action, interval)
        # Per-host interval actions
        try:
            hosts = self.db.get_all_hosts()
        except Exception:
            hosts = []
        for host in hosts:
            if not host.get("alive"):
                continue
            mac = host.get("mac_address") or ""
            if not mac:
                continue
            for action in self._action_definitions.values():
                if (action.get("b_action") or "normal") == "global":
                    continue
                if int(action.get("b_enabled", 1) or 1) != 1:
                    continue
                trigger = (action.get("b_trigger") or "").strip()
                interval = self._parse_interval_seconds(trigger)
                if interval <= 0:
                    continue
                self._publish_next_schedule_for_host(host, action, interval)
    def _publish_next_schedule_for_global(self, action_def: Dict[str, Any], interval: int):
        """Publish next scheduled occurrence for a global action."""
        try:
            action_name = action_def["b_class"]
            mac = self.ctrl_mac
            # Already active?
            active = self.db.query(
                """
                SELECT 1 FROM action_queue
                WHERE action_name=? AND mac_address=?
                  AND status IN ('scheduled','pending','running')
                LIMIT 1
                """,
                (action_name, mac),
            )
            if active:
                return
            # Next occurrence immediately after last completion, else now (UTC)
            last = self._get_last_global_execution_time(action_name)
            next_run = _utcnow() if not last else (last + timedelta(seconds=interval))
            scheduled_for = _db_ts(next_run)
            inserted = self.db.ensure_scheduled_occurrence(
                action_name=action_name,
                next_run_at=scheduled_for,
                mac=mac,
                ip="0.0.0.0",
                priority=int(action_def.get("b_priority", 40) or 40),
                trigger="scheduler",
                tags=action_def.get("b_tags", []),
                metadata={"interval": interval, "is_global": True},
                max_retries=int(action_def.get("b_max_retries", 3) or 3),
            )
            if inserted:
                logger.debug(f"Scheduled global '{action_name}' at {scheduled_for}")
        except Exception as e:
            logger.error(f"Failed to publish global schedule: {e}")
    def _publish_next_schedule_for_host(self, host: Dict[str, Any], action_def: Dict[str, Any], interval: int):
        """Publish next scheduled occurrence for a per-host action."""
        try:
            mac = host.get("mac_address") or ""
            if not mac:
                return
            self._ensure_host_exists(mac)
            action_name = action_def["b_class"]
            # Already active?
            active = self.db.query(
                """
                SELECT 1 FROM action_queue
                WHERE action_name=? AND mac_address=?
                  AND status IN ('scheduled','pending','running')
                LIMIT 1
                """,
                (action_name, mac),
            )
            if active:
                return
            # Next occurrence immediately after last completion, else now (UTC)
            last = self._get_last_execution_time(mac, action_name)
            next_run = _utcnow() if not last else (last + timedelta(seconds=interval))
            scheduled_for = _db_ts(next_run)
            inserted = self.db.ensure_scheduled_occurrence(
                action_name=action_name,
                next_run_at=scheduled_for,
                mac=mac,
                ip=(host.get("ips") or "").split(";")[0] if host.get("ips") else "",
                priority=int(action_def.get("b_priority", 40) or 40),
                trigger="scheduler",
                tags=action_def.get("b_tags", []),
                metadata={"interval": interval, "is_global": False},
                max_retries=int(action_def.get("b_max_retries", 3) or 3),
            )
            if inserted:
                logger.debug(f"Scheduled '{action_name}' for {mac} at {scheduled_for}")
        except Exception as e:
            logger.error(f"Failed to publish host schedule: {e}")
    # ------------------------------------------------------------ global start
    def _evaluate_global_actions(self):
        """Evaluate and queue global actions with on_start trigger."""
        # Lazy init: getattr returns the existing lock on every later call
        self._globals_lock = getattr(self, "_globals_lock", threading.Lock())
        with self._globals_lock:
            try:
                for action in self._action_definitions.values():
                    if (action.get("b_action") or "normal") != "global":
                        continue
                    if int(action.get("b_enabled", 1)) != 1:
                        continue
                    trigger = (action.get("b_trigger") or "").strip()
                    if trigger != "on_start":
                        continue
                    action_name = action["b_class"]
                    # Already executed at least once?
                    last = self._get_last_global_execution_time(action_name)
                    if last is not None:
                        continue
                    # Already queued?
                    existing = self.db.query(
                        """
                        SELECT 1 FROM action_queue
                        WHERE action_name=? AND status IN ('scheduled','pending','running')
                        LIMIT 1
                        """,
                        (action_name,),
                    )
                    if existing:
                        continue
                    # Queue the action
                    self._queue_global_action(action)
                    self._last_global_runs[action_name] = time.time()
                    logger.info(f"Queued global action {action_name}")
            except Exception as e:
                logger.error(f"Error evaluating global actions: {e}")
    def _queue_global_action(self, action_def: Dict[str, Any]):
        """Queue a global action for execution (idempotent insert)."""
        action_name = action_def["b_class"]
        mac = self.ctrl_mac
        ip = "0.0.0.0"
        timeout = int(action_def.get("b_timeout", 300) or 300)
        expires_at = _db_ts(_utcnow() + timedelta(seconds=timeout))
        metadata = {
            "trigger": action_def.get("b_trigger", ""),
            "requirements": action_def.get("b_requires", ""),
            "timeout": timeout,
            "is_global": True,
        }
        try:
            self._ensure_host_exists(mac)
            # Guard with NOT EXISTS to avoid races
            self.db.execute(
                """
                INSERT INTO action_queue (
                    action_name, mac_address, ip, port, hostname, service,
                    priority, status, max_retries, expires_at,
                    trigger_source, tags, metadata
                )
                SELECT ?, ?, ?, NULL, NULL, NULL,
                       ?, 'pending', ?, ?, ?, ?, ?
                WHERE NOT EXISTS (
                    SELECT 1 FROM action_queue
                    WHERE action_name=? AND mac_address=? AND COALESCE(port,0)=0
                      AND status IN ('scheduled','pending','running')
                )
                """,
                (
                    action_name,
                    mac,
                    ip,
                    int(action_def.get("b_priority", 50) or 50),
                    int(action_def.get("b_max_retries", 3) or 3),
                    expires_at,
                    action_def.get("b_trigger", ""),
                    json.dumps(action_def.get("b_tags", [])),
                    json.dumps(metadata),
                    action_name,
                    mac,
                ),
            )
        except Exception as e:
            logger.error(f"Failed to queue global action {action_name}: {e}")
    # ------------------------------------------------------------- host path
    def evaluate_all_triggers(self):
        """Evaluate triggers for all hosts."""
        hosts = self.db.get_all_hosts()  # include dead hosts for on_leave trigger
        for host in hosts:
            mac = host["mac_address"]
            for action_name, action_def in self._action_definitions.items():
                # Skip global actions
                if (action_def.get("b_action") or "normal") == "global":
                    continue
                # Skip disabled actions
                if not int(action_def.get("b_enabled", 1)):
                    continue
                trigger = (action_def.get("b_trigger") or "").strip()
                if not trigger:
                    continue
                # Skip interval triggers (handled elsewhere)
                if trigger.startswith("on_interval:"):
                    continue
                # Evaluate trigger
                if not evaluate_trigger(trigger, host, action_def):
                    continue
                # Evaluate requirements
                requires = action_def.get("b_requires", "")
                if requires and not evaluate_requirements(requires, host, action_def):
                    continue
                # Resolve target port/service
                target_port, target_service = self._resolve_target_port_service(mac, host, action_def)
                # Decide if we should enqueue
                if not self._should_queue_action(mac, action_name, action_def, target_port):
                    continue
                # Queue the action
                self._queue_action(host, action_def, target_port, target_service)
                logger.info(f"Queued {action_name} for {mac} (port={target_port}, service={target_service})")
    def _resolve_target_port_service(
        self, mac: str, host: Dict[str, Any], action_def: Dict[str, Any]
    ) -> Tuple[Optional[int], Optional[str]]:
        """Resolve target port and service for action (service wins over port when present)."""
        ports = _normalize_ports(host.get("ports"))
        target_port: Optional[int] = None
        target_service: Optional[str] = None
        # Try b_service first
        if action_def.get("b_service"):
            try:
                services = (
                    json.loads(action_def["b_service"])
                    if isinstance(action_def["b_service"], str)
                    else action_def["b_service"]
                )
            except Exception:
                services = []
            if services:
                for svc in services:
                    row = self.db.query(
                        "SELECT port FROM port_services "
                        "WHERE mac_address=? AND state='open' AND LOWER(service)=? "
                        "ORDER BY last_seen DESC LIMIT 1",
                        (mac, str(svc).lower()),
                    )
                    if row:
                        target_port = int(row[0]["port"])
                        target_service = str(svc).lower()
                        break
        # Fallback to b_port
        if target_port is None and action_def.get("b_port"):
            if str(action_def["b_port"]) in ports:
                target_port = int(action_def["b_port"])
        return target_port, target_service
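    # Resolution example: with b_service='["ssh"]' and an open-port row
    # (port 22, service 'ssh') in port_services, this returns (22, "ssh");
    # with only b_port=80 and "80" in the host's ports snapshot, (80, None).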
    # ----------------------------------------------------- re-queue policy core
    def _get_last_status(self, mac: str, action_name: str, target_port: Optional[int]) -> Optional[str]:
        """
        Return last known status for (mac, action, port), considering the
        chronological fields (completed_at > started_at > scheduled_for > created_at).
        """
        port_key = 0 if target_port is None else int(target_port)
        row = self.db.query(
            """
            SELECT status
            FROM action_queue
            WHERE mac_address=? AND action_name=? AND COALESCE(port,0)=?
            ORDER BY datetime(COALESCE(completed_at, started_at, scheduled_for, created_at)) DESC
            LIMIT 1
            """,
            (mac, action_name, port_key),
        )
        return row[0]["status"] if row else None
    def _should_queue_action(
        self, mac: str, action_name: str, action_def: Dict[str, Any], target_port: Optional[int]
    ) -> bool:
        """
        Decide if we should enqueue a new job.
        Evaluation order:
        0) no duplicate active job
        1) runtime flags (retry_success_actions / retry_failed_actions)
        1-bis) do NOT enqueue if a retryable failed entry exists (let cleanup_queue() handle it)
        2) cooldown
        3) rate limit
        """
        port_key = 0 if target_port is None else int(target_port)
        # 0) Duplicate protection (active)
        existing = self.db.query(
            """
            SELECT 1 FROM action_queue
            WHERE mac_address=? AND action_name=? AND COALESCE(port,0)=?
              AND status IN ('scheduled','pending','running')
            LIMIT 1
            """,
            (mac, action_name, port_key),
        )
        if existing:
            return False
        # 1) Runtime flags take precedence
        allow_success = bool(getattr(self.shared_data, "retry_success_actions", False))
        allow_failed = bool(getattr(self.shared_data, "retry_failed_actions", True))
        last_status = self._get_last_status(mac, action_name, target_port)
        if last_status == "success" and not allow_success:
            return False
        if last_status == "failed" and not allow_failed:
            return False
        # 1-bis) If a retryable failed entry exists, let cleanup_queue() requeue it (avoid duplicates)
        if allow_failed:
            retryable = self.db.query(
                """
                SELECT 1 FROM action_queue
                WHERE mac_address=? AND action_name=? AND COALESCE(port,0)=?
                  AND status='failed'
                  AND retry_count < max_retries
                  AND COALESCE(error_message,'') != 'expired'
                LIMIT 1
                """,
                (mac, action_name, port_key),
            )
            if retryable:
                return False
        # 2) Cooldown (UTC)
        cooldown = int(action_def.get("b_cooldown", 0) or 0)
        if cooldown > 0:
            last_exec = self._get_last_execution_time(mac, action_name)
            if last_exec and (_utcnow() - last_exec).total_seconds() < cooldown:
                return False
        # 3) Rate limit (UTC)
        rate_limit = (action_def.get("b_rate_limit") or "").strip()
        if rate_limit and not self._check_rate_limit(mac, action_name, rate_limit):
            return False
        return True
    def _queue_action(
        self, host: Dict[str, Any], action_def: Dict[str, Any], target_port: Optional[int], target_service: Optional[str]
    ):
        """Queue action for execution (idempotent insert with NOT EXISTS guard)."""
        action_name = action_def["b_class"]
        mac = host["mac_address"]
        timeout = int(action_def.get("b_timeout", 300) or 300)
        expires_at = _db_ts(_utcnow() + timedelta(seconds=timeout))
        port_key = 0 if target_port is None else int(target_port)
        metadata = {
            "trigger": action_def.get("b_trigger", ""),
            "requirements": action_def.get("b_requires", ""),
            "is_global": False,
            "timeout": timeout,
            "ports_snapshot": host.get("ports") or "",
        }
        try:
            self.db.execute(
                """
                INSERT INTO action_queue (
                    action_name, mac_address, ip, port, hostname, service,
                    priority, status, max_retries, expires_at,
                    trigger_source, tags, metadata
                )
                SELECT ?, ?, ?, ?, ?, ?,
                       ?, 'pending', ?, ?, ?, ?, ?
                WHERE NOT EXISTS (
                    SELECT 1 FROM action_queue
                    WHERE mac_address=? AND action_name=? AND COALESCE(port,0)=?
                      AND status IN ('scheduled','pending','running')
                )
                """,
                (
                    action_name,
                    mac,
                    (host.get("ips") or "").split(";")[0] if host.get("ips") else "",
                    target_port,
                    (host.get("hostnames") or "").split(";")[0] if host.get("hostnames") else "",
                    target_service,
                    int(action_def.get("b_priority", 50) or 50),
                    int(action_def.get("b_max_retries", 3) or 3),
                    expires_at,
                    action_def.get("b_trigger", ""),
                    json.dumps(action_def.get("b_tags", [])),
                    json.dumps(metadata),
                    mac,
                    action_name,
                    port_key,
                ),
            )
        except Exception as e:
            logger.error(f"Failed to queue {action_name} for {mac}: {e}")
    # ------------------------------------------------------------- last times
    def _get_last_execution_time(self, mac: str, action_name: str) -> Optional[datetime]:
        """Get last execution time (DB read only; naive UTC)."""
        row = self.db.query(
            """
            SELECT completed_at FROM action_queue
            WHERE mac_address=? AND action_name=? AND status IN ('success','failed')
            ORDER BY completed_at DESC
            LIMIT 1
            """,
            (mac, action_name),
        )
        if row and row[0].get("completed_at"):
            try:
                return datetime.fromisoformat(row[0]["completed_at"])
            except Exception:
                return None
        return None
    def _get_last_global_execution_time(self, action_name: str) -> Optional[datetime]:
        """Get last global action execution time (naive UTC)."""
        row = self.db.query(
            """
            SELECT completed_at FROM action_queue
            WHERE action_name=? AND status IN ('success','failed')
            ORDER BY completed_at DESC
            LIMIT 1
            """,
            (action_name,),
        )
        if row and row[0].get("completed_at"):
            try:
                return datetime.fromisoformat(row[0]["completed_at"])
            except Exception:
                return None
        return None
    # ------------------------------------------------------------- constraints
    def _check_rate_limit(self, mac: str, action_name: str, rate_limit: str) -> bool:
        """
        Check "X/SECONDS" rate-limit (count based on created_at).
        Returns True if action is allowed to queue.
        """
        try:
            max_count, period = rate_limit.split("/")
            max_count = int(max_count)
            period = int(period)
            since = _db_ts(_utcnow() - timedelta(seconds=period))
            count = self.db.query(
                """
                SELECT COUNT(*) AS c FROM action_queue
                WHERE mac_address=? AND action_name=? AND created_at >= ?
                """,
                (mac, action_name, since),
            )[0]["c"]
            return int(count) < max_count
        except Exception:
            # Invalid format -> do not block
            return True
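    # e.g. b_rate_limit = "3/600" permits at most 3 queue inserts per rolling
    # 600-second window for the same (mac, action) pair.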
    # -------------------------------------------------------------- maintenance
    def cleanup_queue(self):
        """Clean up queue: timeouts, retries, purge old entries."""
        try:
            now_iso = _utcnow_str()
            # 1) Expire pending actions
            self.db.execute(
                """
                UPDATE action_queue
                SET status='failed',
                    completed_at=CURRENT_TIMESTAMP,
                    error_message=COALESCE(error_message,'expired')
                WHERE status='pending'
                  AND expires_at IS NOT NULL
                  AND expires_at < ?
                """,
                (now_iso,),
            )
            # 2) Timeout running actions
            self.db.execute(
                """
                UPDATE action_queue
                SET status='failed',
                    completed_at=CURRENT_TIMESTAMP,
                    error_message=COALESCE(error_message,'timeout')
                WHERE status='running'
                  AND started_at IS NOT NULL
                  AND datetime(started_at, '+' || COALESCE(
                        CAST(json_extract(metadata, '$.timeout') AS INTEGER), 900
                      ) || ' seconds') <= datetime('now')
                """
            )
            # 3) Retry failed actions with exponential backoff
            if bool(getattr(self.shared_data, "retry_failed_actions", True)):
                # Only if no active job exists for the same (action, mac, port)
                self.db.execute(
                    """
                    UPDATE action_queue AS a
                    SET status='pending',
                        retry_count = retry_count + 1,
                        scheduled_for = datetime(
                            'now',
                            '+' || (
                                CASE
                                    WHEN (60 * (1 << retry_count)) > 900 THEN 900
                                    ELSE (60 * (1 << retry_count))
                                END
                            ) || ' seconds'
                        ),
                        error_message = NULL,
                        started_at = NULL,
                        completed_at = NULL
                    WHERE a.status='failed'
                      AND a.retry_count < a.max_retries
                      AND COALESCE(a.error_message,'') != 'expired'
                      AND NOT EXISTS (
                          SELECT 1 FROM action_queue b
                          WHERE b.mac_address=a.mac_address
                            AND b.action_name=a.action_name
                            AND COALESCE(b.port,0)=COALESCE(a.port,0)
                            AND b.status IN ('scheduled','pending','running')
                      )
                    """
                )
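                # Backoff progression: 60 * 2^retry_count seconds, i.e. 60s,
                # 120s, 240s, 480s, then capped at 900s for later attempts.
                # NOTE: retried rows go back to 'pending' with a future
                # scheduled_for; this assumes the orchestrator does not claim
                # pending work before its scheduled_for is due.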
            # 4) Purge old completed entries
            old_date = _db_ts(_utcnow() - timedelta(days=7))
            self.db.execute(
                """
                DELETE FROM action_queue
                WHERE status IN ('success','failed','cancelled','expired')
                  AND completed_at < ?
                """,
                (old_date,),
            )
        except Exception as e:
            logger.error(f"Failed to cleanup queue: {e}")
    def update_priorities(self):
        """Boost priority for actions waiting too long (anti-starvation)."""
        try:
            self.db.execute(
                """
                UPDATE action_queue
                SET priority = MIN(100, priority + 1)
                WHERE status='pending'
                  AND julianday('now') - julianday(created_at) > 0.0417
                """
                # 0.0417 day ≈ 1 hour: each pass bumps hour-old pending work
                # by one priority point, capped at 100.
            )
        except Exception as e:
            logger.error(f"Failed to update priorities: {e}")
# =================================================================== helpers ==
def _normalize_ports(raw) -> List[str]:
    """Normalize ports to list of strings."""
    if not raw:
        return []
    if isinstance(raw, list):
        return [str(p).split("/")[0] for p in raw if p is not None and str(p) != ""]
    if isinstance(raw, int):
        return [str(raw)]
    if isinstance(raw, str):
        s = raw.strip()
        if not s:
            return []
        if s.startswith("[") and s.endswith("]"):
            try:
                arr = json.loads(s)
                return [str(p).split("/")[0] for p in arr]
            except Exception:
                pass
        if ";" in s:
            return [p.strip().split("/")[0] for p in s.split(";") if p.strip()]
        return [s.split("/")[0]]
    return [str(raw)]
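# e.g. _normalize_ports("22/tcp;80/tcp") -> ["22", "80"]
#      _normalize_ports('["22", 443]')   -> ["22", "443"]
#      _normalize_ports(8080)            -> ["8080"]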
def _has_open_service(mac: str, svc: str, host: Dict[str, Any]) -> bool:
    """Check if service is open for host (port_services first, then fallback list)."""
    svc = (svc or "").lower().strip()
    # Check port_services table first
    rows = shared_data.db.query(
        "SELECT 1 FROM port_services WHERE mac_address=? AND state='open' AND LOWER(service)=? LIMIT 1",
        (mac, svc),
    )
    if rows:
        return True
    # Fallback to known port numbers
    ports = set(_normalize_ports(host.get("ports")))
    for p in SERVICE_PORTS.get(svc, []):
        if p in ports:
            return True
    return False
def _last_presence_event_for_mac(mac: str) -> Optional[str]:
    """Get last presence event for MAC (PresenceJoin/PresenceLeave)."""
    rows = shared_data.db.query(
        """
        SELECT action_name
        FROM action_queue
        WHERE mac_address=?
          AND action_name IN ('PresenceJoin','PresenceLeave')
        ORDER BY datetime(COALESCE(completed_at, started_at, scheduled_for, created_at)) DESC
        LIMIT 1
        """,
        (mac,),
    )
    return rows[0]["action_name"] if rows else None
# --------------------------------------------------------------- trigger eval --
def evaluate_trigger(trigger: str, host: Dict[str, Any], action_def: Dict[str, Any]) -> bool:
    """
    Evaluate trigger condition for host.
    Supported triggers:
    - on_start / on_new_host
    - on_host_alive (alias: on_alive), on_host_dead (alias: on_dead)
    - on_join, on_leave
    - on_port_change, on_new_port:PORT
    - on_service:SERVICE, on_web_service
    - on_success:ACTION, on_failure:ACTION
    - on_cred_found:SERVICE
    - on_mac_is:MAC, on_essid_is:ESSID, on_ip_is:IP
    - on_has_cve[:CVE], on_has_cpe[:CPE]
    - on_all:[...], on_any:[...]
    """
    try:
        mac = host["mac_address"]
        s = (trigger or "").strip()
        if not s:
            return False
        # Combined triggers
        if s.startswith("on_all:"):
            try:
                arr = json.loads(s.split(":", 1)[1])
            except Exception:
                return False
            return all(evaluate_trigger(t, host, action_def) for t in arr)
        if s.startswith("on_any:"):
            try:
                arr = json.loads(s.split(":", 1)[1])
            except Exception:
                return False
            return any(evaluate_trigger(t, host, action_def) for t in arr)
        # Skip interval triggers
        if s.startswith("on_interval:"):
            return False
        # Parse trigger name and parameter
        if ":" in s:
            name, param = s.split(":", 1)
            name = name.strip()
            param = (param or "").strip()
        else:
            name, param = s, ""
        # Aliases
        if name == "on_alive":
            name = "on_host_alive"
        if name == "on_dead":
            name = "on_host_dead"
        # Join/Leave events
        if name == "on_join":
            if not bool(host.get("alive")):
                return False
            last = _last_presence_event_for_mac(mac)
            return last != "PresenceJoin"
        if name == "on_leave":
            if bool(host.get("alive")):
                return False
            last = _last_presence_event_for_mac(mac)
            return last != "PresenceLeave"
        # Basic triggers
        if name == "on_start" or name == "on_new_host":
            r = shared_data.db.query(
                """
                SELECT 1 FROM action_queue
                WHERE mac_address=? AND action_name=?
                  AND status IN ('success','failed')
                LIMIT 1
                """,
                (mac, action_def["b_class"]),
            )
            return not bool(r)
        if name == "on_host_alive":
            return bool(host.get("alive"))
        if name == "on_host_dead":
            return not bool(host.get("alive"))
        # Skip port/service triggers for dead hosts
        if not bool(host.get("alive")) and name in {"on_service", "on_web_service", "on_new_port", "on_port_change"}:
            return False
        # Port triggers
        if name == "on_port_change":
            cur = set(_normalize_ports(host.get("ports")))
            prev = set(_normalize_ports(host.get("previous_ports")))
            return cur != prev
        if name == "on_new_port":
            port = str(param)
            cur = set(_normalize_ports(host.get("ports")))
            prev = set(_normalize_ports(host.get("previous_ports")))
            return port in cur and port not in prev
        # Action status triggers
        if name == "on_success":
            parent = param
            r = shared_data.db.query(
                """
                SELECT 1 FROM action_queue
                WHERE mac_address=? AND action_name=? AND status='success'
                ORDER BY completed_at DESC LIMIT 1
                """,
                (mac, parent),
            )
            return bool(r)
        if name == "on_failure":
            parent = param
            r = shared_data.db.query(
                """
                SELECT 1 FROM action_queue
                WHERE mac_address=? AND action_name=? AND status='failed'
                ORDER BY completed_at DESC LIMIT 1
                """,
                (mac, parent),
            )
            return bool(r)
        # Service triggers
        if name == "on_cred_found":
            service = param.lower()
            r = shared_data.db.query(
                "SELECT 1 FROM creds WHERE mac_address=? AND LOWER(service)=? LIMIT 1",
                (mac, service),
            )
            return bool(r)
        if name == "on_service":
            return _has_open_service(mac, param, host)
        if name == "on_web_service":
            return _has_open_service(mac, "http", host) or _has_open_service(mac, "https", host)
        # Identity triggers
        if name == "on_mac_is":
            return str(mac).lower() == param.lower()
        if name == "on_essid_is":
            return (host.get("essid") or "") == param
        if name == "on_ip_is":
            ips = (host.get("ips") or "").split(";") if host.get("ips") else []
            return param in ips
        # Vulnerability triggers
        if name == "on_has_cve":
            if not param:
                r = shared_data.db.query(
                    "SELECT 1 FROM vulnerabilities WHERE mac_address=? AND is_active=1 LIMIT 1",
                    (mac,),
                )
                return bool(r)
            r = shared_data.db.query(
                "SELECT 1 FROM vulnerabilities WHERE mac_address=? AND vuln_id=? AND is_active=1 LIMIT 1",
                (mac, param),
            )
            return bool(r)
        if name == "on_has_cpe":
            if not param:
                r = shared_data.db.query(
                    "SELECT 1 FROM detected_software WHERE mac_address=? AND is_active=1 LIMIT 1",
                    (mac,),
                )
                return bool(r)
            r = shared_data.db.query(
                "SELECT 1 FROM detected_software WHERE mac_address=? AND cpe=? AND is_active=1 LIMIT 1",
                (mac, param),
            )
            return bool(r)
        # Unknown trigger
        logger.debug(f"Unknown trigger: {name}")
        return False
    except Exception as e:
        logger.error(f"Error evaluating trigger '{trigger}': {e}")
        return False
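# Example trigger strings (illustrative):
#   "on_service:ssh"                          -> host exposes an open SSH service
#   "on_new_port:8080"                        -> 8080 present now, absent before
#   'on_all:["on_host_alive", "on_has_cve"]'  -> every sub-trigger must hold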
# ---------------------------------------------------------- requirements eval --
def evaluate_requirements(requires: Any, host: Dict[str, Any], action_def: Dict[str, Any]) -> bool:
    """Evaluate requirements for action."""
    if requires is None:
        return True
    # Already an object
    if isinstance(requires, (dict, list)):
        return evaluate_requirements_object(requires, host, action_def)
    s = str(requires).strip()
    if not s:
        return True
    # JSON string
    if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")):
        try:
            obj = json.loads(s)
            return evaluate_requirements_object(obj, host, action_def)
        except Exception:
            pass
    # Legacy "Action:status" format
    if ":" in s:
        a, st = s.split(":", 1)
        obj = {"action": a.strip(), "status": st.strip()}
        return evaluate_requirements_object(obj, host, action_def)
    return True
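# e.g. the legacy string "ScanPorts:success" is interpreted as
#   {"action": "ScanPorts", "status": "success"}
# ("ScanPorts" is just an illustrative action name).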
def evaluate_requirements_object(req: Any, host: Dict[str, Any], action_def: Dict[str, Any]) -> bool:
    """
    Evaluate requirements object.
    Supported:
    - {"all": [...]} / {"any": [...]} / {"not": {...}}
    - {"action": "ACTION", "status": "STATUS", "scope": "host|global"}
    - {"has_port": PORT}
    - {"has_cred": "SERVICE"}
    - {"has_cve": "CVE"}
    - {"has_cpe": "CPE"}
    - {"mac_is": "MAC"}
    - {"essid_is": "ESSID"}
    - {"service_is_open": "SERVICE"}
    """
    mac = host["mac_address"]
    # Combinators
    if isinstance(req, dict) and "all" in req:
        return all(evaluate_requirements_object(x, host, action_def) for x in (req.get("all") or []))
    if isinstance(req, dict) and "any" in req:
        return any(evaluate_requirements_object(x, host, action_def) for x in (req.get("any") or []))
    if isinstance(req, dict) and "not" in req:
        return not evaluate_requirements_object(req.get("not"), host, action_def)
    # Atomic requirements
    if isinstance(req, dict):
        if "action" in req:
            action = str(req.get("action") or "").strip()
            status = str(req.get("status") or "success").strip()
            scope = str(req.get("scope") or "host").strip().lower()
            if scope == "global":
                r = shared_data.db.query(
                    """
                    SELECT 1 FROM action_queue
                    WHERE action_name=? AND status=?
                    ORDER BY completed_at DESC LIMIT 1
                    """,
                    (action, status),
                )
                return bool(r)
            # Host scope
            r = shared_data.db.query(
                """
                SELECT 1 FROM action_queue
                WHERE mac_address=? AND action_name=? AND status=?
                ORDER BY completed_at DESC LIMIT 1
                """,
                (mac, action, status),
            )
            return bool(r)
        if "has_port" in req:
            want = str(req.get("has_port"))
            return want in set(_normalize_ports(host.get("ports")))
        if "has_cred" in req:
            svc = str(req.get("has_cred") or "").lower()
            r = shared_data.db.query(
                "SELECT 1 FROM creds WHERE mac_address=? AND LOWER(service)=? LIMIT 1",
                (mac, svc),
            )
            return bool(r)
        if "has_cve" in req:
            cve = str(req.get("has_cve") or "")
            r = shared_data.db.query(
                "SELECT 1 FROM vulnerabilities WHERE mac_address=? AND vuln_id=? AND is_active=1 LIMIT 1",
                (mac, cve),
            )
            return bool(r)
        if "has_cpe" in req:
            cpe = str(req.get("has_cpe") or "")
            r = shared_data.db.query(
                "SELECT 1 FROM detected_software WHERE mac_address=? AND cpe=? AND is_active=1 LIMIT 1",
                (mac, cpe),
            )
            return bool(r)
        if "mac_is" in req:
            return str(mac).lower() == str(req.get("mac_is") or "").lower()
        if "essid_is" in req:
            return (host.get("essid") or "") == str(req.get("essid_is") or "")
        if "service_is_open" in req:
            svc = str(req.get("service_is_open") or "").lower()
            return _has_open_service(mac, svc, host)
    # Default: truthy
    return bool(req)
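# Example requirements object (illustrative): run only when SSH creds exist
# and some web service is reachable:
#   {"all": [
#       {"has_cred": "ssh"},
#       {"any": [{"service_is_open": "http"}, {"service_is_open": "https"}]}
#   ]}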