Files
Bjorn/ai_engine.py
infinition aac77a3e76 Add Loki and Sentinel utility classes for web API endpoints
- Implemented LokiUtils class with GET and POST endpoints for managing scripts, jobs, and payloads.
- Added SentinelUtils class with GET and POST endpoints for managing events, rules, devices, and notifications.
- Both classes include error handling and JSON response formatting.
2026-03-14 22:33:10 +01:00

1122 lines
48 KiB
Python

"""
ai_engine.py - Dynamic AI Decision Engine for Bjorn
═══════════════════════════════════════════════════════════════════════════
Purpose:
Lightweight AI decision engine for Raspberry Pi Zero.
Works in tandem with deep learning model trained on external PC.
Architecture:
- Lightweight inference engine (no TensorFlow/PyTorch on Pi)
- Loads pre-trained model weights from PC
- Real-time action selection
- Automatic feature extraction
- Fallback to heuristics when model unavailable
Model Pipeline:
1. Pi: Collect data → Export → Transfer to PC
2. PC: Train deep neural network → Export lightweight model
3. Pi: Load model → Use for decision making
4. Repeat: Continuous learning cycle
Author: Bjorn Team
Version: 2.0.0
"""
import json
import numpy as np
from typing import Dict, List, Any, Optional, Tuple
from pathlib import Path
from logger import Logger
logger = Logger(name="ai_engine.py", level=20)
class BjornAIEngine:
"""
Dynamic AI engine for action selection and prioritization.
Uses pre-trained model from external PC or falls back to heuristics.
"""
def __init__(self, shared_data, model_dir: str = None):
    """
    Initialize AI engine.

    Args:
        shared_data: Application-wide shared state. Must expose `.db`;
            optional config attributes are read via getattr with defaults
            (ai_models_dir, ai_model_perf_check_interval,
            ai_cold_start_bootstrap_weight).
        model_dir: Override for the model directory; when None, falls back
            to shared_data.ai_models_dir or '/home/bjorn/ai_models'.

    Side effects: creates the model directory, loads bootstrap scores and
    the latest model from disk (may contact the AI server), and builds the
    heuristic rule tables.
    """
    self.shared_data = shared_data
    self.db = shared_data.db
    if model_dir is None:
        self.model_dir = Path(getattr(shared_data, 'ai_models_dir', '/home/bjorn/ai_models'))
    else:
        self.model_dir = Path(model_dir)
    self.model_dir.mkdir(parents=True, exist_ok=True)
    # Model state
    self.model_loaded = False
    self.model_weights = None
    self.model_config = None
    self.feature_config = None
    self.last_server_attempted = False
    self.last_server_contact_ok = None
    # AI-03: Model versioning & rollback
    self._previous_model = None  # {weights, config, feature_config}
    self._model_history = []  # [{version, loaded_at, accuracy, avg_reward}]
    self._max_model_versions_on_disk = 3
    self._performance_window = []  # recent reward values for current model
    self._performance_check_interval = int(
        getattr(shared_data, 'ai_model_perf_check_interval', 50)
    )
    self._prev_model_avg_reward = None  # avg reward of the model we replaced
    # AI-04: Cold-start bootstrap scores
    self._bootstrap_scores = {}  # {(action_name, port_profile): [total_reward, count]}
    self._bootstrap_file = self.model_dir / 'ai_bootstrap_scores.json'
    self._bootstrap_weight = float(
        getattr(shared_data, 'ai_cold_start_bootstrap_weight', 0.6)
    )
    self._load_bootstrap_scores()
    # Try to load latest model (may fall back to downloading from the server)
    self._load_latest_model()
    # Fallback heuristics (always available)
    self._init_heuristics()
    logger.info(
        f"AI Engine initialized (model_loaded={self.model_loaded}, "
        f"heuristics_available=True)"
    )
# ═══════════════════════════════════════════════════════════════════════
# MODEL LOADING
# ═══════════════════════════════════════════════════════════════════════
def _load_latest_model(self):
    """Load the most recent model from model directory.

    Picks the lexicographically-latest `bjorn_model_*.json` config that has
    a matching `*_weights.json` file. When no complete model pair exists,
    tries a server download via check_for_updates(); otherwise the engine
    stays in heuristics-only mode.

    On success sets model_config / feature_config / model_weights /
    model_loaded, snapshots the outgoing model for rollback (AI-03), and
    prunes stale model files on disk.
    """
    try:
        # Find all potential model configs (weights files are excluded by name)
        all_json_files = [f for f in self.model_dir.glob("bjorn_model_*.json")
                          if "_weights.json" not in f.name]
        # 1. Filter for files that have matching weights
        valid_models = []
        for f in all_json_files:
            weights_path = f.with_name(f.stem + '_weights.json')
            if weights_path.exists():
                valid_models.append(f)
            else:
                logger.debug(f"Skipping model {f.name}: Weights file missing")
        if not valid_models:
            logger.info(f"No complete models found in {self.model_dir}. Checking server...")
            # Try to download from server (check_for_updates reloads on success)
            if self.check_for_updates():
                return
            logger.info_throttled(
                "No AI model available (server offline or empty). Using heuristics only.",
                key="ai_no_model_available",
                interval_s=600.0,
            )
            return
        # 2. Sort by timestamp in filename (lexicographical) and pick latest
        valid_models = sorted(valid_models)
        latest_model = valid_models[-1]
        weights_file = latest_model.with_name(latest_model.stem + '_weights.json')
        logger.info(f"Loading model: {latest_model.name} (Weights exists!)")
        with open(latest_model, 'r') as f:
            model_data = json.load(f)
        new_config = model_data.get('config', model_data)
        new_feature_config = model_data.get('features', {})
        # Load weights (JSON lists -> numpy arrays)
        with open(weights_file, 'r') as f:
            weights_data = json.load(f)
        new_weights = {
            k: np.array(v) for k, v in weights_data.items()
        }
        del weights_data  # Free raw dict — numpy arrays are the canonical form
        # AI-03: Save previous model for rollback
        if self.model_loaded and self.model_weights is not None:
            self._previous_model = {
                'weights': self.model_weights,
                'config': self.model_config,
                'feature_config': self.feature_config,
            }
            # Record avg reward of outgoing model for performance comparison
            if self._performance_window:
                self._prev_model_avg_reward = (
                    sum(self._performance_window) / len(self._performance_window)
                )
        self._performance_window = []  # reset for new model
        self.model_config = new_config
        self.feature_config = new_feature_config
        self.model_weights = new_weights
        self.model_loaded = True
        # AI-03: Track model history
        from datetime import datetime as _dt
        version = self.model_config.get('version', 'unknown')
        self._model_history.append({
            'version': version,
            'loaded_at': _dt.now().isoformat(),
            'accuracy': self.model_config.get('accuracy'),
            'avg_reward': None,  # filled later as decisions accumulate
        })
        # Keep history bounded
        if len(self._model_history) > 10:
            self._model_history = self._model_history[-10:]
        logger.success(
            f"Model loaded successfully: {version}"
        )
        # AI-03: Prune old model versions on disk (keep N most recent)
        self._prune_old_model_files(valid_models)
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        import traceback
        logger.debug(traceback.format_exc())
        self.model_loaded = False
def _prune_old_model_files(self, valid_models: list):
    """AI-03: Delete all but the newest N model config/weights file pairs.

    Args:
        valid_models: Sorted (oldest-first) list of config Paths whose
            weights files exist; the last `_max_model_versions_on_disk`
            entries are kept.
    """
    try:
        limit = self._max_model_versions_on_disk
        stale = valid_models[:-limit] if len(valid_models) > limit else []
        for cfg in stale:
            weights = cfg.with_name(cfg.stem + '_weights.json')
            try:
                cfg.unlink(missing_ok=True)
                weights.unlink(missing_ok=True)
                logger.info(f"Pruned old model: {cfg.name}")
            except Exception as e:
                logger.debug(f"Could not prune {cfg.name}: {e}")
    except Exception as e:
        logger.debug(f"Model pruning error: {e}")
def reload_model(self) -> bool:
    """Discard in-memory model state and reload the latest model from disk.

    Returns:
        True when a model was successfully loaded again.
    """
    logger.info("Reloading AI model...")
    self.model_loaded = False
    self.model_config = None
    self.model_weights = None
    self.feature_config = None
    self._load_latest_model()
    return self.model_loaded
def rollback_model(self) -> bool:
    """
    AI-03: Rollback to the previous model version.
    Returns True if rollback succeeded.
    """
    if self._previous_model is None:
        logger.warning("No previous model available for rollback")
        return False
    logger.info("Rolling back to previous model version...")
    # Preserve the active model so a rollback can itself be undone.
    undo_snapshot = None
    if self.model_loaded and self.model_weights is not None:
        undo_snapshot = {
            'weights': self.model_weights,
            'config': self.model_config,
            'feature_config': self.feature_config,
        }
    restored = self._previous_model
    self.model_weights = restored['weights']
    self.model_config = restored['config']
    self.feature_config = restored['feature_config']
    self.model_loaded = True
    self._previous_model = undo_snapshot
    self._performance_window = []  # reset performance tracking
    from datetime import datetime as _dt
    version = self.model_config.get('version', 'unknown')
    self._model_history.append({
        'version': f"{version}_rollback",
        'loaded_at': _dt.now().isoformat(),
        'accuracy': self.model_config.get('accuracy'),
        'avg_reward': None,
    })
    logger.success(f"Rolled back to model version: {version}")
    return True
def record_reward(self, reward: float):
    """
    AI-03: Record one reward observation for model performance tracking.

    Keeps a rolling window; once the window reaches the configured check
    interval, compares the average against the replaced model's average and
    auto-rolls-back when the current model performs worse.
    """
    window = self._performance_window
    window.append(reward)
    # Mirror the running average into the latest history entry
    if self._model_history:
        self._model_history[-1]['avg_reward'] = round(sum(window) / len(window), 2)
    if len(window) < self._performance_check_interval:
        return
    current_avg = sum(window) / len(window)
    degraded = (
        self._prev_model_avg_reward is not None
        and self._previous_model is not None
        and current_avg < self._prev_model_avg_reward
    )
    if degraded:
        logger.warning(
            f"Model performance degraded: current avg={current_avg:.2f} vs "
            f"previous avg={self._prev_model_avg_reward:.2f}. Auto-rolling back."
        )
        self.rollback_model()
    else:
        logger.info(
            f"Model performance check passed: avg_reward={current_avg:.2f} "
            f"over {len(window)} decisions"
        )
    # Reset window for next check cycle
    self._performance_window = []
def get_model_info(self) -> Dict[str, Any]:
    """AI-03: Return current version, load state, history, and performance stats."""
    window = self._performance_window
    rolling_avg = round(sum(window) / len(window), 2) if window else None
    current_version = self.model_config.get('version') if self.model_config else None
    return {
        'current_version': current_version,
        'model_loaded': self.model_loaded,
        'has_previous_model': self._previous_model is not None,
        'history': list(self._model_history),
        'performance': {
            'current_avg_reward': rolling_avg,
            'decisions_since_load': len(window),
            'check_interval': self._performance_check_interval,
            'previous_model_avg_reward': self._prev_model_avg_reward,
        },
    }
def check_for_updates(self) -> bool:
    """Check AI Server for new model version.

    Queries `<ai_server_url>/model/latest` with this device's MAC; when the
    advertised version is newer, streams config + weights files to disk and
    reloads the model.

    Returns:
        True only when a newer model was downloaded AND reloaded successfully.

    Side effects:
        Updates self.last_server_attempted / self.last_server_contact_ok.

    NOTE(review): the "newer" test is a lexicographic string comparison —
    correct for timestamp-style versions (which the filename sorting in
    _load_latest_model also assumes), wrong for plain integers ("10" < "9");
    confirm the server's versioning scheme.
    """
    self.last_server_attempted = False
    self.last_server_contact_ok = None
    try:
        import requests
        import os
    except ImportError:
        return False
    url = self.shared_data.config.get("ai_server_url")
    if not url:
        return False
    try:
        logger.debug(f"Checking AI Server for updates at {url}/model/latest")
        from ai_utils import get_system_mac
        params = {'mac_addr': get_system_mac()}
        self.last_server_attempted = True
        resp = requests.get(f"{url}/model/latest", params=params, timeout=5)
        # Any HTTP response means server is reachable.
        self.last_server_contact_ok = True
        if resp.status_code != 200:
            return False
        remote_config = resp.json()
        remote_version = str(remote_config.get("version", "")).strip()
        if not remote_version:
            return False
        current_version = str(self.model_config.get("version", "0")).strip() if self.model_config else "0"
        if remote_version > current_version:
            logger.info(f"New model available: {remote_version} (Local: {current_version})")
            # Download config (stream to avoid loading the whole file into RAM)
            r_conf = requests.get(
                f"{url}/model/download/bjorn_model_{remote_version}.json",
                stream=True, timeout=15,
            )
            if r_conf.status_code == 200:
                conf_path = self.model_dir / f"bjorn_model_{remote_version}.json"
                with open(conf_path, 'wb') as f:
                    for chunk in r_conf.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                    # flush + fsync so a power loss can't leave a truncated file
                    f.flush()
                    os.fsync(f.fileno())
            else:
                logger.info_throttled(
                    f"AI model download skipped (config HTTP {r_conf.status_code})",
                    key=f"ai_model_dl_conf_{r_conf.status_code}",
                    interval_s=300.0,
                )
                return False
            # Download weights (stream to avoid loading the whole file into RAM)
            r_weights = requests.get(
                f"{url}/model/download/bjorn_model_{remote_version}_weights.json",
                stream=True, timeout=30,
            )
            if r_weights.status_code == 200:
                weights_path = self.model_dir / f"bjorn_model_{remote_version}_weights.json"
                with open(weights_path, 'wb') as f:
                    for chunk in r_weights.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                    f.flush()
                    os.fsync(f.fileno())
                logger.success(f"Downloaded model {remote_version} files to Pi.")
            else:
                logger.info_throttled(
                    f"AI model download skipped (weights HTTP {r_weights.status_code})",
                    key=f"ai_model_dl_weights_{r_weights.status_code}",
                    interval_s=300.0,
                )
                return False
            # Reload explicitly
            return self.reload_model()
        logger.debug(f"Server model ({remote_version}) is not newer than local ({current_version})")
        return False
    except Exception as e:
        self.last_server_attempted = True
        self.last_server_contact_ok = False
        # Server may be offline; avoid spamming errors in AI mode.
        logger.info_throttled(
            f"AI server unavailable for model update check: {e}",
            key="ai_model_update_check_failed",
            interval_s=300.0,
        )
        return False
# ═══════════════════════════════════════════════════════════════════════
# DECISION MAKING
# ═══════════════════════════════════════════════════════════════════════
def choose_action(
    self,
    host_context: Dict[str, Any],
    available_actions: List[str],
    exploration_rate: float = None
) -> Tuple[str, float, Dict[str, Any]]:
    """
    Choose the best action for a given host.

    Args:
        host_context: Dict with host information (mac, ports, hostname, etc.)
        available_actions: List of action names that can be executed
        exploration_rate: Probability of random exploration (0.0-1.0);
            defaults to shared_data.ai_exploration_rate (or 0.1).

    Returns:
        Tuple of (action_name, confidence_score, debug_info). action_name is
        None only when available_actions is empty.
    """
    rate = exploration_rate
    if rate is None:
        rate = float(getattr(self.shared_data, "ai_exploration_rate", 0.1))
    try:
        # Epsilon-greedy exploration: occasionally pick a random action
        if rate > 0 and np.random.random() < rate:
            import random
            picked = random.choice(available_actions)
            return picked, 0.0, {'method': 'exploration', 'exploration_rate': rate}
        # Exploit: neural model when available, otherwise heuristic rules
        if self.model_loaded and self.model_weights:
            return self._predict_with_model(host_context, available_actions)
        return self._predict_with_heuristics(host_context, available_actions)
    except Exception as e:
        logger.error(f"Error choosing action: {e}")
        # Ultimate fallback: first available action
        if not available_actions:
            return None, 0.0, {'method': 'no_actions', 'error': 'No available actions'}
        return available_actions[0], 0.0, {'method': 'fallback_error', 'error': str(e)}
def _predict_with_model(
    self,
    host_context: Dict[str, Any],
    available_actions: List[str]
) -> Tuple[str, float, Dict[str, Any]]:
    """
    Use loaded neural network model for prediction.
    Dynamically maps extracted features to model manifest.

    Host/network/temporal features are computed once; each candidate action
    gets its own action features and one forward pass, and the highest
    score wins. Falls back to heuristics on any failure or when the model
    has no feature manifest.
    """
    try:
        from ai_utils import extract_neural_features_dict
        # 1. Get model feature manifest (ordered list of input feature names)
        manifest = self.model_config.get('architecture', {}).get('feature_names', [])
        if not manifest:
            # Legacy fallback
            return self._predict_with_model_legacy(host_context, available_actions)
        # 2. Extract host-level features
        mac = host_context.get('mac', '')
        host = self.db.get_host_by_mac(mac) if mac else {}
        host_data = self._get_host_context_from_db(mac, host)
        net_data = self._get_network_context()
        temp_data_base = self._get_temporal_context(mac)  # MAC-level temporal, called once
        best_action = None
        best_score = -1.0
        all_scores = {}
        # 3. Score each action
        for action in available_actions:
            action_data = self._get_action_context(action, host, mac)
            # Merge action-level temporal overrides into temporal context copy
            temp_data = dict(temp_data_base)
            temp_data['same_action_attempts'] = action_data.pop('same_action_attempts', 0)
            temp_data['is_retry'] = action_data.pop('is_retry', False)
            # Extract all known features into a dict
            features_dict = extract_neural_features_dict(
                host_features=host_data,
                network_features=net_data,
                temporal_features=temp_data,
                action_features=action_data
            )
            # Dynamic mapping: Pull features requested by model manifest
            # Defaults to 0.0 if the Pi doesn't know this feature yet
            input_vector = np.array([float(features_dict.get(name, 0.0)) for name in manifest], dtype=float)
            # Neural inference (supports variable hidden depth from exported model).
            z_out = self._forward_network(input_vector)
            z_out = np.array(z_out).reshape(-1)
            if z_out.size == 1:
                # Binary classifier exported with 1-neuron sigmoid output.
                score = float(self._sigmoid(z_out[0]))
            else:
                probs = self._softmax(z_out)
                score = float(probs[1] if len(probs) > 1 else probs[0])
            all_scores[action] = score
            if score > best_score:
                best_score = score
                best_action = action
        if best_action is None:
            # No actions were scored (empty available_actions)
            return self._predict_with_heuristics(host_context, available_actions)
        # Capture the last input vector (for visualization)
        # Since we iterate, we'll just take the one from the best_action or the last one.
        # Usually input_vector is almost the same for all actions except action-specific bits.
        debug_info = {
            'method': 'neural_network_v3',
            'model_version': self.model_config.get('version'),
            'feature_count': len(manifest),
            'all_scores': all_scores,
            # Convert numpy ndarray → plain Python list so debug_info is
            # always JSON-serialisable (scheduler stores it in action_queue metadata).
            'input_vector': input_vector.tolist(),
        }
        return best_action, float(best_score), debug_info
    except Exception as e:
        logger.error(f"Dynamic model prediction failed: {e}")
        import traceback
        logger.debug(traceback.format_exc())
        return self._predict_with_heuristics(host_context, available_actions)
def _predict_with_model_legacy(self, host_context: Dict[str, Any], available_actions: List[str]) -> Tuple[str, float, Dict[str, Any]]:
    """Fallback for models without feature_names manifest (fixed length 56)"""
    # The fixed-width legacy inference path is gone; defer to heuristic scoring.
    return self._predict_with_heuristics(host_context, available_actions)
def _get_host_context_from_db(self, mac: str, host: Dict) -> Dict:
    """Collect per-host feature values from a DB host row.

    Args:
        mac: Host MAC address (used for service/credential lookups).
        host: Host row dict — reads 'ports' (';'-separated ints), 'ips',
            'vendor', 'first_seen'.

    Returns:
        Dict of numeric/boolean/categorical host features for model input.
    """
    ports_str = host.get('ports', '') or ''
    ports = [int(p) for p in ports_str.split(';') if p.strip().isdigit()]
    vendor = host.get('vendor', '')
    # ''.split(';') yields [''] so ip_count is at least 1, matching original behavior
    ips = (host.get('ips') or '').split(';')
    # Host age in hours since first_seen (0.0 if unknown or unparseable)
    age_hours = 0.0
    if host.get('first_seen'):
        from datetime import datetime
        try:
            ts = host['first_seen']
            first_seen = datetime.fromisoformat(ts) if isinstance(ts, str) else ts
            age_hours = (datetime.now() - first_seen).total_seconds() / 3600
        except (ValueError, TypeError):
            # Narrowed from a bare `except:` (which also swallowed
            # SystemExit/KeyboardInterrupt); only parse errors are expected here.
            pass
    creds = self._get_credentials_for_host(mac)
    return {
        'port_count': len(ports),
        'service_count': len(self._get_services_for_host(mac)),
        'ip_count': len(ips),
        'credential_count': len(creds),
        'age_hours': round(age_hours, 2),
        'has_ssh': 22 in ports,
        'has_http': 80 in ports or 8080 in ports,
        'has_https': 443 in ports,
        'has_smb': 445 in ports,
        'has_rdp': 3389 in ports,
        'has_database': any(p in ports for p in [3306, 5432, 1433]),
        'has_credentials': len(creds) > 0,
        'is_new': age_hours < 24,
        'is_private': True,  # Simple assumption for now
        'has_multiple_ips': len(ips) > 1,
        'vendor_category': self._categorize_vendor(vendor),
        'port_profile': self._detect_port_profile(ports)
    }
def _get_network_context(self) -> Dict:
    """Collect real network-wide stats from DB (called once per choose_action).

    Returns totals, /24 subnet diversity (from each host's first IP), and the
    alive-host ratio; safe defaults on any DB error.
    """
    try:
        hosts = self.db.get_all_hosts()
        host_total = len(hosts)
        seen_subnets = set()
        alive_count = 0
        for entry in hosts:
            candidates = (entry.get('ips') or '').split(';')
            # Only the first non-empty IP contributes a subnet
            first_ip = next((c.strip() for c in candidates if c.strip()), None)
            if first_ip:
                seen_subnets.add('.'.join(first_ip.split('.')[:3]))
            if entry.get('alive'):
                alive_count += 1
        ratio = round(alive_count / host_total, 2) if host_total else 0.0
        return {
            'total_hosts': host_total,
            'subnet_count': len(seen_subnets),
            'similar_vendor_count': 0,  # filled by caller if needed
            'similar_port_profile_count': 0,  # filled by caller if needed
            'active_host_ratio': ratio,
        }
    except Exception as e:
        logger.error(f"Error collecting network context: {e}")
        return {
            'total_hosts': 0, 'subnet_count': 1,
            'similar_vendor_count': 0, 'similar_port_profile_count': 0,
            'active_host_ratio': 1.0,
        }
def _get_temporal_context(self, mac: str) -> Dict:
    """
    Collect real temporal features for a MAC from DB.
    same_action_attempts / is_retry are action-specific — they are NOT
    included here; instead they are merged from _get_action_context()
    inside the per-action loop in _predict_with_model().

    Returns:
        Dict of time-of-day flags plus per-host and global history stats
        from the ml_features table; all history stats default to 0 when
        there are no rows or the query fails.
    """
    from datetime import datetime
    now = datetime.now()
    # Defaults cover the "no history yet" case
    ctx = {
        'hour_of_day': now.hour,
        'day_of_week': now.weekday(),
        'is_weekend': now.weekday() >= 5,
        'is_night': now.hour < 6 or now.hour >= 22,
        'previous_action_count': 0,
        'seconds_since_last': 0,
        'historical_success_rate': 0.0,
        'same_action_attempts': 0,  # placeholder; overwritten per-action
        'is_retry': False,  # placeholder; overwritten per-action
        'global_success_rate': 0.0,
        'hours_since_discovery': 0,
    }
    try:
        # Per-host stats from ml_features (persistent training log)
        rows = self.db.query(
            """
            SELECT
            COUNT(*) AS cnt,
            AVG(CAST(success AS REAL)) AS success_rate,
            MAX(timestamp) AS last_ts
            FROM ml_features
            WHERE mac_address = ?
            """,
            (mac,),
        )
        if rows and rows[0]['cnt']:
            ctx['previous_action_count'] = int(rows[0]['cnt'])
            ctx['historical_success_rate'] = round(float(rows[0]['success_rate'] or 0.0), 2)
            if rows[0]['last_ts']:
                try:
                    last_dt = datetime.fromisoformat(str(rows[0]['last_ts']))
                    ctx['seconds_since_last'] = round(
                        (now - last_dt).total_seconds(), 1
                    )
                except Exception:
                    # Unparseable timestamp — keep the 0 default
                    pass
        # Global success rate (all hosts)
        g = self.db.query(
            "SELECT AVG(CAST(success AS REAL)) AS gsr FROM ml_features"
        )
        if g and g[0]['gsr'] is not None:
            ctx['global_success_rate'] = round(float(g[0]['gsr']), 2)
        # Hours since host first seen
        host = self.db.get_host_by_mac(mac)
        if host and host.get('first_seen'):
            try:
                ts = host['first_seen']
                first_seen = datetime.fromisoformat(ts) if isinstance(ts, str) else ts
                ctx['hours_since_discovery'] = round(
                    (now - first_seen).total_seconds() / 3600, 1
                )
            except Exception:
                pass
    except Exception as e:
        logger.error(f"Error collecting temporal context for {mac}: {e}")
    return ctx
# Action-specific temporal fields populated by _get_action_context
# Class-level lookup: action name -> canonical target port.
# 0 means "no single target port" (whole-network scanners); unknown actions
# fall through to keyword-based inference in _get_action_context.
_ACTION_PORTS = {
    'SSHBruteforce': 22, 'SSHEnumeration': 22, 'StealFilesSSH': 22,
    'WebEnumeration': 80, 'WebVulnScan': 80, 'WebLoginProfiler': 80,
    'WebSurfaceMapper': 80,
    'SMBBruteforce': 445, 'StealFilesSMB': 445,
    'FTPBruteforce': 21, 'StealFilesFTP': 21,
    'TelnetBruteforce': 23, 'StealFilesTelnet': 23,
    'SQLBruteforce': 3306, 'StealDataSQL': 3306,
    'NmapVulnScanner': 0, 'NetworkScanner': 0,
    'RDPBruteforce': 3389,
}
def _get_action_context(self, action_name: str, host: Dict, mac: str = '') -> Dict:
    """
    Collect action-specific features including per-action attempt history.

    Resolves the action's target port (table lookup, then service-keyword
    inference) and, when a MAC is given, counts prior attempts of the same
    action against that host from ml_features.
    """
    kind = self._classify_action_type(action_name)
    port = self._ACTION_PORTS.get(action_name, 0)
    if port == 0:
        # Unknown action: infer a port from service keywords in the name
        lowered = action_name.lower()
        keyword_ports = {'ssh': 22, 'http': 80, 'smb': 445, 'ftp': 21,
                         'telnet': 23, 'sql': 3306, 'rdp': 3389}
        for keyword, candidate in keyword_ports.items():
            if keyword in lowered:
                port = candidate
                break
    ctx = {
        'action_type': kind,
        'target_port': port,
        'is_standard_port': 0 < port < 1024,
        # Action-level temporal (overrides placeholder in temporal_context)
        'same_action_attempts': 0,
        'is_retry': False,
    }
    if not mac:
        return ctx
    try:
        rows = self.db.query(
            """
            SELECT COUNT(*) AS cnt
            FROM ml_features
            WHERE mac_address = ? AND action_name = ?
            """,
            (mac, action_name),
        )
        prior = int(rows[0]['cnt']) if rows else 0
        ctx['same_action_attempts'] = prior
        ctx['is_retry'] = prior > 0
    except Exception as e:
        logger.debug(f"Action context DB query failed for {action_name}: {e}")
    return ctx
def _classify_action_type(self, action_name: str) -> str:
    """Map an action name onto a coarse category string."""
    lowered = action_name.lower()
    # Order matters: first matching rule wins (e.g. 'BruteScan' -> bruteforce)
    rules = (
        (('brute',), 'bruteforce'),
        (('enum', 'scan'), 'enumeration'),
        (('exploit',), 'exploitation'),
        (('dump', 'extract'), 'extraction'),
    )
    for needles, label in rules:
        if any(n in lowered for n in needles):
            return label
    return 'other'
# ═══════════════════════════════════════════════════════════════════════
# AI-04: COLD-START BOOTSTRAP
# ═══════════════════════════════════════════════════════════════════════
def _load_bootstrap_scores(self):
    """Restore persisted (action, port_profile) reward aggregates from disk.

    No-op when the file is missing; malformed entries are skipped.
    """
    try:
        if not self._bootstrap_file.exists():
            return
        with open(self._bootstrap_file, 'r') as f:
            raw = json.load(f)
        # Stored as {"action|profile": [total_reward, count], ...}
        for joined, pair in raw.items():
            pieces = joined.split('|', 1)
            if len(pieces) == 2 and isinstance(pair, list) and len(pair) == 2:
                self._bootstrap_scores[(pieces[0], pieces[1])] = pair
        logger.info(f"Loaded {len(self._bootstrap_scores)} bootstrap score entries")
    except Exception as e:
        logger.debug(f"Could not load bootstrap scores: {e}")
def _save_bootstrap_scores(self):
    """Persist bootstrap aggregates to disk as flat JSON ("action|profile" keys)."""
    try:
        flat = {f"{action}|{profile}": stats
                for (action, profile), stats in self._bootstrap_scores.items()}
        with open(self._bootstrap_file, 'w', encoding='utf-8') as fh:
            json.dump(flat, fh)
    except Exception as e:
        logger.debug(f"Could not save bootstrap scores: {e}")
def update_bootstrap(self, action_name: str, port_profile: str, reward: float):
    """
    AI-04: Fold one observed reward into the (action, port_profile) running average.
    Called after each action execution to accumulate real performance data.
    """
    key = (action_name, port_profile)
    entry = self._bootstrap_scores.setdefault(key, [0.0, 0])
    entry[0] += reward
    entry[1] += 1
    # Persist every 5th observation (across all keys) to reduce disk writes
    total_updates = sum(count for _, count in self._bootstrap_scores.values())
    if total_updates % 5 == 0:
        self._save_bootstrap_scores()
    logger.debug(
        f"Bootstrap updated: {action_name}+{port_profile} "
        f"avg={entry[0]/entry[1]:.1f} (n={entry[1]})"
    )
# ═══════════════════════════════════════════════════════════════════════
# HEURISTIC FALLBACK
# ═══════════════════════════════════════════════════════════════════════
def _init_heuristics(self):
    """Build the static rule tables used when no trained model is available."""
    # Port number -> candidate actions
    port_rules = {
        22: ['SSHBruteforce', 'SSHEnumeration'],
        80: ['WebEnumeration', 'WebVulnScan'],
        443: ['WebEnumeration', 'SSLScan'],
        445: ['SMBBruteforce', 'SMBEnumeration'],
        3389: ['RDPBruteforce'],
        21: ['FTPBruteforce', 'FTPEnumeration'],
        23: ['TelnetBruteforce'],
        3306: ['MySQLBruteforce'],
        5432: ['PostgresBruteforce'],
        1433: ['MSSQLBruteforce'],
    }
    # Detected service name -> candidate actions
    service_rules = {
        'ssh': ['SSHBruteforce', 'SSHEnumeration'],
        'http': ['WebEnumeration', 'WebVulnScan'],
        'https': ['WebEnumeration', 'SSLScan'],
        'smb': ['SMBBruteforce', 'SMBEnumeration'],
        'ftp': ['FTPBruteforce', 'FTPEnumeration'],
        'mysql': ['MySQLBruteforce'],
        'postgres': ['PostgresBruteforce'],
    }
    # Device port-profile -> candidate actions
    profile_rules = {
        'camera': ['WebEnumeration', 'DefaultCredCheck', 'RTSPBruteforce'],
        'nas': ['SMBBruteforce', 'WebEnumeration', 'SSHBruteforce'],
        'web_server': ['WebEnumeration', 'WebVulnScan'],
        'database': ['MySQLBruteforce', 'PostgresBruteforce'],
        'linux_server': ['SSHBruteforce', 'WebEnumeration'],
        'windows_server': ['SMBBruteforce', 'RDPBruteforce'],
    }
    self.heuristics = {
        'port_based': port_rules,
        'service_based': service_rules,
        'profile_based': profile_rules,
    }
def _predict_with_heuristics(
    self,
    host_context: Dict[str, Any],
    available_actions: List[str]
) -> Tuple[str, float, Dict[str, Any]]:
    """
    Use rule-based heuristics for action selection.
    AI-04: Blends static rules with bootstrap scores from actual execution data.

    Args:
        host_context: Host info dict; only 'mac' is read here.
        available_actions: Candidate action names to score.

    Returns:
        (action_name, confidence 0-1, debug_info). action_name is None only
        when available_actions is empty; on internal errors the first
        available action is returned with confidence 0.0.

    Fixes vs previous revision: removed the dead `base_weight` local (it was
    assigned from self._bootstrap_weight but never used) and the no-op
    `score / 1.0` division.
    NOTE(review): the configured self._bootstrap_weight is therefore still
    unused here — the sample-count ramp below decides the blend weight;
    confirm whether the config value should participate.
    """
    try:
        mac = host_context.get('mac', '')
        host = self.db.get_host_by_mac(mac) if mac else {}
        # Get ports and services
        ports_str = host.get('ports', '') or ''
        ports = {int(p) for p in ports_str.split(';') if p.strip().isdigit()}
        services = self._get_services_for_host(mac)
        # Detect port profile
        port_profile = self._detect_port_profile(ports)
        # Static heuristic scoring: ports +0.3, services +0.4, profile +0.3
        static_scores = {action: 0.0 for action in available_actions}
        for port in ports:
            for action in self.heuristics['port_based'].get(port, []):
                if action in static_scores:
                    static_scores[action] += 0.3
        for service in services:
            for action in self.heuristics['service_based'].get(service, []):
                if action in static_scores:
                    static_scores[action] += 0.4
        for action in self.heuristics['profile_based'].get(port_profile, []):
            if action in static_scores:
                static_scores[action] += 0.3
        # AI-04: Blend static scores with observed bootstrap averages
        blended_scores = {}
        bootstrap_used = False
        for action in available_actions:
            static_score = static_scores.get(action, 0.0)
            entry = self._bootstrap_scores.get((action, port_profile))
            if entry and entry[1] > 0:
                bootstrap_used = True
                bootstrap_avg = entry[0] / entry[1]
                # Normalize bootstrap avg to 0-1 range (assume reward range ~-30 to +200)
                bootstrap_norm = max(0.0, min(1.0, (bootstrap_avg + 30) / 230))
                sample_count = entry[1]
                # Trust bootstrap more as samples accumulate:
                # 1 sample -> 40% weight, >=20 samples -> 80% weight.
                if sample_count < 20:
                    t = (sample_count - 1) / 19.0
                    bootstrap_w = 0.4 + t * (0.8 - 0.4)
                else:
                    bootstrap_w = 0.8
                blended_scores[action] = (
                    (1.0 - bootstrap_w) * static_score + bootstrap_w * bootstrap_norm
                )
            else:
                blended_scores[action] = static_score
        # Find best action (ties resolve to the earliest available action)
        if blended_scores:
            best_action = max(blended_scores, key=blended_scores.get)
            best_score = blended_scores[best_action]
            # Clamp confidence to 0-1 (was the no-op `min(score / 1.0, 1.0)`)
            if best_score > 0:
                best_score = min(best_score, 1.0)
            debug_info = {
                'method': 'heuristics_bootstrap' if bootstrap_used else 'heuristics',
                'port_profile': port_profile,
                'ports': list(ports)[:10],
                'services': services,
                'bootstrap_used': bootstrap_used,
                'all_scores': {k: round(v, 4) for k, v in blended_scores.items() if v > 0}
            }
            return best_action, best_score, debug_info
        # Ultimate fallback
        if available_actions:
            return available_actions[0], 0.1, {'method': 'fallback_first'}
        return None, 0.0, {'method': 'no_actions'}
    except Exception as e:
        logger.error(f"Heuristic prediction failed: {e}")
        if available_actions:
            return available_actions[0], 0.0, {'method': 'fallback_error', 'error': str(e)}
        return None, 0.0, {'method': 'error', 'error': str(e)}
# ═══════════════════════════════════════════════════════════════════════
# HELPER METHODS
# ═══════════════════════════════════════════════════════════════════════
@staticmethod
def _relu(x):
    """ReLU activation: elementwise max(x, 0)."""
    return np.clip(x, 0, None)
@staticmethod
def _sigmoid(x):
    """Logistic sigmoid: 1 / (1 + e^-x)."""
    return np.reciprocal(1.0 + np.exp(-x))
@staticmethod
def _softmax(x):
    """Softmax over a 1-D array; shifts by the max for numerical stability."""
    shifted = np.exp(x - np.max(x))
    return shifted / np.sum(shifted)
def _forward_network(self, input_vector: np.ndarray) -> np.ndarray:
    """
    Forward pass through the exported dense network.

    Hidden depth is discovered dynamically from the weight keys
    (w1/b1, w2/b2, ...); ReLU between hidden layers, linear output via
    w_out/b_out (caller applies sigmoid/softmax).
    """
    activation = input_vector
    weights = self.model_weights
    depth = 1
    while True:
        w_key = f'w{depth}'
        if w_key not in weights:
            break
        activation = self._relu(activation @ weights[w_key] + weights[f'b{depth}'])
        depth += 1
    return activation @ weights['w_out'] + weights['b_out']
def _get_services_for_host(self, mac: str) -> List[str]:
    """Return distinct non-empty service names recorded for this MAC.

    Returns an empty list on any DB error (e.g. missing table) — never raises.
    """
    try:
        results = self.db.query("""
            SELECT DISTINCT service
            FROM port_services
            WHERE mac_address=?
        """, (mac,))
        return [r['service'] for r in results if r.get('service')]
    except Exception:
        # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt propagate.
        return []
def _get_credentials_for_host(self, mac: str) -> List[Dict]:
    """Return credential rows (service, user, port) found for this MAC.

    Returns an empty list on any DB error (e.g. missing table) — never raises.
    """
    try:
        return self.db.query("""
            SELECT service, user, port
            FROM creds
            WHERE mac_address=?
        """, (mac,))
    except Exception:
        # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt propagate.
        return []
def _categorize_vendor(self, vendor: str) -> str:
    """Bucket a vendor string into a coarse category (same as feature_logger)."""
    if not vendor:
        return 'unknown'
    needle = vendor.lower()
    # Order matters: first matching category wins ('apple' -> 'compute').
    category_keywords = (
        ('networking', ('cisco', 'juniper', 'ubiquiti', 'mikrotik', 'tp-link')),
        ('iot', ('hikvision', 'dahua', 'axis')),
        ('nas', ('synology', 'qnap')),
        ('compute', ('raspberry', 'intel', 'apple', 'dell', 'hp')),
        ('virtualization', ('vmware', 'microsoft')),
        ('mobile', ('apple', 'samsung', 'huawei')),
    )
    for label, keywords in category_keywords:
        if any(kw in needle for kw in keywords):
            return label
    return 'other'
def _detect_port_profile(self, ports) -> str:
    """Guess a device profile from its open ports (same as feature_logger).

    Requires at least 2 overlapping signature ports; on ties the earlier
    signature wins; otherwise returns 'generic'.
    """
    open_ports = set(ports)
    signatures = (
        ('camera', {554, 80, 8000}),
        ('web_server', {80, 443, 8080}),
        ('nas', {5000, 5001, 548, 139, 445}),
        ('database', {3306, 5432, 1433, 27017}),
        ('linux_server', {22, 80, 443}),
        ('windows_server', {135, 139, 445, 3389}),
        ('printer', {9100, 515, 631}),
        ('router', {22, 23, 80, 443, 161}),
    )
    best_name, best_overlap = 'generic', 0
    for name, sig in signatures:
        hits = len(open_ports & sig)
        if hits > best_overlap:
            best_name, best_overlap = name, hits
    return best_name if best_overlap >= 2 else 'generic'
# ═══════════════════════════════════════════════════════════════════════
# STATISTICS
# ═══════════════════════════════════════════════════════════════════════
def get_stats(self) -> Dict[str, Any]:
    """Summarize engine state: decision mode, model metadata, AI-03/AI-04 extras."""
    stats = {
        'model_loaded': self.model_loaded,
        'heuristics_available': True,
        'decision_mode': 'neural_network' if self.model_loaded else 'heuristics',
    }
    if self.model_loaded and self.model_config:
        cfg = self.model_config
        stats['model_version'] = cfg.get('version')
        stats['model_trained_at'] = cfg.get('trained_at')
        stats['model_accuracy'] = cfg.get('accuracy')
        stats['training_samples'] = cfg.get('training_samples')
    # AI-03: Include model versioning info
    stats['model_info'] = self.get_model_info()
    # AI-04: Include bootstrap stats
    stats['bootstrap_entries'] = len(self._bootstrap_scores)
    return stats
# ═══════════════════════════════════════════════════════════════════════════
# SINGLETON FACTORY
# ═══════════════════════════════════════════════════════════════════════════
def get_or_create_ai_engine(shared_data) -> Optional['BjornAIEngine']:
    """
    Return the single BjornAIEngine instance attached to shared_data.
    Creates it on first call; subsequent calls return the cached instance.
    Use this instead of BjornAIEngine(shared_data) to avoid loading model
    weights multiple times (orchestrator + scheduler + web each need AI).
    """
    cached = getattr(shared_data, '_ai_engine_singleton', None)
    if cached is None:
        try:
            cached = BjornAIEngine(shared_data)
        except Exception as e:
            logger.error(f"Failed to create BjornAIEngine singleton: {e}")
            cached = None
        shared_data._ai_engine_singleton = cached
    return shared_data._ai_engine_singleton
def invalidate_ai_engine(shared_data) -> None:
    """Drop the cached singleton (e.g. after a mode reset or model update)."""
    setattr(shared_data, '_ai_engine_singleton', None)
# ═══════════════════════════════════════════════════════════════════════════
# END OF FILE
# ═══════════════════════════════════════════════════════════════════════════