Add RLUtils class for managing RL/AI dashboard endpoints

- Implemented methods for fetching AI stats, training history, and recent experiences.
- Added functionality to set operation mode (MANUAL, AUTO, AI) with appropriate handling.
- Included helper methods for querying the database and sending JSON responses.
- Integrated model metadata extraction for visualization purposes.
This commit is contained in:
Fabien POLLY
2026-02-18 22:36:10 +01:00
parent b8a13cc698
commit eb20b168a6
684 changed files with 53278 additions and 27977 deletions

View File

@@ -12,6 +12,9 @@ from typing import Any, Dict, List, Optional
from init_shared import shared_data
from logger import Logger
from action_scheduler import ActionScheduler
from ai_engine import get_or_create_ai_engine, invalidate_ai_engine
from feature_logger import FeatureLogger
from data_consolidator import DataConsolidator
logger = Logger(name="orchestrator.py", level=logging.DEBUG)
@@ -25,10 +28,117 @@ class Orchestrator:
self.network_scanner = None
self.scheduler = None
self.scheduler_thread = None
self._loop_error_backoff = 1.0
# ┌─────────────────────────────────────────────────────────┐
# │ AI / Feature-logging Components │
# └─────────────────────────────────────────────────────────┘
# feature_logger runs in AUTO and AI mode to collect training data
# from ALL automated executions.
# ai_engine + data_consolidator run only in AI mode.
self.ai_engine = None
self.data_consolidator = None
self.ai_enabled = bool(self.shared_data.operation_mode == "AI")
self._ai_server_failure_streak = 0
# FeatureLogger: active as long as the orchestrator runs (AUTO or AI)
self.feature_logger = None
if self.shared_data.operation_mode in ("AUTO", "AI"):
try:
self.feature_logger = FeatureLogger(self.shared_data)
logger.info("FeatureLogger initialized (data collection active)")
except Exception as e:
logger.info_throttled(
f"FeatureLogger unavailable; execution data will not be logged: {e}",
key="orch_feature_logger_init_failed",
interval_s=300.0,
)
self.feature_logger = None
if self.ai_enabled:
try:
self.ai_engine = get_or_create_ai_engine(self.shared_data)
self.data_consolidator = DataConsolidator(self.shared_data)
logger.info("AI engine + DataConsolidator initialized (AI mode)")
except Exception as e:
logger.info_throttled(
f"AI mode active but AI components unavailable; continuing heuristic-only: {e}",
key="orch_ai_init_failed",
interval_s=300.0,
)
self.ai_engine = None
self.data_consolidator = None
self.ai_enabled = False
# Load all available actions
self.load_actions()
logger.info(f"Actions loaded: {list(self.actions.keys())}")
def _is_enabled_value(self, value: Any) -> bool:
"""Robust parser for b_enabled values coming from DB."""
if value is None:
return True
if isinstance(value, bool):
return value
if isinstance(value, (int, float)):
return int(value) == 1
s = str(value).strip().lower()
if s in {"1", "true", "yes", "on"}:
return True
if s in {"0", "false", "no", "off"}:
return False
try:
return int(float(s)) == 1
except Exception:
return True
def _is_action_eligible_for_ai_learning(self, action_name: str) -> bool:
"""Exclude control-plane actions from AI training/reward."""
return str(action_name) not in {"NetworkScanner"}
def _update_ai_server_health(self, contact_events: List[bool]) -> None:
    """
    Update consecutive AI server failure counter and fallback to AUTO when needed.
    `contact_events` contains one bool per attempted contact in this cycle.

    Args:
        contact_events: One bool per server contact attempted this cycle;
            True means the contact succeeded.
    """
    # No contact was attempted this cycle -> nothing to conclude.
    if not contact_events:
        return
    # A single successful contact is enough to consider the server healthy.
    contacted_ok = any(contact_events)
    if contacted_ok:
        if self._ai_server_failure_streak > 0:
            logger.info("AI server contact recovered; reset failure streak")
        self._ai_server_failure_streak = 0
        return
    self._ai_server_failure_streak += 1
    # Configurable threshold, clamped to at least 1 so fallback can trigger.
    max_failures = max(
        1,
        int(getattr(self.shared_data, "ai_server_max_failures_before_auto", 3)),
    )
    # getattr guards against self.ai_engine being None.
    model_loaded = bool(getattr(self.ai_engine, "model_loaded", False))
    if self.shared_data.operation_mode == "AI" and (not model_loaded):
        remaining_cycles = max(0, max_failures - self._ai_server_failure_streak)
        if remaining_cycles > 0:
            # Pre-fallback warning so operators can see the countdown.
            logger.info_throttled(
                f"AI server unreachable ({self._ai_server_failure_streak}/{max_failures}) and no local model loaded; "
                f"AUTO fallback in {remaining_cycles} cycle(s) if server remains offline",
                key="orch_ai_unreachable_no_model_pre_fallback",
                interval_s=60.0,
            )
    # Fall back only when in AI mode, past the threshold, AND no local
    # model can keep AI decisions running while the server is offline.
    if (
        self.shared_data.operation_mode == "AI"
        and self._ai_server_failure_streak >= max_failures
        and (not model_loaded)
    ):
        logger.warning(
            f"AI server unreachable for {self._ai_server_failure_streak} consecutive cycles and no local AI model is loaded; "
            "switching operation mode to AUTO (heuristics-only)"
        )
        self.shared_data.operation_mode = "AUTO"
        self._disable_ai_components()
def load_actions(self):
"""Load all actions from database"""
@@ -64,9 +174,82 @@ class Orchestrator:
except Exception as e:
logger.error(f"Failed to load action {b_class}: {e}")
# ----------------------------------------------------------------- AI mode
def _ensure_feature_logger(self) -> None:
    """Init FeatureLogger if not yet running (called when entering AUTO or AI mode)."""
    # Idempotent: keep the existing logger if one is already active.
    if self.feature_logger is not None:
        return
    try:
        self.feature_logger = FeatureLogger(self.shared_data)
        logger.info("FeatureLogger enabled")
    except Exception as e:
        # Non-fatal: the orchestrator keeps running without data collection.
        # Throttled so repeated init failures do not flood the log.
        logger.info_throttled(
            f"FeatureLogger unavailable: {e}",
            key="orch_feature_logger_enable_failed",
            interval_s=300.0,
        )
def _enable_ai_components(self) -> None:
    """Lazy-init AI-specific helpers when switching to AI mode at runtime.

    Ensures the FeatureLogger is running, then (re)creates the AI engine
    and DataConsolidator. On failure, all AI helpers are reset and
    ai_enabled stays False (heuristic-only operation).
    """
    self._ensure_feature_logger()
    # Fast path: both helpers already exist from a previous AI session.
    if self.ai_engine and self.data_consolidator:
        self.ai_enabled = True
        return
    try:
        self.ai_engine = get_or_create_ai_engine(self.shared_data)
        self.data_consolidator = DataConsolidator(self.shared_data)
        self.ai_enabled = True
        # Warn early when no local model is loaded; the health check may
        # later fall back to AUTO if the server stays unreachable.
        if self.ai_engine and not bool(getattr(self.ai_engine, "model_loaded", False)):
            logger.warning(
                "AI mode active but no local model loaded yet; "
                "will fallback to AUTO if server stays unreachable"
            )
        logger.info("AI engine + DataConsolidator enabled")
    except Exception as e:
        # Roll back to a consistent heuristic-only state.
        self.ai_engine = None
        self.data_consolidator = None
        self.ai_enabled = False
        logger.info_throttled(
            f"AI components not available; staying heuristic-only: {e}",
            key="orch_ai_enable_failed",
            interval_s=300.0,
        )
def _disable_ai_components(self) -> None:
"""Drop AI-specific helpers when leaving AI mode.
FeatureLogger is kept alive so AUTO mode still collects data."""
self.ai_enabled = False
self.ai_engine = None
self.data_consolidator = None
# Release cached AI engine singleton so model weights can be freed in AUTO mode.
try:
invalidate_ai_engine(self.shared_data)
except Exception:
pass
def _sync_ai_components(self) -> None:
"""Keep runtime AI helpers aligned with shared_data.operation_mode."""
mode = self.shared_data.operation_mode
if mode == "AI":
if not self.ai_enabled:
self._enable_ai_components()
else:
if self.ai_enabled:
self._disable_ai_components()
# Ensure feature_logger is alive in AUTO mode too
if mode == "AUTO":
self._ensure_feature_logger()
def start_scheduler(self):
"""Start the scheduler in background"""
if self.scheduler_thread and self.scheduler_thread.is_alive():
logger.info("ActionScheduler thread already running")
return
logger.info("Starting ActionScheduler in background...")
self.scheduler = ActionScheduler(self.shared_data)
self.scheduler_thread = threading.Thread(
@@ -87,24 +270,227 @@ class Orchestrator:
)
return action
def _build_host_state(self, mac_address: str) -> Dict:
"""
Build RL state dict from host data in database.
Args:
mac_address: Target MAC address
Returns:
Dict with keys: mac, ports, hostname
"""
try:
# Get host from database
host = self.shared_data.db.get_host_by_mac(mac_address)
if not host:
logger.warning(f"Host not found for MAC: {mac_address}")
return {'mac': mac_address, 'ports': [], 'hostname': ''}
# Parse ports
ports_str = host.get('ports', '')
ports = []
if ports_str:
for p in ports_str.split(';'):
p = p.strip()
if p.isdigit():
ports.append(int(p))
# Get first hostname
hostnames_str = host.get('hostnames', '')
hostname = hostnames_str.split(';')[0] if hostnames_str else ''
return {
'mac': mac_address,
'ports': ports,
'hostname': hostname
}
except Exception as e:
logger.error(f"Error building host state: {e}")
return {'mac': mac_address, 'ports': [], 'hostname': ''}
def _calculate_reward(
self,
action_name: str,
success: bool,
duration: float,
mac: str,
state_before: Dict,
state_after: Dict
) -> float:
"""
Calculate reward for RL update.
Reward structure:
- Base: +50 for success, -5 for failure
- Credentials found: +100
- New services: +20 per service
- Time bonus: +20 if <30s, -10 if >120s
- New ports discovered: +15 per port
Args:
action_name: Name of action executed
success: Did action succeed?
duration: Execution time in seconds
mac: Target MAC address
state_before: State dict before action
state_after: State dict after action
Returns:
Reward value (float)
"""
if not self._is_action_eligible_for_ai_learning(action_name):
return 0.0
# Base reward
reward = 50.0 if success else -5.0
if not success:
# Penalize time waste on failure
reward -= (duration * 0.1)
return reward
# ─────────────────────────────────────────────────────────
# Check for credentials found (high value!)
# ─────────────────────────────────────────────────────────
try:
recent_creds = self.shared_data.db.query("""
SELECT COUNT(*) as cnt FROM creds
WHERE mac_address=?
AND first_seen > datetime('now', '-1 minute')
""", (mac,))
if recent_creds and recent_creds[0]['cnt'] > 0:
creds_count = recent_creds[0]['cnt']
reward += 100 * creds_count # 100 per credential!
logger.info(f"RL: +{100*creds_count} reward for {creds_count} credentials")
except Exception as e:
logger.error(f"Error checking credentials: {e}")
# ─────────────────────────────────────────────────────────
# Check for new services discovered
# ─────────────────────────────────────────────────────────
try:
# Compare ports before/after
ports_before = set(state_before.get('ports', []))
ports_after = set(state_after.get('ports', []))
new_ports = ports_after - ports_before
if new_ports:
reward += 15 * len(new_ports)
logger.info(f"RL: +{15*len(new_ports)} reward for {len(new_ports)} new ports")
except Exception as e:
logger.error(f"Error checking new ports: {e}")
# ─────────────────────────────────────────────────────────
# Time efficiency bonus/penalty
# ─────────────────────────────────────────────────────────
if duration < 30:
reward += 20 # Fast execution bonus
elif duration > 120:
reward -= 10 # Slow execution penalty
# ─────────────────────────────────────────────────────────
# Action-specific bonuses
# ─────────────────────────────────────────────────────────
if action_name == "SSHBruteforce" and success:
# Extra bonus for SSH success (difficult action)
reward += 30
logger.debug(f"RL Reward calculated: {reward:.1f} for {action_name}")
return reward
def execute_queued_action(self, queued_action: Dict[str, Any]) -> bool:
"""Execute a single queued action"""
"""Execute a single queued action with RL integration"""
queue_id = queued_action['id']
action_name = queued_action['action_name']
mac = queued_action['mac_address']
ip = queued_action['ip']
port = queued_action['port']
logger.info(f"Executing: {action_name} for {ip}:{port}")
# Parse metadata once — used throughout this function
metadata = json.loads(queued_action.get('metadata', '{}'))
source = str(metadata.get('decision_method', 'unknown'))
source_label = f"[{source.upper()}]" if source != 'unknown' else ""
decision_origin = str(metadata.get('decision_origin', 'unknown'))
ai_confidence = metadata.get('ai_confidence')
ai_threshold = metadata.get('ai_threshold', getattr(self.shared_data, "ai_confirm_threshold", 0.3))
ai_reason = str(metadata.get('ai_reason', 'n/a'))
ai_method = metadata.get('ai_method')
if not ai_method:
ai_method = (metadata.get('ai_debug') or {}).get('method')
ai_method = str(ai_method or 'n/a')
ai_model_loaded = bool(metadata.get('ai_model_loaded', bool(getattr(self.ai_engine, "model_loaded", False)) if self.ai_engine else False))
decision_scope = str(metadata.get('decision_scope', 'unknown'))
exec_payload = {
"action": action_name,
"target": ip,
"port": port,
"decision_method": source,
"decision_origin": decision_origin,
"decision_scope": decision_scope,
"ai_method": ai_method,
"ai_confidence": ai_confidence if isinstance(ai_confidence, (int, float)) else None,
"ai_threshold": ai_threshold if isinstance(ai_threshold, (int, float)) else None,
"ai_model_loaded": ai_model_loaded,
"ai_reason": ai_reason,
}
logger.info(f"Executing {source_label}: {action_name} for {ip}:{port}")
logger.info(f"[DECISION_EXEC] {json.dumps(exec_payload)}")
# Guard rail: stale queue rows can exist for disabled or not-loaded actions.
try:
action_row = self.shared_data.db.get_action_by_class(action_name)
if action_row and not self._is_enabled_value(action_row.get("b_enabled", 1)):
self.shared_data.db.update_queue_status(
queue_id,
'cancelled',
f"Action {action_name} disabled (b_enabled=0)",
)
logger.info(f"Skipping queued disabled action: {action_name}")
return False
except Exception as e:
logger.debug(f"Could not verify b_enabled for {action_name}: {e}")
if action_name not in self.actions:
self.shared_data.db.update_queue_status(
queue_id,
'cancelled',
f"Action {action_name} not loaded",
)
logger.warning(f"Skipping queued action not loaded: {action_name}")
return False
# ┌─────────────────────────────────────────────────────────┐
# │ STEP 1: Capture state BEFORE action (all modes) │
# └─────────────────────────────────────────────────────────┘
state_before = None
if self.feature_logger:
try:
state_before = self._build_host_state(mac)
logger.debug(f"State before captured for {mac}")
except Exception as e:
logger.info_throttled(
f"State capture skipped: {e}",
key="orch_state_before_failed",
interval_s=120.0,
)
# Update status to running
self.shared_data.db.update_queue_status(queue_id, 'running')
# ┌─────────────────────────────────────────────────────────┐
# │ EXECUTE ACTION (existing code) │
# └─────────────────────────────────────────────────────────┘
start_time = time.time()
success = False
try:
# Check if action is loaded
if action_name not in self.actions:
raise Exception(f"Action {action_name} not loaded")
action = self.actions[action_name]
# Prepare row data for compatibility
@@ -115,12 +501,49 @@ class Orchestrator:
"Alive": 1
}
# Prepare status details
if ip and ip != "0.0.0.0":
port_str = str(port).strip() if port is not None else ""
has_port = bool(port_str) and port_str.lower() != "none"
target_display = f"{ip}:{port_str}" if has_port else ip
status_msg = f"{action_name} on {ip}"
details = f"Target: {target_display}"
self.shared_data.action_target_ip = target_display
else:
status_msg = f"{action_name} (Global)"
details = "Scanning network..."
self.shared_data.action_target_ip = ""
# Update shared status for display
self.shared_data.bjorn_orch_status = action_name
self.shared_data.bjorn_status_text2 = ip
self.shared_data.bjorn_status_text2 = self.shared_data.action_target_ip or ip
self.shared_data.update_status(status_msg, details)
# Check if global action
metadata = json.loads(queued_action.get('metadata', '{}'))
# --- AI Dashboard Metadata (AI mode only) ---
if (
self.ai_enabled
and self.shared_data.operation_mode == "AI"
and self._is_action_eligible_for_ai_learning(action_name)
):
decision_method = metadata.get('decision_method', 'heuristic')
self.shared_data.active_action = action_name
self.shared_data.last_decision_method = decision_method
self.shared_data.last_ai_decision = metadata.get('ai_debug', {})
ai_exec_payload = {
"action": action_name,
"method": decision_method,
"origin": decision_origin,
"target": ip,
"ai_method": ai_method,
"ai_confidence": ai_confidence if isinstance(ai_confidence, (int, float)) else None,
"ai_threshold": ai_threshold if isinstance(ai_threshold, (int, float)) else None,
"ai_model_loaded": ai_model_loaded,
"reason": ai_reason,
}
logger.info(f"[AI_EXEC] {json.dumps(ai_exec_payload)}")
# Check if global action (metadata already parsed above)
if metadata.get('is_global') and hasattr(action, 'scan'):
# Execute global scan
action.scan()
@@ -134,23 +557,92 @@ class Orchestrator:
action_name
)
# Determine success
success = (result == 'success')
# Update queue status based on result
if result == 'success':
if success:
self.shared_data.db.update_queue_status(queue_id, 'success')
logger.success(f"Action {action_name} completed successfully for {ip}")
else:
self.shared_data.db.update_queue_status(queue_id, 'failed')
logger.warning(f"Action {action_name} failed for {ip}")
return result == 'success'
except Exception as e:
logger.error(f"Error executing action {action_name}: {e}")
self.shared_data.db.update_queue_status(queue_id, 'failed', str(e))
return False
success = False
finally:
if (
self.ai_enabled
and self.shared_data.operation_mode == "AI"
and self._is_action_eligible_for_ai_learning(action_name)
):
ai_done_payload = {
"action": action_name,
"success": bool(success),
"method": source,
"origin": decision_origin,
}
logger.info(f"[AI_DONE] {json.dumps(ai_done_payload)}")
self.shared_data.active_action = None
# Clear status text
self.shared_data.bjorn_status_text2 = ""
self.shared_data.action_target_ip = ""
# Reset Status to Thinking/Idle
self.shared_data.update_status("Thinking", "Processing results...")
duration = time.time() - start_time
# ┌─────────────────────────────────────────────────────────┐
# │ STEP 2: Log execution features (AUTO + AI modes) │
# └─────────────────────────────────────────────────────────┘
if self.feature_logger and state_before and self._is_action_eligible_for_ai_learning(action_name):
try:
reward = self._calculate_reward(
action_name=action_name,
success=success,
duration=duration,
mac=mac,
state_before=state_before,
state_after=self._build_host_state(mac),
)
self.feature_logger.log_action_execution(
mac_address=mac,
ip_address=ip,
action_name=action_name,
success=success,
duration=duration,
reward=reward,
raw_event={
'port': port,
'action': action_name,
'queue_id': queue_id,
# metadata already parsed — no second json.loads
'metadata': metadata,
# Tag decision source so the training pipeline can weight
# human choices (MANUAL would be logged if orchestrator
# ever ran in that mode) vs automated ones.
'decision_source': self.shared_data.operation_mode,
'human_override': False,
},
)
logger.debug(f"Features logged for {action_name} (mode={self.shared_data.operation_mode})")
except Exception as e:
logger.info_throttled(
f"Feature logging skipped: {e}",
key="orch_feature_log_failed",
interval_s=120.0,
)
elif self.feature_logger and state_before:
logger.debug(f"Feature logging disabled for {action_name} (excluded from AI learning)")
return success
def run(self):
"""Main loop: start scheduler and consume queue"""
@@ -164,9 +656,13 @@ class Orchestrator:
# Main execution loop
idle_time = 0
consecutive_idle_logs = 0
self._last_background_task = 0
while not self.shared_data.orchestrator_should_exit:
try:
# Allow live mode switching from the UI without restarting the process.
self._sync_ai_components()
# Get next action from queue
next_action = self.get_next_action()
@@ -174,14 +670,17 @@ class Orchestrator:
# Reset idle counters
idle_time = 0
consecutive_idle_logs = 0
self._loop_error_backoff = 1.0
# Execute the action
self.execute_queued_action(next_action)
else:
# IDLE mode
idle_time += 1
self.shared_data.bjorn_orch_status = "IDLE"
self.shared_data.bjorn_status_text2 = ""
self.shared_data.action_target_ip = ""
# Log periodically (less spam)
if idle_time % 30 == 0: # Every 30 seconds
@@ -192,18 +691,96 @@ class Orchestrator:
# Event-driven wait (max 5s to check for exit signals)
self.shared_data.queue_event.wait(timeout=5)
self.shared_data.queue_event.clear()
# Periodically process background tasks (even if busy)
current_time = time.time()
sync_interval = int(getattr(self.shared_data, "ai_sync_interval", 60))
if current_time - self._last_background_task > sync_interval:
self._process_background_tasks()
self._last_background_task = current_time
except Exception as e:
logger.error(f"Error in orchestrator loop: {e}")
time.sleep(1)
time.sleep(self._loop_error_backoff)
self._loop_error_backoff = min(self._loop_error_backoff * 2.0, 10.0)
# Cleanup on exit
# Cleanup on exit (OUTSIDE while loop)
if self.scheduler:
self.scheduler.stop()
self.shared_data.queue_event.set()
if self.scheduler_thread and self.scheduler_thread.is_alive():
self.scheduler_thread.join(timeout=10.0)
if self.scheduler_thread.is_alive():
logger.warning("ActionScheduler thread did not exit cleanly")
logger.info("Orchestrator stopped")
def _process_background_tasks(self):
    """Run periodic tasks like consolidation, upload retries, and model updates (AI mode only).

    Each attempted AI-server contact is recorded as a bool and fed to
    _update_ai_server_health() at the end for the AUTO-fallback decision.
    """
    # Only relevant when AI components are active AND the mode is still AI.
    if not (self.ai_enabled and self.shared_data.operation_mode == "AI"):
        return
    # One entry per server contact attempted this cycle (True = success).
    ai_server_contact_events: List[bool] = []
    try:
        # Consolidate features
        batch_size = int(getattr(self.shared_data, "ai_batch_size", 100))
        max_batches = max(1, int(getattr(self.shared_data, "ai_consolidation_max_batches", 2)))
        stats = self.data_consolidator.consolidate_features(
            batch_size=batch_size,
            max_batches=max_batches,
        )
        if stats.get("records_processed", 0) > 0:
            logger.info(f"AI Consolidation: {stats['records_processed']} records processed")
            logger.debug(f"DEBUG STATS: {stats}")
        # Auto-export after consolidation
        max_export_records = max(100, int(getattr(self.shared_data, "ai_export_max_records", 1000)))
        filepath, count = self.data_consolidator.export_for_training(
            format="csv",
            compress=True,
            max_records=max_export_records,
        )
        if filepath:
            logger.info(f"AI export ready: {count} records -> {filepath}")
            self.data_consolidator.upload_to_server(filepath)
            # Record the outcome of the upload attempt, if one was made.
            if getattr(self.data_consolidator, "last_server_attempted", False):
                ai_server_contact_events.append(
                    bool(getattr(self.data_consolidator, "last_server_contact_ok", False))
                )
        # Always retry any pending uploads when the server comes back.
        self.data_consolidator.flush_pending_uploads(max_files=2)
        if getattr(self.data_consolidator, "last_server_attempted", False):
            ai_server_contact_events.append(
                bool(getattr(self.data_consolidator, "last_server_contact_ok", False))
            )
    except Exception as e:
        # Best effort: consolidation/upload failures must not kill the loop.
        logger.info_throttled(
            f"AI background tasks skipped: {e}",
            key="orch_ai_background_failed",
            interval_s=120.0,
        )
    # Check for model updates (tolerant when server is offline)
    try:
        if self.ai_engine and self.ai_engine.check_for_updates():
            logger.info("AI model updated from server")
        if self.ai_engine and getattr(self.ai_engine, "last_server_attempted", False):
            ai_server_contact_events.append(
                bool(getattr(self.ai_engine, "last_server_contact_ok", False))
            )
        elif self.ai_engine and not bool(getattr(self.ai_engine, "model_loaded", False)):
            # No model loaded and no successful server contact path this cycle.
            ai_server_contact_events.append(False)
    except Exception as e:
        logger.debug(f"AI model update check skipped: {e}")
    # Feed the contact outcomes into the failure-streak / fallback logic.
    self._update_ai_server_health(ai_server_contact_events)
if __name__ == "__main__":
    # Script entry point: build the orchestrator and run its main loop.
    # The duplicated orchestrator.run() call was removed — a second call
    # would re-enter the main loop after a requested shutdown.
    orchestrator = Orchestrator()
    orchestrator.run()