mirror of
https://github.com/infinition/Bjorn.git
synced 2026-03-09 14:12:00 +00:00
Add RLUtils class for managing RL/AI dashboard endpoints
- Implemented methods for fetching AI stats, training history, and recent experiences. - Added functionality to set operation mode (MANUAL, AUTO, AI) with appropriate handling. - Included helper methods for querying the database and sending JSON responses. - Integrated model metadata extraction for visualization purposes.
This commit is contained in:
762
feature_logger.py
Normal file
762
feature_logger.py
Normal file
@@ -0,0 +1,762 @@
|
||||
"""
|
||||
feature_logger.py - Dynamic Feature Logging Engine for Bjorn
|
||||
═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Purpose:
|
||||
Automatically capture ALL relevant features from action executions
|
||||
for deep learning model training. No manual feature declaration needed.
|
||||
|
||||
Architecture:
|
||||
- Automatic feature extraction from all data sources
|
||||
- Time-series aggregation
|
||||
- Network topology features
|
||||
- Action success patterns
|
||||
- Lightweight storage optimized for Pi Zero
|
||||
- Export format ready for deep learning
|
||||
|
||||
Author: Bjorn Team (Enhanced AI Version)
|
||||
Version: 2.0.0
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import hashlib
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from collections import defaultdict, deque
|
||||
from logger import Logger
|
||||
|
||||
# Module-wide logger shared by every FeatureLogger method.
# NOTE(review): level=20 equals logging.INFO on the stdlib scale; presumably the
# project's Logger wrapper interprets it the same way - confirm.
logger = Logger(name="feature_logger.py", level=20)
|
||||
|
||||
|
||||
class FeatureLogger:
    """
    Captures comprehensive features from network reconnaissance
    and action execution for deep learning.

    The logger relies on ``shared_data.db`` exposing the calls used below:
    ``execute(sql[, params])``, ``query(sql[, params])`` returning rows that
    support ``row['col']`` / ``row.get('col')``, ``get_host_by_mac(mac)`` and
    ``get_all_hosts()`` returning dict-like host records whose multi-valued
    fields (``ips``, ``ports``, ``hostnames``) are ``;``-separated strings.

    Every public entry point is best-effort: failures are logged and never
    propagated, so feature logging cannot break action execution.
    """

    # Fallback for the number of hosts kept in the in-memory history.
    _DEFAULT_HOSTS_LIMIT = 512

    # Vendor keyword table. Iteration order matters: the FIRST category with a
    # matching substring wins, so 'apple' resolves to 'compute' and 'xiaomi'
    # to 'iot' even though both are also listed under 'mobile'.
    _VENDOR_CATEGORIES = {
        'networking': ['cisco', 'juniper', 'ubiquiti', 'mikrotik', 'tp-link', 'netgear', 'asus', 'd-link', 'linksys'],
        'iot': ['hikvision', 'dahua', 'axis', 'hanwha', 'tuya', 'sonoff', 'shelly', 'xiaomi', 'yeelight'],
        'nas': ['synology', 'qnap', 'netapp', 'truenas', 'unraid'],
        'compute': ['raspberry', 'intel', 'apple', 'dell', 'hp', 'lenovo', 'acer'],
        'virtualization': ['vmware', 'microsoft', 'citrix', 'proxmox'],
        'mobile': ['apple', 'samsung', 'huawei', 'xiaomi', 'google', 'oneplus']
    }

    # Port signatures per device profile. On equal overlap the profile listed
    # first wins (strict ">" comparison in _detect_port_profile).
    _PORT_PROFILES = {
        'camera': {554, 80, 8000, 37777},
        'web_server': {80, 443, 8080, 8443},
        'nas': {5000, 5001, 548, 139, 445},
        'database': {3306, 5432, 1433, 27017, 6379},
        'linux_server': {22, 80, 443},
        'windows_server': {135, 139, 445, 3389},
        'printer': {9100, 515, 631},
        'router': {22, 23, 80, 443, 161}
    }

    # Hostname substrings mapped to the hint they produce.
    _HOSTNAME_KEYWORDS = {
        'nas': ['nas', 'storage', 'diskstation'],
        'camera': ['cam', 'ipc', 'nvr', 'dvr'],
        'router': ['router', 'gateway', 'gw'],
        'server': ['server', 'srv', 'host'],
        'printer': ['printer', 'print'],
        'iot': ['iot', 'sensor', 'smart']
    }

    def __init__(self, shared_data):
        """
        Initialize feature logger with database connection.

        Args:
            shared_data: Project shared-state object; must expose ``db`` and
                may expose ``ai_feature_hosts_limit`` (minimum enforced: 64).
        """
        self.shared_data = shared_data
        self.db = shared_data.db

        # A malformed config value must not prevent the logger from starting.
        try:
            configured_limit = int(
                getattr(self.shared_data, "ai_feature_hosts_limit", self._DEFAULT_HOSTS_LIMIT)
            )
        except (TypeError, ValueError):
            configured_limit = self._DEFAULT_HOSTS_LIMIT
        self._max_hosts_tracked = max(64, configured_limit)

        # Rolling windows for temporal features (memory efficient)
        self.recent_actions = deque(maxlen=100)
        self.host_history = defaultdict(lambda: deque(maxlen=50))

        # Initialize feature tables
        self._ensure_tables_exist()

        logger.info("FeatureLogger initialized - auto-discovery mode enabled")

    # ═══════════════════════════════════════════════════════════════════════
    # DATABASE SCHEMA
    # ═══════════════════════════════════════════════════════════════════════

    def _ensure_tables_exist(self):
        """
        Create feature logging tables if they don't exist.

        Creates ``ml_features`` (plus two indexes), ``ml_features_aggregated``
        and ``ml_export_batches``. Any failure is logged and swallowed.
        """
        try:
            # Main feature log table
            self.db.execute("""
                CREATE TABLE IF NOT EXISTS ml_features (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,

                    -- Identifiers
                    mac_address TEXT,
                    ip_address TEXT,
                    action_name TEXT,

                    -- Context features (JSON)
                    host_features TEXT, -- Vendor, ports, services, etc.
                    network_features TEXT, -- Topology, neighbors, subnets
                    temporal_features TEXT, -- Time patterns, sequences
                    action_features TEXT, -- Action-specific metadata

                    -- Outcome
                    success INTEGER,
                    duration_seconds REAL,
                    reward REAL,

                    -- Raw event data (for replay)
                    raw_event TEXT,

                    -- Consolidation status
                    consolidated INTEGER DEFAULT 0,
                    export_batch_id INTEGER
                )
            """)

            # Index for fast queries
            self.db.execute("""
                CREATE INDEX IF NOT EXISTS idx_ml_features_mac
                ON ml_features(mac_address, timestamp DESC)
            """)

            self.db.execute("""
                CREATE INDEX IF NOT EXISTS idx_ml_features_consolidated
                ON ml_features(consolidated, timestamp)
            """)

            # Aggregated features table (pre-computed for efficiency)
            self.db.execute("""
                CREATE TABLE IF NOT EXISTS ml_features_aggregated (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    computed_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,

                    mac_address TEXT,
                    time_window TEXT, -- 'hourly', 'daily', 'weekly'

                    -- Aggregated metrics
                    total_actions INTEGER,
                    success_rate REAL,
                    avg_duration REAL,
                    total_reward REAL,

                    -- Action distribution
                    action_counts TEXT, -- JSON: {action_name: count}

                    -- Discovery metrics
                    new_ports_found INTEGER,
                    new_services_found INTEGER,
                    credentials_found INTEGER,

                    -- Feature vector (for DL)
                    feature_vector TEXT, -- JSON array of numerical features

                    UNIQUE(mac_address, time_window, computed_at)
                )
            """)

            # Export batches tracking
            self.db.execute("""
                CREATE TABLE IF NOT EXISTS ml_export_batches (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
                    record_count INTEGER,
                    file_path TEXT,
                    status TEXT DEFAULT 'pending', -- pending, exported, transferred
                    notes TEXT
                )
            """)

            logger.info("ML feature tables initialized")

        except Exception as e:
            logger.error(f"Failed to create ML tables: {e}")

    # ═══════════════════════════════════════════════════════════════════════
    # AUTOMATIC FEATURE EXTRACTION
    # ═══════════════════════════════════════════════════════════════════════

    def log_action_execution(
        self,
        mac_address: str,
        ip_address: str,
        action_name: str,
        success: bool,
        duration: float,
        reward: float,
        raw_event: Dict[str, Any]
    ):
        """
        Log a complete action execution with automatically extracted features.

        Features are extracted BEFORE the in-memory history is updated, so
        temporal features such as ``is_retry`` describe the state prior to
        this execution.

        Args:
            mac_address: Target MAC address (records without one are skipped)
            ip_address: Target IP address
            action_name: Name of executed action
            success: Whether action succeeded
            duration: Execution time in seconds
            reward: Calculated reward value
            raw_event: Complete event data (for replay/debugging)
        """
        try:
            # Shield against missing MAC
            if not mac_address:
                logger.debug("Skipping ML log: missing MAC address")
                return

            # Extract features from multiple sources
            host_features = self._extract_host_features(mac_address, ip_address)
            network_features = self._extract_network_features(mac_address)
            temporal_features = self._extract_temporal_features(mac_address, action_name)
            action_features = self._extract_action_features(action_name, raw_event)

            # Store in database. _to_json tolerates datetime (and other
            # non-JSON) values, which host records and raw events may carry.
            self.db.execute("""
                INSERT INTO ml_features (
                    mac_address, ip_address, action_name,
                    host_features, network_features, temporal_features, action_features,
                    success, duration_seconds, reward, raw_event
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                mac_address, ip_address, action_name,
                self._to_json(host_features),
                self._to_json(network_features),
                self._to_json(temporal_features),
                self._to_json(action_features),
                1 if success else 0,
                duration,
                reward,
                self._to_json(raw_event)
            ))

            # Update rolling windows (one shared timestamp for both entries)
            logged_at = time.time()
            self.recent_actions.append({
                'mac': mac_address,
                'action': action_name,
                'success': success,
                'timestamp': logged_at
            })
            self.host_history[mac_address].append({
                'action': action_name,
                'success': success,
                'timestamp': logged_at
            })
            self._prune_host_history()

            logger.debug(
                f"Logged features for {action_name} on {mac_address} "
                f"(success={success}, features={len(host_features)}+{len(network_features)}+"
                f"{len(temporal_features)}+{len(action_features)})"
            )

            # Prune old database records to save disk space (keep last 1000)
            if random.random() < 0.05:  # 5% chance to prune to avoid overhead every hit
                self._prune_database_records()

        except Exception as e:
            logger.error(f"Failed to log action execution: {e}")

    def _prune_host_history(self):
        """Bound host_history keys to avoid unbounded growth over very long runtimes.

        When more than ``_max_hosts_tracked`` hosts are present, the hosts
        whose most recent entry is oldest are dropped first; hosts with no
        entries are treated as oldest.
        """
        try:
            overflow = len(self.host_history) - self._max_hosts_tracked
            if overflow <= 0:
                return

            def last_activity(item):
                _, entries = item
                return entries[-1]['timestamp'] if entries else 0.0

            # Stable sort: ties keep insertion order, oldest first.
            oldest_first = sorted(self.host_history.items(), key=last_activity)
            for mac, _ in oldest_first[:overflow]:
                self.host_history.pop(mac, None)
        except Exception as e:
            # Best-effort housekeeping: never let it break action logging.
            logger.debug(f"Failed to prune host history: {e}")

    def _prune_database_records(self, limit: int = 1000):
        """Keep the ml_features table within a reasonable size limit.

        Args:
            limit: Number of most recent rows to keep. The value is bound as a
                SQL parameter rather than interpolated into the statement.
        """
        try:
            # id breaks ties between rows sharing the same second-resolution
            # timestamp, so the retained set is deterministic.
            self.db.execute("""
                DELETE FROM ml_features
                WHERE id NOT IN (
                    SELECT id FROM ml_features
                    ORDER BY timestamp DESC, id DESC
                    LIMIT ?
                )
            """, (int(limit),))
        except Exception as e:
            logger.debug(f"Failed to prune ml_features: {e}")

    def _extract_host_features(self, mac: str, ip: str) -> Dict[str, Any]:
        """
        Extract features about the target host.
        Auto-discovers all relevant attributes from database.

        Args:
            mac: Target MAC address used to look the host up.
            ip: Target IP address (currently unused; IPs come from the host record).

        Returns:
            Feature dict; empty when the host is unknown, partial when an
            error interrupts extraction.
        """
        features: Dict[str, Any] = {}

        try:
            # Get host data
            host = self.db.get_host_by_mac(mac)
            if not host:
                return features

            # Basic identifiers (hashed for privacy if needed)
            features['mac_hash'] = hashlib.md5(mac.encode()).hexdigest()[:8]
            features['vendor_oui'] = mac[:8].upper() if mac else None

            # Vendor classification
            vendor = host.get('vendor', '')
            features['vendor'] = vendor
            features['vendor_category'] = self._categorize_vendor(vendor)

            # Network interfaces
            ips = self._split_field(host.get('ips', ''))
            features['ip_count'] = len(ips)
            features['has_multiple_ips'] = len(ips) > 1

            # Subnet classification (assumes dotted IPv4 notation)
            if ips:
                features['subnet'] = '.'.join(ips[0].split('.')[:3]) + '.0/24'
                features['is_private'] = self._is_private_ip(ips[0])

            # Open ports
            ports = self._parse_ports(host.get('ports', ''))
            features['port_count'] = len(ports)
            features['ports'] = sorted(ports)[:20]  # Limit to top 20

            # Port profiles (auto-detect common patterns)
            features['port_profile'] = self._detect_port_profile(ports)
            features['has_ssh'] = 22 in ports
            features['has_http'] = 80 in ports or 8080 in ports
            features['has_https'] = 443 in ports
            features['has_smb'] = 445 in ports
            features['has_rdp'] = 3389 in ports
            features['has_database'] = any(p in ports for p in [3306, 5432, 1433, 27017])

            # Services detected
            services = self._get_services_for_host(mac)
            features['service_count'] = len(services)
            features['services'] = services

            # Hostnames
            hostnames = self._split_field(host.get('hostnames', ''))
            features['hostname_count'] = len(hostnames)
            if hostnames:
                features['primary_hostname'] = hostnames[0]
                features['hostname_hints'] = self._extract_hostname_hints(hostnames[0])

            # First/last seen
            features['first_seen'] = host.get('first_seen')
            features['last_seen'] = host.get('last_seen')

            # Calculate age; an unparseable timestamp counts as "just seen"
            if host.get('first_seen'):
                first_seen_dt = self._parse_timestamp(host['first_seen'])
                age_hours = self._hours_since(first_seen_dt) if first_seen_dt else 0.0
                features['age_hours'] = round(age_hours, 2)
                features['is_new'] = age_hours < 24

            # Credentials found
            creds = self._get_credentials_for_host(mac)
            features['credential_count'] = len(creds)
            features['has_credentials'] = len(creds) > 0

            # OS fingerprinting hints
            features['os_hints'] = self._guess_os(vendor, ports, hostnames)

        except Exception as e:
            logger.error(f"Error extracting host features: {e}")

        return features

    def _extract_network_features(self, mac: str) -> Dict[str, Any]:
        """
        Extract network topology and relationship features.
        Discovers patterns in the network structure.

        Args:
            mac: Target MAC address.

        Returns:
            Feature dict (possibly partial if an error interrupts extraction).
        """
        features: Dict[str, Any] = {}

        try:
            # Get all hosts
            all_hosts = self.db.get_all_hosts()

            # Network size
            features['total_hosts'] = len(all_hosts)

            # Subnet distribution, keyed on each host's first IP
            subnet_counts = defaultdict(int)
            for h in all_hosts:
                ips = self._split_field(h.get('ips', ''))
                if ips:
                    subnet = '.'.join(ips[0].split('.')[:3]) + '.0'
                    subnet_counts[subnet] += 1

            features['subnet_count'] = len(subnet_counts)
            features['largest_subnet_size'] = max(subnet_counts.values()) if subnet_counts else 0

            # Similar hosts (same vendor); the count includes the target itself
            target_host = self.db.get_host_by_mac(mac)
            if target_host:
                vendor = target_host.get('vendor', '')
                similar = sum(1 for h in all_hosts if h.get('vendor') == vendor)
                features['similar_vendor_count'] = similar

            # Port correlation (hosts with similar port profiles)
            target_ports = set()
            if target_host:
                target_ports = set(self._parse_ports(target_host.get('ports', '')))

            if target_ports:
                similar_port_hosts = 0
                for h in all_hosts:
                    if h.get('mac_address') == mac:
                        continue
                    other_ports = set(self._parse_ports(h.get('ports', '')))

                    # Calculate Jaccard similarity
                    if other_ports:
                        intersection = len(target_ports & other_ports)
                        union = len(target_ports | other_ports)
                        similarity = intersection / union if union > 0 else 0
                        if similarity > 0.5:  # >50% similar
                            similar_port_hosts += 1

                features['similar_port_profile_count'] = similar_port_hosts

            # Network activity level
            recent_hosts = sum(1 for h in all_hosts
                               if self._is_recently_active(h.get('last_seen')))
            features['active_host_ratio'] = round(recent_hosts / len(all_hosts), 2) if all_hosts else 0

        except Exception as e:
            logger.error(f"Error extracting network features: {e}")

        return features

    def _extract_temporal_features(self, mac: str, action: str) -> Dict[str, Any]:
        """
        Extract time-based and sequence features.
        Discovers temporal patterns in attack sequences.

        Args:
            mac: Target MAC address.
            action: Name of the action about to be logged.

        Returns:
            Feature dict (possibly partial if an error interrupts extraction).
        """
        features: Dict[str, Any] = {}

        try:
            # Current time features (local wall-clock time)
            now = datetime.now()
            features['hour_of_day'] = now.hour
            features['day_of_week'] = now.weekday()
            features['is_weekend'] = now.weekday() >= 5
            features['is_night'] = now.hour < 6 or now.hour >= 22

            # Action history for this host; .get() avoids creating an empty
            # defaultdict entry for hosts never seen before.
            history = list(self.host_history.get(mac, []))
            features['previous_action_count'] = len(history)

            if history:
                # Last action
                last = history[-1]
                features['last_action'] = last['action']
                features['last_action_success'] = last['success']
                features['seconds_since_last'] = round(time.time() - last['timestamp'], 1)

                # Success rate history
                successes = sum(1 for h in history if h['success'])
                features['historical_success_rate'] = round(successes / len(history), 2)

                # Action sequence
                recent_sequence = [h['action'] for h in history[-5:]]
                features['recent_action_sequence'] = recent_sequence

                # Repeated action detection
                same_action_count = sum(1 for h in history if h['action'] == action)
                features['same_action_attempts'] = same_action_count
                features['is_retry'] = same_action_count > 0

            # Global action patterns
            recent = list(self.recent_actions)
            if recent:
                # Action distribution in recent history
                action_counts = defaultdict(int)
                for a in recent:
                    action_counts[a['action']] += 1

                features['most_common_recent_action'] = max(
                    action_counts.items(),
                    key=lambda x: x[1]
                )[0] if action_counts else None

                # Global success rate
                global_successes = sum(1 for a in recent if a['success'])
                features['global_success_rate'] = round(
                    global_successes / len(recent), 2
                )

            # Time since first seen; an unparseable timestamp counts as 0 hours
            host = self.db.get_host_by_mac(mac)
            if host and host.get('first_seen'):
                first_seen = self._parse_timestamp(host['first_seen'])
                elapsed_hours = self._hours_since(first_seen) if first_seen else 0.0
                features['hours_since_discovery'] = round(elapsed_hours, 1)

        except Exception as e:
            logger.error(f"Error extracting temporal features: {e}")

        return features

    def _extract_action_features(self, action_name: str, raw_event: Dict) -> Dict[str, Any]:
        """
        Extract action-specific features.
        Auto-discovers relevant metadata from action execution.

        Args:
            action_name: Name of the executed action.
            raw_event: Event payload; ``port`` and ``metadata`` keys are used
                when present. A non-dict payload is treated as empty.

        Returns:
            Feature dict (possibly partial if an error interrupts extraction).
        """
        features: Dict[str, Any] = {}

        try:
            features['action_name'] = action_name

            # Action type classification
            features['action_type'] = self._classify_action_type(action_name)

            event = raw_event if isinstance(raw_event, dict) else {}

            # Port-specific actions. A malformed port is skipped instead of
            # aborting the remaining feature extraction.
            port = event.get('port')
            if port:
                try:
                    port_number = int(port)
                except (TypeError, ValueError):
                    port_number = None
                if port_number is not None:
                    features['target_port'] = port_number
                    features['is_standard_port'] = port_number < 1024

            # Extract any additional metadata from raw event
            # This allows actions to add custom features
            metadata = event.get('metadata')
            if isinstance(metadata, dict):
                # Flatten scalar metadata into features
                for key, value in metadata.items():
                    if isinstance(value, (int, float, bool, str)):
                        features[f'meta_{key}'] = value

            # Execution context
            features['operation_mode'] = self.shared_data.operation_mode

        except Exception as e:
            logger.error(f"Error extracting action features: {e}")

        return features

    # ═══════════════════════════════════════════════════════════════════════
    # HELPER METHODS
    # ═══════════════════════════════════════════════════════════════════════

    @staticmethod
    def _to_json(payload: Any) -> str:
        """Serialize *payload* to JSON, stringifying values json cannot encode.

        ``default=str`` keeps a sample loggable when it carries e.g. datetime
        values. Unsupported dict KEYS still raise TypeError.
        """
        return json.dumps(payload, default=str)

    @staticmethod
    def _split_field(value: Any) -> List[str]:
        """Split a ';'-separated field into stripped, non-empty parts."""
        return [part.strip() for part in (value or '').split(';') if part.strip()]

    @staticmethod
    def _parse_ports(value: Any) -> List[int]:
        """Parse a ';'-separated port field, ignoring non-numeric entries."""
        # isdecimal() only accepts characters that int() can parse.
        return [int(p) for p in (value or '').split(';') if p.strip().isdecimal()]

    @staticmethod
    def _parse_timestamp(value: Any) -> Optional[datetime]:
        """Return *value* as a datetime, or None when it cannot be interpreted.

        Accepts datetime objects and ISO-8601 strings understood by
        ``datetime.fromisoformat``.
        """
        if isinstance(value, datetime):
            return value
        if isinstance(value, str):
            try:
                return datetime.fromisoformat(value)
            except ValueError:
                return None
        return None

    @staticmethod
    def _hours_since(moment: datetime) -> float:
        """Return hours elapsed since *moment* (negative for future values).

        "Now" is taken in the timestamp's own timezone so that both naive and
        timezone-aware values can be subtracted without a TypeError.
        NOTE(review): naive timestamps are compared against local time; if the
        database stores UTC (e.g. SQLite CURRENT_TIMESTAMP) ages are skewed by
        the local UTC offset - confirm how first_seen/last_seen are written.
        """
        return (datetime.now(moment.tzinfo) - moment).total_seconds() / 3600

    def _categorize_vendor(self, vendor: str) -> str:
        """Categorize vendor into high-level groups.

        Returns 'unknown' for an empty vendor, the first category containing a
        matching substring otherwise, and 'other' when nothing matches.
        """
        if not vendor:
            return 'unknown'

        vendor_lower = vendor.lower()
        for category, keywords in self._VENDOR_CATEGORIES.items():
            if any(keyword in vendor_lower for keyword in keywords):
                return category

        return 'other'

    def _is_private_ip(self, ip: str) -> bool:
        """Check if IP is in private range.

        Only the dotted-quad ranges 10/8, 172.16/12 and 192.168/16 are
        recognised, based on the first two octets.
        """
        if not ip:
            return False

        parts = ip.split('.')
        if len(parts) != 4:
            return False

        try:
            first = int(parts[0])
            second = int(parts[1])
        except ValueError:
            return False

        return (
            first == 10 or
            (first == 172 and 16 <= second <= 31) or
            (first == 192 and second == 168)
        )

    def _detect_port_profile(self, ports: List[int]) -> str:
        """Auto-detect device type from port signature.

        Returns 'unknown' without ports, otherwise the profile sharing the
        most ports with *ports*, or 'generic' when the best overlap is below 2.
        """
        if not ports:
            return 'unknown'

        port_set = set(ports)
        max_overlap = 0
        best_profile = 'generic'

        for profile_name, profile_ports in self._PORT_PROFILES.items():
            overlap = len(port_set & profile_ports)
            if overlap > max_overlap:
                max_overlap = overlap
                best_profile = profile_name

        return best_profile if max_overlap >= 2 else 'generic'

    def _get_services_for_host(self, mac: str) -> List[str]:
        """Get list of detected services for host (empty list on any failure)."""
        try:
            results = self.db.query("""
                SELECT DISTINCT service
                FROM port_services
                WHERE mac_address=?
            """, (mac,))

            return [r['service'] for r in (results or []) if r.get('service')]
        except Exception as e:
            logger.debug(f"Failed to fetch services for {mac}: {e}")
            return []

    def _extract_hostname_hints(self, hostname: str) -> List[str]:
        """Extract device-type hints from substrings of the hostname."""
        if not hostname:
            return []

        hostname_lower = hostname.lower()
        return [
            hint
            for hint, words in self._HOSTNAME_KEYWORDS.items()
            if any(word in hostname_lower for word in words)
        ]

    def _get_credentials_for_host(self, mac: str) -> List[Dict]:
        """Get credentials found for host (empty list on any failure)."""
        try:
            return self.db.query("""
                SELECT service, user, port
                FROM creds
                WHERE mac_address=?
            """, (mac,)) or []
        except Exception as e:
            logger.debug(f"Failed to fetch credentials for {mac}: {e}")
            return []

    def _guess_os(self, vendor: str, ports: List[int], hostnames: List[str]) -> str:
        """Guess OS from available indicators.

        Checks are ordered: vendor/RDP indicators first, then port sets, then
        hostname keywords. Matching is substring-based, so e.g. any hostname
        containing 'mac' is reported as 'macos'.
        """
        if not vendor and not ports and not hostnames:
            return 'unknown'

        vendor_lower = (vendor or '').lower()
        port_set = set(ports or [])
        hostname = hostnames[0].lower() if hostnames else ''

        # Strong indicators
        if 'microsoft' in vendor_lower or 3389 in port_set:
            return 'windows'
        if 'apple' in vendor_lower or 'mac' in hostname:
            return 'macos'
        if 'raspberry' in vendor_lower:
            return 'linux'

        # Port-based guessing
        if {22, 80} <= port_set:
            return 'linux'
        if {135, 139, 445} <= port_set:
            return 'windows'

        # Hostname hints
        if any(word in hostname for word in ['ubuntu', 'debian', 'centos', 'rhel']):
            return 'linux'

        return 'unknown'

    def _is_recently_active(self, last_seen: Optional[str]) -> bool:
        """Check if host was active in last 24h.

        Accepts an ISO-8601 string or a datetime (naive or timezone-aware);
        anything else, or an unparseable string, yields False.
        """
        if not last_seen:
            return False

        last_seen_dt = self._parse_timestamp(last_seen)
        if last_seen_dt is None:
            return False

        return self._hours_since(last_seen_dt) < 24

    def _classify_action_type(self, action_name: str) -> str:
        """Classify action into high-level categories by keyword in its name."""
        # Tolerate a missing name instead of raising on .lower()
        action_lower = (action_name or '').lower()

        if 'brute' in action_lower or 'crack' in action_lower:
            return 'bruteforce'
        elif 'scan' in action_lower or 'enum' in action_lower:
            return 'enumeration'
        elif 'exploit' in action_lower:
            return 'exploitation'
        elif 'dump' in action_lower or 'extract' in action_lower:
            return 'extraction'
        else:
            return 'other'

    # ═══════════════════════════════════════════════════════════════════════
    # FEATURE AGGREGATION & EXPORT
    # ═══════════════════════════════════════════════════════════════════════

    def get_stats(self) -> Dict[str, Any]:
        """Get current feature logging statistics.

        Returns:
            Dict with row counts from ``ml_features`` and the sizes of the
            in-memory buffers, or an empty dict if the queries fail.
        """
        try:
            total = self.db.query("SELECT COUNT(*) as cnt FROM ml_features")[0]['cnt']
            unconsolidated = self.db.query(
                "SELECT COUNT(*) as cnt FROM ml_features WHERE consolidated=0"
            )[0]['cnt']

            return {
                'total_features_logged': total,
                'unconsolidated_count': unconsolidated,
                'ready_for_export': unconsolidated,
                'recent_actions_buffer': len(self.recent_actions),
                'hosts_tracked': len(self.host_history)
            }
        except Exception as e:
            logger.error(f"Error getting feature stats: {e}")
            return {}
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# END OF FILE
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
Reference in New Issue
Block a user