# Bjorn/ai_engine.py
# Last commit eb20b168a6 (Fabien POLLY): Add RLUtils class for managing RL/AI dashboard endpoints
# - Implemented methods for fetching AI stats, training history, and recent experiences.
# - Added functionality to set operation mode (MANUAL, AUTO, AI) with appropriate handling.
# - Included helper methods for querying the database and sending JSON responses.
# - Integrated model metadata extraction for visualization purposes.
# Committed 2026-02-18 22:36:10 +01:00
#
# 868 lines, 37 KiB, Python
"""
ai_engine.py - Dynamic AI Decision Engine for Bjorn
═══════════════════════════════════════════════════════════════════════════
Purpose:
    Lightweight AI decision engine for the Raspberry Pi Zero.
    Works in tandem with a deep-learning model trained on an external PC.

Architecture:
    - Lightweight inference engine (NumPy only; no TensorFlow/PyTorch on the Pi)
    - Loads pre-trained model weights exported from the PC (from the local
      model directory, or downloaded from the AI server when none is present)
    - Real-time action selection
    - Automatic feature extraction
    - Falls back to rule-based heuristics when no model is available

Model pipeline:
    1. Pi: collect data → export → transfer to PC
    2. PC: train deep neural network → export lightweight model
    3. Pi: load model → use it for decision making
    4. Repeat: continuous learning cycle

Author: Bjorn Team
Version: 2.0.0
"""
import json
import numpy as np
from typing import Dict, List, Any, Optional, Tuple
from pathlib import Path
from logger import Logger
# Module-wide logger; level=20 presumably maps to INFO as in the stdlib
# logging module — confirm against logger.Logger.
logger = Logger(level=20, name="ai_engine.py")
class BjornAIEngine:
    """
    Dynamic AI engine for action selection and prioritization.

    Scores candidate actions for a host with a pre-trained dense neural
    network exported from an external PC (evaluated here with NumPy only).
    When no model is loaded, or inference fails, it falls back to
    rule-based heuristics driven by open ports, detected services and a
    coarse port profile.
    """

    # Filename prefix shared by model config / weights files in model_dir.
    _MODEL_PREFIX = "bjorn_model_"

    def __init__(self, shared_data, model_dir: Optional[str] = None):
        """
        Initialize the engine and try to load the newest available model.

        Args:
            shared_data: Shared application state. ``shared_data.db`` is
                required; ``ai_models_dir``, ``ai_exploration_rate`` and
                ``config.get("ai_server_url")`` are read when used.
            model_dir: Directory holding ``bjorn_model_<version>.json`` and
                ``bjorn_model_<version>_weights.json`` pairs. Defaults to
                ``shared_data.ai_models_dir`` or ``/home/bjorn/ai_models``;
                created if it does not exist.
        """
        self.shared_data = shared_data
        self.db = shared_data.db
        if model_dir is None:
            self.model_dir = Path(getattr(shared_data, 'ai_models_dir', '/home/bjorn/ai_models'))
        else:
            self.model_dir = Path(model_dir)
        self.model_dir.mkdir(parents=True, exist_ok=True)

        # Model state (populated by _load_latest_model).
        self.model_loaded = False
        self.model_weights = None  # dict: weight name -> np.ndarray
        self.model_config = None
        self.feature_config = None

        # Outcome of the most recent AI-server contact (see check_for_updates).
        self.last_server_attempted = False
        self.last_server_contact_ok = None

        # Try to load latest model
        self._load_latest_model()

        # Fallback heuristics are always initialised.
        self._init_heuristics()

        logger.info(
            f"AI Engine initialized (model_loaded={self.model_loaded}, "
            f"heuristics_available=True)"
        )

    # ═══════════════════════════════════════════════════════════════════════
    # MODEL LOADING
    # ═══════════════════════════════════════════════════════════════════════

    @staticmethod
    def _weights_path_for(config_path: Path) -> Path:
        """Return the ``*_weights.json`` file paired with a model config file."""
        return config_path.with_name(config_path.stem + '_weights.json')

    @classmethod
    def _version_from_path(cls, config_path: Path) -> str:
        """Extract ``<version>`` from a ``bjorn_model_<version>.json`` path."""
        return config_path.stem[len(cls._MODEL_PREFIX):]

    @staticmethod
    def _version_key(version) -> Tuple:
        """
        Build a sort key that orders version strings "naturally".

        Runs of decimal digits compare as integers and everything else
        compares as text, so "2.10" > "2.9" and "v10" > "v2". Equal-length
        numeric timestamps order exactly as they would lexicographically.
        """
        import re
        key = []
        for token in re.findall(r'\d+|\D+', str(version).strip()):
            # Tag each token so ints and strings are never compared directly.
            if token.isdecimal():
                key.append((0, int(token), ''))
            else:
                key.append((1, 0, token))
        return tuple(key)

    @classmethod
    def _is_newer_version(cls, candidate: str, reference: str) -> bool:
        """Return True when ``candidate`` sorts strictly after ``reference``."""
        return cls._version_key(candidate) > cls._version_key(reference)

    @staticmethod
    def _is_safe_version(version: str) -> bool:
        """
        Return True when ``version`` can be embedded in a local file name.

        The version comes from the AI server and is interpolated into
        ``model_dir / f"bjorn_model_{version}.json"``; path separators (or
        NUL) would let it escape ``model_dir``.
        """
        return bool(version) and not any(c in version for c in ('/', '\\', '\x00'))

    def _load_latest_model(self):
        """
        Load the newest complete model found in ``self.model_dir``.

        A model is complete when ``bjorn_model_<version>.json`` has a sibling
        ``bjorn_model_<version>_weights.json``. The newest one is chosen with
        ``_version_key``, the same ordering ``check_for_updates`` uses. When
        no complete model exists locally the AI server is asked for one.
        Engine state is only updated once both files parsed successfully;
        on any failure ``model_loaded`` is False and heuristics are used.
        """
        try:
            # 1. Keep only configs whose weights file is present.
            valid_models = []
            for config_path in self.model_dir.glob(f"{self._MODEL_PREFIX}*.json"):
                if "_weights.json" in config_path.name:
                    continue
                if self._weights_path_for(config_path).exists():
                    valid_models.append(config_path)
                else:
                    logger.debug(f"Skipping model {config_path.name}: Weights file missing")

            if not valid_models:
                logger.info(f"No complete models found in {self.model_dir}. Checking server...")
                # Try to download from server
                if self.check_for_updates():
                    return
                logger.info_throttled(
                    "No AI model available (server offline or empty). Using heuristics only.",
                    key="ai_no_model_available",
                    interval_s=600.0,
                )
                return

            # 2. Pick the newest version; the file name breaks exact key ties
            #    so the choice does not depend on glob order.
            latest_model = max(
                valid_models,
                key=lambda p: (self._version_key(self._version_from_path(p)), p.name),
            )
            weights_file = self._weights_path_for(latest_model)
            logger.info(f"Loading model: {latest_model.name} (Weights exists!)")

            with open(latest_model, 'r', encoding='utf-8') as f:
                model_data = json.load(f)
            # Exports either nest metadata under "config" or are flat.
            model_config = model_data.get('config', model_data)
            feature_config = model_data.get('features', {})

            with open(weights_file, 'r', encoding='utf-8') as f:
                weights_data = json.load(f)
            model_weights = {k: np.array(v) for k, v in weights_data.items()}
            del weights_data  # Free raw dict — numpy arrays are the canonical form

            # Commit state only after everything parsed successfully.
            self.model_config = model_config
            self.feature_config = feature_config
            self.model_weights = model_weights
            self.model_loaded = True
            logger.success(
                f"Model loaded successfully: {self.model_config.get('version', 'unknown')}"
            )
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            import traceback
            logger.debug(traceback.format_exc())
            self.model_loaded = False

    def reload_model(self) -> bool:
        """
        Drop the current model and load the newest one from disk again.

        Returns:
            True if a model is loaded afterwards.
        """
        logger.info("Reloading AI model...")
        self.model_loaded = False
        self.model_weights = None
        self.model_config = None
        self.feature_config = None
        self._load_latest_model()
        return self.model_loaded

    @staticmethod
    def _download_to(url: str, dest: Path, timeout: float) -> int:
        """
        Stream ``url`` into ``dest`` atomically.

        The body is written to ``<dest>.part``, fsync'ed, then moved over
        ``dest`` with ``os.replace``, so an interrupted transfer never leaves
        a truncated file under the final name. The HTTP response is always
        closed. Nothing is written unless the server answers HTTP 200.

        Returns:
            The HTTP status code of the response.

        Raises:
            Whatever ``requests`` or the filesystem raise (the ``.part`` file
            is removed first).
        """
        import os
        import requests

        tmp_path = dest.with_name(dest.name + '.part')
        with requests.get(url, stream=True, timeout=timeout) as resp:
            if resp.status_code != 200:
                return resp.status_code
            try:
                with open(tmp_path, 'wb') as f:
                    for chunk in resp.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                    f.flush()
                    os.fsync(f.fileno())
                os.replace(tmp_path, dest)
            except BaseException:
                try:
                    tmp_path.unlink()
                except OSError:
                    pass
                raise
            return resp.status_code

    def check_for_updates(self) -> bool:
        """
        Ask the AI server for its latest model and install it if newer.

        Queries ``<ai_server_url>/model/latest``. When the advertised version
        sorts after the local one (numeric-aware, see ``_version_key``), it
        downloads the config and weights files into ``self.model_dir`` and
        reloads.

        Side effects:
            Sets ``last_server_attempted`` / ``last_server_contact_ok``.

        Returns:
            True only if a newer model was downloaded and successfully loaded.
        """
        self.last_server_attempted = False
        self.last_server_contact_ok = None
        try:
            import requests
        except ImportError:
            return False

        url = self.shared_data.config.get("ai_server_url")
        if not url:
            return False

        try:
            logger.debug(f"Checking AI Server for updates at {url}/model/latest")
            from ai_utils import get_system_mac
            params = {'mac_addr': get_system_mac()}
            self.last_server_attempted = True
            resp = requests.get(f"{url}/model/latest", params=params, timeout=5)
            # Any HTTP response means server is reachable.
            self.last_server_contact_ok = True
            if resp.status_code != 200:
                return False

            remote_config = resp.json()
            remote_version = str(remote_config.get("version", "")).strip()
            if not remote_version:
                return False
            if not self._is_safe_version(remote_version):
                logger.info_throttled(
                    f"AI model update rejected (unsafe version {remote_version!r})",
                    key="ai_model_unsafe_version",
                    interval_s=300.0,
                )
                return False

            # A missing/None local version counts as "0" (anything is newer).
            current_version = str((self.model_config or {}).get("version") or "0").strip()
            if not self._is_newer_version(remote_version, current_version):
                logger.debug(f"Server model ({remote_version}) is not newer than local ({current_version})")
                return False

            logger.info(f"New model available: {remote_version} (Local: {current_version})")
            downloads = (
                # (file suffix, log label, throttle-key tag, timeout seconds)
                (".json", "config", "conf", 15),
                ("_weights.json", "weights", "weights", 30),
            )
            for suffix, label, tag, timeout in downloads:
                filename = f"{self._MODEL_PREFIX}{remote_version}{suffix}"
                status = self._download_to(
                    f"{url}/model/download/{filename}",
                    self.model_dir / filename,
                    timeout,
                )
                if status != 200:
                    logger.info_throttled(
                        f"AI model download skipped ({label} HTTP {status})",
                        key=f"ai_model_dl_{tag}_{status}",
                        interval_s=300.0,
                    )
                    return False
            logger.success(f"Downloaded model {remote_version} files to Pi.")

            # Reload explicitly
            return self.reload_model()
        except Exception as e:
            self.last_server_attempted = True
            self.last_server_contact_ok = False
            # Server may be offline; avoid spamming errors in AI mode.
            logger.info_throttled(
                f"AI server unavailable for model update check: {e}",
                key="ai_model_update_check_failed",
                interval_s=300.0,
            )
            return False

    # ═══════════════════════════════════════════════════════════════════════
    # DECISION MAKING
    # ═══════════════════════════════════════════════════════════════════════

    def choose_action(
        self,
        host_context: Dict[str, Any],
        available_actions: List[str],
        exploration_rate: Optional[float] = None
    ) -> Tuple[Optional[str], float, Dict[str, Any]]:
        """
        Choose the best action for a given host.

        Args:
            host_context: Dict with host information; only ``'mac'`` is read
                here and by the predictors.
            available_actions: Names of the actions that may be executed.
            exploration_rate: Probability (0.0-1.0) of picking a uniformly
                random action instead. Defaults to
                ``shared_data.ai_exploration_rate`` or 0.1.

        Returns:
            Tuple of (action_name, confidence_score, debug_info).
            action_name is None when ``available_actions`` is empty.
        """
        if not available_actions:
            return None, 0.0, {'method': 'no_actions', 'error': 'No available actions'}

        if exploration_rate is None:
            exploration_rate = float(getattr(self.shared_data, "ai_exploration_rate", 0.1))

        try:
            # Exploration: random action
            if exploration_rate > 0 and np.random.random() < exploration_rate:
                import random
                action = random.choice(available_actions)
                return action, 0.0, {'method': 'exploration', 'exploration_rate': exploration_rate}

            if self.model_loaded and self.model_weights:
                return self._predict_with_model(host_context, available_actions)

            return self._predict_with_heuristics(host_context, available_actions)
        except Exception as e:
            logger.error(f"Error choosing action: {e}")
            # Ultimate fallback: first available action
            return available_actions[0], 0.0, {'method': 'fallback_error', 'error': str(e)}

    def _predict_with_model(
        self,
        host_context: Dict[str, Any],
        available_actions: List[str]
    ) -> Tuple[str, float, Dict[str, Any]]:
        """
        Score every action with the loaded neural network and pick the best.

        Features are extracted per action and mapped onto the model's
        ``architecture.feature_names`` manifest (unknown names default to
        0.0). Falls back to heuristics if there is no manifest, no action
        could be scored, or anything raises.
        """
        try:
            from ai_utils import extract_neural_features_dict

            # 1. Get model feature manifest
            manifest = (self.model_config.get('architecture') or {}).get('feature_names', [])
            if not manifest:
                # Legacy fallback
                return self._predict_with_model_legacy(host_context, available_actions)

            # 2. Host-level context (computed once for all actions).
            mac = host_context.get('mac', '')
            # The DB may return a falsy value for unknown hosts; normalise to {}.
            host = (self.db.get_host_by_mac(mac) if mac else None) or {}
            host_data = self._get_host_context_from_db(mac, host)
            net_data = self._get_network_context()
            temp_data_base = self._get_temporal_context(mac, host)

            best_action = None
            best_score = -1.0
            best_vector = None
            all_scores = {}

            # 3. Score each action
            for action in available_actions:
                action_data = self._get_action_context(action, host, mac)
                # Move the action-level temporal fields into a copy of the
                # host-level temporal context.
                temp_data = dict(temp_data_base)
                temp_data['same_action_attempts'] = action_data.pop('same_action_attempts', 0)
                temp_data['is_retry'] = action_data.pop('is_retry', False)

                features_dict = extract_neural_features_dict(
                    host_features=host_data,
                    network_features=net_data,
                    temporal_features=temp_data,
                    action_features=action_data,
                )

                # Order features exactly as the model manifest expects;
                # features the Pi does not know yet default to 0.0.
                input_vector = np.array(
                    [float(features_dict.get(name, 0.0)) for name in manifest],
                    dtype=float,
                )

                z_out = np.asarray(self._forward_network(input_vector)).reshape(-1)
                if z_out.size == 1:
                    # Binary classifier exported with 1-neuron sigmoid output.
                    score = float(self._sigmoid(z_out[0]))
                else:
                    # Multi-output head: probability of class 1 ("success").
                    score = float(self._softmax(z_out)[1])

                all_scores[action] = score
                if score > best_score:
                    best_score = score
                    best_action = action
                    best_vector = input_vector

            if best_action is None:
                return self._predict_with_heuristics(host_context, available_actions)

            debug_info = {
                'method': 'neural_network_v3',
                'model_version': self.model_config.get('version'),
                'feature_count': len(manifest),
                'all_scores': all_scores,
                # Input vector of the chosen action, as a plain list so the
                # dict stays JSON-serialisable.
                'input_vector': best_vector.tolist(),
            }
            return best_action, float(best_score), debug_info
        except Exception as e:
            logger.error(f"Dynamic model prediction failed: {e}")
            import traceback
            logger.debug(traceback.format_exc())
            return self._predict_with_heuristics(host_context, available_actions)

    def _predict_with_model_legacy(
        self,
        host_context: Dict[str, Any],
        available_actions: List[str]
    ) -> Tuple[str, float, Dict[str, Any]]:
        """
        Placeholder for models that lack a ``feature_names`` manifest.

        Not implemented for such models yet: it simply delegates to the
        heuristic predictor.
        """
        return self._predict_with_heuristics(host_context, available_actions)

    def _get_host_context_from_db(self, mac: str, host: Optional[Dict]) -> Dict:
        """
        Build host-level features from a hosts-table row.

        Args:
            mac: Host MAC, used to look up services and credentials.
            host: Row dict with optional ``ports`` / ``ips`` (';'-separated),
                ``vendor`` and ``first_seen`` keys; None is treated as {}.
        """
        host = host or {}
        ports_str = host.get('ports', '') or ''
        ports = [int(p) for p in ports_str.split(';') if p.strip().isdecimal()]
        vendor = host.get('vendor', '')

        age_hours = 0.0
        if host.get('first_seen'):
            from datetime import datetime
            try:
                ts = host['first_seen']
                first_seen = datetime.fromisoformat(ts) if isinstance(ts, str) else ts
                age_hours = (datetime.now() - first_seen).total_seconds() / 3600
            except Exception:
                # Unparseable or incompatible timestamp: keep age at 0.
                pass

        creds = self._get_credentials_for_host(mac)
        # NOTE: counts ';'-separated segments, so an empty value yields 1.
        # Kept as-is — presumably must match the training-side feature
        # logger; confirm before changing.
        ip_segments = (host.get('ips') or '').split(';')
        return {
            'port_count': len(ports),
            'service_count': len(self._get_services_for_host(mac)),
            'ip_count': len(ip_segments),
            'credential_count': len(creds),
            'age_hours': round(age_hours, 2),
            'has_ssh': 22 in ports,
            'has_http': 80 in ports or 8080 in ports,
            'has_https': 443 in ports,
            'has_smb': 445 in ports,
            'has_rdp': 3389 in ports,
            'has_database': any(p in ports for p in [3306, 5432, 1433]),
            'has_credentials': len(creds) > 0,
            'is_new': age_hours < 24,
            'is_private': True,  # Simple assumption for now
            'has_multiple_ips': len(ip_segments) > 1,
            'vendor_category': self._categorize_vendor(vendor),
            'port_profile': self._detect_port_profile(ports),
        }

    def _get_network_context(self) -> Dict:
        """
        Collect network-wide stats from the DB.

        The /24-style subnet key is taken from each host's first non-empty IP.
        On DB failure, returns neutral defaults.
        """
        try:
            all_hosts = self.db.get_all_hosts()
            total = len(all_hosts)
            subnets = set()
            active = 0
            for h in all_hosts:
                for ip in (h.get('ips') or '').split(';'):
                    ip = ip.strip()
                    if ip:
                        subnets.add('.'.join(ip.split('.')[:3]))
                        break
                if h.get('alive'):
                    active += 1
            return {
                'total_hosts': total,
                'subnet_count': len(subnets),
                'similar_vendor_count': 0,        # filled by caller if needed
                'similar_port_profile_count': 0,  # filled by caller if needed
                'active_host_ratio': round(active / total, 2) if total else 0.0,
            }
        except Exception as e:
            logger.error(f"Error collecting network context: {e}")
            return {
                'total_hosts': 0, 'subnet_count': 1,
                'similar_vendor_count': 0, 'similar_port_profile_count': 0,
                'active_host_ratio': 1.0,
            }

    def _get_temporal_context(self, mac: str, host: Optional[Dict] = None) -> Dict:
        """
        Collect temporal features for a MAC from the clock and the DB.

        ``same_action_attempts`` / ``is_retry`` are action-specific; they are
        only placeholders here and are overwritten per action in
        ``_predict_with_model``.

        Args:
            mac: Host MAC address.
            host: Already-fetched host row; when None it is looked up again.
        """
        from datetime import datetime
        now = datetime.now()
        ctx = {
            'hour_of_day': now.hour,
            'day_of_week': now.weekday(),
            'is_weekend': now.weekday() >= 5,
            'is_night': now.hour < 6 or now.hour >= 22,
            'previous_action_count': 0,
            'seconds_since_last': 0,
            'historical_success_rate': 0.0,
            'same_action_attempts': 0,  # placeholder; overwritten per-action
            'is_retry': False,          # placeholder; overwritten per-action
            'global_success_rate': 0.0,
            'hours_since_discovery': 0,
        }
        try:
            # Per-host stats from ml_features (persistent training log)
            rows = self.db.query(
                """
                SELECT
                    COUNT(*) AS cnt,
                    AVG(CAST(success AS REAL)) AS success_rate,
                    MAX(timestamp) AS last_ts
                FROM ml_features
                WHERE mac_address = ?
                """,
                (mac,),
            )
            if rows and rows[0]['cnt']:
                ctx['previous_action_count'] = int(rows[0]['cnt'])
                ctx['historical_success_rate'] = round(float(rows[0]['success_rate'] or 0.0), 2)
                if rows[0]['last_ts']:
                    try:
                        last_dt = datetime.fromisoformat(str(rows[0]['last_ts']))
                        ctx['seconds_since_last'] = round((now - last_dt).total_seconds(), 1)
                    except Exception:
                        pass

            # Global success rate (all hosts)
            g = self.db.query(
                "SELECT AVG(CAST(success AS REAL)) AS gsr FROM ml_features"
            )
            if g and g[0]['gsr'] is not None:
                ctx['global_success_rate'] = round(float(g[0]['gsr']), 2)

            # Hours since host first seen
            if host is None:
                host = self.db.get_host_by_mac(mac)
            if host and host.get('first_seen'):
                try:
                    ts = host['first_seen']
                    first_seen = datetime.fromisoformat(ts) if isinstance(ts, str) else ts
                    ctx['hours_since_discovery'] = round(
                        (now - first_seen).total_seconds() / 3600, 1
                    )
                except Exception:
                    pass
        except Exception as e:
            logger.error(f"Error collecting temporal context for {mac}: {e}")
        return ctx

    # Known target port per action name (0 = not port-specific).
    _ACTION_PORTS = {
        'SSHBruteforce': 22, 'SSHEnumeration': 22, 'StealFilesSSH': 22,
        'WebEnumeration': 80, 'WebVulnScan': 80, 'WebLoginProfiler': 80,
        'WebSurfaceMapper': 80,
        'SMBBruteforce': 445, 'StealFilesSMB': 445,
        'FTPBruteforce': 21, 'StealFilesFTP': 21,
        'TelnetBruteforce': 23, 'StealFilesTelnet': 23,
        'SQLBruteforce': 3306, 'StealDataSQL': 3306,
        'NmapVulnScanner': 0, 'NetworkScanner': 0,
        'RDPBruteforce': 3389,
    }

    def _get_action_context(self, action_name: str, host: Dict, mac: str = '') -> Dict:
        """
        Collect action-specific features, including per-action attempt history.

        The target port comes from ``_ACTION_PORTS`` or, failing that, from a
        service keyword in the action name. ``same_action_attempts`` /
        ``is_retry`` are counted from ``ml_features`` rows for this MAC and
        action when ``mac`` is given.
        """
        action_type = self._classify_action_type(action_name)
        target_port = self._ACTION_PORTS.get(action_name, 0)

        # If port not in lookup, try to infer from action name
        if target_port == 0:
            name_lower = action_name.lower()
            for svc, port in [('ssh', 22), ('http', 80), ('smb', 445), ('ftp', 21),
                              ('telnet', 23), ('sql', 3306), ('rdp', 3389)]:
                if svc in name_lower:
                    target_port = port
                    break

        ctx = {
            'action_type': action_type,
            'target_port': target_port,
            'is_standard_port': 0 < target_port < 1024,
            # Action-level temporal (overrides placeholder in temporal_context)
            'same_action_attempts': 0,
            'is_retry': False,
        }

        if mac:
            try:
                r = self.db.query(
                    """
                    SELECT COUNT(*) AS cnt
                    FROM ml_features
                    WHERE mac_address = ? AND action_name = ?
                    """,
                    (mac, action_name),
                )
                attempts = int(r[0]['cnt']) if r else 0
                ctx['same_action_attempts'] = attempts
                ctx['is_retry'] = attempts > 0
            except Exception as e:
                logger.debug(f"Action context DB query failed for {action_name}: {e}")
        return ctx

    def _classify_action_type(self, action_name: str) -> str:
        """Classify an action name into a coarse type by keyword."""
        name = action_name.lower()
        if 'brute' in name:
            return 'bruteforce'
        if 'enum' in name or 'scan' in name:
            return 'enumeration'
        if 'exploit' in name:
            return 'exploitation'
        if 'dump' in name or 'extract' in name:
            return 'extraction'
        return 'other'

    # ═══════════════════════════════════════════════════════════════════════
    # HEURISTIC FALLBACK
    # ═══════════════════════════════════════════════════════════════════════

    def _init_heuristics(self):
        """Initialize rule-based heuristics for cold start."""
        self.heuristics = {
            'port_based': {
                22: ['SSHBruteforce', 'SSHEnumeration'],
                80: ['WebEnumeration', 'WebVulnScan'],
                443: ['WebEnumeration', 'SSLScan'],
                445: ['SMBBruteforce', 'SMBEnumeration'],
                3389: ['RDPBruteforce'],
                21: ['FTPBruteforce', 'FTPEnumeration'],
                23: ['TelnetBruteforce'],
                3306: ['MySQLBruteforce'],
                5432: ['PostgresBruteforce'],
                1433: ['MSSQLBruteforce']
            },
            'service_based': {
                'ssh': ['SSHBruteforce', 'SSHEnumeration'],
                'http': ['WebEnumeration', 'WebVulnScan'],
                'https': ['WebEnumeration', 'SSLScan'],
                'smb': ['SMBBruteforce', 'SMBEnumeration'],
                'ftp': ['FTPBruteforce', 'FTPEnumeration'],
                'mysql': ['MySQLBruteforce'],
                'postgres': ['PostgresBruteforce']
            },
            'profile_based': {
                'camera': ['WebEnumeration', 'DefaultCredCheck', 'RTSPBruteforce'],
                'nas': ['SMBBruteforce', 'WebEnumeration', 'SSHBruteforce'],
                'web_server': ['WebEnumeration', 'WebVulnScan'],
                'database': ['MySQLBruteforce', 'PostgresBruteforce'],
                'linux_server': ['SSHBruteforce', 'WebEnumeration'],
                'windows_server': ['SMBBruteforce', 'RDPBruteforce']
            }
        }

    def _predict_with_heuristics(
        self,
        host_context: Dict[str, Any],
        available_actions: List[str]
    ) -> Tuple[Optional[str], float, Dict[str, Any]]:
        """
        Pick an action with additive rule-based scores.

        Each open port adds 0.3, each detected service adds 0.4 and the host's
        port profile adds 0.3 to every matching available action. The best
        score is clamped to 1.0. Unknown hosts are treated as having no ports.
        """
        try:
            mac = host_context.get('mac', '')
            # The DB may return a falsy value for unknown hosts; normalise to {}.
            host = (self.db.get_host_by_mac(mac) if mac else None) or {}

            ports_str = host.get('ports', '') or ''
            ports = {int(p) for p in ports_str.split(';') if p.strip().isdecimal()}
            services = self._get_services_for_host(mac)
            port_profile = self._detect_port_profile(ports)

            action_scores = {action: 0.0 for action in available_actions}

            for port in ports:
                for action in self.heuristics['port_based'].get(port, ()):
                    if action in action_scores:
                        action_scores[action] += 0.3

            for service in services:
                for action in self.heuristics['service_based'].get(service, ()):
                    if action in action_scores:
                        action_scores[action] += 0.4

            for action in self.heuristics['profile_based'].get(port_profile, ()):
                if action in action_scores:
                    action_scores[action] += 0.3

            if action_scores:
                best_action = max(action_scores, key=action_scores.get)
                # Several rules can reinforce one action; clamp into [0, 1].
                best_score = min(action_scores[best_action], 1.0)
                debug_info = {
                    'method': 'heuristics',
                    'port_profile': port_profile,
                    'ports': sorted(ports)[:10],
                    'services': services,
                    'all_scores': {k: v for k, v in action_scores.items() if v > 0}
                }
                return best_action, best_score, debug_info

            # Ultimate fallback
            if available_actions:
                return available_actions[0], 0.1, {'method': 'fallback_first'}
            return None, 0.0, {'method': 'no_actions'}
        except Exception as e:
            logger.error(f"Heuristic prediction failed: {e}")
            if available_actions:
                return available_actions[0], 0.0, {'method': 'fallback_error', 'error': str(e)}
            return None, 0.0, {'method': 'error', 'error': str(e)}

    # ═══════════════════════════════════════════════════════════════════════
    # HELPER METHODS
    # ═══════════════════════════════════════════════════════════════════════

    @staticmethod
    def _relu(x):
        """ReLU activation function."""
        return np.maximum(0, x)

    @staticmethod
    def _sigmoid(x):
        """Logistic sigmoid 1 / (1 + e^-x), computed without exp overflow."""
        # 1 / (1 + e^-x) == exp(-log(1 + e^-x)) == exp(-logaddexp(0, -x))
        return np.exp(-np.logaddexp(0.0, -x))

    @staticmethod
    def _softmax(x):
        """Softmax activation function."""
        exp_x = np.exp(x - np.max(x))  # Numerical stability
        return exp_x / exp_x.sum()

    def _forward_network(self, input_vector: np.ndarray) -> np.ndarray:
        """
        Forward pass through the exported dense network.

        Applies ReLU(a @ w{i} + b{i}) for i = 1, 2, ... while ``w{i}`` exists
        in ``self.model_weights``, then returns the raw (pre-activation)
        output ``a @ w_out + b_out``.
        """
        a = input_vector
        layer_idx = 1
        while f'w{layer_idx}' in self.model_weights:
            w = self.model_weights[f'w{layer_idx}']
            b = self.model_weights[f'b{layer_idx}']
            a = self._relu(np.dot(a, w) + b)
            layer_idx += 1
        return np.dot(a, self.model_weights['w_out']) + self.model_weights['b_out']

    def _get_services_for_host(self, mac: str) -> List[str]:
        """Return the distinct non-empty service names recorded for a host."""
        try:
            results = self.db.query("""
                SELECT DISTINCT service
                FROM port_services
                WHERE mac_address=?
            """, (mac,))
            return [r['service'] for r in results if r.get('service')]
        except Exception as e:
            logger.debug(f"Service lookup failed for {mac}: {e}")
            return []

    def _get_credentials_for_host(self, mac: str) -> List[Dict]:
        """Return the credential rows (service, user, port) found for a host."""
        try:
            return self.db.query("""
                SELECT service, user, port
                FROM creds
                WHERE mac_address=?
            """, (mac,))
        except Exception as e:
            logger.debug(f"Credential lookup failed for {mac}: {e}")
            return []

    # Vendor substrings per category; first matching category wins, so
    # 'apple' resolves to 'compute'. Same as feature_logger.
    _VENDOR_CATEGORIES = {
        'networking': ['cisco', 'juniper', 'ubiquiti', 'mikrotik', 'tp-link'],
        'iot': ['hikvision', 'dahua', 'axis'],
        'nas': ['synology', 'qnap'],
        'compute': ['raspberry', 'intel', 'apple', 'dell', 'hp'],
        'virtualization': ['vmware', 'microsoft'],
        'mobile': ['apple', 'samsung', 'huawei']
    }

    def _categorize_vendor(self, vendor: str) -> str:
        """Categorize a vendor string (same as feature_logger)."""
        if not vendor:
            return 'unknown'
        vendor_lower = vendor.lower()
        for category, vendors in self._VENDOR_CATEGORIES.items():
            if any(v in vendor_lower for v in vendors):
                return category
        return 'other'

    # Characteristic ports per device profile. Same as feature_logger.
    _PORT_PROFILES = {
        'camera': {554, 80, 8000},
        'web_server': {80, 443, 8080},
        'nas': {5000, 5001, 548, 139, 445},
        'database': {3306, 5432, 1433, 27017},
        'linux_server': {22, 80, 443},
        'windows_server': {135, 139, 445, 3389},
        'printer': {9100, 515, 631},
        'router': {22, 23, 80, 443, 161}
    }

    def _detect_port_profile(self, ports) -> str:
        """
        Detect a device profile from open ports (same as feature_logger).

        Returns the profile sharing the most ports (the first one listed wins
        ties), or 'generic' when the best overlap is below 2.
        """
        port_set = set(ports)
        max_overlap = 0
        best_profile = 'generic'
        for profile_name, profile_ports in self._PORT_PROFILES.items():
            overlap = len(port_set & profile_ports)
            if overlap > max_overlap:
                max_overlap = overlap
                best_profile = profile_name
        return best_profile if max_overlap >= 2 else 'generic'

    # ═══════════════════════════════════════════════════════════════════════
    # STATISTICS
    # ═══════════════════════════════════════════════════════════════════════

    def get_stats(self) -> Dict[str, Any]:
        """Return engine status plus model metadata when a model is loaded."""
        stats = {
            'model_loaded': self.model_loaded,
            'heuristics_available': True,
            'decision_mode': 'neural_network' if self.model_loaded else 'heuristics'
        }
        if self.model_loaded and self.model_config:
            stats.update({
                'model_version': self.model_config.get('version'),
                'model_trained_at': self.model_config.get('trained_at'),
                'model_accuracy': self.model_config.get('accuracy'),
                'training_samples': self.model_config.get('training_samples')
            })
        return stats
# ═══════════════════════════════════════════════════════════════════════════
# SINGLETON FACTORY
# ═══════════════════════════════════════════════════════════════════════════
def get_or_create_ai_engine(shared_data) -> Optional['BjornAIEngine']:
    """
    Return the shared BjornAIEngine cached on ``shared_data``.

    The first call builds the engine and stores it as
    ``shared_data._ai_engine_singleton``; later calls hand back that same
    object. Callers (orchestrator, scheduler, web) should use this rather than
    constructing BjornAIEngine directly, so model weights are loaded once.
    If construction raises, the error is logged, None is cached and returned,
    and the next call tries again.
    """
    engine = getattr(shared_data, '_ai_engine_singleton', None)
    if engine is not None:
        return engine

    try:
        engine = BjornAIEngine(shared_data)
    except Exception as e:
        logger.error(f"Failed to create BjornAIEngine singleton: {e}")
        engine = None

    shared_data._ai_engine_singleton = engine
    return engine
def invalidate_ai_engine(shared_data) -> None:
    """
    Forget the cached AI engine on ``shared_data``.

    The next get_or_create_ai_engine() call will build a fresh instance
    (useful after a mode reset or a model update).
    """
    setattr(shared_data, '_ai_engine_singleton', None)
# ═══════════════════════════════════════════════════════════════════════════
# END OF FILE
# ═══════════════════════════════════════════════════════════════════════════