Add RLUtils class for managing RL/AI dashboard endpoints

- Implemented methods for fetching AI stats, training history, and recent experiences.
- Added functionality to set operation mode (MANUAL, AUTO, AI) with appropriate handling.
- Included helper methods for querying the database and sending JSON responses.
- Integrated model metadata extraction for visualization purposes.
This commit is contained in:
Fabien POLLY
2026-02-18 22:36:10 +01:00
parent b8a13cc698
commit eb20b168a6
684 changed files with 53278 additions and 27977 deletions

View File

@@ -1,4 +1,4 @@
# action_scheduler.py
# action_scheduler.py
# Smart Action Scheduler for Bjorn - queue-only implementation
# Handles trigger evaluation, requirements checking, and queue management.
#
@@ -24,6 +24,7 @@ from typing import Any, Dict, List, Optional, Tuple
from init_shared import shared_data
from logger import Logger
from ai_engine import get_or_create_ai_engine
logger = Logger(name="action_scheduler.py")
@@ -73,6 +74,8 @@ class ActionScheduler:
# Runtime flags
self.running = True
self.check_interval = 5 # seconds between iterations
self._stop_event = threading.Event()
self._error_backoff = 1.0
# Action definition cache
self._action_definitions: Dict[str, Dict[str, Any]] = {}
@@ -85,6 +88,22 @@ class ActionScheduler:
self._last_source_is_studio: Optional[bool] = None
# Enforce DB invariants (idempotent)
self._ensure_db_invariants()
# Throttling for priorities
self._last_priority_update = 0.0
self._priority_update_interval = 60.0 # seconds
# Initialize AI engine for recommendations ONLY in AI mode.
# Uses singleton so model weights are loaded only once across the process.
self.ai_engine = None
if self.shared_data.operation_mode == "AI":
self.ai_engine = get_or_create_ai_engine(self.shared_data)
if self.ai_engine is None:
logger.info_throttled(
"AI engine unavailable in scheduler; continuing heuristic-only",
key="scheduler_ai_init_failed",
interval_s=300.0,
)
logger.info("ActionScheduler initialized")
@@ -95,8 +114,24 @@ class ActionScheduler:
logger.info("ActionScheduler starting main loop")
while self.running and not self.shared_data.orchestrator_should_exit:
try:
# If the user toggles AI mode at runtime, enable/disable AI engine without restart.
if self.shared_data.operation_mode == "AI" and self.ai_engine is None:
self.ai_engine = get_or_create_ai_engine(self.shared_data)
if self.ai_engine:
logger.info("Scheduler: AI engine enabled (singleton)")
else:
logger.info_throttled(
"Scheduler: AI engine unavailable; staying heuristic-only",
key="scheduler_ai_enable_failed",
interval_s=300.0,
)
elif self.shared_data.operation_mode != "AI" and self.ai_engine is not None:
self.ai_engine = None
# Refresh action cache if needed
self._refresh_cache_if_needed()
# Keep queue consistent with current enable/disable flags.
self._cancel_queued_disabled_actions()
# 1) Promote scheduled actions that are due
self._promote_scheduled_to_pending()
@@ -114,21 +149,260 @@ class ActionScheduler:
self.cleanup_queue()
self.update_priorities()
time.sleep(self.check_interval)
self._error_backoff = 1.0
if self._stop_event.wait(self.check_interval):
break
except Exception as e:
logger.error(f"Error in scheduler loop: {e}")
time.sleep(self.check_interval)
if self._stop_event.wait(self._error_backoff):
break
self._error_backoff = min(self._error_backoff * 2.0, 15.0)
logger.info("ActionScheduler stopped")
# ----------------------------------------------------------------- priorities
def update_priorities(self):
    """
    Re-evaluate priorities of pending queue entries.

    Runs at most once per ``_priority_update_interval`` seconds and applies:
      1. Anti-starvation aging: +1 priority per pass for actions pending
         more than one hour, capped at 100.
      2. In AI mode, a confidence-scaled boost for actions recommended by
         the AI engine (delegated to ``_apply_ai_priority_boost``).
    """
    current = time.time()
    # Throttle: skip the whole pass until the interval has elapsed.
    if current - self._last_priority_update < self._priority_update_interval:
        return
    try:
        # Anti-starvation aging. julianday() is available on every SQLite
        # build, and 0.0417 days is roughly one hour. MIN(100, ...) caps
        # unbounded priority inflation.
        aged = self.db.execute(
            """
            UPDATE action_queue
            SET priority = MIN(100, priority + 1)
            WHERE status='pending'
            AND julianday('now') - julianday(created_at) > 0.0417
            """
        )
        self._last_priority_update = current
        if aged and aged > 0:
            logger.debug(f"Aged {aged} pending actions in queue")
        # AI recommendation boost applies only in AI mode.
        if self.shared_data.operation_mode == "AI":
            if self.ai_engine:
                self._apply_ai_priority_boost()
            else:
                logger.warning("Operation mode is AI, but ai_engine is not initialized!")
    except Exception as e:
        logger.error(f"Failed to update priorities: {e}")
def _apply_ai_priority_boost(self):
    """Boost priority of actions recommended by AI engine.

    For every host that has pending queue rows, asks the AI engine to pick
    one action among that host's pending actions (exploration disabled).
    When the recommendation clears the configured confidence threshold,
    the matching pending rows receive a confidence-scaled priority boost
    and their metadata is annotated with a decision trace.

    Best-effort: every failure is logged and never propagates to the
    scheduler loop.
    """
    try:
        if not self.ai_engine:
            logger.warning("AI Boost skipped: ai_engine is None")
            return
        # Get list of unique hosts with pending actions
        hosts = self.db.query("""
            SELECT DISTINCT mac_address FROM action_queue
            WHERE status='pending'
        """)
        if not hosts:
            return
        for row in hosts:
            mac = row['mac_address']
            if not mac:
                continue
            # Get available actions for this host
            available = [
                r['action_name'] for r in self.db.query("""
                    SELECT DISTINCT action_name FROM action_queue
                    WHERE mac_address=? AND status='pending'
                """, (mac,))
            ]
            if not available:
                continue
            # Get host context
            host_data = self.db.get_host_by_mac(mac)
            if not host_data:
                continue
            # Context shape mirrors _annotate_decision_metadata: first
            # hostname and numeric ports from ';'-separated DB fields.
            context = {
                'mac': mac,
                'hostname': (host_data.get('hostnames') or '').split(';')[0],
                'ports': [
                    int(p) for p in (host_data.get('ports') or '').split(';')
                    if p.isdigit()
                ]
            }
            # Ask AI for recommendation
            recommended_action, confidence, debug = self.ai_engine.choose_action(
                host_context=context,
                available_actions=available,
                exploration_rate=0.0  # No exploration in scheduler
            )
            # Normalize debug so .get() below is always safe.
            if not isinstance(debug, dict):
                debug = {}
            threshold = self._get_ai_confirm_threshold()
            if recommended_action and confidence >= threshold:  # Only boost if confident
                # Boost recommended action
                boost_amount = int(20 * confidence)  # Scale boost by confidence
                affected = self.db.execute("""
                    UPDATE action_queue
                    SET priority = priority + ?
                    WHERE mac_address=? AND action_name=? AND status='pending'
                """, (boost_amount, mac, recommended_action))
                if affected and affected > 0:
                    # NEW: Update metadata to reflect AI influence
                    try:
                        # We fetch all matching IDs to update their metadata
                        rows = self.db.query("""
                            SELECT id, metadata FROM action_queue
                            WHERE mac_address=? AND action_name=? AND status='pending'
                        """, (mac, recommended_action))
                        # NOTE(review): 'row' below shadows the outer hosts
                        # row — harmless (outer value is re-read next
                        # iteration) but worth renaming eventually.
                        for row in rows:
                            meta = json.loads(row['metadata'] or '{}')
                            meta['decision_method'] = f"ai_boosted ({debug.get('method', 'unknown')})"
                            meta['decision_origin'] = "ai_boosted"
                            meta['decision_scope'] = "priority_boost"
                            meta['ai_confidence'] = confidence
                            meta['ai_threshold'] = threshold
                            meta['ai_method'] = str(debug.get('method', 'unknown'))
                            meta['ai_recommended_action'] = recommended_action
                            meta['ai_model_loaded'] = bool(getattr(self.ai_engine, "model_loaded", False))
                            meta['ai_reason'] = "priority_boost_applied"
                            meta['ai_debug'] = debug  # Includes all_scores and input_vector
                            self.db.execute("UPDATE action_queue SET metadata=? WHERE id=?",
                                            (json.dumps(meta), row['id']))
                    except Exception as meta_e:
                        logger.error(f"Failed to update metadata for AI boost: {meta_e}")
                    logger.info(
                        f"[AI_BOOST] action={recommended_action} mac={mac} boost={boost_amount} "
                        f"conf={float(confidence):.2f} thr={float(threshold):.2f} "
                        f"method={debug.get('method', 'unknown')}"
                    )
    except Exception as e:
        logger.error(f"Error applying AI priority boost: {e}")
def stop(self):
    """Signal the scheduler loop to terminate.

    Sets the stop event first so a loop blocked in ``Event.wait`` wakes
    immediately instead of sleeping out the current interval, then clears
    the running flag checked at the top of the loop.
    """
    logger.info("Stopping ActionScheduler...")
    self._stop_event.set()
    self.running = False
# --------------------------------------------------------------- definitions
def _get_ai_confirm_threshold(self) -> float:
"""Return normalized AI confirmation threshold in [0.0, 1.0]."""
try:
raw = float(getattr(self.shared_data, "ai_confirm_threshold", 0.3))
except Exception:
raw = 0.3
return max(0.0, min(1.0, raw))
def _annotate_decision_metadata(
    self,
    metadata: Dict[str, Any],
    action_name: str,
    context: Dict[str, Any],
    decision_scope: str,
) -> None:
    """
    Fill metadata with a consistent decision trace:
    decision_method/origin + AI method/confidence/threshold/reason.

    Mutates ``metadata`` in place. Pre-existing decision_method/origin are
    preserved (setdefault); the ai_* keys are written fresh. Never raises:
    AI failures downgrade the trace to ``ai_reason="ai_check_failed"``.

    :param metadata: queue-entry metadata dict to annotate (mutated).
    :param action_name: candidate action the scheduler wants to queue.
    :param context: host context handed to the AI engine (mac/hostname/ports).
    :param decision_scope: where the decision is made, e.g. "queue_host",
        "queue_global", "scheduled_host", "scheduled_global".
    """
    # Heuristic is the baseline; only upgraded to ai_confirmed below.
    metadata.setdefault("decision_method", "heuristic")
    metadata.setdefault("decision_origin", "heuristic")
    metadata["decision_scope"] = decision_scope
    threshold = self._get_ai_confirm_threshold()
    metadata["ai_threshold"] = threshold
    if self.shared_data.operation_mode != "AI":
        metadata["ai_reason"] = "ai_mode_disabled"
        return
    if not self.ai_engine:
        metadata["ai_reason"] = "ai_engine_unavailable"
        return
    try:
        # Single-candidate query: the engine only scores this one action.
        recommended, confidence, debug = self.ai_engine.choose_action(
            host_context=context,
            available_actions=[action_name],
            exploration_rate=0.0,
        )
        ai_method = str((debug or {}).get("method", "unknown"))
        confidence_f = float(confidence or 0.0)
        model_loaded = bool(getattr(self.ai_engine, "model_loaded", False))
        metadata["ai_method"] = ai_method
        metadata["ai_confidence"] = confidence_f
        metadata["ai_recommended_action"] = recommended or ""
        metadata["ai_model_loaded"] = model_loaded
        if recommended == action_name and confidence_f >= threshold:
            metadata["decision_method"] = f"ai_confirmed ({ai_method})"
            metadata["decision_origin"] = "ai_confirmed"
            metadata["ai_reason"] = "recommended_above_threshold"
        elif recommended != action_name:
            metadata["decision_origin"] = "heuristic"
            metadata["ai_reason"] = "recommended_different_action"
        else:
            # Engine picked this action but confidence is below threshold.
            metadata["decision_origin"] = "heuristic"
            metadata["ai_reason"] = "confidence_below_threshold"
    except Exception as e:
        metadata["ai_reason"] = "ai_check_failed"
        logger.debug(f"AI decision annotation failed for {action_name}: {e}")
def _log_queue_decision(
    self,
    action_name: str,
    mac: str,
    metadata: Dict[str, Any],
    target_port: Optional[int] = None,
    target_service: Optional[str] = None,
) -> None:
    """Emit a compact, explicit queue-decision log line.

    Renders the decision trace written by ``_annotate_decision_metadata``
    into a single greppable ``[QUEUE_DECISION]`` line; non-numeric
    confidence/threshold values render as "n/a".
    """

    def _fmt(value: Any) -> str:
        # Two-decimal rendering for numbers, "n/a" for anything else.
        return f"{float(value):.2f}" if isinstance(value, (int, float)) else "n/a"

    decision = str(metadata.get("decision_method", "heuristic"))
    origin = str(metadata.get("decision_origin", "heuristic"))
    ai_method = str(metadata.get("ai_method", "n/a"))
    ai_reason = str(metadata.get("ai_reason", "n/a"))
    scope = str(metadata.get("decision_scope", "unknown"))
    conf_txt = _fmt(metadata.get("ai_confidence"))
    thr_txt = _fmt(metadata.get("ai_threshold"))
    model_loaded = bool(metadata.get("ai_model_loaded", False))
    port_txt = str(target_port) if target_port is not None else "None"
    svc_txt = target_service or "None"
    logger.info(
        f"[QUEUE_DECISION] scope={scope} action={action_name} mac={mac} port={port_txt} service={svc_txt} "
        f"decision={decision} origin={origin} ai_method={ai_method} conf={conf_txt} thr={thr_txt} "
        f"model_loaded={model_loaded} reason={ai_reason}"
    )
# ---------- replace this method ----------
def _refresh_cache_if_needed(self):
"""Refresh action definitions cache if expired or source flipped."""
@@ -160,6 +434,9 @@ class ActionScheduler:
# Build cache (expect same action schema: b_class, b_trigger, b_action, etc.)
self._action_definitions = {a["b_class"]: a for a in actions}
# Runtime truth: orchestrator loads from `actions`, so align b_enabled to it
# even when scheduler uses `actions_studio` as source.
self._overlay_runtime_enabled_flags()
logger.info(f"Refreshed action cache from '{source}': {len(self._action_definitions)} actions")
except AttributeError as e:
@@ -179,6 +456,67 @@ class ActionScheduler:
except Exception as e:
logger.error(f"Failed to refresh action cache: {e}")
def _is_action_enabled(self, action_def: Dict[str, Any]) -> bool:
"""Parse b_enabled robustly across int/bool/string/null values."""
raw = action_def.get("b_enabled", 1)
if raw is None:
return True
if isinstance(raw, bool):
return raw
if isinstance(raw, (int, float)):
return int(raw) == 1
s = str(raw).strip().lower()
if s in {"1", "true", "yes", "on"}:
return True
if s in {"0", "false", "no", "off"}:
return False
try:
return int(float(s)) == 1
except Exception:
# Conservative default: keep action enabled when value is malformed.
return True
def _overlay_runtime_enabled_flags(self):
"""
Override cached `b_enabled` with runtime `actions` table values.
This keeps scheduler decisions aligned with orchestrator loaded actions.
"""
try:
runtime_rows = self.db.list_actions()
runtime_map = {r.get("b_class"): r.get("b_enabled", 1) for r in runtime_rows}
for action_name, action_def in self._action_definitions.items():
if action_name in runtime_map:
action_def["b_enabled"] = runtime_map[action_name]
except Exception as e:
logger.warning(f"Could not overlay runtime b_enabled flags: {e}")
def _cancel_queued_disabled_actions(self):
    """Mark queued entries of currently disabled actions as cancelled.

    Collects cached definitions whose ``b_enabled`` flag is off and cancels
    their 'scheduled'/'pending' queue rows in one UPDATE, stamping
    completed_at and a 'disabled_by_config' marker (preserving any existing
    error_message). Failures are logged and swallowed.
    """
    try:
        disabled_names = []
        for name, definition in self._action_definitions.items():
            if not self._is_action_enabled(definition):
                disabled_names.append(name)
        if not disabled_names:
            return
        marks = ",".join("?" for _ in disabled_names)
        affected = self.db.execute(
            f"""
            UPDATE action_queue
            SET status='cancelled',
                completed_at=CURRENT_TIMESTAMP,
                error_message=COALESCE(error_message, 'disabled_by_config')
            WHERE status IN ('scheduled','pending')
            AND action_name IN ({marks})
            """,
            tuple(disabled_names),
        )
        if affected:
            logger.info(f"Cancelled {affected} queued action(s) because b_enabled=0")
    except Exception as e:
        logger.error(f"Failed to cancel queued disabled actions: {e}")
# ------------------------------------------------------------------ helpers
@@ -248,7 +586,7 @@ class ActionScheduler:
for action in self._action_definitions.values():
if (action.get("b_action") or "normal") != "global":
continue
if int(action.get("b_enabled", 1) or 1) != 1:
if not self._is_action_enabled(action):
continue
trigger = (action.get("b_trigger") or "").strip()
@@ -275,7 +613,7 @@ class ActionScheduler:
for action in self._action_definitions.values():
if (action.get("b_action") or "normal") == "global":
continue
if int(action.get("b_enabled", 1) or 1) != 1:
if not self._is_action_enabled(action):
continue
trigger = (action.get("b_trigger") or "").strip()
@@ -309,6 +647,19 @@ class ActionScheduler:
next_run = _utcnow() if not last else (last + timedelta(seconds=interval))
scheduled_for = _db_ts(next_run)
metadata = {
"interval": interval,
"is_global": True,
"decision_method": "heuristic",
"decision_origin": "heuristic",
}
self._annotate_decision_metadata(
metadata=metadata,
action_name=action_name,
context={"mac": mac, "hostname": "Bjorn-C2", "ports": []},
decision_scope="scheduled_global",
)
inserted = self.db.ensure_scheduled_occurrence(
action_name=action_name,
next_run_at=scheduled_for,
@@ -317,7 +668,7 @@ class ActionScheduler:
priority=int(action_def.get("b_priority", 40) or 40),
trigger="scheduler",
tags=action_def.get("b_tags", []),
metadata={"interval": interval, "is_global": True},
metadata=metadata,
max_retries=int(action_def.get("b_max_retries", 3) or 3),
)
if inserted:
@@ -354,6 +705,23 @@ class ActionScheduler:
next_run = _utcnow() if not last else (last + timedelta(seconds=interval))
scheduled_for = _db_ts(next_run)
metadata = {
"interval": interval,
"is_global": False,
"decision_method": "heuristic",
"decision_origin": "heuristic",
}
self._annotate_decision_metadata(
metadata=metadata,
action_name=action_name,
context={
"mac": mac,
"hostname": (host.get("hostnames") or "").split(";")[0],
"ports": [int(p) for p in (host.get("ports") or "").split(";") if p.isdigit()],
},
decision_scope="scheduled_host",
)
inserted = self.db.ensure_scheduled_occurrence(
action_name=action_name,
next_run_at=scheduled_for,
@@ -362,7 +730,7 @@ class ActionScheduler:
priority=int(action_def.get("b_priority", 40) or 40),
trigger="scheduler",
tags=action_def.get("b_tags", []),
metadata={"interval": interval, "is_global": False},
metadata=metadata,
max_retries=int(action_def.get("b_max_retries", 3) or 3),
)
if inserted:
@@ -382,7 +750,7 @@ class ActionScheduler:
for action in self._action_definitions.values():
if (action.get("b_action") or "normal") != "global":
continue
if int(action.get("b_enabled", 1)) != 1:
if not self._is_action_enabled(action):
continue
trigger = (action.get("b_trigger") or "").strip()
@@ -409,14 +777,13 @@ class ActionScheduler:
continue
# Queue the action
self._queue_global_action(action)
self._last_global_runs[action_name] = time.time()
logger.info(f"Queued global action {action_name}")
if self._queue_global_action(action):
self._last_global_runs[action_name] = time.time()
except Exception as e:
logger.error(f"Error evaluating global actions: {e}")
def _queue_global_action(self, action_def: Dict[str, Any]):
def _queue_global_action(self, action_def: Dict[str, Any]) -> bool:
"""Queue a global action for execution (idempotent insert)."""
action_name = action_def["b_class"]
mac = self.ctrl_mac
@@ -429,12 +796,30 @@ class ActionScheduler:
"requirements": action_def.get("b_requires", ""),
"timeout": timeout,
"is_global": True,
"decision_method": "heuristic",
"decision_origin": "heuristic",
}
# Global context (controller itself)
context = {
"mac": mac,
"hostname": "Bjorn-C2",
"ports": [] # Global actions usually don't target specific ports on controller
}
self._annotate_decision_metadata(
metadata=metadata,
action_name=action_name,
context=context,
decision_scope="queue_global",
)
ai_conf = metadata.get("ai_confidence")
if isinstance(ai_conf, (int, float)) and metadata.get("decision_origin") == "ai_confirmed":
action_def["b_priority"] = int(action_def.get("b_priority", 50) or 50) + int(20 * float(ai_conf))
try:
self._ensure_host_exists(mac)
# Guard with NOT EXISTS to avoid races
self.db.execute(
affected = self.db.execute(
"""
INSERT INTO action_queue (
action_name, mac_address, ip, port, hostname, service,
@@ -463,8 +848,13 @@ class ActionScheduler:
mac,
),
)
if affected and affected > 0:
self._log_queue_decision(action_name=action_name, mac=mac, metadata=metadata)
return True
return False
except Exception as e:
logger.error(f"Failed to queue global action {action_name}: {e}")
return False
# ------------------------------------------------------------- host path
@@ -480,7 +870,7 @@ class ActionScheduler:
continue
# Skip disabled actions
if not int(action_def.get("b_enabled", 1)):
if not self._is_action_enabled(action_def):
continue
trigger = (action_def.get("b_trigger") or "").strip()
@@ -509,7 +899,6 @@ class ActionScheduler:
# Queue the action
self._queue_action(host, action_def, target_port, target_service)
logger.info(f"Queued {action_name} for {mac} (port={target_port}, service={target_service})")
def _resolve_target_port_service(
self, mac: str, host: Dict[str, Any], action_def: Dict[str, Any]
@@ -640,7 +1029,7 @@ class ActionScheduler:
def _queue_action(
self, host: Dict[str, Any], action_def: Dict[str, Any], target_port: Optional[int], target_service: Optional[str]
):
) -> bool:
"""Queue action for execution (idempotent insert with NOT EXISTS guard)."""
action_name = action_def["b_class"]
mac = host["mac_address"]
@@ -653,11 +1042,29 @@ class ActionScheduler:
"requirements": action_def.get("b_requires", ""),
"is_global": False,
"timeout": timeout,
"decision_method": "heuristic",
"decision_origin": "heuristic",
"ports_snapshot": host.get("ports") or "",
}
context = {
"mac": mac,
"hostname": (host.get("hostnames") or "").split(";")[0],
"ports": [int(p) for p in (host.get("ports") or "").split(";") if p.isdigit()],
}
self._annotate_decision_metadata(
metadata=metadata,
action_name=action_name,
context=context,
decision_scope="queue_host",
)
ai_conf = metadata.get("ai_confidence")
if isinstance(ai_conf, (int, float)) and metadata.get("decision_origin") == "ai_confirmed":
# Apply small priority boost only when AI confirmed this exact action.
action_def["b_priority"] = int(action_def.get("b_priority", 50) or 50) + int(20 * float(ai_conf))
try:
self.db.execute(
affected = self.db.execute(
"""
INSERT INTO action_queue (
action_name, mac_address, ip, port, hostname, service,
@@ -690,8 +1097,19 @@ class ActionScheduler:
self_port,
),
)
if affected and affected > 0:
self._log_queue_decision(
action_name=action_name,
mac=mac,
metadata=metadata,
target_port=target_port,
target_service=target_service,
)
return True
return False
except Exception as e:
logger.error(f"Failed to queue {action_name} for {mac}: {e}")
return False
# ------------------------------------------------------------- last times
@@ -708,7 +1126,11 @@ class ActionScheduler:
)
if row and row[0].get("completed_at"):
try:
return datetime.fromisoformat(row[0]["completed_at"])
val = row[0]["completed_at"]
if isinstance(val, str):
return datetime.fromisoformat(val)
elif isinstance(val, datetime):
return val
except Exception:
return None
return None
@@ -726,7 +1148,11 @@ class ActionScheduler:
)
if row and row[0].get("completed_at"):
try:
return datetime.fromisoformat(row[0]["completed_at"])
val = row[0]["completed_at"]
if isinstance(val, str):
return datetime.fromisoformat(val)
elif isinstance(val, datetime):
return val
except Exception:
return None
return None
@@ -840,19 +1266,7 @@ class ActionScheduler:
except Exception as e:
logger.error(f"Failed to cleanup queue: {e}")
def update_priorities(self):
    """Boost priority for actions waiting too long (anti-starvation)."""
    # NOTE(review): this is a duplicate of the throttled update_priorities
    # defined earlier in the class; Python keeps only the last definition,
    # so one of the two is dead. The surrounding diff indicates this copy
    # is the one being removed — confirm before merging.
    try:
        self.db.execute(
            """
            UPDATE action_queue
            SET priority = MIN(100, priority + 1)
            WHERE status='pending'
            AND julianday('now') - julianday(created_at) > 0.0417
            """
        )
    except Exception as e:
        logger.error(f"Failed to update priorities: {e}")
# update_priorities is defined above (line ~166); this duplicate is removed.
# =================================================================== helpers ==