mirror of
https://github.com/infinition/Bjorn.git
synced 2025-12-12 15:44:58 +00:00
1238 lines
45 KiB
Python
1238 lines
45 KiB
Python
# action_scheduler.py
|
|
# Smart Action Scheduler for Bjorn - queue-only implementation
|
|
# Handles trigger evaluation, requirements checking, and queue management.
|
|
#
|
|
# Invariants we enforce:
|
|
# - At most ONE "active" row per (action_name, mac_address, COALESCE(port,0))
|
|
# where active ∈ {'scheduled','pending','running'}.
|
|
# - Retries for failed entries are coordinated by cleanup_queue() (with backoff)
|
|
# and never compete with trigger-based enqueues.
|
|
#
|
|
# Runtime knobs (from shared.py):
|
|
# shared_data.retry_success_actions : bool (default False)
|
|
# shared_data.retry_failed_actions : bool (default True)
|
|
#
|
|
# These take precedence over cooldown / rate-limit for NON-interval triggers.
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import time
|
|
import threading
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
from init_shared import shared_data
|
|
from logger import Logger
|
|
|
|
logger = Logger(name="action_scheduler.py")
|
|
|
|
# ---------- UTC helpers (match SQLite's UTC CURRENT_TIMESTAMP) ----------
|
|
def _utcnow() -> datetime:
|
|
"""Naive UTC datetime to compare with SQLite TEXT timestamps (UTC)."""
|
|
return datetime.now(timezone.utc).replace(tzinfo=None)
|
|
|
|
def _utcnow_str() -> str:
    """Return the current UTC time as a 'YYYY-MM-DD HH:MM:SS' string.

    The layout matches SQLite's TEXT timestamps, so the result can be used
    directly in string comparisons inside queries.
    """
    stamp = _utcnow()
    return stamp.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
|
def _db_ts(dt: datetime) -> str:
|
|
"""Format any datetime as 'YYYY-MM-DD HH:MM:SS' (UTC expected)."""
|
|
return dt.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
|
# Service → fallback ports (used when port_services table has nothing for a host)
|
|
# Well-known port numbers (as strings) per service name. Consulted by
# _has_open_service() when the port_services table has no matching row.
SERVICE_PORTS: Dict[str, List[str]] = dict(
    ssh=["22"],
    http=["80", "8080"],
    https=["443"],
    smb=["445"],
    ftp=["21"],
    telnet=["23"],
    mysql=["3306"],
    mssql=["1433"],
    postgres=["5432"],
    rdp=["3389"],
    vnc=["5900"],
)
|
|
|
|
|
|
class ActionScheduler:
|
|
"""
|
|
Smart scheduler that evaluates triggers and enqueues actions.
|
|
Does NOT execute actions - that's the orchestrator's job.
|
|
"""
|
|
|
|
def __init__(self, shared_data_):
    """Set up scheduler state and make sure the DB prerequisites exist.

    Args:
        shared_data_: Shared runtime object; this constructor reads its
            ``db`` attribute and calls ``get_raspberry_mac()``.
    """
    self.shared_data = shared_data_
    self.db = shared_data_.db

    # Loop control
    self.running = True
    self.check_interval = 5  # seconds between iterations

    # Cached action definitions, keyed by b_class
    self._action_definitions: Dict[str, Dict[str, Any]] = {}
    self._last_cache_refresh = 0.0
    self._cache_ttl = 60.0  # seconds

    # Bookkeeping for global actions (action name -> time.time() of enqueue)
    self._last_global_runs: Dict[str, float] = {}
    # Which source the cache was last loaded from (None = never loaded)
    self._last_source_is_studio: Optional[bool] = None

    # Global actions are attached to the controller's MAC, or to a
    # placeholder when no MAC is reported.
    controller_mac = self.shared_data.get_raspberry_mac() or "__GLOBAL__"
    self.ctrl_mac = controller_mac.lower()
    self._ensure_host_exists(self.ctrl_mac)

    # Enforce DB invariants (idempotent)
    self._ensure_db_invariants()

    logger.info("ActionScheduler initialized")
|
|
|
|
# --------------------------------------------------------------------- loop
|
|
|
|
def run(self):
    """Run the scheduler loop until stopped.

    Each pass executes the following steps in order, then sleeps for
    ``self.check_interval`` seconds:

      1. refresh the action-definition cache if needed
      2. promote due scheduled actions to pending
      3. publish the next occurrence of interval actions
      4. evaluate global on_start actions
      5. evaluate per-host triggers
      6. queue maintenance (cleanup, then priority boost)

    Every step is guarded individually so that an exception in one step
    (for example trigger evaluation) is logged and does not prevent the
    remaining steps, in particular queue maintenance, from running in the
    same pass.

    The loop ends when ``self.running`` is cleared or
    ``shared_data.orchestrator_should_exit`` becomes truthy.
    """
    logger.info("ActionScheduler starting main loop")

    # Method names are resolved on every pass so the execution order stays
    # explicit and a failing step can be named in the log.
    step_names = (
        "_refresh_cache_if_needed",
        "_promote_scheduled_to_pending",
        "_publish_all_upcoming",
        "_evaluate_global_actions",
        "evaluate_all_triggers",
        "cleanup_queue",
        "update_priorities",
    )

    while self.running and not self.shared_data.orchestrator_should_exit:
        for step_name in step_names:
            try:
                getattr(self, step_name)()
            except Exception as e:
                logger.error(f"Error in scheduler loop ({step_name}): {e}")

        time.sleep(self.check_interval)

    logger.info("ActionScheduler stopped")
|
|
|
|
def stop(self):
    """Request a graceful shutdown of the scheduler loop.

    Only clears the ``running`` flag; ``run()`` notices it the next time it
    evaluates its loop condition.
    """
    logger.info("Stopping ActionScheduler...")
    self.running = False
|
|
|
|
# --------------------------------------------------------------- definitions
|
|
|
|
# ---------- cache refresh (TTL / source switch) ----------
|
|
def _refresh_cache_if_needed(self):
|
|
"""Refresh action definitions cache if expired or source flipped."""
|
|
now = time.time()
|
|
use_studio = bool(getattr(self.shared_data, "use_actions_studio", False))
|
|
|
|
# Refresh if TTL expired or the source changed (actions ↔ studio_actions)
|
|
if (now - self._last_cache_refresh > self._cache_ttl) or (self._last_source_is_studio != use_studio):
|
|
self._refresh_action_cache(use_studio=use_studio)
|
|
self._last_cache_refresh = now
|
|
self._last_source_is_studio = use_studio
|
|
|
|
|
|
# ---------- cache reload (actions / studio) ----------
|
|
def _refresh_action_cache(self, use_studio: Optional[bool] = None):
    """Rebuild ``_action_definitions`` from the database.

    Args:
        use_studio: Load via ``db.list_studio_actions()`` when true, via
            ``db.list_actions()`` otherwise. ``None`` means "read
            ``shared_data.use_actions_studio``".

    Errors are logged, never raised. If the studio load raises
    ``AttributeError`` and the DB adapter exposes ``list_actions``, the
    plain action list is loaded instead.
    """
    if use_studio is None:
        use_studio = bool(getattr(self.shared_data, "use_actions_studio", False))

    try:
        source = "studio" if use_studio else "actions"
        rows = self.db.list_studio_actions() if use_studio else self.db.list_actions()

        # Index by class name (rows are expected to carry a b_class key).
        self._action_definitions = {row["b_class"]: row for row in rows}
        logger.info(f"Refreshed action cache from '{source}': {len(self._action_definitions)} actions")

    except AttributeError as e:
        can_fall_back = use_studio and hasattr(self.db, "list_actions")
        if not can_fall_back:
            logger.error(f"Action cache refresh failed (no suitable DB method): {e}")
            return

        logger.warning(f"DB has no list_studio_actions(); falling back to list_actions(): {e}")
        try:
            rows = self.db.list_actions()
            self._action_definitions = {row["b_class"]: row for row in rows}
            logger.info(f"Refreshed action cache from 'actions' (fallback): {len(self._action_definitions)} actions")
        except Exception as ee:
            logger.error(f"Fallback list_actions() failed: {ee}")

    except Exception as e:
        logger.error(f"Failed to refresh action cache: {e}")
|
|
|
|
|
|
# ------------------------------------------------------------------ helpers
|
|
|
|
def _ensure_db_invariants(self):
|
|
"""
|
|
Create a partial UNIQUE index that forbids more than one active entry
|
|
for the same (action_name, mac_address, COALESCE(port,0)).
|
|
"""
|
|
try:
|
|
self.db.execute(
|
|
"""
|
|
CREATE UNIQUE INDEX IF NOT EXISTS uq_action_active
|
|
ON action_queue(action_name, mac_address, COALESCE(port,0))
|
|
WHERE status IN ('scheduled','pending','running')
|
|
"""
|
|
)
|
|
except Exception as e:
|
|
# If the SQLite build does not support partial/expression indexes,
|
|
# we still have app-level guards (NOT EXISTS inserts). But this
|
|
# index is recommended to make the invariant bulletproof.
|
|
logger.warning(f"Could not create unique partial index (fallback to app-level guards): {e}")
|
|
|
|
def _promote_scheduled_to_pending(self):
|
|
"""Promote due scheduled actions to pending status."""
|
|
try:
|
|
promoted = self.db.promote_due_scheduled_to_pending()
|
|
if promoted:
|
|
logger.debug(f"Promoted {promoted} scheduled action(s) to pending")
|
|
except Exception as e:
|
|
logger.error(f"Failed to promote scheduled actions: {e}")
|
|
|
|
def _ensure_host_exists(self, mac: str):
|
|
"""Ensure host exists in database (idempotent)."""
|
|
if not mac:
|
|
return
|
|
try:
|
|
self.db.execute(
|
|
"""
|
|
INSERT INTO hosts (mac_address, alive, updated_at)
|
|
VALUES (?, 0, CURRENT_TIMESTAMP)
|
|
ON CONFLICT(mac_address) DO UPDATE SET
|
|
updated_at = CURRENT_TIMESTAMP
|
|
""",
|
|
(mac,),
|
|
)
|
|
except Exception:
|
|
pass
|
|
|
|
# ---------------------------------------------------------- interval logic
|
|
|
|
def _parse_interval_seconds(self, trigger: str) -> int:
|
|
"""Parse interval from trigger string 'on_interval:SECONDS'."""
|
|
if not trigger or not trigger.startswith("on_interval:"):
|
|
return 0
|
|
try:
|
|
return max(0, int(trigger.split(":", 1)[1] or 0))
|
|
except Exception:
|
|
return 0
|
|
|
|
def _publish_all_upcoming(self):
|
|
"""
|
|
Publish next scheduled occurrence for all interval actions.
|
|
|
|
NOTE: By design, the runtime flags do not cancel interval publishing.
|
|
"""
|
|
# Global interval actions
|
|
for action in self._action_definitions.values():
|
|
if (action.get("b_action") or "normal") != "global":
|
|
continue
|
|
if int(action.get("b_enabled", 1) or 1) != 1:
|
|
continue
|
|
|
|
trigger = (action.get("b_trigger") or "").strip()
|
|
interval = self._parse_interval_seconds(trigger)
|
|
if interval <= 0:
|
|
continue
|
|
|
|
self._publish_next_schedule_for_global(action, interval)
|
|
|
|
# Per-host interval actions
|
|
try:
|
|
hosts = self.db.get_all_hosts()
|
|
except Exception:
|
|
hosts = []
|
|
|
|
for host in hosts:
|
|
if not host.get("alive"):
|
|
continue
|
|
|
|
mac = host.get("mac_address") or ""
|
|
if not mac:
|
|
continue
|
|
|
|
for action in self._action_definitions.values():
|
|
if (action.get("b_action") or "normal") == "global":
|
|
continue
|
|
if int(action.get("b_enabled", 1) or 1) != 1:
|
|
continue
|
|
|
|
trigger = (action.get("b_trigger") or "").strip()
|
|
interval = self._parse_interval_seconds(trigger)
|
|
if interval <= 0:
|
|
continue
|
|
|
|
self._publish_next_schedule_for_host(host, action, interval)
|
|
|
|
def _publish_next_schedule_for_global(self, action_def: Dict[str, Any], interval: int):
|
|
"""Publish next scheduled occurrence for a global action."""
|
|
try:
|
|
action_name = action_def["b_class"]
|
|
mac = self.ctrl_mac
|
|
|
|
# Already active?
|
|
active = self.db.query(
|
|
"""
|
|
SELECT 1 FROM action_queue
|
|
WHERE action_name=? AND mac_address=?
|
|
AND status IN ('scheduled','pending','running')
|
|
LIMIT 1
|
|
""",
|
|
(action_name, mac),
|
|
)
|
|
if active:
|
|
return
|
|
|
|
# Next occurrence immediately after last completion, else now (UTC)
|
|
last = self._get_last_global_execution_time(action_name)
|
|
next_run = _utcnow() if not last else (last + timedelta(seconds=interval))
|
|
scheduled_for = _db_ts(next_run)
|
|
|
|
inserted = self.db.ensure_scheduled_occurrence(
|
|
action_name=action_name,
|
|
next_run_at=scheduled_for,
|
|
mac=mac,
|
|
ip="0.0.0.0",
|
|
priority=int(action_def.get("b_priority", 40) or 40),
|
|
trigger="scheduler",
|
|
tags=action_def.get("b_tags", []),
|
|
metadata={"interval": interval, "is_global": True},
|
|
max_retries=int(action_def.get("b_max_retries", 3) or 3),
|
|
)
|
|
if inserted:
|
|
logger.debug(f"Scheduled global '{action_name}' at {scheduled_for}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to publish global schedule: {e}")
|
|
|
|
def _publish_next_schedule_for_host(self, host: Dict[str, Any], action_def: Dict[str, Any], interval: int):
|
|
"""Publish next scheduled occurrence for a per-host action."""
|
|
try:
|
|
mac = host.get("mac_address") or ""
|
|
if not mac:
|
|
return
|
|
|
|
self._ensure_host_exists(mac)
|
|
action_name = action_def["b_class"]
|
|
|
|
# Already active?
|
|
active = self.db.query(
|
|
"""
|
|
SELECT 1 FROM action_queue
|
|
WHERE action_name=? AND mac_address=?
|
|
AND status IN ('scheduled','pending','running')
|
|
LIMIT 1
|
|
""",
|
|
(action_name, mac),
|
|
)
|
|
if active:
|
|
return
|
|
|
|
# Next occurrence immediately after last completion, else now (UTC)
|
|
last = self._get_last_execution_time(mac, action_name)
|
|
next_run = _utcnow() if not last else (last + timedelta(seconds=interval))
|
|
scheduled_for = _db_ts(next_run)
|
|
|
|
inserted = self.db.ensure_scheduled_occurrence(
|
|
action_name=action_name,
|
|
next_run_at=scheduled_for,
|
|
mac=mac,
|
|
ip=(host.get("ips") or "").split(";")[0] if host.get("ips") else "",
|
|
priority=int(action_def.get("b_priority", 40) or 40),
|
|
trigger="scheduler",
|
|
tags=action_def.get("b_tags", []),
|
|
metadata={"interval": interval, "is_global": False},
|
|
max_retries=int(action_def.get("b_max_retries", 3) or 3),
|
|
)
|
|
if inserted:
|
|
logger.debug(f"Scheduled '{action_name}' for {mac} at {scheduled_for}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to publish host schedule: {e}")
|
|
|
|
# ------------------------------------------------------------ global start
|
|
|
|
def _evaluate_global_actions(self):
|
|
"""Evaluate and queue global actions with on_start trigger."""
|
|
self._globals_lock = getattr(self, "_globals_lock", threading.Lock())
|
|
|
|
with self._globals_lock:
|
|
try:
|
|
for action in self._action_definitions.values():
|
|
if (action.get("b_action") or "normal") != "global":
|
|
continue
|
|
if int(action.get("b_enabled", 1)) != 1:
|
|
continue
|
|
|
|
trigger = (action.get("b_trigger") or "").strip()
|
|
if trigger != "on_start":
|
|
continue
|
|
|
|
action_name = action["b_class"]
|
|
|
|
# Already executed at least once?
|
|
last = self._get_last_global_execution_time(action_name)
|
|
if last is not None:
|
|
continue
|
|
|
|
# Already queued?
|
|
existing = self.db.query(
|
|
"""
|
|
SELECT 1 FROM action_queue
|
|
WHERE action_name=? AND status IN ('scheduled','pending','running')
|
|
LIMIT 1
|
|
""",
|
|
(action_name,),
|
|
)
|
|
if existing:
|
|
continue
|
|
|
|
# Queue the action
|
|
self._queue_global_action(action)
|
|
self._last_global_runs[action_name] = time.time()
|
|
logger.info(f"Queued global action {action_name}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error evaluating global actions: {e}")
|
|
|
|
def _queue_global_action(self, action_def: Dict[str, Any]):
    """Insert a 'pending' queue row for a global action.

    The row is attached to the controller MAC with IP "0.0.0.0" and no
    port, hostname or service. The INSERT carries a NOT EXISTS guard, so it
    does nothing when a scheduled/pending/running row already exists for
    (action, controller MAC, port 0). DB errors are logged, never raised.
    """
    name = action_def["b_class"]
    controller_mac = self.ctrl_mac
    placeholder_ip = "0.0.0.0"
    timeout_s = int(action_def.get("b_timeout", 300) or 300)
    # expires_at is "now + timeout" (UTC).
    deadline = _db_ts(_utcnow() + timedelta(seconds=timeout_s))

    meta = dict(
        trigger=action_def.get("b_trigger", ""),
        requirements=action_def.get("b_requires", ""),
        timeout=timeout_s,
        is_global=True,
    )

    insert_sql = """
        INSERT INTO action_queue (
            action_name, mac_address, ip, port, hostname, service,
            priority, status, max_retries, expires_at,
            trigger_source, tags, metadata
        )
        SELECT ?, ?, ?, NULL, NULL, NULL,
               ?, 'pending', ?, ?, ?, ?, ?
        WHERE NOT EXISTS (
            SELECT 1 FROM action_queue
            WHERE action_name=? AND mac_address=? AND COALESCE(port,0)=0
              AND status IN ('scheduled','pending','running')
        )
    """

    try:
        self._ensure_host_exists(controller_mac)
        values = (
            # inserted columns
            name,
            controller_mac,
            placeholder_ip,
            int(action_def.get("b_priority", 50) or 50),
            int(action_def.get("b_max_retries", 3) or 3),
            deadline,
            action_def.get("b_trigger", ""),
            json.dumps(action_def.get("b_tags", [])),
            json.dumps(meta),
            # NOT EXISTS guard
            name,
            controller_mac,
        )
        self.db.execute(insert_sql, values)
    except Exception as e:
        logger.error(f"Failed to queue global action {name}: {e}")
|
|
|
|
# ------------------------------------------------------------- host path
|
|
|
|
def evaluate_all_triggers(self):
    """Evaluate non-interval triggers of per-host actions for every host.

    For each (host, action) pair whose trigger fires and whose requirements
    hold, the target port/service is resolved and the action is enqueued if
    ``_should_queue_action`` allows it. All hosts returned by
    ``db.get_all_hosts()`` are considered, not only alive ones.

    Differences from the previous implementation:
      * A ``None`` or empty ``b_enabled`` is treated as "enabled" instead
        of raising ``TypeError`` from ``int()`` and aborting the whole
        pass.
      * The host-independent filtering (global, disabled, empty or interval
        trigger) is done once instead of once per host.
    """
    candidates: List[Tuple[str, Dict[str, Any], str]] = []
    for action_name, action_def in self._action_definitions.items():
        # Global actions are handled by _evaluate_global_actions().
        if (action_def.get("b_action") or "normal") == "global":
            continue

        enabled = action_def.get("b_enabled", 1)
        if enabled is None or enabled == "":
            enabled = 1  # unset flag -> enabled by default
        if not int(enabled):
            continue

        trigger = (action_def.get("b_trigger") or "").strip()
        # Interval triggers are published by _publish_all_upcoming().
        if not trigger or trigger.startswith("on_interval:"):
            continue

        candidates.append((action_name, action_def, trigger))

    if not candidates:
        return

    hosts = self.db.get_all_hosts()  # include dead hosts for on_leave trigger
    for host in hosts:
        mac = host["mac_address"]

        for action_name, action_def, trigger in candidates:
            if not evaluate_trigger(trigger, host, action_def):
                continue

            requires = action_def.get("b_requires", "")
            if requires and not evaluate_requirements(requires, host, action_def):
                continue

            target_port, target_service = self._resolve_target_port_service(mac, host, action_def)

            if not self._should_queue_action(mac, action_name, action_def, target_port):
                continue

            self._queue_action(host, action_def, target_port, target_service)
            logger.info(f"Queued {action_name} for {mac} (port={target_port}, service={target_service})")
|
|
|
|
def _resolve_target_port_service(
    self, mac: str, host: Dict[str, Any], action_def: Dict[str, Any]
) -> Tuple[Optional[int], Optional[str]]:
    """Pick the port (and service) an action should target on a host.

    ``b_service`` is tried first: it may be a list or a JSON-encoded list
    of service names, and the first name with an open row in
    ``port_services`` for this MAC wins. If none matches, ``b_port`` is
    used when it appears in the host's ports.

    Returns:
        ``(port, service)``; ``service`` is ``None`` for the ``b_port``
        fallback, and both are ``None`` when nothing matches.
    """
    open_ports = _normalize_ports(host.get("ports"))

    # 1) Service-based resolution
    service_spec = action_def.get("b_service")
    if service_spec:
        try:
            candidates = json.loads(service_spec) if isinstance(service_spec, str) else service_spec
        except Exception:
            candidates = []

        for candidate in candidates or ():
            svc_name = str(candidate).lower()
            rows = self.db.query(
                "SELECT port FROM port_services "
                "WHERE mac_address=? AND state='open' AND LOWER(service)=? "
                "ORDER BY last_seen DESC LIMIT 1",
                (mac, svc_name),
            )
            if rows:
                return int(rows[0]["port"]), svc_name

    # 2) Port-based fallback
    fallback_port = action_def.get("b_port")
    if fallback_port and str(fallback_port) in open_ports:
        return int(fallback_port), None

    return None, None
|
|
|
|
# ----------------------------------------------------- re-queue policy core
|
|
|
|
def _get_last_status(self, mac: str, action_name: str, target_port: Optional[int]) -> Optional[str]:
|
|
"""
|
|
Return last known status for (mac, action, port), considering the
|
|
chronological fields (completed_at > started_at > scheduled_for > created_at).
|
|
"""
|
|
self_port = 0 if target_port is None else int(target_port)
|
|
row = self.db.query(
|
|
"""
|
|
SELECT status
|
|
FROM action_queue
|
|
WHERE mac_address=? AND action_name=? AND COALESCE(port,0)=?
|
|
ORDER BY datetime(COALESCE(completed_at, started_at, scheduled_for, created_at)) DESC
|
|
LIMIT 1
|
|
""",
|
|
(mac, action_name, self_port),
|
|
)
|
|
return row[0]["status"] if row else None
|
|
|
|
def _should_queue_action(
|
|
self, mac: str, action_name: str, action_def: Dict[str, Any], target_port: Optional[int]
|
|
) -> bool:
|
|
"""
|
|
Decide if we should enqueue a new job.
|
|
|
|
Evaluation order:
|
|
0) no duplicate active job
|
|
1) runtime flags (retry_success_actions / retry_failed_actions)
|
|
1-bis) do NOT enqueue if a retryable failed exists (let cleanup_queue() handle it)
|
|
2) cooldown
|
|
3) rate limit
|
|
"""
|
|
self_port = 0 if target_port is None else int(target_port)
|
|
|
|
# 0) Duplicate protection (active)
|
|
existing = self.db.query(
|
|
"""
|
|
SELECT 1 FROM action_queue
|
|
WHERE mac_address=? AND action_name=? AND COALESCE(port,0)=?
|
|
AND status IN ('scheduled','pending','running')
|
|
LIMIT 1
|
|
""",
|
|
(mac, action_name, self_port),
|
|
)
|
|
if existing:
|
|
return False
|
|
|
|
# 1) Runtime flags take precedence
|
|
allow_success = bool(getattr(self.shared_data, "retry_success_actions", False))
|
|
allow_failed = bool(getattr(self.shared_data, "retry_failed_actions", True))
|
|
last_status = self._get_last_status(mac, action_name, target_port)
|
|
|
|
if last_status == "success" and not allow_success:
|
|
return False
|
|
if last_status == "failed" and not allow_failed:
|
|
return False
|
|
|
|
# 1-bis) If a retryable failed exists, let cleanup_queue() requeue it (avoid duplicates)
|
|
if allow_failed:
|
|
retryable = self.db.query(
|
|
"""
|
|
SELECT 1 FROM action_queue
|
|
WHERE mac_address=? AND action_name=? AND COALESCE(port,0)=?
|
|
AND status='failed'
|
|
AND retry_count < max_retries
|
|
AND COALESCE(error_message,'') != 'expired'
|
|
LIMIT 1
|
|
""",
|
|
(mac, action_name, self_port),
|
|
)
|
|
if retryable:
|
|
return False
|
|
|
|
# 2) Cooldown (UTC)
|
|
cooldown = int(action_def.get("b_cooldown", 0) or 0)
|
|
if cooldown > 0:
|
|
last_exec = self._get_last_execution_time(mac, action_name)
|
|
if last_exec and (_utcnow() - last_exec).total_seconds() < cooldown:
|
|
return False
|
|
|
|
# 3) Rate limit (UTC)
|
|
rate_limit = (action_def.get("b_rate_limit") or "").strip()
|
|
if rate_limit and not self._check_rate_limit(mac, action_name, rate_limit):
|
|
return False
|
|
|
|
return True
|
|
|
|
def _queue_action(
    self, host: Dict[str, Any], action_def: Dict[str, Any], target_port: Optional[int], target_service: Optional[str]
):
    """Insert a 'pending' queue row for a per-host action.

    The INSERT carries a NOT EXISTS guard, so it does nothing when a
    scheduled/pending/running row already exists for
    (mac, action, COALESCE(port,0)). ``expires_at`` is "now + b_timeout"
    (UTC). DB errors are logged, never raised.
    """
    name = action_def["b_class"]
    mac = host["mac_address"]
    timeout_s = int(action_def.get("b_timeout", 300) or 300)
    deadline = _db_ts(_utcnow() + timedelta(seconds=timeout_s))
    port_key = int(target_port) if target_port is not None else 0

    meta = dict(
        trigger=action_def.get("b_trigger", ""),
        requirements=action_def.get("b_requires", ""),
        is_global=False,
        timeout=timeout_s,
        ports_snapshot=host.get("ports") or "",
    )

    insert_sql = """
        INSERT INTO action_queue (
            action_name, mac_address, ip, port, hostname, service,
            priority, status, max_retries, expires_at,
            trigger_source, tags, metadata
        )
        SELECT ?, ?, ?, ?, ?, ?,
               ?, 'pending', ?, ?, ?, ?, ?
        WHERE NOT EXISTS (
            SELECT 1 FROM action_queue
            WHERE mac_address=? AND action_name=? AND COALESCE(port,0)=?
              AND status IN ('scheduled','pending','running')
        )
    """

    try:
        # First entries of the ';'-separated address and hostname lists.
        known_ips = host.get("ips")
        first_ip = known_ips.split(";")[0] if known_ips else ""
        known_names = host.get("hostnames")
        first_hostname = known_names.split(";")[0] if known_names else ""

        values = (
            # inserted columns
            name,
            mac,
            first_ip,
            target_port,
            first_hostname,
            target_service,
            int(action_def.get("b_priority", 50) or 50),
            int(action_def.get("b_max_retries", 3) or 3),
            deadline,
            action_def.get("b_trigger", ""),
            json.dumps(action_def.get("b_tags", [])),
            json.dumps(meta),
            # NOT EXISTS guard
            mac,
            name,
            port_key,
        )
        self.db.execute(insert_sql, values)
    except Exception as e:
        logger.error(f"Failed to queue {name} for {mac}: {e}")
|
|
|
|
# ------------------------------------------------------------- last times
|
|
|
|
def _get_last_execution_time(self, mac: str, action_name: str) -> Optional[datetime]:
|
|
"""Get last execution time (DB read only; naive UTC)."""
|
|
row = self.db.query(
|
|
"""
|
|
SELECT completed_at FROM action_queue
|
|
WHERE mac_address=? AND action_name=? AND status IN ('success','failed')
|
|
ORDER BY completed_at DESC
|
|
LIMIT 1
|
|
""",
|
|
(mac, action_name),
|
|
)
|
|
if row and row[0].get("completed_at"):
|
|
try:
|
|
return datetime.fromisoformat(row[0]["completed_at"])
|
|
except Exception:
|
|
return None
|
|
return None
|
|
|
|
def _get_last_global_execution_time(self, action_name: str) -> Optional[datetime]:
|
|
"""Get last global action execution time (naive UTC)."""
|
|
row = self.db.query(
|
|
"""
|
|
SELECT completed_at FROM action_queue
|
|
WHERE action_name=? AND status IN ('success','failed')
|
|
ORDER BY completed_at DESC
|
|
LIMIT 1
|
|
""",
|
|
(action_name,),
|
|
)
|
|
if row and row[0].get("completed_at"):
|
|
try:
|
|
return datetime.fromisoformat(row[0]["completed_at"])
|
|
except Exception:
|
|
return None
|
|
return None
|
|
|
|
# ------------------------------------------------------------- constraints
|
|
|
|
def _check_rate_limit(self, mac: str, action_name: str, rate_limit: str) -> bool:
|
|
"""
|
|
Check "X/SECONDS" rate-limit (count based on created_at).
|
|
Returns True if action is allowed to queue.
|
|
"""
|
|
try:
|
|
max_count, period = rate_limit.split("/")
|
|
max_count = int(max_count)
|
|
period = int(period)
|
|
since = _db_ts(_utcnow() - timedelta(seconds=period))
|
|
|
|
count = self.db.query(
|
|
"""
|
|
SELECT COUNT(*) AS c FROM action_queue
|
|
WHERE mac_address=? AND action_name=? AND created_at >= ?
|
|
""",
|
|
(mac, action_name, since),
|
|
)[0]["c"]
|
|
|
|
return int(count) < max_count
|
|
except Exception:
|
|
# Invalid format -> do not block
|
|
return True
|
|
|
|
# -------------------------------------------------------------- maintenance
|
|
|
|
def cleanup_queue(self):
    """Run queue maintenance: expiry, timeouts, retries, purge.

    Steps, in order:
      1) pending rows whose ``expires_at`` is in the past become 'failed'
         (error_message 'expired' unless one is already set)
      2) running rows older than their ``metadata.timeout`` (default 900 s)
         become 'failed' (error_message 'timeout' unless already set)
      3) if ``shared_data.retry_failed_actions`` is truthy (default True),
         failed rows with retries left, not marked 'expired', and with no
         active sibling are set back to 'pending' with a backoff of
         ``60 * 2**retry_count`` seconds, capped at 900
      4) success/failed/cancelled/expired rows completed more than 7 days
         ago are deleted

    Each statement is executed and guarded on its own, so a failure in one
    step (for example a constraint violation in the retry UPDATE) is logged
    and no longer prevents the later steps, notably the purge, from
    running.
    """

    def _run_step(label: str, sql: str, params: Optional[Tuple[Any, ...]] = None) -> None:
        """Execute one maintenance statement; log and swallow failures."""
        try:
            if params is None:
                self.db.execute(sql)
            else:
                self.db.execute(sql, params)
        except Exception as e:
            logger.error(f"Failed to cleanup queue ({label}): {e}")

    try:
        now_iso = _utcnow_str()
        old_date = _db_ts(_utcnow() - timedelta(days=7))
        retry_enabled = bool(getattr(self.shared_data, "retry_failed_actions", True))
    except Exception as e:
        logger.error(f"Failed to cleanup queue: {e}")
        return

    # 1) Expire pending actions
    _run_step(
        "expire pending",
        """
        UPDATE action_queue
        SET status='failed',
            completed_at=CURRENT_TIMESTAMP,
            error_message=COALESCE(error_message,'expired')
        WHERE status='pending'
          AND expires_at IS NOT NULL
          AND expires_at < ?
        """,
        (now_iso,),
    )

    # 2) Timeout running actions
    _run_step(
        "timeout running",
        """
        UPDATE action_queue
        SET status='failed',
            completed_at=CURRENT_TIMESTAMP,
            error_message=COALESCE(error_message,'timeout')
        WHERE status='running'
          AND started_at IS NOT NULL
          AND datetime(started_at, '+' || COALESCE(
                CAST(json_extract(metadata, '$.timeout') AS INTEGER), 900
              ) || ' seconds') <= datetime('now')
        """,
    )

    # 3) Retry failed actions with exponential backoff, only if no active
    #    job exists for the same (action, mac, port).
    # NOTE(review): this UPDATE does not touch expires_at; confirm that a
    # retried row with an already-past expires_at is not immediately
    # re-failed as 'expired' by step 1 on the next pass.
    if retry_enabled:
        _run_step(
            "retry failed",
            """
            UPDATE action_queue AS a
            SET status='pending',
                retry_count = retry_count + 1,
                scheduled_for = datetime(
                    'now',
                    '+' || (
                        CASE
                            WHEN (60 * (1 << retry_count)) > 900 THEN 900
                            ELSE (60 * (1 << retry_count))
                        END
                    ) || ' seconds'
                ),
                error_message = NULL,
                started_at = NULL,
                completed_at = NULL
            WHERE a.status='failed'
              AND a.retry_count < a.max_retries
              AND COALESCE(a.error_message,'') != 'expired'
              AND NOT EXISTS (
                  SELECT 1 FROM action_queue b
                  WHERE b.mac_address=a.mac_address
                    AND b.action_name=a.action_name
                    AND COALESCE(b.port,0)=COALESCE(a.port,0)
                    AND b.status IN ('scheduled','pending','running')
              )
            """,
        )

    # 4) Purge old completed entries
    _run_step(
        "purge old",
        """
        DELETE FROM action_queue
        WHERE status IN ('success','failed','cancelled','expired')
          AND completed_at < ?
        """,
        (old_date,),
    )
|
|
|
|
def update_priorities(self):
    """Raise the priority of long-waiting pending actions (anti-starvation).

    Every pending row created more than 0.0417 days (about one hour) ago
    gets ``priority + 1``, capped at 100. DB errors are logged, never
    raised.
    """
    boost_sql = """
        UPDATE action_queue
        SET priority = MIN(100, priority + 1)
        WHERE status='pending'
          AND julianday('now') - julianday(created_at) > 0.0417
    """
    try:
        self.db.execute(boost_sql)
    except Exception as e:
        logger.error(f"Failed to update priorities: {e}")
|
|
|
|
|
|
# =================================================================== helpers ==
|
|
|
|
def _normalize_ports(raw) -> List[str]:
|
|
"""Normalize ports to list of strings."""
|
|
if not raw:
|
|
return []
|
|
if isinstance(raw, list):
|
|
return [str(p).split("/")[0] for p in raw if p is not None and str(p) != ""]
|
|
if isinstance(raw, int):
|
|
return [str(raw)]
|
|
if isinstance(raw, str):
|
|
s = raw.strip()
|
|
if not s:
|
|
return []
|
|
if s.startswith("[") and s.endswith("]"):
|
|
try:
|
|
arr = json.loads(s)
|
|
return [str(p).split("/")[0] for p in arr]
|
|
except Exception:
|
|
pass
|
|
if ";" in s:
|
|
return [p.strip().split("/")[0] for p in s.split(";") if p.strip()]
|
|
return [s.split("/")[0]]
|
|
return [str(raw)]
|
|
|
|
|
|
def _has_open_service(mac: str, svc: str, host: Dict[str, Any]) -> bool:
    """Tell whether service *svc* is open on the host.

    The ``port_services`` table is checked first (open state,
    case-insensitive service name). If it has no matching row, the host's
    own port list is compared against the well-known ports listed for the
    service in ``SERVICE_PORTS``.
    """
    service = (svc or "").lower().strip()

    # 1) Recorded services
    recorded = shared_data.db.query(
        "SELECT 1 FROM port_services WHERE mac_address=? AND state='open' AND LOWER(service)=? LIMIT 1",
        (mac, service),
    )
    if recorded:
        return True

    # 2) Fallback: well-known port numbers
    host_ports = set(_normalize_ports(host.get("ports")))
    return any(port in host_ports for port in SERVICE_PORTS.get(service, []))
|
|
|
|
|
|
def _last_presence_event_for_mac(mac: str) -> Optional[str]:
    """Return the most recent presence action queued for *mac*.

    Looks for 'PresenceJoin' / 'PresenceLeave' rows in ``action_queue``,
    ordered by the first non-NULL of completed_at, started_at,
    scheduled_for and created_at.

    Returns:
        The action name of the newest such row, or ``None`` if none exists.
    """
    presence_sql = """
        SELECT action_name
        FROM action_queue
        WHERE mac_address=?
          AND action_name IN ('PresenceJoin','PresenceLeave')
        ORDER BY datetime(COALESCE(completed_at, started_at, scheduled_for, created_at)) DESC
        LIMIT 1
    """
    found = shared_data.db.query(presence_sql, (mac,))
    if not found:
        return None
    return found[0]["action_name"]
|
|
|
|
|
|
# --------------------------------------------------------------- trigger eval --
|
|
|
|
def evaluate_trigger(trigger: str, host: Dict[str, Any], action_def: Dict[str, Any]) -> bool:
    """Tell whether *trigger* currently fires for *host*.

    Supported triggers:
      - on_start / on_new_host (no finished run of this action on the host)
      - on_host_alive (alias on_alive), on_host_dead (alias on_dead)
      - on_join, on_leave
      - on_port_change, on_new_port:PORT
      - on_service:SERVICE, on_web_service
      - on_success:ACTION, on_failure:ACTION
      - on_cred_found:SERVICE
      - on_mac_is:MAC, on_essid_is:ESSID, on_ip_is:IP
      - on_has_cve[:CVE], on_has_cpe[:CPE]
      - on_all:[...], on_any:[...] (JSON arrays of nested triggers)

    'on_interval:' triggers always evaluate to ``False`` here. Unknown
    triggers and any exception also yield ``False``.
    """

    def _row_exists(sql: str, params: tuple) -> bool:
        """Return True when the query yields at least one row."""
        return bool(shared_data.db.query(sql, params))

    try:
        mac = host["mac_address"]
        spec = (trigger or "").strip()
        if not spec:
            return False

        # Combined triggers: the payload after the first ':' is a JSON array.
        for prefix, combine in (("on_all:", all), ("on_any:", any)):
            if spec.startswith(prefix):
                try:
                    children = json.loads(spec.split(":", 1)[1])
                except Exception:
                    return False
                return combine(evaluate_trigger(child, host, action_def) for child in children)

        # Interval triggers are not evaluated here.
        if spec.startswith("on_interval:"):
            return False

        # Split "name:param" on the first ':' (param is "" when absent).
        name, _, param = spec.partition(":")
        name = name.strip()
        param = param.strip()

        # Aliases
        name = {"on_alive": "on_host_alive", "on_dead": "on_host_dead"}.get(name, name)

        alive = bool(host.get("alive"))

        # Join/Leave events: fire unless the last presence event already
        # matches the host's current state.
        if name == "on_join":
            return alive and _last_presence_event_for_mac(mac) != "PresenceJoin"
        if name == "on_leave":
            return (not alive) and _last_presence_event_for_mac(mac) != "PresenceLeave"

        # Basic triggers
        if name in ("on_start", "on_new_host"):
            return not _row_exists(
                """
                SELECT 1 FROM action_queue
                WHERE mac_address=? AND action_name=?
                  AND status IN ('success','failed')
                LIMIT 1
                """,
                (mac, action_def["b_class"]),
            )
        if name == "on_host_alive":
            return alive
        if name == "on_host_dead":
            return not alive

        # Port/service triggers never fire for dead hosts.
        if not alive and name in {"on_service", "on_web_service", "on_new_port", "on_port_change"}:
            return False

        # Port triggers
        if name == "on_port_change":
            current = set(_normalize_ports(host.get("ports")))
            previous = set(_normalize_ports(host.get("previous_ports")))
            return current != previous
        if name == "on_new_port":
            wanted = str(param)
            current = set(_normalize_ports(host.get("ports")))
            previous = set(_normalize_ports(host.get("previous_ports")))
            return wanted in current and wanted not in previous

        # Action status triggers
        if name == "on_success":
            return _row_exists(
                """
                SELECT 1 FROM action_queue
                WHERE mac_address=? AND action_name=? AND status='success'
                ORDER BY completed_at DESC LIMIT 1
                """,
                (mac, param),
            )
        if name == "on_failure":
            return _row_exists(
                """
                SELECT 1 FROM action_queue
                WHERE mac_address=? AND action_name=? AND status='failed'
                ORDER BY completed_at DESC LIMIT 1
                """,
                (mac, param),
            )

        # Service triggers
        if name == "on_cred_found":
            return _row_exists(
                "SELECT 1 FROM creds WHERE mac_address=? AND LOWER(service)=? LIMIT 1",
                (mac, param.lower()),
            )
        if name == "on_service":
            return _has_open_service(mac, param, host)
        if name == "on_web_service":
            return _has_open_service(mac, "http", host) or _has_open_service(mac, "https", host)

        # Identity triggers
        if name == "on_mac_is":
            return str(mac).lower() == param.lower()
        if name == "on_essid_is":
            return (host.get("essid") or "") == param
        if name == "on_ip_is":
            raw_ips = host.get("ips")
            known_ips = raw_ips.split(";") if raw_ips else []
            return param in known_ips

        # Vulnerability triggers (with or without a specific identifier)
        if name == "on_has_cve":
            if param:
                return _row_exists(
                    "SELECT 1 FROM vulnerabilities WHERE mac_address=? AND vuln_id=? AND is_active=1 LIMIT 1",
                    (mac, param),
                )
            return _row_exists(
                "SELECT 1 FROM vulnerabilities WHERE mac_address=? AND is_active=1 LIMIT 1",
                (mac,),
            )
        if name == "on_has_cpe":
            if param:
                return _row_exists(
                    "SELECT 1 FROM detected_software WHERE mac_address=? AND cpe=? AND is_active=1 LIMIT 1",
                    (mac, param),
                )
            return _row_exists(
                "SELECT 1 FROM detected_software WHERE mac_address=? AND is_active=1 LIMIT 1",
                (mac,),
            )

        # Unknown trigger
        logger.debug(f"Unknown trigger: {name}")
        return False

    except Exception as e:
        logger.error(f"Error evaluating trigger '{trigger}': {e}")
        return False
|
|
|
|
|
|
# ---------------------------------------------------------- requirements eval --
|
|
|
|
def evaluate_requirements(requires: Any, host: Dict[str, Any], action_def: Dict[str, Any]) -> bool:
    """
    Evaluate an action's requirements for one host.

    Accepted forms of ``requires``:
      - None, or text that is empty after stripping: nothing required -> True.
      - dict / list: handed to evaluate_requirements_object() unchanged.
      - JSON text ("{...}" or "[...]"): decoded, then evaluated as an object.
      - Legacy "Action:status" text: evaluated as
        {"action": "Action", "status": "status"}.
      - Any other text: True.

    Args:
        requires: Requirement specification in one of the forms above.
        host: Host dict, forwarded to evaluate_requirements_object().
        action_def: Action definition, forwarded unchanged.

    Returns:
        True when the requirements are satisfied or there are none.

    Raises:
        Whatever evaluate_requirements_object() raises. Errors raised while
        evaluating decoded JSON propagate exactly like they do for dict/list
        input; only JSON *decoding* errors are tolerated here.
    """
    if requires is None:
        return True

    # Already an object
    if isinstance(requires, (dict, list)):
        return evaluate_requirements_object(requires, host, action_def)

    s = str(requires).strip()
    if not s:
        return True

    # JSON string
    if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")):
        try:
            obj = json.loads(s)
        except ValueError:
            # json.JSONDecodeError is a ValueError subclass. Text that merely
            # looks like JSON is not an error here: fall through to the
            # legacy / default handling below.
            pass
        else:
            # Evaluate outside the try so a failure during evaluation is not
            # mistaken for a decoding failure and the (valid) JSON text is
            # never re-interpreted as legacy "Action:status".
            return evaluate_requirements_object(obj, host, action_def)

    # Legacy "Action:status" format
    if ":" in s:
        a, st = s.split(":", 1)
        obj = {"action": a.strip(), "status": st.strip()}
        return evaluate_requirements_object(obj, host, action_def)

    return True
|
|
|
|
|
|
def evaluate_requirements_object(req: Any, host: Dict[str, Any], action_def: Dict[str, Any]) -> bool:
    """
    Recursively evaluate a decoded requirements object for one host.

    Recognised dict forms, checked in this order (the first matching key wins):
      - {"all": [...]} / {"any": [...]} / {"not": {...}}
      - {"action": "ACTION", "status": "STATUS", "scope": "host|global"}
        ("status" defaults to "success", "scope" defaults to "host")
      - {"has_port": PORT}
      - {"has_cred": "SERVICE"}
      - {"has_cve": "CVE"}
      - {"has_cpe": "CPE"}
      - {"mac_is": "MAC"}
      - {"essid_is": "ESSID"}
      - {"service_is_open": "SERVICE"}

    Anything else (a non-dict value, or a dict without a recognised key) is
    reduced to its plain truthiness.

    Raises:
        KeyError: if ``host`` has no "mac_address" entry (looked up first,
        whatever the shape of ``req``).
    """
    mac = host["mac_address"]

    # Non-dict values have no keys to inspect: plain truthiness.
    if not isinstance(req, dict):
        return bool(req)

    def _row_exists(sql, params):
        # True when the parameterized query yields a non-empty result.
        return bool(shared_data.db.query(sql, params))

    # ---- Combinators ----
    if "all" in req:
        for child in req.get("all") or []:
            if not evaluate_requirements_object(child, host, action_def):
                return False
        return True

    if "any" in req:
        for child in req.get("any") or []:
            if evaluate_requirements_object(child, host, action_def):
                return True
        return False

    if "not" in req:
        negated = evaluate_requirements_object(req.get("not"), host, action_def)
        return not negated

    # ---- Atomic requirements ----
    if "action" in req:
        wanted_action = str(req.get("action") or "").strip()
        wanted_status = str(req.get("status") or "success").strip()
        wanted_scope = str(req.get("scope") or "host").strip().lower()

        if wanted_scope == "global":
            # Global scope: match the action/status pair regardless of host.
            return _row_exists(
                """
                SELECT 1 FROM action_queue
                WHERE action_name=? AND status=?
                ORDER BY completed_at DESC LIMIT 1
                """,
                (wanted_action, wanted_status),
            )

        # Host scope: same lookup, restricted to this host's MAC.
        return _row_exists(
            """
            SELECT 1 FROM action_queue
            WHERE mac_address=? AND action_name=? AND status=?
            ORDER BY completed_at DESC LIMIT 1
            """,
            (mac, wanted_action, wanted_status),
        )

    if "has_port" in req:
        # Ports are compared as strings.
        wanted_port = str(req.get("has_port"))
        host_ports = set(_normalize_ports(host.get("ports")))
        return wanted_port in host_ports

    if "has_cred" in req:
        service_name = str(req.get("has_cred") or "").lower()
        return _row_exists(
            "SELECT 1 FROM creds WHERE mac_address=? AND LOWER(service)=? LIMIT 1",
            (mac, service_name),
        )

    if "has_cve" in req:
        cve_id = str(req.get("has_cve") or "")
        return _row_exists(
            "SELECT 1 FROM vulnerabilities WHERE mac_address=? AND vuln_id=? AND is_active=1 LIMIT 1",
            (mac, cve_id),
        )

    if "has_cpe" in req:
        cpe_id = str(req.get("has_cpe") or "")
        return _row_exists(
            "SELECT 1 FROM detected_software WHERE mac_address=? AND cpe=? AND is_active=1 LIMIT 1",
            (mac, cpe_id),
        )

    if "mac_is" in req:
        # Case-insensitive MAC comparison.
        expected_mac = str(req.get("mac_is") or "").lower()
        return str(mac).lower() == expected_mac

    if "essid_is" in req:
        # Exact (case-sensitive) ESSID comparison.
        expected_essid = str(req.get("essid_is") or "")
        return (host.get("essid") or "") == expected_essid

    if "service_is_open" in req:
        service_name = str(req.get("service_is_open") or "").lower()
        return _has_open_service(mac, service_name, host)

    # Dict without a recognised key: truthy unless empty.
    return bool(req)
|