mirror of
https://github.com/infinition/Bjorn.git
synced 2026-03-09 06:01:59 +00:00
Add RLUtils class for managing RL/AI dashboard endpoints
- Implemented methods for fetching AI stats, training history, and recent experiences. - Added functionality to set operation mode (MANUAL, AUTO, AI) with appropriate handling. - Included helper methods for querying the database and sending JSON responses. - Integrated model metadata extraction for visualization purposes.
This commit is contained in:
867
ai_engine.py
Normal file
867
ai_engine.py
Normal file
@@ -0,0 +1,867 @@
|
||||
"""
|
||||
ai_engine.py - Dynamic AI Decision Engine for Bjorn
|
||||
═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Purpose:
|
||||
Lightweight AI decision engine for Raspberry Pi Zero.
|
||||
Works in tandem with deep learning model trained on external PC.
|
||||
|
||||
Architecture:
|
||||
- Lightweight inference engine (no TensorFlow/PyTorch on Pi)
|
||||
- Loads pre-trained model weights from PC
|
||||
- Real-time action selection
|
||||
- Automatic feature extraction
|
||||
- Fallback to heuristics when model unavailable
|
||||
|
||||
Model Pipeline:
|
||||
1. Pi: Collect data → Export → Transfer to PC
|
||||
2. PC: Train deep neural network → Export lightweight model
|
||||
3. Pi: Load model → Use for decision making
|
||||
4. Repeat: Continuous learning cycle
|
||||
|
||||
Author: Bjorn Team
|
||||
Version: 2.0.0
|
||||
"""
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from pathlib import Path
|
||||
from logger import Logger
|
||||
|
||||
logger = Logger(name="ai_engine.py", level=20)
|
||||
|
||||
|
||||
class BjornAIEngine:
    """
    Dynamic AI engine for action selection and prioritization.
    Uses pre-trained model from external PC or falls back to heuristics.

    The engine never trains on-device: it only loads JSON-exported weights
    (see _load_latest_model / check_for_updates) and runs a small dense
    forward pass (_forward_network). When no model is available, the
    rule tables built by _init_heuristics drive the decision instead.
    """
|
||||
|
||||
def __init__(self, shared_data, model_dir: Optional[str] = None):
    """
    Initialize AI engine.

    Args:
        shared_data: Application-wide shared state. Must expose a ``db``
            handle; may expose ``ai_models_dir`` and config values.
        model_dir: Optional override for the model directory. When None,
            falls back to ``shared_data.ai_models_dir`` or
            ``/home/bjorn/ai_models``.
    """
    self.shared_data = shared_data
    self.db = shared_data.db  # alias used by all DB helper methods

    if model_dir is None:
        self.model_dir = Path(getattr(shared_data, 'ai_models_dir', '/home/bjorn/ai_models'))
    else:
        self.model_dir = Path(model_dir)

    # Create the directory up front so later downloads/loads can't fail on it.
    self.model_dir.mkdir(parents=True, exist_ok=True)

    # Model state
    self.model_loaded = False    # True once config + weights parsed successfully
    self.model_weights = None    # dict: layer key (w1/b1/.../w_out/b_out) -> np.ndarray
    self.model_config = None     # parsed model JSON config
    self.feature_config = None   # 'features' section of the model JSON
    # Diagnostics for the most recent server contact (set by check_for_updates)
    self.last_server_attempted = False
    self.last_server_contact_ok = None

    # Try to load latest model
    self._load_latest_model()

    # Fallback heuristics (always available)
    self._init_heuristics()

    logger.info(
        f"AI Engine initialized (model_loaded={self.model_loaded}, "
        f"heuristics_available=True)"
    )
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# MODEL LOADING
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _load_latest_model(self):
    """Load the most recent model from model directory.

    A "complete" model is a ``bjorn_model_*.json`` config with a matching
    ``*_weights.json`` file. When none exists locally, a one-shot server
    check (check_for_updates) is attempted before giving up. On any
    failure the engine stays in heuristics-only mode (model_loaded=False).
    """
    try:
        # Find all potential model configs
        all_json_files = [f for f in self.model_dir.glob("bjorn_model_*.json")
                          if "_weights.json" not in f.name]

        # 1. Filter for files that have matching weights
        valid_models = []
        for f in all_json_files:
            weights_path = f.with_name(f.stem + '_weights.json')
            if weights_path.exists():
                valid_models.append(f)
            else:
                logger.debug(f"Skipping model {f.name}: Weights file missing")

        if not valid_models:
            logger.info(f"No complete models found in {self.model_dir}. Checking server...")
            # Try to download from server
            if self.check_for_updates():
                # check_for_updates -> reload_model already re-entered this
                # method and finished the load; nothing left to do here.
                return

            logger.info_throttled(
                "No AI model available (server offline or empty). Using heuristics only.",
                key="ai_no_model_available",
                interval_s=600.0,
            )
            return

        # 2. Sort by timestamp in filename (lexicographical) and pick latest.
        # NOTE(review): assumes the exporter embeds a fixed-width/zero-padded
        # timestamp in the filename so lexicographic order == chronological.
        latest_model = sorted(valid_models)[-1]
        weights_file = latest_model.with_name(latest_model.stem + '_weights.json')

        logger.info(f"Loading model: {latest_model.name} (Weights exists!)")

        with open(latest_model, 'r') as f:
            model_data = json.load(f)

        # Config may be nested under 'config' or be the top-level object itself.
        self.model_config = model_data.get('config', model_data)
        self.feature_config = model_data.get('features', {})

        # Load weights
        with open(weights_file, 'r') as f:
            weights_data = json.load(f)
        self.model_weights = {
            k: np.array(v) for k, v in weights_data.items()
        }
        del weights_data  # Free raw dict — numpy arrays are the canonical form

        self.model_loaded = True
        logger.success(
            f"Model loaded successfully: {self.model_config.get('version', 'unknown')}"
        )

    except Exception as e:
        # Broad catch is intentional: a corrupt model must never take the
        # engine down — heuristics remain as the fallback decision path.
        logger.error(f"Failed to load model: {e}")
        import traceback
        logger.debug(traceback.format_exc())
        self.model_loaded = False
|
||||
|
||||
def reload_model(self) -> bool:
    """Drop the in-memory model and re-read the newest export from disk.

    Returns:
        True when a model was successfully (re)loaded, False otherwise.
    """
    logger.info("Reloading AI model...")
    # Clear every piece of cached state first, so that a failed reload
    # leaves the engine cleanly in heuristics-only mode.
    self.model_loaded = False
    self.model_weights = None
    self.model_config = None
    self.feature_config = None
    self._load_latest_model()
    return self.model_loaded
|
||||
|
||||
def check_for_updates(self) -> bool:
    """Check AI Server for new model version.

    Contacts ``<ai_server_url>/model/latest``; if the advertised version is
    newer than the local one, streams the config and weights files to the
    model directory and reloads the model.

    Side effects:
        Updates ``last_server_attempted`` / ``last_server_contact_ok``
        diagnostics on every call.

    Returns:
        True only when a newer model was downloaded AND reloaded
        successfully; False in every other case (no URL configured,
        server unreachable, nothing newer, or download failure).
    """
    self.last_server_attempted = False
    self.last_server_contact_ok = None
    try:
        # Imported lazily so the engine still works on installs without requests.
        import requests
        import os
    except ImportError:
        return False

    url = self.shared_data.config.get("ai_server_url")
    if not url:
        return False

    try:
        logger.debug(f"Checking AI Server for updates at {url}/model/latest")
        from ai_utils import get_system_mac
        params = {'mac_addr': get_system_mac()}
        self.last_server_attempted = True
        resp = requests.get(f"{url}/model/latest", params=params, timeout=5)
        # Any HTTP response means server is reachable.
        self.last_server_contact_ok = True

        if resp.status_code != 200:
            return False

        remote_config = resp.json()
        remote_version = str(remote_config.get("version", "")).strip()

        if not remote_version:
            return False

        current_version = str(self.model_config.get("version", "0")).strip() if self.model_config else "0"

        # NOTE(review): this is a lexicographic string comparison. It is
        # correct for zero-padded/timestamp versions but wrong for plain
        # integers (e.g. "10" < "9") — confirm the exporter's version format.
        if remote_version > current_version:
            logger.info(f"New model available: {remote_version} (Local: {current_version})")

            # Download config (stream to avoid loading the whole file into RAM)
            r_conf = requests.get(
                f"{url}/model/download/bjorn_model_{remote_version}.json",
                stream=True, timeout=15,
            )
            if r_conf.status_code == 200:
                conf_path = self.model_dir / f"bjorn_model_{remote_version}.json"
                with open(conf_path, 'wb') as f:
                    for chunk in r_conf.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                    # fsync so a power loss on the Pi can't leave a torn file.
                    f.flush()
                    os.fsync(f.fileno())
            else:
                logger.info_throttled(
                    f"AI model download skipped (config HTTP {r_conf.status_code})",
                    key=f"ai_model_dl_conf_{r_conf.status_code}",
                    interval_s=300.0,
                )
                return False

            # Download weights (stream to avoid loading the whole file into RAM)
            r_weights = requests.get(
                f"{url}/model/download/bjorn_model_{remote_version}_weights.json",
                stream=True, timeout=30,
            )
            if r_weights.status_code == 200:
                weights_path = self.model_dir / f"bjorn_model_{remote_version}_weights.json"
                with open(weights_path, 'wb') as f:
                    for chunk in r_weights.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                    f.flush()
                    os.fsync(f.fileno())

                logger.success(f"Downloaded model {remote_version} files to Pi.")
            else:
                logger.info_throttled(
                    f"AI model download skipped (weights HTTP {r_weights.status_code})",
                    key=f"ai_model_dl_weights_{r_weights.status_code}",
                    interval_s=300.0,
                )
                return False

            # Reload explicitly
            return self.reload_model()

        logger.debug(f"Server model ({remote_version}) is not newer than local ({current_version})")
        return False

    except Exception as e:
        self.last_server_attempted = True
        self.last_server_contact_ok = False
        # Server may be offline; avoid spamming errors in AI mode.
        logger.info_throttled(
            f"AI server unavailable for model update check: {e}",
            key="ai_model_update_check_failed",
            interval_s=300.0,
        )
        return False
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# DECISION MAKING
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
def choose_action(
    self,
    host_context: Dict[str, Any],
    available_actions: List[str],
    exploration_rate: float = None
) -> Tuple[str, float, Dict[str, Any]]:
    """
    Choose the best action for a given host.

    The decision cascades through three strategies: epsilon-greedy random
    exploration, neural-network scoring (when a model is loaded), and
    rule-based heuristics as the final fallback.

    Args:
        host_context: Dict with host information (mac, ports, hostname, etc.)
        available_actions: List of action names that can be executed
        exploration_rate: Probability of random exploration (0.0-1.0);
            defaults to ``shared_data.ai_exploration_rate`` (0.1 if unset).

    Returns:
        Tuple of (action_name, confidence_score, debug_info). action_name
        is None only when no actions are available.
    """
    if exploration_rate is None:
        exploration_rate = float(getattr(self.shared_data, "ai_exploration_rate", 0.1))

    try:
        # Epsilon-greedy exploration branch: pick uniformly at random.
        if exploration_rate > 0 and np.random.random() < exploration_rate:
            import random
            picked = random.choice(available_actions)
            return picked, 0.0, {'method': 'exploration', 'exploration_rate': exploration_rate}

        # Exploitation: prefer the trained network when weights are present.
        if self.model_loaded and self.model_weights:
            return self._predict_with_model(host_context, available_actions)

        return self._predict_with_heuristics(host_context, available_actions)

    except Exception as e:
        logger.error(f"Error choosing action: {e}")
        # Ultimate fallback: first available action
        if available_actions:
            return available_actions[0], 0.0, {'method': 'fallback_error', 'error': str(e)}
        return None, 0.0, {'method': 'no_actions', 'error': 'No available actions'}
|
||||
|
||||
def _predict_with_model(
    self,
    host_context: Dict[str, Any],
    available_actions: List[str]
) -> Tuple[str, float, Dict[str, Any]]:
    """
    Use loaded neural network model for prediction.
    Dynamically maps extracted features to model manifest.

    Each candidate action is scored independently: the same host/network/
    temporal context plus per-action features are assembled into the input
    vector order dictated by the model's ``feature_names`` manifest, run
    through the network, and the highest score wins. Falls back to
    heuristics on any error or when no manifest is present.
    """
    try:
        from ai_utils import extract_neural_features_dict

        # 1. Get model feature manifest
        manifest = self.model_config.get('architecture', {}).get('feature_names', [])
        if not manifest:
            # Legacy fallback
            return self._predict_with_model_legacy(host_context, available_actions)

        # 2. Extract host-level features
        mac = host_context.get('mac', '')
        host = self.db.get_host_by_mac(mac) if mac else {}

        host_data = self._get_host_context_from_db(mac, host)
        net_data = self._get_network_context()
        temp_data_base = self._get_temporal_context(mac)  # MAC-level temporal, called once

        best_action = None
        best_score = -1.0
        all_scores = {}

        # 3. Score each action
        for action in available_actions:
            action_data = self._get_action_context(action, host, mac)

            # Merge action-level temporal overrides into temporal context copy.
            # pop() keeps the action dict free of the temporal keys so the
            # same value is not fed to the model twice.
            temp_data = dict(temp_data_base)
            temp_data['same_action_attempts'] = action_data.pop('same_action_attempts', 0)
            temp_data['is_retry'] = action_data.pop('is_retry', False)

            # Extract all known features into a dict
            features_dict = extract_neural_features_dict(
                host_features=host_data,
                network_features=net_data,
                temporal_features=temp_data,
                action_features=action_data
            )

            # Dynamic mapping: Pull features requested by model manifest
            # Defaults to 0.0 if the Pi doesn't know this feature yet
            input_vector = np.array([float(features_dict.get(name, 0.0)) for name in manifest], dtype=float)

            # Neural inference (supports variable hidden depth from exported model).
            z_out = self._forward_network(input_vector)
            z_out = np.array(z_out).reshape(-1)
            if z_out.size == 1:
                # Binary classifier exported with 1-neuron sigmoid output.
                score = float(self._sigmoid(z_out[0]))
            else:
                # Multi-class head: probability of class 1 is the "success" score.
                probs = self._softmax(z_out)
                score = float(probs[1] if len(probs) > 1 else probs[0])

            all_scores[action] = score
            if score > best_score:
                best_score = score
                best_action = action

        if best_action is None:
            return self._predict_with_heuristics(host_context, available_actions)

        # Capture the last input vector (for visualization).
        # NOTE(review): this is the vector of the LAST action iterated, not
        # necessarily best_action's — acceptable because only action-specific
        # bits differ between iterations, but confirm before relying on it.

        debug_info = {
            'method': 'neural_network_v3',
            'model_version': self.model_config.get('version'),
            'feature_count': len(manifest),
            'all_scores': all_scores,
            # Convert numpy ndarray → plain Python list so debug_info is
            # always JSON-serialisable (scheduler stores it in action_queue metadata).
            'input_vector': input_vector.tolist(),
        }

        return best_action, float(best_score), debug_info

    except Exception as e:
        logger.error(f"Dynamic model prediction failed: {e}")
        import traceback
        logger.debug(traceback.format_exc())
        return self._predict_with_heuristics(host_context, available_actions)
|
||||
|
||||
def _predict_with_model_legacy(self, host_context: Dict[str, Any], available_actions: List[str]) -> Tuple[str, float, Dict[str, Any]]:
|
||||
"""Fallback for models without feature_names manifest (fixed length 56)"""
|
||||
# ... very similar to previous v2 but using hardcoded list ...
|
||||
return self._predict_with_heuristics(host_context, available_actions)
|
||||
|
||||
def _get_host_context_from_db(self, mac: str, host: Dict) -> Dict:
|
||||
"""Helper to collect host features from DB"""
|
||||
ports_str = host.get('ports', '') or ''
|
||||
ports = [int(p) for p in ports_str.split(';') if p.strip().isdigit()]
|
||||
vendor = host.get('vendor', '')
|
||||
|
||||
# Calculate age
|
||||
age_hours = 0.0
|
||||
if host.get('first_seen'):
|
||||
from datetime import datetime
|
||||
try:
|
||||
ts = host['first_seen']
|
||||
first_seen = datetime.fromisoformat(ts) if isinstance(ts, str) else ts
|
||||
age_hours = (datetime.now() - first_seen).total_seconds() / 3600
|
||||
except: pass
|
||||
|
||||
creds = self._get_credentials_for_host(mac)
|
||||
|
||||
return {
|
||||
'port_count': len(ports),
|
||||
'service_count': len(self._get_services_for_host(mac)),
|
||||
'ip_count': len((host.get('ips') or '').split(';')),
|
||||
'credential_count': len(creds),
|
||||
'age_hours': round(age_hours, 2),
|
||||
'has_ssh': 22 in ports,
|
||||
'has_http': 80 in ports or 8080 in ports,
|
||||
'has_https': 443 in ports,
|
||||
'has_smb': 445 in ports,
|
||||
'has_rdp': 3389 in ports,
|
||||
'has_database': any(p in ports for p in [3306, 5432, 1433]),
|
||||
'has_credentials': len(creds) > 0,
|
||||
'is_new': age_hours < 24,
|
||||
'is_private': True, # Simple assumption for now
|
||||
'has_multiple_ips': len((host.get('ips') or '').split(';')) > 1,
|
||||
'vendor_category': self._categorize_vendor(vendor),
|
||||
'port_profile': self._detect_port_profile(ports)
|
||||
}
|
||||
|
||||
def _get_network_context(self) -> Dict:
|
||||
"""Collect real network-wide stats from DB (called once per choose_action)."""
|
||||
try:
|
||||
all_hosts = self.db.get_all_hosts()
|
||||
total = len(all_hosts)
|
||||
|
||||
# Subnet diversity
|
||||
subnets = set()
|
||||
active = 0
|
||||
for h in all_hosts:
|
||||
ips = (h.get('ips') or '').split(';')
|
||||
for ip in ips:
|
||||
ip = ip.strip()
|
||||
if ip:
|
||||
subnets.add('.'.join(ip.split('.')[:3]))
|
||||
break
|
||||
if h.get('alive'):
|
||||
active += 1
|
||||
|
||||
return {
|
||||
'total_hosts': total,
|
||||
'subnet_count': len(subnets),
|
||||
'similar_vendor_count': 0, # filled by caller if needed
|
||||
'similar_port_profile_count': 0, # filled by caller if needed
|
||||
'active_host_ratio': round(active / total, 2) if total else 0.0,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error collecting network context: {e}")
|
||||
return {
|
||||
'total_hosts': 0, 'subnet_count': 1,
|
||||
'similar_vendor_count': 0, 'similar_port_profile_count': 0,
|
||||
'active_host_ratio': 1.0,
|
||||
}
|
||||
|
||||
def _get_temporal_context(self, mac: str) -> Dict:
    """
    Collect real temporal features for a MAC from DB.

    same_action_attempts / is_retry are action-specific — they are NOT
    included here; instead they are merged from _get_action_context()
    inside the per-action loop in _predict_with_model().

    Returns:
        Dict of clock features (hour/day/weekend/night) plus per-host
        history (action count, success rates, recency) pulled from the
        ml_features table. All DB-derived fields keep their neutral
        defaults when a query fails.
    """
    from datetime import datetime
    now = datetime.now()

    # Neutral defaults: any DB failure below leaves these in place.
    ctx = {
        'hour_of_day': now.hour,
        'day_of_week': now.weekday(),
        'is_weekend': now.weekday() >= 5,
        'is_night': now.hour < 6 or now.hour >= 22,
        'previous_action_count': 0,
        'seconds_since_last': 0,
        'historical_success_rate': 0.0,
        'same_action_attempts': 0,  # placeholder; overwritten per-action
        'is_retry': False,          # placeholder; overwritten per-action
        'global_success_rate': 0.0,
        'hours_since_discovery': 0,
    }

    try:
        # Per-host stats from ml_features (persistent training log)
        rows = self.db.query(
            """
            SELECT
                COUNT(*) AS cnt,
                AVG(CAST(success AS REAL)) AS success_rate,
                MAX(timestamp) AS last_ts
            FROM ml_features
            WHERE mac_address = ?
            """,
            (mac,),
        )
        if rows and rows[0]['cnt']:
            ctx['previous_action_count'] = int(rows[0]['cnt'])
            ctx['historical_success_rate'] = round(float(rows[0]['success_rate'] or 0.0), 2)
            if rows[0]['last_ts']:
                try:
                    last_dt = datetime.fromisoformat(str(rows[0]['last_ts']))
                    ctx['seconds_since_last'] = round(
                        (now - last_dt).total_seconds(), 1
                    )
                except Exception:
                    # Unparseable timestamp — keep the 0 default.
                    pass

        # Global success rate (all hosts)
        g = self.db.query(
            "SELECT AVG(CAST(success AS REAL)) AS gsr FROM ml_features"
        )
        if g and g[0]['gsr'] is not None:
            ctx['global_success_rate'] = round(float(g[0]['gsr']), 2)

        # Hours since host first seen
        host = self.db.get_host_by_mac(mac)
        if host and host.get('first_seen'):
            try:
                ts = host['first_seen']
                first_seen = datetime.fromisoformat(ts) if isinstance(ts, str) else ts
                ctx['hours_since_discovery'] = round(
                    (now - first_seen).total_seconds() / 3600, 1
                )
            except Exception:
                pass

    except Exception as e:
        logger.error(f"Error collecting temporal context for {mac}: {e}")

    return ctx
|
||||
|
||||
# Action-specific temporal fields populated by _get_action_context
|
||||
_ACTION_PORTS = {
|
||||
'SSHBruteforce': 22, 'SSHEnumeration': 22, 'StealFilesSSH': 22,
|
||||
'WebEnumeration': 80, 'WebVulnScan': 80, 'WebLoginProfiler': 80,
|
||||
'WebSurfaceMapper': 80,
|
||||
'SMBBruteforce': 445, 'StealFilesSMB': 445,
|
||||
'FTPBruteforce': 21, 'StealFilesFTP': 21,
|
||||
'TelnetBruteforce': 23, 'StealFilesTelnet': 23,
|
||||
'SQLBruteforce': 3306, 'StealDataSQL': 3306,
|
||||
'NmapVulnScanner': 0, 'NetworkScanner': 0,
|
||||
'RDPBruteforce': 3389,
|
||||
}
|
||||
|
||||
def _get_action_context(self, action_name: str, host: Dict, mac: str = '') -> Dict:
|
||||
"""
|
||||
Collect action-specific features including per-action attempt history.
|
||||
Merges action-type + target-port info with action-level temporal stats.
|
||||
"""
|
||||
action_type = self._classify_action_type(action_name)
|
||||
target_port = self._ACTION_PORTS.get(action_name, 0)
|
||||
|
||||
# If port not in lookup, try to infer from action name
|
||||
if target_port == 0:
|
||||
name_lower = action_name.lower()
|
||||
for svc, port in [('ssh', 22), ('http', 80), ('smb', 445), ('ftp', 21),
|
||||
('telnet', 23), ('sql', 3306), ('rdp', 3389)]:
|
||||
if svc in name_lower:
|
||||
target_port = port
|
||||
break
|
||||
|
||||
ctx = {
|
||||
'action_type': action_type,
|
||||
'target_port': target_port,
|
||||
'is_standard_port': 0 < target_port < 1024,
|
||||
# Action-level temporal (overrides placeholder in temporal_context)
|
||||
'same_action_attempts': 0,
|
||||
'is_retry': False,
|
||||
}
|
||||
|
||||
if mac:
|
||||
try:
|
||||
r = self.db.query(
|
||||
"""
|
||||
SELECT COUNT(*) AS cnt
|
||||
FROM ml_features
|
||||
WHERE mac_address = ? AND action_name = ?
|
||||
""",
|
||||
(mac, action_name),
|
||||
)
|
||||
attempts = int(r[0]['cnt']) if r else 0
|
||||
ctx['same_action_attempts'] = attempts
|
||||
ctx['is_retry'] = attempts > 0
|
||||
except Exception as e:
|
||||
logger.debug(f"Action context DB query failed for {action_name}: {e}")
|
||||
|
||||
return ctx
|
||||
|
||||
def _classify_action_type(self, action_name: str) -> str:
|
||||
"""Classify action name into a type"""
|
||||
name = action_name.lower()
|
||||
if 'brute' in name: return 'bruteforce'
|
||||
if 'enum' in name or 'scan' in name: return 'enumeration'
|
||||
if 'exploit' in name: return 'exploitation'
|
||||
if 'dump' in name or 'extract' in name: return 'extraction'
|
||||
return 'other'
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# HEURISTIC FALLBACK
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _init_heuristics(self):
|
||||
"""Initialize rule-based heuristics for cold start"""
|
||||
self.heuristics = {
|
||||
'port_based': {
|
||||
22: ['SSHBruteforce', 'SSHEnumeration'],
|
||||
80: ['WebEnumeration', 'WebVulnScan'],
|
||||
443: ['WebEnumeration', 'SSLScan'],
|
||||
445: ['SMBBruteforce', 'SMBEnumeration'],
|
||||
3389: ['RDPBruteforce'],
|
||||
21: ['FTPBruteforce', 'FTPEnumeration'],
|
||||
23: ['TelnetBruteforce'],
|
||||
3306: ['MySQLBruteforce'],
|
||||
5432: ['PostgresBruteforce'],
|
||||
1433: ['MSSQLBruteforce']
|
||||
},
|
||||
'service_based': {
|
||||
'ssh': ['SSHBruteforce', 'SSHEnumeration'],
|
||||
'http': ['WebEnumeration', 'WebVulnScan'],
|
||||
'https': ['WebEnumeration', 'SSLScan'],
|
||||
'smb': ['SMBBruteforce', 'SMBEnumeration'],
|
||||
'ftp': ['FTPBruteforce', 'FTPEnumeration'],
|
||||
'mysql': ['MySQLBruteforce'],
|
||||
'postgres': ['PostgresBruteforce']
|
||||
},
|
||||
'profile_based': {
|
||||
'camera': ['WebEnumeration', 'DefaultCredCheck', 'RTSPBruteforce'],
|
||||
'nas': ['SMBBruteforce', 'WebEnumeration', 'SSHBruteforce'],
|
||||
'web_server': ['WebEnumeration', 'WebVulnScan'],
|
||||
'database': ['MySQLBruteforce', 'PostgresBruteforce'],
|
||||
'linux_server': ['SSHBruteforce', 'WebEnumeration'],
|
||||
'windows_server': ['SMBBruteforce', 'RDPBruteforce']
|
||||
}
|
||||
}
|
||||
|
||||
def _predict_with_heuristics(
|
||||
self,
|
||||
host_context: Dict[str, Any],
|
||||
available_actions: List[str]
|
||||
) -> Tuple[str, float, Dict[str, Any]]:
|
||||
"""
|
||||
Use rule-based heuristics for action selection.
|
||||
Provides decent performance without machine learning.
|
||||
"""
|
||||
try:
|
||||
mac = host_context.get('mac', '')
|
||||
host = self.db.get_host_by_mac(mac) if mac else {}
|
||||
|
||||
# Get ports and services
|
||||
ports_str = host.get('ports', '') or ''
|
||||
ports = {int(p) for p in ports_str.split(';') if p.strip().isdigit()}
|
||||
services = self._get_services_for_host(mac)
|
||||
|
||||
# Detect port profile
|
||||
port_profile = self._detect_port_profile(ports)
|
||||
|
||||
# Scoring system
|
||||
action_scores = {action: 0.0 for action in available_actions}
|
||||
|
||||
# Score based on ports
|
||||
for port in ports:
|
||||
if port in self.heuristics['port_based']:
|
||||
for action in self.heuristics['port_based'][port]:
|
||||
if action in action_scores:
|
||||
action_scores[action] += 0.3
|
||||
|
||||
# Score based on services
|
||||
for service in services:
|
||||
if service in self.heuristics['service_based']:
|
||||
for action in self.heuristics['service_based'][service]:
|
||||
if action in action_scores:
|
||||
action_scores[action] += 0.4
|
||||
|
||||
# Score based on port profile
|
||||
if port_profile in self.heuristics['profile_based']:
|
||||
for action in self.heuristics['profile_based'][port_profile]:
|
||||
if action in action_scores:
|
||||
action_scores[action] += 0.3
|
||||
|
||||
# Find best action
|
||||
if action_scores:
|
||||
best_action = max(action_scores, key=action_scores.get)
|
||||
best_score = action_scores[best_action]
|
||||
|
||||
# Normalize score to 0-1
|
||||
if best_score > 0:
|
||||
best_score = min(best_score / 1.0, 1.0)
|
||||
|
||||
debug_info = {
|
||||
'method': 'heuristics',
|
||||
'port_profile': port_profile,
|
||||
'ports': list(ports)[:10],
|
||||
'services': services,
|
||||
'all_scores': {k: v for k, v in action_scores.items() if v > 0}
|
||||
}
|
||||
|
||||
return best_action, best_score, debug_info
|
||||
|
||||
# Ultimate fallback
|
||||
if available_actions:
|
||||
return available_actions[0], 0.1, {'method': 'fallback_first'}
|
||||
|
||||
return None, 0.0, {'method': 'no_actions'}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Heuristic prediction failed: {e}")
|
||||
if available_actions:
|
||||
return available_actions[0], 0.0, {'method': 'fallback_error', 'error': str(e)}
|
||||
return None, 0.0, {'method': 'error', 'error': str(e)}
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# HELPER METHODS
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
@staticmethod
|
||||
def _relu(x):
|
||||
"""ReLU activation function"""
|
||||
return np.maximum(0, x)
|
||||
|
||||
@staticmethod
|
||||
def _sigmoid(x):
|
||||
"""Sigmoid activation function"""
|
||||
return 1.0 / (1.0 + np.exp(-x))
|
||||
|
||||
@staticmethod
|
||||
def _softmax(x):
|
||||
"""Softmax activation function"""
|
||||
exp_x = np.exp(x - np.max(x)) # Numerical stability
|
||||
return exp_x / exp_x.sum()
|
||||
|
||||
def _forward_network(self, input_vector: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Forward pass through exported dense network with dynamic hidden depth.
|
||||
Expected keys: w1/b1, w2/b2, ..., w_out/b_out
|
||||
"""
|
||||
a = input_vector
|
||||
layer_idx = 1
|
||||
while f'w{layer_idx}' in self.model_weights:
|
||||
w = self.model_weights[f'w{layer_idx}']
|
||||
b = self.model_weights[f'b{layer_idx}']
|
||||
a = self._relu(np.dot(a, w) + b)
|
||||
layer_idx += 1
|
||||
return np.dot(a, self.model_weights['w_out']) + self.model_weights['b_out']
|
||||
|
||||
def _get_services_for_host(self, mac: str) -> List[str]:
|
||||
"""Get detected services for host"""
|
||||
try:
|
||||
results = self.db.query("""
|
||||
SELECT DISTINCT service
|
||||
FROM port_services
|
||||
WHERE mac_address=?
|
||||
""", (mac,))
|
||||
return [r['service'] for r in results if r.get('service')]
|
||||
except:
|
||||
return []
|
||||
|
||||
def _get_credentials_for_host(self, mac: str) -> List[Dict]:
|
||||
"""Get credentials found for host"""
|
||||
try:
|
||||
return self.db.query("""
|
||||
SELECT service, user, port
|
||||
FROM creds
|
||||
WHERE mac_address=?
|
||||
""", (mac,))
|
||||
except:
|
||||
return []
|
||||
|
||||
def _categorize_vendor(self, vendor: str) -> str:
|
||||
"""Categorize vendor (same as feature_logger)"""
|
||||
if not vendor:
|
||||
return 'unknown'
|
||||
|
||||
vendor_lower = vendor.lower()
|
||||
categories = {
|
||||
'networking': ['cisco', 'juniper', 'ubiquiti', 'mikrotik', 'tp-link'],
|
||||
'iot': ['hikvision', 'dahua', 'axis'],
|
||||
'nas': ['synology', 'qnap'],
|
||||
'compute': ['raspberry', 'intel', 'apple', 'dell', 'hp'],
|
||||
'virtualization': ['vmware', 'microsoft'],
|
||||
'mobile': ['apple', 'samsung', 'huawei']
|
||||
}
|
||||
|
||||
for category, vendors in categories.items():
|
||||
if any(v in vendor_lower for v in vendors):
|
||||
return category
|
||||
|
||||
return 'other'
|
||||
|
||||
def _detect_port_profile(self, ports) -> str:
|
||||
"""Detect device profile from ports (same as feature_logger)"""
|
||||
port_set = set(ports)
|
||||
|
||||
profiles = {
|
||||
'camera': {554, 80, 8000},
|
||||
'web_server': {80, 443, 8080},
|
||||
'nas': {5000, 5001, 548, 139, 445},
|
||||
'database': {3306, 5432, 1433, 27017},
|
||||
'linux_server': {22, 80, 443},
|
||||
'windows_server': {135, 139, 445, 3389},
|
||||
'printer': {9100, 515, 631},
|
||||
'router': {22, 23, 80, 443, 161}
|
||||
}
|
||||
|
||||
max_overlap = 0
|
||||
best_profile = 'generic'
|
||||
|
||||
for profile_name, profile_ports in profiles.items():
|
||||
overlap = len(port_set & profile_ports)
|
||||
if overlap > max_overlap:
|
||||
max_overlap = overlap
|
||||
best_profile = profile_name
|
||||
|
||||
return best_profile if max_overlap >= 2 else 'generic'
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# STATISTICS
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
    """Summarize the engine's current decision state.

    Returns:
        Dict with the active decision mode plus, when a model is loaded,
        its version / training metadata from the model config.
    """
    summary = {
        'model_loaded': self.model_loaded,
        'heuristics_available': True,
        'decision_mode': 'neural_network' if self.model_loaded else 'heuristics',
    }

    if self.model_loaded and self.model_config:
        cfg = self.model_config
        summary['model_version'] = cfg.get('version')
        summary['model_trained_at'] = cfg.get('trained_at')
        summary['model_accuracy'] = cfg.get('accuracy')
        summary['training_samples'] = cfg.get('training_samples')

    return summary
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# SINGLETON FACTORY
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def get_or_create_ai_engine(shared_data) -> Optional['BjornAIEngine']:
    """
    Return the single BjornAIEngine instance attached to shared_data.

    The instance is created lazily on first call and cached on
    shared_data; later calls return the cached object. Prefer this over
    BjornAIEngine(shared_data) so model weights are loaded only once
    (orchestrator + scheduler + web each need AI). Returns None when
    construction fails (a retry happens on the next call).
    """
    cached = getattr(shared_data, '_ai_engine_singleton', None)
    if cached is None:
        try:
            cached = BjornAIEngine(shared_data)
        except Exception as e:
            logger.error(f"Failed to create BjornAIEngine singleton: {e}")
            cached = None
        shared_data._ai_engine_singleton = cached
    return shared_data._ai_engine_singleton
|
||||
|
||||
|
||||
def invalidate_ai_engine(shared_data) -> None:
    """Forget the cached engine so the next get_or_create_ai_engine call
    builds a fresh one (e.g. after a mode reset or model update)."""
    shared_data._ai_engine_singleton = None
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# END OF FILE
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
Reference in New Issue
Block a user