mirror of
https://github.com/infinition/Bjorn.git
synced 2026-03-15 08:52:00 +00:00
Add Loki and Sentinel utility classes for web API endpoints
- Implemented LokiUtils class with GET and POST endpoints for managing scripts, jobs, and payloads. - Added SentinelUtils class with GET and POST endpoints for managing events, rules, devices, and notifications. - Both classes include error handling and JSON response formatting.
This commit is contained in:
346
ai_engine.py
346
ai_engine.py
@@ -59,10 +59,28 @@ class BjornAIEngine:
|
||||
self.feature_config = None
|
||||
self.last_server_attempted = False
|
||||
self.last_server_contact_ok = None
|
||||
|
||||
|
||||
# AI-03: Model versioning & rollback
|
||||
self._previous_model = None # {weights, config, feature_config}
|
||||
self._model_history = [] # [{version, loaded_at, accuracy, avg_reward}]
|
||||
self._max_model_versions_on_disk = 3
|
||||
self._performance_window = [] # recent reward values for current model
|
||||
self._performance_check_interval = int(
|
||||
getattr(shared_data, 'ai_model_perf_check_interval', 50)
|
||||
)
|
||||
self._prev_model_avg_reward = None # avg reward of the model we replaced
|
||||
|
||||
# AI-04: Cold-start bootstrap scores
|
||||
self._bootstrap_scores = {} # {(action_name, port_profile): [total_reward, count]}
|
||||
self._bootstrap_file = self.model_dir / 'ai_bootstrap_scores.json'
|
||||
self._bootstrap_weight = float(
|
||||
getattr(shared_data, 'ai_cold_start_bootstrap_weight', 0.6)
|
||||
)
|
||||
self._load_bootstrap_scores()
|
||||
|
||||
# Try to load latest model
|
||||
self._load_latest_model()
|
||||
|
||||
|
||||
# Fallback heuristics (always available)
|
||||
self._init_heuristics()
|
||||
|
||||
@@ -79,9 +97,9 @@ class BjornAIEngine:
|
||||
"""Load the most recent model from model directory"""
|
||||
try:
|
||||
# Find all potential model configs
|
||||
all_json_files = [f for f in self.model_dir.glob("bjorn_model_*.json")
|
||||
all_json_files = [f for f in self.model_dir.glob("bjorn_model_*.json")
|
||||
if "_weights.json" not in f.name]
|
||||
|
||||
|
||||
# 1. Filter for files that have matching weights
|
||||
valid_models = []
|
||||
for f in all_json_files:
|
||||
@@ -90,50 +108,103 @@ class BjornAIEngine:
|
||||
valid_models.append(f)
|
||||
else:
|
||||
logger.debug(f"Skipping model {f.name}: Weights file missing")
|
||||
|
||||
|
||||
if not valid_models:
|
||||
logger.info(f"No complete models found in {self.model_dir}. Checking server...")
|
||||
# Try to download from server
|
||||
if self.check_for_updates():
|
||||
return
|
||||
|
||||
|
||||
logger.info_throttled(
|
||||
"No AI model available (server offline or empty). Using heuristics only.",
|
||||
key="ai_no_model_available",
|
||||
interval_s=600.0,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
# 2. Sort by timestamp in filename (lexicographical) and pick latest
|
||||
latest_model = sorted(valid_models)[-1]
|
||||
valid_models = sorted(valid_models)
|
||||
latest_model = valid_models[-1]
|
||||
weights_file = latest_model.with_name(latest_model.stem + '_weights.json')
|
||||
|
||||
|
||||
logger.info(f"Loading model: {latest_model.name} (Weights exists!)")
|
||||
|
||||
|
||||
with open(latest_model, 'r') as f:
|
||||
model_data = json.load(f)
|
||||
|
||||
self.model_config = model_data.get('config', model_data)
|
||||
self.feature_config = model_data.get('features', {})
|
||||
|
||||
|
||||
new_config = model_data.get('config', model_data)
|
||||
new_feature_config = model_data.get('features', {})
|
||||
|
||||
# Load weights
|
||||
with open(weights_file, 'r') as f:
|
||||
weights_data = json.load(f)
|
||||
self.model_weights = {
|
||||
new_weights = {
|
||||
k: np.array(v) for k, v in weights_data.items()
|
||||
}
|
||||
del weights_data # Free raw dict — numpy arrays are the canonical form
|
||||
|
||||
|
||||
# AI-03: Save previous model for rollback
|
||||
if self.model_loaded and self.model_weights is not None:
|
||||
self._previous_model = {
|
||||
'weights': self.model_weights,
|
||||
'config': self.model_config,
|
||||
'feature_config': self.feature_config,
|
||||
}
|
||||
# Record avg reward of outgoing model for performance comparison
|
||||
if self._performance_window:
|
||||
self._prev_model_avg_reward = (
|
||||
sum(self._performance_window) / len(self._performance_window)
|
||||
)
|
||||
self._performance_window = [] # reset for new model
|
||||
|
||||
self.model_config = new_config
|
||||
self.feature_config = new_feature_config
|
||||
self.model_weights = new_weights
|
||||
self.model_loaded = True
|
||||
|
||||
# AI-03: Track model history
|
||||
from datetime import datetime as _dt
|
||||
version = self.model_config.get('version', 'unknown')
|
||||
self._model_history.append({
|
||||
'version': version,
|
||||
'loaded_at': _dt.now().isoformat(),
|
||||
'accuracy': self.model_config.get('accuracy'),
|
||||
'avg_reward': None, # filled later as decisions accumulate
|
||||
})
|
||||
# Keep history bounded
|
||||
if len(self._model_history) > 10:
|
||||
self._model_history = self._model_history[-10:]
|
||||
|
||||
logger.success(
|
||||
f"Model loaded successfully: {self.model_config.get('version', 'unknown')}"
|
||||
f"Model loaded successfully: {version}"
|
||||
)
|
||||
|
||||
|
||||
# AI-03: Prune old model versions on disk (keep N most recent)
|
||||
self._prune_old_model_files(valid_models)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
import traceback
|
||||
logger.debug(traceback.format_exc())
|
||||
self.model_loaded = False
|
||||
|
||||
def _prune_old_model_files(self, valid_models: list):
|
||||
"""AI-03: Keep only the N most recent model versions on disk."""
|
||||
try:
|
||||
keep = self._max_model_versions_on_disk
|
||||
if len(valid_models) <= keep:
|
||||
return
|
||||
to_remove = valid_models[:-keep]
|
||||
for config_path in to_remove:
|
||||
weights_path = config_path.with_name(config_path.stem + '_weights.json')
|
||||
try:
|
||||
config_path.unlink(missing_ok=True)
|
||||
weights_path.unlink(missing_ok=True)
|
||||
logger.info(f"Pruned old model: {config_path.name}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not prune {config_path.name}: {e}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Model pruning error: {e}")
|
||||
|
||||
def reload_model(self) -> bool:
|
||||
"""Reload model from disk"""
|
||||
@@ -142,9 +213,103 @@ class BjornAIEngine:
|
||||
self.model_weights = None
|
||||
self.model_config = None
|
||||
self.feature_config = None
|
||||
|
||||
|
||||
self._load_latest_model()
|
||||
return self.model_loaded
|
||||
|
||||
def rollback_model(self) -> bool:
|
||||
"""
|
||||
AI-03: Rollback to the previous model version.
|
||||
Returns True if rollback succeeded.
|
||||
"""
|
||||
if self._previous_model is None:
|
||||
logger.warning("No previous model available for rollback")
|
||||
return False
|
||||
|
||||
logger.info("Rolling back to previous model version...")
|
||||
# Current model becomes the "next" previous (so we can undo a rollback)
|
||||
current_backup = None
|
||||
if self.model_loaded and self.model_weights is not None:
|
||||
current_backup = {
|
||||
'weights': self.model_weights,
|
||||
'config': self.model_config,
|
||||
'feature_config': self.feature_config,
|
||||
}
|
||||
|
||||
self.model_weights = self._previous_model['weights']
|
||||
self.model_config = self._previous_model['config']
|
||||
self.feature_config = self._previous_model['feature_config']
|
||||
self.model_loaded = True
|
||||
self._previous_model = current_backup
|
||||
self._performance_window = [] # reset
|
||||
|
||||
version = self.model_config.get('version', 'unknown')
|
||||
from datetime import datetime as _dt
|
||||
self._model_history.append({
|
||||
'version': f"{version}_rollback",
|
||||
'loaded_at': _dt.now().isoformat(),
|
||||
'accuracy': self.model_config.get('accuracy'),
|
||||
'avg_reward': None,
|
||||
})
|
||||
|
||||
logger.success(f"Rolled back to model version: {version}")
|
||||
return True
|
||||
|
||||
def record_reward(self, reward: float):
|
||||
"""
|
||||
AI-03: Record a reward for performance tracking.
|
||||
After N decisions, auto-rollback if performance has degraded.
|
||||
"""
|
||||
self._performance_window.append(reward)
|
||||
|
||||
# Update current history entry
|
||||
if self._model_history:
|
||||
self._model_history[-1]['avg_reward'] = round(
|
||||
sum(self._performance_window) / len(self._performance_window), 2
|
||||
)
|
||||
|
||||
# Check for auto-rollback after sufficient samples
|
||||
if len(self._performance_window) >= self._performance_check_interval:
|
||||
current_avg = sum(self._performance_window) / len(self._performance_window)
|
||||
|
||||
if (
|
||||
self._prev_model_avg_reward is not None
|
||||
and current_avg < self._prev_model_avg_reward
|
||||
and self._previous_model is not None
|
||||
):
|
||||
logger.warning(
|
||||
f"Model performance degraded: current avg={current_avg:.2f} vs "
|
||||
f"previous avg={self._prev_model_avg_reward:.2f}. Auto-rolling back."
|
||||
)
|
||||
self.rollback_model()
|
||||
else:
|
||||
logger.info(
|
||||
f"Model performance check passed: avg_reward={current_avg:.2f} "
|
||||
f"over {len(self._performance_window)} decisions"
|
||||
)
|
||||
# Reset window for next check cycle
|
||||
self._performance_window = []
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
"""AI-03: Return current version, history, and performance stats."""
|
||||
current_avg = None
|
||||
if self._performance_window:
|
||||
current_avg = round(
|
||||
sum(self._performance_window) / len(self._performance_window), 2
|
||||
)
|
||||
|
||||
return {
|
||||
'current_version': self.model_config.get('version') if self.model_config else None,
|
||||
'model_loaded': self.model_loaded,
|
||||
'has_previous_model': self._previous_model is not None,
|
||||
'history': list(self._model_history),
|
||||
'performance': {
|
||||
'current_avg_reward': current_avg,
|
||||
'decisions_since_load': len(self._performance_window),
|
||||
'check_interval': self._performance_check_interval,
|
||||
'previous_model_avg_reward': self._prev_model_avg_reward,
|
||||
},
|
||||
}
|
||||
|
||||
def check_for_updates(self) -> bool:
|
||||
"""Check AI Server for new model version."""
|
||||
@@ -596,10 +761,62 @@ class BjornAIEngine:
|
||||
if 'dump' in name or 'extract' in name: return 'extraction'
|
||||
return 'other'
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# AI-04: COLD-START BOOTSTRAP
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _load_bootstrap_scores(self):
|
||||
"""Load persisted bootstrap scores from disk."""
|
||||
try:
|
||||
if self._bootstrap_file.exists():
|
||||
with open(self._bootstrap_file, 'r') as f:
|
||||
raw = json.load(f)
|
||||
# Stored as {"action|profile": [total_reward, count], ...}
|
||||
for key_str, val in raw.items():
|
||||
parts = key_str.split('|', 1)
|
||||
if len(parts) == 2 and isinstance(val, list) and len(val) == 2:
|
||||
self._bootstrap_scores[(parts[0], parts[1])] = val
|
||||
logger.info(f"Loaded {len(self._bootstrap_scores)} bootstrap score entries")
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not load bootstrap scores: {e}")
|
||||
|
||||
def _save_bootstrap_scores(self):
|
||||
"""Persist bootstrap scores to disk."""
|
||||
try:
|
||||
serializable = {
|
||||
f"{k[0]}|{k[1]}": v for k, v in self._bootstrap_scores.items()
|
||||
}
|
||||
with open(self._bootstrap_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(serializable, f)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not save bootstrap scores: {e}")
|
||||
|
||||
def update_bootstrap(self, action_name: str, port_profile: str, reward: float):
|
||||
"""
|
||||
AI-04: Update running average reward for an (action, port_profile) pair.
|
||||
Called after each action execution to accumulate real performance data.
|
||||
"""
|
||||
key = (action_name, port_profile)
|
||||
if key not in self._bootstrap_scores:
|
||||
self._bootstrap_scores[key] = [0.0, 0]
|
||||
entry = self._bootstrap_scores[key]
|
||||
entry[0] += reward
|
||||
entry[1] += 1
|
||||
|
||||
# Persist periodically (every 5 updates to reduce disk writes)
|
||||
total_updates = sum(v[1] for v in self._bootstrap_scores.values())
|
||||
if total_updates % 5 == 0:
|
||||
self._save_bootstrap_scores()
|
||||
|
||||
logger.debug(
|
||||
f"Bootstrap updated: {action_name}+{port_profile} "
|
||||
f"avg={entry[0]/entry[1]:.1f} (n={entry[1]})"
|
||||
)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# HEURISTIC FALLBACK
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def _init_heuristics(self):
|
||||
"""Initialize rule-based heuristics for cold start"""
|
||||
self.heuristics = {
|
||||
@@ -641,68 +858,99 @@ class BjornAIEngine:
|
||||
) -> Tuple[str, float, Dict[str, Any]]:
|
||||
"""
|
||||
Use rule-based heuristics for action selection.
|
||||
Provides decent performance without machine learning.
|
||||
AI-04: Blends static rules with bootstrap scores from actual execution data.
|
||||
"""
|
||||
try:
|
||||
mac = host_context.get('mac', '')
|
||||
host = self.db.get_host_by_mac(mac) if mac else {}
|
||||
|
||||
|
||||
# Get ports and services
|
||||
ports_str = host.get('ports', '') or ''
|
||||
ports = {int(p) for p in ports_str.split(';') if p.strip().isdigit()}
|
||||
services = self._get_services_for_host(mac)
|
||||
|
||||
|
||||
# Detect port profile
|
||||
port_profile = self._detect_port_profile(ports)
|
||||
|
||||
# Scoring system
|
||||
action_scores = {action: 0.0 for action in available_actions}
|
||||
|
||||
|
||||
# Static heuristic scoring
|
||||
static_scores = {action: 0.0 for action in available_actions}
|
||||
|
||||
# Score based on ports
|
||||
for port in ports:
|
||||
if port in self.heuristics['port_based']:
|
||||
for action in self.heuristics['port_based'][port]:
|
||||
if action in action_scores:
|
||||
action_scores[action] += 0.3
|
||||
|
||||
if action in static_scores:
|
||||
static_scores[action] += 0.3
|
||||
|
||||
# Score based on services
|
||||
for service in services:
|
||||
if service in self.heuristics['service_based']:
|
||||
for action in self.heuristics['service_based'][service]:
|
||||
if action in action_scores:
|
||||
action_scores[action] += 0.4
|
||||
|
||||
if action in static_scores:
|
||||
static_scores[action] += 0.4
|
||||
|
||||
# Score based on port profile
|
||||
if port_profile in self.heuristics['profile_based']:
|
||||
for action in self.heuristics['profile_based'][port_profile]:
|
||||
if action in action_scores:
|
||||
action_scores[action] += 0.3
|
||||
|
||||
if action in static_scores:
|
||||
static_scores[action] += 0.3
|
||||
|
||||
# AI-04: Blend static scores with bootstrap scores
|
||||
blended_scores = {}
|
||||
bootstrap_used = False
|
||||
for action in available_actions:
|
||||
static_score = static_scores.get(action, 0.0)
|
||||
key = (action, port_profile)
|
||||
entry = self._bootstrap_scores.get(key)
|
||||
|
||||
if entry and entry[1] > 0:
|
||||
bootstrap_used = True
|
||||
bootstrap_avg = entry[0] / entry[1]
|
||||
# Normalize bootstrap avg to 0-1 range (assume reward range ~-30 to +200)
|
||||
bootstrap_norm = max(0.0, min(1.0, (bootstrap_avg + 30) / 230))
|
||||
sample_count = entry[1]
|
||||
|
||||
# Lerp bootstrap weight from 40% to 80% over 20 samples
|
||||
base_weight = self._bootstrap_weight # default 0.6
|
||||
if sample_count < 20:
|
||||
# Interpolate: at 1 sample -> 0.4, at 20 samples -> 0.8
|
||||
t = (sample_count - 1) / 19.0
|
||||
bootstrap_w = 0.4 + t * (0.8 - 0.4)
|
||||
else:
|
||||
bootstrap_w = 0.8
|
||||
static_w = 1.0 - bootstrap_w
|
||||
|
||||
blended_scores[action] = static_w * static_score + bootstrap_w * bootstrap_norm
|
||||
else:
|
||||
blended_scores[action] = static_score
|
||||
|
||||
# Find best action
|
||||
action_scores = blended_scores
|
||||
if action_scores:
|
||||
best_action = max(action_scores, key=action_scores.get)
|
||||
best_score = action_scores[best_action]
|
||||
|
||||
|
||||
# Normalize score to 0-1
|
||||
if best_score > 0:
|
||||
best_score = min(best_score / 1.0, 1.0)
|
||||
|
||||
|
||||
debug_info = {
|
||||
'method': 'heuristics',
|
||||
'method': 'heuristics_bootstrap' if bootstrap_used else 'heuristics',
|
||||
'port_profile': port_profile,
|
||||
'ports': list(ports)[:10],
|
||||
'services': services,
|
||||
'all_scores': {k: v for k, v in action_scores.items() if v > 0}
|
||||
'bootstrap_used': bootstrap_used,
|
||||
'all_scores': {k: round(v, 4) for k, v in action_scores.items() if v > 0}
|
||||
}
|
||||
|
||||
|
||||
return best_action, best_score, debug_info
|
||||
|
||||
|
||||
# Ultimate fallback
|
||||
if available_actions:
|
||||
return available_actions[0], 0.1, {'method': 'fallback_first'}
|
||||
|
||||
|
||||
return None, 0.0, {'method': 'no_actions'}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Heuristic prediction failed: {e}")
|
||||
if available_actions:
|
||||
@@ -824,7 +1072,7 @@ class BjornAIEngine:
|
||||
'heuristics_available': True,
|
||||
'decision_mode': 'neural_network' if self.model_loaded else 'heuristics'
|
||||
}
|
||||
|
||||
|
||||
if self.model_loaded and self.model_config:
|
||||
stats.update({
|
||||
'model_version': self.model_config.get('version'),
|
||||
@@ -832,7 +1080,13 @@ class BjornAIEngine:
|
||||
'model_accuracy': self.model_config.get('accuracy'),
|
||||
'training_samples': self.model_config.get('training_samples')
|
||||
})
|
||||
|
||||
|
||||
# AI-03: Include model versioning info
|
||||
stats['model_info'] = self.get_model_info()
|
||||
|
||||
# AI-04: Include bootstrap stats
|
||||
stats['bootstrap_entries'] = len(self._bootstrap_scores)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user