mirror of
https://github.com/infinition/Bjorn.git
synced 2026-03-08 05:51:59 +00:00
- Implemented methods for fetching AI stats, training history, and recent experiences. - Added functionality to set operation mode (MANUAL, AUTO, AI) with appropriate handling. - Included helper methods for querying the database and sending JSON responses. - Integrated model metadata extraction for visualization purposes.
1652 lines
62 KiB
Python
1652 lines
62 KiB
Python
# action_scheduler.py
|
|
# Smart Action Scheduler for Bjorn - queue-only implementation
|
|
# Handles trigger evaluation, requirements checking, and queue management.
|
|
#
|
|
# Invariants we enforce:
|
|
# - At most ONE "active" row per (action_name, mac_address, COALESCE(port,0))
|
|
# where active ∈ {'scheduled','pending','running'}.
|
|
# - Retries for failed entries are coordinated by cleanup_queue() (with backoff)
|
|
# and never compete with trigger-based enqueues.
|
|
#
|
|
# Runtime knobs (from shared.py):
|
|
# shared_data.retry_success_actions : bool (default False)
|
|
# shared_data.retry_failed_actions : bool (default True)
|
|
#
|
|
# These take precedence over cooldown / rate-limit for NON-interval triggers.
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import time
|
|
import threading
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
from init_shared import shared_data
|
|
from logger import Logger
|
|
from ai_engine import get_or_create_ai_engine
|
|
|
|
# Module-level logger used by the helpers and the ActionScheduler class below.
logger = Logger(name="action_scheduler.py")
|
|
|
|
# ---------- UTC helpers (match SQLite's UTC CURRENT_TIMESTAMP) ----------
|
|
def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    The tzinfo is dropped so the value can be compared directly with the
    naive UTC timestamps SQLite stores as TEXT.
    """
    aware_now = datetime.now(tz=timezone.utc)
    return aware_now.replace(tzinfo=None)
|
|
|
|
def _utcnow_str() -> str:
    """Return the current UTC time as a 'YYYY-MM-DD HH:MM:SS' string.

    The layout matches the TEXT timestamps SQLite produces, so the result
    can be compared against them as plain strings.
    """
    sqlite_layout = "%Y-%m-%d %H:%M:%S"
    current = _utcnow()
    return current.strftime(sqlite_layout)
|
|
|
|
def _db_ts(dt: datetime) -> str:
    """Render *dt* as 'YYYY-MM-DD HH:MM:SS'.

    No timezone conversion is performed; callers are expected to pass a
    UTC value. Sub-second precision is discarded by the format.
    """
    return f"{dt:%Y-%m-%d %H:%M:%S}"
|
|
|
|
# Service name → fallback ports, used when the port_services table has
# nothing for a host. Keys are lowercase service names; ports are kept as
# strings (not ints), and a service may list several candidates.
SERVICE_PORTS: Dict[str, List[str]] = {
    "ssh": ["22"],
    "http": ["80", "8080"],
    "https": ["443"],
    "smb": ["445"],
    "ftp": ["21"],
    "telnet": ["23"],
    "mysql": ["3306"],
    "mssql": ["1433"],
    "postgres": ["5432"],
    "rdp": ["3389"],
    "vnc": ["5900"],
}
|
|
|
|
|
|
class ActionScheduler:
|
|
"""
|
|
Smart scheduler that evaluates triggers and enqueues actions.
|
|
Does NOT execute actions - that's the orchestrator's job.
|
|
"""
|
|
|
|
def __init__(self, shared_data_):
    """Bind the scheduler to shared state and set up its runtime bookkeeping.

    Args:
        shared_data_: Shared application state. The constructor reads its
            ``db`` attribute, ``get_raspberry_mac()`` and ``operation_mode``.

    Side effects: upserts the controller host row, attempts to create the
    unique "active row" index, and (in AI mode only) fetches the AI engine.
    """
    self.shared_data = shared_data_
    self.db = shared_data_.db

    # Global actions are queued against the controller's own MAC; a
    # placeholder is used when the MAC cannot be determined.
    controller_mac = self.shared_data.get_raspberry_mac() or "__GLOBAL__"
    self.ctrl_mac = controller_mac.lower()
    self._ensure_host_exists(self.ctrl_mac)

    # Loop control.
    self.running = True
    self.check_interval = 5  # seconds between iterations
    self._stop_event = threading.Event()
    self._error_backoff = 1.0

    # Cached action definitions, keyed by b_class, refreshed on a TTL.
    self._action_definitions: Dict[str, Dict[str, Any]] = {}
    self._last_cache_refresh = 0.0
    self._cache_ttl = 60.0  # seconds

    # Wall-clock time each global action was last queued by this instance.
    self._last_global_runs: Dict[str, float] = {}

    # Which source (studio vs plain actions) the cache was last built from;
    # None until the first refresh.
    self._last_source_is_studio: Optional[bool] = None

    # Enforce DB invariants (idempotent).
    self._ensure_db_invariants()

    # Priority updates are throttled to one pass per interval.
    self._last_priority_update = 0.0
    self._priority_update_interval = 60.0  # seconds

    # The AI engine is only fetched in AI mode; per the original note the
    # factory hands out a singleton so model weights load once per process.
    self.ai_engine = None
    if self.shared_data.operation_mode == "AI":
        engine = get_or_create_ai_engine(self.shared_data)
        self.ai_engine = engine
        if engine is None:
            logger.info_throttled(
                "AI engine unavailable in scheduler; continuing heuristic-only",
                key="scheduler_ai_init_failed",
                interval_s=300.0,
            )

    logger.info("ActionScheduler initialized")
|
|
|
|
# --------------------------------------------------------------------- loop
|
|
|
|
def run(self):
    """Run the scheduler loop until stopped.

    Each iteration syncs the AI engine with the current operation mode,
    runs the scheduling steps in a fixed order, then sleeps for
    ``check_interval`` seconds. Any exception is logged and followed by a
    wait that doubles on consecutive failures (capped at 15 s) and resets
    to 1 s after a clean iteration. The loop ends when ``running`` is
    cleared, the orchestrator asks to exit, or the stop event is set.
    """
    logger.info("ActionScheduler starting main loop")

    while self.running and not self.shared_data.orchestrator_should_exit:
        try:
            # Follow runtime toggles of AI mode without requiring a restart.
            ai_mode = self.shared_data.operation_mode == "AI"
            if ai_mode and self.ai_engine is None:
                self.ai_engine = get_or_create_ai_engine(self.shared_data)
                if self.ai_engine:
                    logger.info("Scheduler: AI engine enabled (singleton)")
                else:
                    logger.info_throttled(
                        "Scheduler: AI engine unavailable; staying heuristic-only",
                        key="scheduler_ai_enable_failed",
                        interval_s=300.0,
                    )
            elif not ai_mode and self.ai_engine is not None:
                self.ai_engine = None

            # Order matters: each step runs after the previous one finished.
            for step in (
                self._refresh_cache_if_needed,         # reload action definitions when stale
                self._cancel_queued_disabled_actions,  # keep queue consistent with enable flags
                self._promote_scheduled_to_pending,    # 1) promote due scheduled rows
                self._publish_all_upcoming,            # 2) publish next interval occurrences
                self._evaluate_global_actions,         # 3) global on_start actions
                self.evaluate_all_triggers,            # 4) per-host triggers
                self.cleanup_queue,                    # 5) queue maintenance
                self.update_priorities,                #    ...and priority refresh
            ):
                step()

            self._error_backoff = 1.0
            stop_requested = self._stop_event.wait(self.check_interval)
            if stop_requested:
                break

        except Exception as e:
            logger.error(f"Error in scheduler loop: {e}")
            if self._stop_event.wait(self._error_backoff):
                break
            self._error_backoff = min(self._error_backoff * 2.0, 15.0)

    logger.info("ActionScheduler stopped")
|
|
|
|
# ----------------------------------------------------------------- priorities
|
|
|
|
def update_priorities(self):
    """Refresh priorities of pending queue rows (throttled).

    Does nothing unless ``_priority_update_interval`` seconds have passed
    since the last successful pass. Otherwise:

    1. Ageing: every pending row created more than ~1 hour ago
       (0.0417 days) gets +1 priority, capped at 100.
    2. In AI mode, rows recommended by the AI engine get an extra boost.

    Errors are logged and never raised.
    """
    now = time.time()
    elapsed = now - self._last_priority_update
    if elapsed < self._priority_update_interval:
        return

    # julianday() arithmetic is used for the age test; MIN() enforces the cap.
    aging_sql = """
        UPDATE action_queue
        SET priority = MIN(100, priority + 1)
        WHERE status='pending'
          AND julianday('now') - julianday(created_at) > 0.0417
    """

    try:
        aged = self.db.execute(aging_sql)

        # Only a successful ageing pass restarts the throttle window.
        self._last_priority_update = now

        if aged and aged > 0:
            logger.debug(f"Aged {aged} pending actions in queue")

        if self.shared_data.operation_mode == "AI":
            if self.ai_engine:
                self._apply_ai_priority_boost()
            else:
                logger.warning("Operation mode is AI, but ai_engine is not initialized!")

    except Exception as e:
        logger.error(f"Failed to update priorities: {e}")
|
|
|
|
def _apply_ai_priority_boost(self):
    """Boost the priority of pending actions the AI engine recommends.

    For each host that has pending rows, the engine is asked (exploration
    disabled) to choose among that host's pending action names. When it
    names an action with confidence at or above the configured threshold,
    every matching pending row gets ``int(20 * confidence)`` added to its
    priority and its metadata is annotated with the AI decision trace.

    Robustness:
      * The confidence is coerced with ``float(confidence or 0.0)`` (same
        rule as ``_annotate_decision_metadata``), so a missing confidence
        means "not confident" instead of raising.
      * Failures are isolated per host: one bad host is logged and skipped
        rather than aborting the boost for every remaining host.

    Nothing is raised to the caller; all errors are logged.
    """
    if not self.ai_engine:
        logger.warning("AI Boost skipped: ai_engine is None")
        return

    try:
        # Hosts that currently have at least one pending row.
        hosts = self.db.query(
            """
            SELECT DISTINCT mac_address FROM action_queue
            WHERE status='pending'
            """
        )
    except Exception as e:
        logger.error(f"Error applying AI priority boost: {e}")
        return

    # The threshold does not depend on the host, so read it once.
    threshold = self._get_ai_confirm_threshold()

    for host_row in hosts or []:
        mac = host_row['mac_address']
        if not mac:
            continue

        try:
            # Candidate actions = distinct pending action names for this host.
            available = [
                r['action_name'] for r in self.db.query(
                    """
                    SELECT DISTINCT action_name FROM action_queue
                    WHERE mac_address=? AND status='pending'
                    """,
                    (mac,),
                )
            ]
            if not available:
                continue

            host_data = self.db.get_host_by_mac(mac)
            if not host_data:
                continue

            # hostnames / ports are ';'-separated strings; keep the first
            # hostname and only the purely numeric port entries.
            context = {
                'mac': mac,
                'hostname': (host_data.get('hostnames') or '').split(';')[0],
                'ports': [
                    int(p) for p in (host_data.get('ports') or '').split(';')
                    if p.isdigit()
                ],
            }

            recommended_action, confidence, debug = self.ai_engine.choose_action(
                host_context=context,
                available_actions=available,
                exploration_rate=0.0,  # No exploration in scheduler
            )
            if not isinstance(debug, dict):
                debug = {}

            # Plain float: safe to compare, format and JSON-encode.
            confidence_f = float(confidence or 0.0)

            # Written as "not (>=)" so a NaN confidence is treated as not confident.
            if not recommended_action or not (confidence_f >= threshold):
                continue

            boost_amount = int(20 * confidence_f)  # Scale boost by confidence

            affected = self.db.execute(
                """
                UPDATE action_queue
                SET priority = priority + ?
                WHERE mac_address=? AND action_name=? AND status='pending'
                """,
                (boost_amount, mac, recommended_action),
            )
            if not (affected and affected > 0):
                continue

            # Record the AI influence on every boosted row. A metadata
            # failure is logged but does not undo the priority boost.
            try:
                rows = self.db.query(
                    """
                    SELECT id, metadata FROM action_queue
                    WHERE mac_address=? AND action_name=? AND status='pending'
                    """,
                    (mac, recommended_action),
                )
                for row in rows:
                    meta = json.loads(row['metadata'] or '{}')
                    meta['decision_method'] = f"ai_boosted ({debug.get('method', 'unknown')})"
                    meta['decision_origin'] = "ai_boosted"
                    meta['decision_scope'] = "priority_boost"
                    meta['ai_confidence'] = confidence_f
                    meta['ai_threshold'] = threshold
                    meta['ai_method'] = str(debug.get('method', 'unknown'))
                    meta['ai_recommended_action'] = recommended_action
                    meta['ai_model_loaded'] = bool(getattr(self.ai_engine, "model_loaded", False))
                    meta['ai_reason'] = "priority_boost_applied"
                    # NOTE(review): debug is stored verbatim and must be
                    # JSON-serialisable, otherwise the dump below fails and
                    # is reported by the handler — confirm against ai_engine.
                    meta['ai_debug'] = debug  # Includes all_scores and input_vector
                    self.db.execute(
                        "UPDATE action_queue SET metadata=? WHERE id=?",
                        (json.dumps(meta), row['id']),
                    )
            except Exception as meta_e:
                logger.error(f"Failed to update metadata for AI boost: {meta_e}")

            logger.info(
                f"[AI_BOOST] action={recommended_action} mac={mac} boost={boost_amount} "
                f"conf={confidence_f:.2f} thr={float(threshold):.2f} "
                f"method={debug.get('method', 'unknown')}"
            )

        except Exception as e:
            logger.error(f"Error applying AI priority boost: {e}")
|
|
|
|
def stop(self):
    """Ask the scheduler loop to exit.

    Clears ``running`` and sets the stop event, so a loop currently
    blocked in ``_stop_event.wait()`` wakes up and breaks out.
    """
    logger.info("Stopping ActionScheduler...")
    self.running = False
    self._stop_event.set()
|
|
|
|
# --------------------------------------------------------------- definitions
|
|
|
|
def _get_ai_confirm_threshold(self) -> float:
    """Return the AI confirmation threshold clamped to [0.0, 1.0].

    Reads ``shared_data.ai_confirm_threshold``; a missing attribute or a
    value that cannot be converted to float falls back to 0.3.
    """
    fallback = 0.3
    try:
        configured = getattr(self.shared_data, "ai_confirm_threshold", fallback)
        value = float(configured)
    except Exception:
        value = fallback
    upper_bounded = min(1.0, value)
    return max(0.0, upper_bounded)
|
|
|
|
def _annotate_decision_metadata(
    self,
    metadata: Dict[str, Any],
    action_name: str,
    context: Dict[str, Any],
    decision_scope: str,
) -> None:
    """Stamp *metadata* (in place) with a consistent decision trace.

    Always sets ``decision_scope`` and ``ai_threshold`` and defaults
    ``decision_method`` / ``decision_origin`` to "heuristic". In AI mode
    with an engine available, the engine is asked whether it would pick
    *action_name* for *context* (exploration disabled); its method,
    confidence, recommendation and model-loaded flag are recorded together
    with an ``ai_reason`` describing the outcome.

    Args:
        metadata: Dict to annotate; mutated in place.
        action_name: The action being queued; the only candidate offered.
        context: Host context passed to the engine as ``host_context``.
        decision_scope: Label stored under ``decision_scope``.

    Never raises for engine failures: they are recorded as
    ``ai_reason="ai_check_failed"`` and logged at debug level.
    """
    for key in ("decision_method", "decision_origin"):
        metadata.setdefault(key, "heuristic")
    metadata["decision_scope"] = decision_scope

    threshold = self._get_ai_confirm_threshold()
    metadata["ai_threshold"] = threshold

    # Without AI mode or without an engine there is nothing to consult.
    if self.shared_data.operation_mode != "AI":
        metadata["ai_reason"] = "ai_mode_disabled"
        return
    if not self.ai_engine:
        metadata["ai_reason"] = "ai_engine_unavailable"
        return

    try:
        recommended, confidence, debug = self.ai_engine.choose_action(
            host_context=context,
            available_actions=[action_name],
            exploration_rate=0.0,
        )

        ai_method = str((debug or {}).get("method", "unknown"))
        confidence_f = float(confidence or 0.0)

        metadata.update(
            {
                "ai_method": ai_method,
                "ai_confidence": confidence_f,
                "ai_recommended_action": recommended or "",
                "ai_model_loaded": bool(getattr(self.ai_engine, "model_loaded", False)),
            }
        )

        if recommended != action_name:
            origin, reason = "heuristic", "recommended_different_action"
        elif confidence_f >= threshold:
            # Only a confident, matching recommendation rewrites the method.
            metadata["decision_method"] = f"ai_confirmed ({ai_method})"
            origin, reason = "ai_confirmed", "recommended_above_threshold"
        else:
            origin, reason = "heuristic", "confidence_below_threshold"

        metadata["decision_origin"] = origin
        metadata["ai_reason"] = reason

    except Exception as e:
        metadata["ai_reason"] = "ai_check_failed"
        logger.debug(f"AI decision annotation failed for {action_name}: {e}")
|
|
|
|
def _log_queue_decision(
    self,
    action_name: str,
    mac: str,
    metadata: Dict[str, Any],
    target_port: Optional[int] = None,
    target_service: Optional[str] = None,
) -> None:
    """Log one compact ``[QUEUE_DECISION]`` line for a queued action.

    Values are read from *metadata* with readable fallbacks ("heuristic",
    "n/a", "unknown") so the line is always complete. Numeric confidence
    and threshold are printed with two decimals; a missing port or service
    is printed as the literal "None".
    """

    def _two_decimals(value: Any) -> str:
        # Only real numbers are formatted; anything else is reported as n/a.
        if isinstance(value, (int, float)):
            return f"{float(value):.2f}"
        return "n/a"

    scope = str(metadata.get("decision_scope", "unknown"))
    decision = str(metadata.get("decision_method", "heuristic"))
    origin = str(metadata.get("decision_origin", "heuristic"))
    ai_method = str(metadata.get("ai_method", "n/a"))
    ai_reason = str(metadata.get("ai_reason", "n/a"))
    model_loaded = bool(metadata.get("ai_model_loaded", False))

    conf_txt = _two_decimals(metadata.get("ai_confidence"))
    thr_txt = _two_decimals(metadata.get("ai_threshold"))

    if target_port is None:
        port_txt = "None"
    else:
        port_txt = str(target_port)
    svc_txt = target_service or "None"

    logger.info(
        f"[QUEUE_DECISION] scope={scope} action={action_name} mac={mac} port={port_txt} service={svc_txt} "
        f"decision={decision} origin={origin} ai_method={ai_method} conf={conf_txt} thr={thr_txt} "
        f"model_loaded={model_loaded} reason={ai_reason}"
    )
|
|
|
|
# ---------- action definition cache: refresh policy ----------
|
|
def _refresh_cache_if_needed(self):
    """Reload the action-definition cache when it is stale.

    A reload happens when the TTL has elapsed or when the configured
    source (``shared_data.use_actions_studio``) differs from the one used
    for the previous load. The refresh timestamp and source marker are
    updated after the reload call returns.
    """
    now = time.time()
    use_studio = bool(getattr(self.shared_data, "use_actions_studio", False))

    ttl_expired = (now - self._last_cache_refresh) > self._cache_ttl
    source_changed = self._last_source_is_studio != use_studio
    if not (ttl_expired or source_changed):
        return

    self._refresh_action_cache(use_studio=use_studio)
    self._last_cache_refresh = now
    self._last_source_is_studio = use_studio
|
|
|
|
|
|
# ---------- action definition cache: reload from DB ----------
|
|
def _refresh_action_cache(self, use_studio: Optional[bool] = None):
    """Rebuild ``_action_definitions`` from the database.

    Args:
        use_studio: Load from ``db.list_studio_actions()`` when true,
            otherwise from ``db.list_actions()``. ``None`` means "use
            ``shared_data.use_actions_studio``".

    The cache is keyed by each action's ``b_class``. After a primary load
    the ``b_enabled`` flags are overlaid from the runtime ``actions``
    table. If the studio loader raises ``AttributeError`` and the adapter
    has ``list_actions``, that is used as a fallback. All failures are
    logged; nothing is raised.
    """
    if use_studio is None:
        use_studio = bool(getattr(self.shared_data, "use_actions_studio", False))

    try:
        if use_studio:
            source, actions = "studio", self.db.list_studio_actions()
        else:
            source, actions = "actions", self.db.list_actions()

        # Both sources are expected to share one schema (b_class, b_trigger, ...).
        self._action_definitions = {a["b_class"]: a for a in actions}

        # The orchestrator loads from `actions`, so its b_enabled values win
        # even when definitions come from the studio source.
        self._overlay_runtime_enabled_flags()
        logger.info(f"Refreshed action cache from '{source}': {len(self._action_definitions)} actions")

    except AttributeError as e:
        can_fall_back = use_studio and hasattr(self.db, "list_actions")
        if not can_fall_back:
            logger.error(f"Action cache refresh failed (no suitable DB method): {e}")
            return

        logger.warning(f"DB has no list_studio_actions(); falling back to list_actions(): {e}")
        try:
            actions = self.db.list_actions()
            self._action_definitions = {a["b_class"]: a for a in actions}
            logger.info(f"Refreshed action cache from 'actions' (fallback): {len(self._action_definitions)} actions")
        except Exception as ee:
            logger.error(f"Fallback list_actions() failed: {ee}")

    except Exception as e:
        logger.error(f"Failed to refresh action cache: {e}")
|
|
|
|
def _is_action_enabled(self, action_def: Dict[str, Any]) -> bool:
    """Interpret an action's ``b_enabled`` flag across common encodings.

    Args:
        action_def: Action definition; ``b_enabled`` may be missing, None,
            a bool, a number or a string.

    Returns:
        * missing / None -> True
        * bool -> the value itself
        * number -> True only when its integer part equals 1
        * string -> "1/true/yes/on" -> True, "0/false/no/off" -> False,
          otherwise parsed as a number with the rule above
        * anything that cannot be interpreted (garbage text, NaN or
          infinity, as a string or as a float) -> True, so a malformed
          value never silently disables an action.
    """
    raw = action_def.get("b_enabled", 1)
    if raw is None:
        return True
    if isinstance(raw, bool):
        return raw

    if isinstance(raw, (int, float)):
        try:
            return int(raw) == 1
        except (ValueError, OverflowError):
            # int(nan) / int(inf): treat like any other malformed value.
            return True

    text = str(raw).strip().lower()
    if text in {"1", "true", "yes", "on"}:
        return True
    if text in {"0", "false", "no", "off"}:
        return False
    try:
        return int(float(text)) == 1
    except (ValueError, OverflowError):
        # Conservative default: keep action enabled when value is malformed.
        return True
|
|
|
|
def _overlay_runtime_enabled_flags(self):
    """Copy ``b_enabled`` from ``db.list_actions()`` onto cached definitions.

    Rows are matched by ``b_class``; a row without ``b_enabled`` counts
    as 1. Cached actions with no matching row are left untouched. Any
    failure is logged as a warning and the cache is kept as it is.
    """
    try:
        runtime_enabled = {}
        for row in self.db.list_actions():
            runtime_enabled[row.get("b_class")] = row.get("b_enabled", 1)

        for name, definition in self._action_definitions.items():
            if name not in runtime_enabled:
                continue
            definition["b_enabled"] = runtime_enabled[name]
    except Exception as e:
        logger.warning(f"Could not overlay runtime b_enabled flags: {e}")
|
|
|
|
def _cancel_queued_disabled_actions(self):
    """Cancel 'scheduled'/'pending' queue rows of actions that are disabled.

    Disabled actions are those in the cache for which
    ``_is_action_enabled`` is false. Matching rows are set to 'cancelled'
    with ``completed_at`` stamped and ``error_message`` defaulted to
    'disabled_by_config' (an existing message is kept). Errors are logged
    and never raised.
    """
    try:
        disabled = tuple(
            name
            for name, definition in self._action_definitions.items()
            if not self._is_action_enabled(definition)
        )
        if not disabled:
            return

        # One "?" per name: only placeholders are interpolated into the SQL;
        # the names themselves are bound as parameters.
        placeholders = ",".join("?" * len(disabled))
        cancel_sql = f"""
            UPDATE action_queue
            SET status='cancelled',
                completed_at=CURRENT_TIMESTAMP,
                error_message=COALESCE(error_message, 'disabled_by_config')
            WHERE status IN ('scheduled','pending')
              AND action_name IN ({placeholders})
        """
        cancelled = self.db.execute(cancel_sql, disabled)
        if cancelled:
            logger.info(f"Cancelled {cancelled} queued action(s) because b_enabled=0")
    except Exception as e:
        logger.error(f"Failed to cancel queued disabled actions: {e}")
|
|
|
|
|
|
# ------------------------------------------------------------------ helpers
|
|
|
|
def _ensure_db_invariants(self):
    """Create (if absent) the unique index guarding "one active row".

    The partial UNIQUE index forbids more than one row in status
    'scheduled', 'pending' or 'running' for the same
    (action_name, mac_address, COALESCE(port,0)). If creation fails, a
    warning is logged and the scheduler relies on its application-level
    NOT EXISTS guards instead.
    """
    create_index_sql = """
        CREATE UNIQUE INDEX IF NOT EXISTS uq_action_active
        ON action_queue(action_name, mac_address, COALESCE(port,0))
        WHERE status IN ('scheduled','pending','running')
    """
    try:
        self.db.execute(create_index_sql)
    except Exception as e:
        # E.g. a SQLite build without partial/expression index support.
        logger.warning(f"Could not create unique partial index (fallback to app-level guards): {e}")
|
|
|
|
def _promote_scheduled_to_pending(self):
    """Ask the DB layer to move due 'scheduled' rows to 'pending'.

    Delegates to ``db.promote_due_scheduled_to_pending()`` and logs the
    returned count when it is non-zero. Errors are logged, never raised.
    """
    try:
        promoted_count = self.db.promote_due_scheduled_to_pending()
        if not promoted_count:
            return
        logger.debug(f"Promoted {promoted_count} scheduled action(s) to pending")
    except Exception as e:
        logger.error(f"Failed to promote scheduled actions: {e}")
|
|
|
|
def _ensure_host_exists(self, mac: str):
    """Make sure a ``hosts`` row exists for *mac* (idempotent, best-effort).

    Inserts the host with ``alive=0``; if the MAC is already present only
    ``updated_at`` is refreshed. An empty MAC is ignored.

    Failures are not raised, because callers treat this as a convenience
    step, but they are logged at debug level instead of being dropped
    silently, so a broken ``hosts`` table is diagnosable.
    """
    if not mac:
        return
    try:
        self.db.execute(
            """
            INSERT INTO hosts (mac_address, alive, updated_at)
            VALUES (?, 0, CURRENT_TIMESTAMP)
            ON CONFLICT(mac_address) DO UPDATE SET
                updated_at = CURRENT_TIMESTAMP
            """,
            (mac,),
        )
    except Exception as e:
        logger.debug(f"Could not ensure host {mac} exists: {e}")
|
|
|
|
# ---------------------------------------------------------- interval logic
|
|
|
|
def _parse_interval_seconds(self, trigger: str) -> int:
    """Extract the interval from an ``'on_interval:SECONDS'`` trigger.

    Returns:
        The interval in whole seconds, or 0 when *trigger* is empty, is
        not an interval trigger, has no value, is not an integer, or is
        negative.
    """
    if not trigger or not trigger.startswith("on_interval:"):
        return 0

    # Everything after the first ':' is the raw seconds value.
    _, _, raw_seconds = trigger.partition(":")
    try:
        seconds = int(raw_seconds or 0)
    except Exception:
        return 0
    return seconds if seconds > 0 else 0
|
|
|
|
def _publish_all_upcoming(self):
    """Publish the next scheduled occurrence for every interval action.

    Enabled actions whose trigger is ``on_interval:N`` (N > 0) are split
    into global actions, published once against the controller, and
    per-host actions, published for each host that is alive and has a MAC
    address. Global actions are published first.

    The enabled/interval filtering is done once up front rather than once
    per host. A failure to list hosts is logged (previously it was
    silently ignored) and per-host publishing is skipped for this pass.

    NOTE: By design, the runtime flags do not cancel interval publishing.
    """
    global_jobs = []
    host_jobs = []
    for action in self._action_definitions.values():
        if not self._is_action_enabled(action):
            continue
        trigger = (action.get("b_trigger") or "").strip()
        interval = self._parse_interval_seconds(trigger)
        if interval <= 0:
            continue
        is_global = (action.get("b_action") or "normal") == "global"
        (global_jobs if is_global else host_jobs).append((action, interval))

    # Global interval actions
    for action, interval in global_jobs:
        self._publish_next_schedule_for_global(action, interval)

    # Per-host interval actions
    if not host_jobs:
        return  # nothing host-bound to publish; skip the host query

    try:
        hosts = self.db.get_all_hosts()
    except Exception as e:
        logger.error(f"Failed to list hosts for interval scheduling: {e}")
        return

    for host in hosts or []:
        if not host.get("alive"):
            continue
        if not host.get("mac_address"):
            continue
        for action, interval in host_jobs:
            self._publish_next_schedule_for_host(host, action, interval)
|
|
|
|
def _publish_next_schedule_for_global(self, action_def: Dict[str, Any], interval: int):
    """Schedule the next run of a global interval action, if none is active.

    Nothing is done when the action already has a 'scheduled', 'pending'
    or 'running' row for the controller MAC. Otherwise the next run is
    ``last execution + interval`` seconds, or "now" (UTC) when there is no
    previous execution, and the occurrence is handed to
    ``db.ensure_scheduled_occurrence``. Errors are logged, never raised.

    Args:
        action_def: Cached action definition (must contain ``b_class``).
        interval: Interval in seconds between runs.
    """
    try:
        action_name = action_def["b_class"]
        mac = self.ctrl_mac

        # At most one active row per action on the controller.
        already_active = self.db.query(
            """
            SELECT 1 FROM action_queue
            WHERE action_name=? AND mac_address=?
              AND status IN ('scheduled','pending','running')
            LIMIT 1
            """,
            (action_name, mac),
        )
        if already_active:
            return

        last_run = self._get_last_global_execution_time(action_name)
        if last_run:
            next_run = last_run + timedelta(seconds=interval)
        else:
            next_run = _utcnow()
        scheduled_for = _db_ts(next_run)

        metadata = {
            "interval": interval,
            "is_global": True,
            "decision_method": "heuristic",
            "decision_origin": "heuristic",
        }
        controller_context = {"mac": mac, "hostname": "Bjorn-C2", "ports": []}
        self._annotate_decision_metadata(
            metadata=metadata,
            action_name=action_name,
            context=controller_context,
            decision_scope="scheduled_global",
        )

        # "or" fallbacks also cover NULL / 0 values stored in the definition.
        priority = int(action_def.get("b_priority", 40) or 40)
        max_retries = int(action_def.get("b_max_retries", 3) or 3)

        inserted = self.db.ensure_scheduled_occurrence(
            action_name=action_name,
            next_run_at=scheduled_for,
            mac=mac,
            ip="0.0.0.0",
            priority=priority,
            trigger="scheduler",
            tags=action_def.get("b_tags", []),
            metadata=metadata,
            max_retries=max_retries,
        )
        if inserted:
            logger.debug(f"Scheduled global '{action_name}' at {scheduled_for}")

    except Exception as e:
        logger.error(f"Failed to publish global schedule: {e}")
|
|
|
|
def _publish_next_schedule_for_host(self, host: Dict[str, Any], action_def: Dict[str, Any], interval: int):
    """Schedule the next run of a per-host interval action, if none is active.

    Hosts without a MAC are ignored. Nothing is done when the action
    already has a 'scheduled', 'pending' or 'running' row for this MAC
    (any port). Otherwise the next run is ``last execution + interval``
    seconds, or "now" (UTC) when there is none, and the occurrence is
    handed to ``db.ensure_scheduled_occurrence``. Errors are logged, never
    raised.

    Args:
        host: Host row; ``hostnames``, ``ports`` and ``ips`` are read as
            ';'-separated strings.
        action_def: Cached action definition (must contain ``b_class``).
        interval: Interval in seconds between runs.
    """
    try:
        mac = host.get("mac_address") or ""
        if not mac:
            return

        self._ensure_host_exists(mac)
        action_name = action_def["b_class"]

        already_active = self.db.query(
            """
            SELECT 1 FROM action_queue
            WHERE action_name=? AND mac_address=?
              AND status IN ('scheduled','pending','running')
            LIMIT 1
            """,
            (action_name, mac),
        )
        if already_active:
            return

        last_run = self._get_last_execution_time(mac, action_name)
        if last_run:
            next_run = last_run + timedelta(seconds=interval)
        else:
            next_run = _utcnow()
        scheduled_for = _db_ts(next_run)

        # First hostname, and only the purely numeric port entries.
        first_hostname = (host.get("hostnames") or "").split(";")[0]
        numeric_ports = [
            int(p) for p in (host.get("ports") or "").split(";") if p.isdigit()
        ]

        metadata = {
            "interval": interval,
            "is_global": False,
            "decision_method": "heuristic",
            "decision_origin": "heuristic",
        }
        self._annotate_decision_metadata(
            metadata=metadata,
            action_name=action_name,
            context={"mac": mac, "hostname": first_hostname, "ports": numeric_ports},
            decision_scope="scheduled_host",
        )

        # First IP of the host, or an empty string when none is recorded.
        ips = host.get("ips")
        first_ip = ips.split(";")[0] if ips else ""

        inserted = self.db.ensure_scheduled_occurrence(
            action_name=action_name,
            next_run_at=scheduled_for,
            mac=mac,
            ip=first_ip,
            priority=int(action_def.get("b_priority", 40) or 40),
            trigger="scheduler",
            tags=action_def.get("b_tags", []),
            metadata=metadata,
            max_retries=int(action_def.get("b_max_retries", 3) or 3),
        )
        if inserted:
            logger.debug(f"Scheduled '{action_name}' for {mac} at {scheduled_for}")

    except Exception as e:
        logger.error(f"Failed to publish host schedule: {e}")
|
|
|
|
# ------------------------------------------------------------ global start
|
|
|
|
def _evaluate_global_actions(self):
    """Queue enabled global ``on_start`` actions that have never run.

    An action is queued only when it has no recorded execution and no
    'scheduled', 'pending' or 'running' row. The time of each successful
    enqueue is remembered in ``_last_global_runs``. The pass runs under a
    lock created on first use; errors are logged, never raised.
    """
    # Lazily create the lock the first time this method runs.
    if not hasattr(self, "_globals_lock"):
        self._globals_lock = threading.Lock()

    with self._globals_lock:
        try:
            for action in self._action_definitions.values():
                is_global = (action.get("b_action") or "normal") == "global"
                if not is_global or not self._is_action_enabled(action):
                    continue
                if (action.get("b_trigger") or "").strip() != "on_start":
                    continue

                action_name = action["b_class"]

                # on_start runs once: skip anything with a recorded execution.
                if self._get_last_global_execution_time(action_name) is not None:
                    continue

                # Skip when a row for this action is already active.
                already_queued = self.db.query(
                    """
                    SELECT 1 FROM action_queue
                    WHERE action_name=? AND status IN ('scheduled','pending','running')
                    LIMIT 1
                    """,
                    (action_name,),
                )
                if already_queued:
                    continue

                if self._queue_global_action(action):
                    self._last_global_runs[action_name] = time.time()

        except Exception as e:
            logger.error(f"Error evaluating global actions: {e}")
|
|
|
|
def _queue_global_action(self, action_def: Dict[str, Any]) -> bool:
    """Insert a 'pending' queue row for a global action (idempotent).

    The row targets the controller MAC with ip "0.0.0.0" and no port,
    hostname or service. It expires ``b_timeout`` seconds (default 300)
    from now. The INSERT is guarded by NOT EXISTS, so no second active row
    is created for the same action and MAC.

    When the decision was confirmed by the AI engine, the row's priority
    is raised by ``int(20 * confidence)``. The boost is applied to a local
    value only: *action_def* is the shared cached definition and is never
    modified, so the boost cannot leak into later reads of the cache.

    Args:
        action_def: Cached action definition (must contain ``b_class``).

    Returns:
        True when a row was inserted, False when an active row already
        existed or the insert failed (failures are logged).
    """
    action_name = action_def["b_class"]
    mac = self.ctrl_mac
    ip = "0.0.0.0"
    timeout = int(action_def.get("b_timeout", 300) or 300)
    expires_at = _db_ts(_utcnow() + timedelta(seconds=timeout))

    metadata = {
        "trigger": action_def.get("b_trigger", ""),
        "requirements": action_def.get("b_requires", ""),
        "timeout": timeout,
        "is_global": True,
        "decision_method": "heuristic",
        "decision_origin": "heuristic",
    }

    # Global context (controller itself); no specific ports are targeted.
    context = {
        "mac": mac,
        "hostname": "Bjorn-C2",
        "ports": [],
    }
    self._annotate_decision_metadata(
        metadata=metadata,
        action_name=action_name,
        context=context,
        decision_scope="queue_global",
    )

    # Effective priority for this row only; the definition stays untouched.
    priority = int(action_def.get("b_priority", 50) or 50)
    ai_conf = metadata.get("ai_confidence")
    if isinstance(ai_conf, (int, float)) and metadata.get("decision_origin") == "ai_confirmed":
        priority += int(20 * float(ai_conf))

    try:
        self._ensure_host_exists(mac)
        # Guard with NOT EXISTS to avoid races
        affected = self.db.execute(
            """
            INSERT INTO action_queue (
                action_name, mac_address, ip, port, hostname, service,
                priority, status, max_retries, expires_at,
                trigger_source, tags, metadata
            )
            SELECT ?, ?, ?, NULL, NULL, NULL,
                   ?, 'pending', ?, ?, ?, ?, ?
            WHERE NOT EXISTS (
                SELECT 1 FROM action_queue
                WHERE action_name=? AND mac_address=? AND COALESCE(port,0)=0
                  AND status IN ('scheduled','pending','running')
            )
            """,
            (
                action_name,
                mac,
                ip,
                priority,
                int(action_def.get("b_max_retries", 3) or 3),
                expires_at,
                action_def.get("b_trigger", ""),
                json.dumps(action_def.get("b_tags", [])),
                json.dumps(metadata),
                action_name,
                mac,
            ),
        )
        if affected and affected > 0:
            self._log_queue_decision(action_name=action_name, mac=mac, metadata=metadata)
            return True
        return False

    except Exception as e:
        logger.error(f"Failed to queue global action {action_name}: {e}")
        return False
|
|
|
|
# ------------------------------------------------------------- host path
|
|
|
|
def evaluate_all_triggers(self):
    """Evaluate non-interval triggers of per-host actions for every host.

    All hosts are considered, including dead ones, so that ``on_leave``
    triggers can fire. For each host and each enabled, non-global action
    with a non-interval trigger, the trigger and then the requirements are
    evaluated, the target port/service is resolved, and the action is
    queued if ``_should_queue_action`` allows it.
    """
    for host in self.db.get_all_hosts():
        mac = host["mac_address"]

        for action_name, action_def in self._action_definitions.items():
            # Global and disabled actions are not handled on the host path.
            is_global = (action_def.get("b_action") or "normal") == "global"
            if is_global or not self._is_action_enabled(action_def):
                continue

            # Interval triggers are published elsewhere; empty ones never fire.
            trigger = (action_def.get("b_trigger") or "").strip()
            if not trigger or trigger.startswith("on_interval:"):
                continue

            if not evaluate_trigger(trigger, host, action_def):
                continue

            requires = action_def.get("b_requires", "")
            if requires and not evaluate_requirements(requires, host, action_def):
                continue

            target_port, target_service = self._resolve_target_port_service(mac, host, action_def)

            if self._should_queue_action(mac, action_name, action_def, target_port):
                self._queue_action(host, action_def, target_port, target_service)
|
|
|
|
def _resolve_target_port_service(
    self, mac: str, host: Dict[str, Any], action_def: Dict[str, Any]
) -> Tuple[Optional[int], Optional[str]]:
    """Pick the port (and service) an action should target on a host.

    ``b_service`` wins over ``b_port``: its services (a list, or a JSON
    string that decodes to one) are tried in order, and the first with an
    open entry in ``port_services`` for this MAC yields that row's port
    (most recently seen first) and the lowercase service name. Otherwise
    ``b_port`` is used when it appears among the host's normalized ports.

    Returns:
        ``(port, service)``; ``service`` is None unless the match came
        from ``b_service``; both are None when nothing matched.
    """
    ports = _normalize_ports(host.get("ports"))

    service_spec = action_def.get("b_service")
    services: Any = []
    if service_spec:
        try:
            if isinstance(service_spec, str):
                services = json.loads(service_spec)
            else:
                services = service_spec
        except Exception:
            # Undecodable b_service: behave as if no service was given.
            services = []

    for svc in services or ():
        service_name = str(svc).lower()
        match = self.db.query(
            "SELECT port FROM port_services "
            "WHERE mac_address=? AND state='open' AND LOWER(service)=? "
            "ORDER BY last_seen DESC LIMIT 1",
            (mac, service_name),
        )
        if match:
            return int(match[0]["port"]), service_name

    # Fallback to b_port
    configured_port = action_def.get("b_port")
    if configured_port and str(configured_port) in ports:
        return int(configured_port), None

    return None, None
|
|
|
|
# ----------------------------------------------------- re-queue policy core
|
|
|
|
def _get_last_status(self, mac: str, action_name: str, target_port: Optional[int]) -> Optional[str]:
|
|
"""
|
|
Return last known status for (mac, action, port), considering the
|
|
chronological fields (completed_at > started_at > scheduled_for > created_at).
|
|
"""
|
|
self_port = 0 if target_port is None else int(target_port)
|
|
row = self.db.query(
|
|
"""
|
|
SELECT status
|
|
FROM action_queue
|
|
WHERE mac_address=? AND action_name=? AND COALESCE(port,0)=?
|
|
ORDER BY datetime(COALESCE(completed_at, started_at, scheduled_for, created_at)) DESC
|
|
LIMIT 1
|
|
""",
|
|
(mac, action_name, self_port),
|
|
)
|
|
return row[0]["status"] if row else None
|
|
|
|
def _should_queue_action(
    self, mac: str, action_name: str, action_def: Dict[str, Any], target_port: Optional[int]
) -> bool:
    """
    Gatekeeper for trigger-based enqueues: True means "go ahead and queue".

    Checks run in this order and the first one that objects wins:
      0) an active job (scheduled/pending/running) already exists
      1) the last status is success/failed and the matching runtime flag
         (retry_success_actions / retry_failed_actions) forbids a re-run
      1-bis) a failed row that is still retryable exists; cleanup_queue()
             owns that retry, so we stay out of its way
      2) the action's b_cooldown has not elapsed since the last execution
      3) the action's b_rate_limit is exhausted
    """
    port_key = int(target_port) if target_port is not None else 0
    job_key = (mac, action_name, port_key)

    # 0) Never stack a second active job on the same (mac, action, port).
    active_rows = self.db.query(
        """
        SELECT 1 FROM action_queue
        WHERE mac_address=? AND action_name=? AND COALESCE(port,0)=?
          AND status IN ('scheduled','pending','running')
        LIMIT 1
        """,
        job_key,
    )
    if active_rows:
        return False

    # 1) Runtime retry flags outrank cooldown and rate limiting.
    retry_success = bool(getattr(self.shared_data, "retry_success_actions", False))
    retry_failed = bool(getattr(self.shared_data, "retry_failed_actions", True))
    previous = self._get_last_status(mac, action_name, target_port)

    blocked_after_success = previous == "success" and not retry_success
    blocked_after_failure = previous == "failed" and not retry_failed
    if blocked_after_success or blocked_after_failure:
        return False

    # 1-bis) A retryable failure is requeued by cleanup_queue(), not here.
    if retry_failed:
        pending_retry = self.db.query(
            """
            SELECT 1 FROM action_queue
            WHERE mac_address=? AND action_name=? AND COALESCE(port,0)=?
              AND status='failed'
              AND retry_count < max_retries
              AND COALESCE(error_message,'') != 'expired'
            LIMIT 1
            """,
            job_key,
        )
        if pending_retry:
            return False

    # 2) Cooldown, measured in seconds against naive-UTC "now".
    cooldown_s = int(action_def.get("b_cooldown", 0) or 0)
    if cooldown_s > 0:
        finished_at = self._get_last_execution_time(mac, action_name)
        if finished_at and (_utcnow() - finished_at).total_seconds() < cooldown_s:
            return False

    # 3) Rate limit; an empty spec means "unlimited".
    limit_spec = (action_def.get("b_rate_limit") or "").strip()
    if not limit_spec:
        return True
    return bool(self._check_rate_limit(mac, action_name, limit_spec))
|
|
|
|
def _queue_action(
    self, host: Dict[str, Any], action_def: Dict[str, Any], target_port: Optional[int], target_service: Optional[str]
) -> bool:
    """Insert a 'pending' queue row for an action on a host, at most once.

    The INSERT carries a NOT EXISTS guard on
    (mac_address, action_name, COALESCE(port, 0)) restricted to the active
    statuses, so calling this twice does not create a second active row.

    Args:
        host: Host row; reads ``mac_address``, ``ips``, ``hostnames`` and
            ``ports`` (``ips``/``hostnames`` are ';'-separated strings).
        action_def: Action definition; reads ``b_class``, ``b_timeout``,
            ``b_priority``, ``b_max_retries``, ``b_trigger``,
            ``b_requires`` and ``b_tags``. It is never modified.
        target_port: Port to store on the row, or None.
        target_service: Service name to store on the row, or None.

    Returns:
        True if a row was inserted, False if the guard suppressed the
        insert or the database call failed (the failure is logged).
    """
    action_name = action_def["b_class"]
    mac = host["mac_address"]
    timeout = int(action_def.get("b_timeout", 300) or 300)
    expires_at = _db_ts(_utcnow() + timedelta(seconds=timeout))
    port_key = 0 if target_port is None else int(target_port)

    # First entry of each ';'-separated field; empty string when absent.
    ip = (host.get("ips") or "").split(";")[0]
    hostname = (host.get("hostnames") or "").split(";")[0]

    metadata = {
        "trigger": action_def.get("b_trigger", ""),
        "requirements": action_def.get("b_requires", ""),
        "is_global": False,
        "timeout": timeout,
        "decision_method": "heuristic",
        "decision_origin": "heuristic",
        "ports_snapshot": host.get("ports") or "",
    }

    context = {
        "mac": mac,
        "hostname": hostname,
        # Go through _normalize_ports so list-valued or "22/tcp"-style
        # port data is handled the same way as everywhere else in the file.
        "ports": [int(p) for p in _normalize_ports(host.get("ports")) if p.isdigit()],
    }
    self._annotate_decision_metadata(
        metadata=metadata,
        action_name=action_name,
        context=context,
        decision_scope="queue_host",
    )

    # Priority is computed locally. Writing the boost back into action_def
    # would make it accumulate every time the same definition is queued.
    priority = int(action_def.get("b_priority", 50) or 50)
    ai_conf = metadata.get("ai_confidence")
    if isinstance(ai_conf, (int, float)) and metadata.get("decision_origin") == "ai_confirmed":
        # Apply small priority boost only when AI confirmed this exact action.
        priority += int(20 * float(ai_conf))

    try:
        affected = self.db.execute(
            """
            INSERT INTO action_queue (
                action_name, mac_address, ip, port, hostname, service,
                priority, status, max_retries, expires_at,
                trigger_source, tags, metadata
            )
            SELECT ?, ?, ?, ?, ?, ?,
                   ?, 'pending', ?, ?, ?, ?, ?
            WHERE NOT EXISTS (
                SELECT 1 FROM action_queue
                WHERE mac_address=? AND action_name=? AND COALESCE(port,0)=?
                  AND status IN ('scheduled','pending','running')
            )
            """,
            (
                action_name,
                mac,
                ip,
                target_port,
                hostname,
                target_service,
                priority,
                int(action_def.get("b_max_retries", 3) or 3),
                expires_at,
                action_def.get("b_trigger", ""),
                json.dumps(action_def.get("b_tags", [])),
                json.dumps(metadata),
                mac,
                action_name,
                port_key,
            ),
        )
        if affected and affected > 0:
            self._log_queue_decision(
                action_name=action_name,
                mac=mac,
                metadata=metadata,
                target_port=target_port,
                target_service=target_service,
            )
            return True
        return False
    except Exception as e:
        logger.error(f"Failed to queue {action_name} for {mac}: {e}")
        return False
|
|
|
|
# ------------------------------------------------------------- last times
|
|
|
|
def _get_last_execution_time(self, mac: str, action_name: str) -> Optional[datetime]:
|
|
"""Get last execution time (DB read only; naive UTC)."""
|
|
row = self.db.query(
|
|
"""
|
|
SELECT completed_at FROM action_queue
|
|
WHERE mac_address=? AND action_name=? AND status IN ('success','failed')
|
|
ORDER BY completed_at DESC
|
|
LIMIT 1
|
|
""",
|
|
(mac, action_name),
|
|
)
|
|
if row and row[0].get("completed_at"):
|
|
try:
|
|
val = row[0]["completed_at"]
|
|
if isinstance(val, str):
|
|
return datetime.fromisoformat(val)
|
|
elif isinstance(val, datetime):
|
|
return val
|
|
except Exception:
|
|
return None
|
|
return None
|
|
|
|
def _get_last_global_execution_time(self, action_name: str) -> Optional[datetime]:
|
|
"""Get last global action execution time (naive UTC)."""
|
|
row = self.db.query(
|
|
"""
|
|
SELECT completed_at FROM action_queue
|
|
WHERE action_name=? AND status IN ('success','failed')
|
|
ORDER BY completed_at DESC
|
|
LIMIT 1
|
|
""",
|
|
(action_name,),
|
|
)
|
|
if row and row[0].get("completed_at"):
|
|
try:
|
|
val = row[0]["completed_at"]
|
|
if isinstance(val, str):
|
|
return datetime.fromisoformat(val)
|
|
elif isinstance(val, datetime):
|
|
return val
|
|
except Exception:
|
|
return None
|
|
return None
|
|
|
|
# ------------------------------------------------------------- constraints
|
|
|
|
def _check_rate_limit(self, mac: str, action_name: str, rate_limit: str) -> bool:
|
|
"""
|
|
Check "X/SECONDS" rate-limit (count based on created_at).
|
|
Returns True if action is allowed to queue.
|
|
"""
|
|
try:
|
|
max_count, period = rate_limit.split("/")
|
|
max_count = int(max_count)
|
|
period = int(period)
|
|
since = _db_ts(_utcnow() - timedelta(seconds=period))
|
|
|
|
count = self.db.query(
|
|
"""
|
|
SELECT COUNT(*) AS c FROM action_queue
|
|
WHERE mac_address=? AND action_name=? AND created_at >= ?
|
|
""",
|
|
(mac, action_name, since),
|
|
)[0]["c"]
|
|
|
|
return int(count) < max_count
|
|
except Exception:
|
|
# Invalid format -> do not block
|
|
return True
|
|
|
|
# -------------------------------------------------------------- maintenance
|
|
|
|
def cleanup_queue(self):
    """
    Periodic queue maintenance, run as four ordered steps:

      1) pending rows past expires_at become failed ('expired')
      2) running rows past started_at + timeout become failed ('timeout');
         the timeout comes from metadata.timeout, defaulting to 900 s
      3) if retry_failed_actions is on, failed rows with retries left (and
         not marked 'expired') go back to pending with a backoff of
         60 * 2^retry_count seconds, capped at 900, unless an active row
         already exists for the same (action, mac, port)
      4) finished rows completed more than 7 days ago are deleted

    Any exception is logged and swallowed.

    NOTE(review): step 3 does not touch expires_at, so a requeued row whose
    expires_at is already in the past looks eligible for step 1 on a later
    run -- confirm this is intended.
    """
    expire_pending_sql = """
        UPDATE action_queue
        SET status='failed',
            completed_at=CURRENT_TIMESTAMP,
            error_message=COALESCE(error_message,'expired')
        WHERE status='pending'
          AND expires_at IS NOT NULL
          AND expires_at < ?
    """
    timeout_running_sql = """
        UPDATE action_queue
        SET status='failed',
            completed_at=CURRENT_TIMESTAMP,
            error_message=COALESCE(error_message,'timeout')
        WHERE status='running'
          AND started_at IS NOT NULL
          AND datetime(started_at, '+' || COALESCE(
                CAST(json_extract(metadata, '$.timeout') AS INTEGER), 900
              ) || ' seconds') <= datetime('now')
    """
    retry_failed_sql = """
        UPDATE action_queue AS a
        SET status='pending',
            retry_count = retry_count + 1,
            scheduled_for = datetime(
                'now',
                '+' || (
                    CASE
                        WHEN (60 * (1 << retry_count)) > 900 THEN 900
                        ELSE (60 * (1 << retry_count))
                    END
                ) || ' seconds'
            ),
            error_message = NULL,
            started_at = NULL,
            completed_at = NULL
        WHERE a.status='failed'
          AND a.retry_count < a.max_retries
          AND COALESCE(a.error_message,'') != 'expired'
          AND NOT EXISTS (
              SELECT 1 FROM action_queue b
              WHERE b.mac_address=a.mac_address
                AND b.action_name=a.action_name
                AND COALESCE(b.port,0)=COALESCE(a.port,0)
                AND b.status IN ('scheduled','pending','running')
          )
    """
    purge_old_sql = """
        DELETE FROM action_queue
        WHERE status IN ('success','failed','cancelled','expired')
          AND completed_at < ?
    """

    try:
        # 1) Expire pending actions (compared against a UTC text timestamp).
        cutoff_now = _utcnow_str()
        self.db.execute(expire_pending_sql, (cutoff_now,))

        # 2) Timeout running actions.
        self.db.execute(timeout_running_sql)

        # 3) Retry failed actions with exponential backoff, only when no
        #    active job exists for the same (action, mac, port).
        retries_enabled = bool(getattr(self.shared_data, "retry_failed_actions", True))
        if retries_enabled:
            self.db.execute(retry_failed_sql)

        # 4) Purge entries completed more than a week ago.
        purge_before = _db_ts(_utcnow() - timedelta(days=7))
        self.db.execute(purge_old_sql, (purge_before,))

    except Exception as e:
        logger.error(f"Failed to cleanup queue: {e}")
|
|
|
|
# update_priorities is defined above (line ~166); this duplicate is removed.
|
|
|
|
|
|
# =================================================================== helpers ==
|
|
|
|
def _normalize_ports(raw) -> List[str]:
|
|
"""Normalize ports to list of strings."""
|
|
if not raw:
|
|
return []
|
|
if isinstance(raw, list):
|
|
return [str(p).split("/")[0] for p in raw if p is not None and str(p) != ""]
|
|
if isinstance(raw, int):
|
|
return [str(raw)]
|
|
if isinstance(raw, str):
|
|
s = raw.strip()
|
|
if not s:
|
|
return []
|
|
if s.startswith("[") and s.endswith("]"):
|
|
try:
|
|
arr = json.loads(s)
|
|
return [str(p).split("/")[0] for p in arr]
|
|
except Exception:
|
|
pass
|
|
if ";" in s:
|
|
return [p.strip().split("/")[0] for p in s.split(";") if p.strip()]
|
|
return [s.split("/")[0]]
|
|
return [str(raw)]
|
|
|
|
|
|
def _has_open_service(mac: str, svc: str, host: Dict[str, Any]) -> bool:
    """
    Tell whether a service counts as open on a host.

    The port_services table is consulted first (case-insensitive service
    name, state 'open'). If it has no matching row, the host's normalized
    port list is checked against the ports listed for that service in
    SERVICE_PORTS.
    """
    wanted = (svc or "").lower().strip()

    # Authoritative source: a recorded open service for this MAC.
    recorded = shared_data.db.query(
        "SELECT 1 FROM port_services WHERE mac_address=? AND state='open' AND LOWER(service)=? LIMIT 1",
        (mac, wanted),
    )
    if recorded:
        return True

    # Fallback: any of the service's known ports present on the host.
    open_ports = set(_normalize_ports(host.get("ports")))
    return any(candidate in open_ports for candidate in SERVICE_PORTS.get(wanted, []))
|
|
|
|
|
|
def _last_presence_event_for_mac(mac: str) -> Optional[str]:
    """
    Return 'PresenceJoin' or 'PresenceLeave', whichever was queued most
    recently for this MAC, or None if neither was ever queued.

    Recency uses the first non-NULL of completed_at, started_at,
    scheduled_for and created_at.
    """
    newest = shared_data.db.query(
        """
        SELECT action_name
        FROM action_queue
        WHERE mac_address=?
          AND action_name IN ('PresenceJoin','PresenceLeave')
        ORDER BY datetime(COALESCE(completed_at, started_at, scheduled_for, created_at)) DESC
        LIMIT 1
        """,
        (mac,),
    )
    if not newest:
        return None
    return newest[0]["action_name"]
|
|
|
|
|
|
# --------------------------------------------------------------- trigger eval --
|
|
|
|
def evaluate_trigger(trigger: str, host: Dict[str, Any], action_def: Dict[str, Any]) -> bool:
    """
    Decide whether ``trigger`` currently fires for ``host``.

    Recognised trigger forms:
      - on_start / on_new_host   (action never finished on this host)
      - on_host_alive (alias on_alive), on_host_dead (alias on_dead)
      - on_join, on_leave        (presence flips, based on the last queued
                                  PresenceJoin / PresenceLeave)
      - on_port_change, on_new_port:PORT
      - on_service:SERVICE, on_web_service
      - on_success:ACTION, on_failure:ACTION
      - on_cred_found:SERVICE
      - on_mac_is:MAC, on_essid_is:ESSID, on_ip_is:IP
      - on_has_cve[:CVE], on_has_cpe[:CPE]
      - on_all:[...], on_any:[...]   (JSON array of nested triggers)

    on_interval:* always evaluates to False here. Unknown triggers are
    logged at debug level and evaluate to False; any exception is logged
    and also evaluates to False.
    """

    def found(sql: str, params: Tuple[Any, ...]) -> bool:
        # True when the query yields at least one row.
        return bool(shared_data.db.query(sql, params))

    try:
        mac = host["mac_address"]
        spec = (trigger or "").strip()
        if not spec:
            return False

        # Combined triggers: payload after the first ':' is a JSON array.
        for prefix, combine in (("on_all:", all), ("on_any:", any)):
            if spec.startswith(prefix):
                try:
                    members = json.loads(spec.split(":", 1)[1])
                except Exception:
                    return False
                return combine(evaluate_trigger(member, host, action_def) for member in members)

        # Interval triggers are not evaluated by this function.
        if spec.startswith("on_interval:"):
            return False

        # Split "name:param"; param is empty when there is no ':'.
        name, _, param = spec.partition(":")
        name = name.strip()
        param = param.strip()

        # Short aliases map onto the canonical names.
        name = {"on_alive": "on_host_alive", "on_dead": "on_host_dead"}.get(name, name)

        alive = bool(host.get("alive"))

        # Join/Leave events: fire only when presence differs from the last
        # recorded presence event.
        if name == "on_join":
            return alive and _last_presence_event_for_mac(mac) != "PresenceJoin"
        if name == "on_leave":
            return (not alive) and _last_presence_event_for_mac(mac) != "PresenceLeave"

        # Basic triggers
        if name in ("on_start", "on_new_host"):
            already_ran = found(
                """
                SELECT 1 FROM action_queue
                WHERE mac_address=? AND action_name=?
                  AND status IN ('success','failed')
                LIMIT 1
                """,
                (mac, action_def["b_class"]),
            )
            return not already_ran
        if name == "on_host_alive":
            return alive
        if name == "on_host_dead":
            return not alive

        # Port/service triggers never fire for dead hosts.
        if not alive and name in {"on_service", "on_web_service", "on_new_port", "on_port_change"}:
            return False

        # Port triggers compare the current and previous port sets.
        if name in ("on_port_change", "on_new_port"):
            current = set(_normalize_ports(host.get("ports")))
            previous = set(_normalize_ports(host.get("previous_ports")))
            if name == "on_port_change":
                return current != previous
            return param in current and param not in previous

        # Action status triggers: param names the parent action.
        if name == "on_success":
            return found(
                """
                SELECT 1 FROM action_queue
                WHERE mac_address=? AND action_name=? AND status='success'
                ORDER BY completed_at DESC LIMIT 1
                """,
                (mac, param),
            )
        if name == "on_failure":
            return found(
                """
                SELECT 1 FROM action_queue
                WHERE mac_address=? AND action_name=? AND status='failed'
                ORDER BY completed_at DESC LIMIT 1
                """,
                (mac, param),
            )

        # Service triggers
        if name == "on_cred_found":
            return found(
                "SELECT 1 FROM creds WHERE mac_address=? AND LOWER(service)=? LIMIT 1",
                (mac, param.lower()),
            )
        if name == "on_service":
            return _has_open_service(mac, param, host)
        if name == "on_web_service":
            return _has_open_service(mac, "http", host) or _has_open_service(mac, "https", host)

        # Identity triggers
        if name == "on_mac_is":
            return str(mac).lower() == param.lower()
        if name == "on_essid_is":
            return (host.get("essid") or "") == param
        if name == "on_ip_is":
            raw_ips = host.get("ips")
            return param in (raw_ips.split(";") if raw_ips else [])

        # Vulnerability triggers: without a param, any active entry matches.
        if name == "on_has_cve":
            if param:
                return found(
                    "SELECT 1 FROM vulnerabilities WHERE mac_address=? AND vuln_id=? AND is_active=1 LIMIT 1",
                    (mac, param),
                )
            return found(
                "SELECT 1 FROM vulnerabilities WHERE mac_address=? AND is_active=1 LIMIT 1",
                (mac,),
            )
        if name == "on_has_cpe":
            if param:
                return found(
                    "SELECT 1 FROM detected_software WHERE mac_address=? AND cpe=? AND is_active=1 LIMIT 1",
                    (mac, param),
                )
            return found(
                "SELECT 1 FROM detected_software WHERE mac_address=? AND is_active=1 LIMIT 1",
                (mac,),
            )

        # Unknown trigger
        logger.debug(f"Unknown trigger: {name}")
        return False

    except Exception as e:
        logger.error(f"Error evaluating trigger '{trigger}': {e}")
        return False
|
|
|
|
|
|
# ---------------------------------------------------------- requirements eval --
|
|
|
|
def evaluate_requirements(requires: Any, host: Dict[str, Any], action_def: Dict[str, Any]) -> bool:
    """
    Check an action's requirements in whatever form they were supplied.

    - None or blank text: nothing is required, so True.
    - dict / list: handed to evaluate_requirements_object() unchanged.
    - text shaped like a JSON object or array: parsed, then evaluated as
      an object; if that fails the text falls through to the rules below.
    - legacy "Action:status" text: evaluated as
      {"action": Action, "status": status}.
    - any other text: True.
    """
    if requires is None:
        return True

    # Structured requirements need no parsing.
    if isinstance(requires, (dict, list)):
        return evaluate_requirements_object(requires, host, action_def)

    text = str(requires).strip()
    if not text:
        return True

    # JSON string: wrapped in {...} or [...].
    if (text[:1], text[-1:]) in {("{", "}"), ("[", "]")}:
        try:
            parsed = json.loads(text)
            return evaluate_requirements_object(parsed, host, action_def)
        except Exception:
            pass

    # Legacy "Action:status" format (split on the first ':').
    action, colon, status = text.partition(":")
    if colon:
        legacy = {"action": action.strip(), "status": status.strip()}
        return evaluate_requirements_object(legacy, host, action_def)

    return True
|
|
|
|
|
|
def evaluate_requirements_object(req: Any, host: Dict[str, Any], action_def: Dict[str, Any]) -> bool:
    """
    Evaluate a structured requirement against a host.

    Supported:
      - [ ... ]  a bare list: every element must hold (implicit "all");
        an empty list evaluates to False
      - {"all": [...]} / {"any": [...]} / {"not": {...}}
      - {"action": "ACTION", "status": "STATUS", "scope": "host|global"}
        (status defaults to "success", scope to "host")
      - {"has_port": PORT}
      - {"has_cred": "SERVICE"}
      - {"has_cve": "CVE"}
      - {"has_cpe": "CPE"}
      - {"mac_is": "MAC"}
      - {"essid_is": "ESSID"}
      - {"service_is_open": "SERVICE"}

    Anything else (unknown dict keys, scalars) falls back to ``bool(req)``.

    Raises:
        KeyError: if ``host`` has no ``mac_address``.
    """

    def found(sql: str, params: Tuple[Any, ...]) -> bool:
        # True when the query yields at least one row.
        return bool(shared_data.db.query(sql, params))

    mac = host["mac_address"]

    # A bare list used to hit the truthiness fallback, so any non-empty
    # list passed without its elements being checked. Treat it as an AND;
    # an empty list keeps its previous (falsy) result.
    if isinstance(req, list):
        return bool(req) and all(
            evaluate_requirements_object(item, host, action_def) for item in req
        )

    if not isinstance(req, dict):
        # Default: truthy
        return bool(req)

    # Combinators
    if "all" in req:
        return all(evaluate_requirements_object(x, host, action_def) for x in (req.get("all") or []))
    if "any" in req:
        return any(evaluate_requirements_object(x, host, action_def) for x in (req.get("any") or []))
    if "not" in req:
        return not evaluate_requirements_object(req.get("not"), host, action_def)

    # Atomic requirements
    if "action" in req:
        action = str(req.get("action") or "").strip()
        status = str(req.get("status") or "success").strip()
        scope = str(req.get("scope") or "host").strip().lower()

        if scope == "global":
            # Any host having reached that status satisfies the requirement.
            return found(
                """
                SELECT 1 FROM action_queue
                WHERE action_name=? AND status=?
                ORDER BY completed_at DESC LIMIT 1
                """,
                (action, status),
            )

        # Host scope
        return found(
            """
            SELECT 1 FROM action_queue
            WHERE mac_address=? AND action_name=? AND status=?
            ORDER BY completed_at DESC LIMIT 1
            """,
            (mac, action, status),
        )

    if "has_port" in req:
        want = str(req.get("has_port"))
        return want in set(_normalize_ports(host.get("ports")))

    if "has_cred" in req:
        svc = str(req.get("has_cred") or "").lower()
        return found(
            "SELECT 1 FROM creds WHERE mac_address=? AND LOWER(service)=? LIMIT 1",
            (mac, svc),
        )

    if "has_cve" in req:
        cve = str(req.get("has_cve") or "")
        return found(
            "SELECT 1 FROM vulnerabilities WHERE mac_address=? AND vuln_id=? AND is_active=1 LIMIT 1",
            (mac, cve),
        )

    if "has_cpe" in req:
        cpe = str(req.get("has_cpe") or "")
        return found(
            "SELECT 1 FROM detected_software WHERE mac_address=? AND cpe=? AND is_active=1 LIMIT 1",
            (mac, cpe),
        )

    if "mac_is" in req:
        return str(mac).lower() == str(req.get("mac_is") or "").lower()

    if "essid_is" in req:
        return (host.get("essid") or "") == str(req.get("essid_is") or "")

    if "service_is_open" in req:
        svc = str(req.get("service_is_open") or "").lower()
        return _has_open_service(mac, svc, host)

    # Default: truthy
    return bool(req)
|