Mirror of https://github.com/infinition/Bjorn.git
Synced 2025-12-13 08:04:59 +00:00

BREAKING CHANGE: Complete architecture refactor preparing the BJORN V2 release: APIs, assets, UI, webapp, logic, attacks, and many new features.
38  db_utils/__init__.py  Normal file
@@ -0,0 +1,38 @@
# db_utils/__init__.py
# Database utilities package

from .base import DatabaseBase
from .config import ConfigOps
from .hosts import HostOps
from .actions import ActionOps
from .queue import QueueOps
from .vulnerabilities import VulnerabilityOps
from .software import SoftwareOps
from .credentials import CredentialOps
from .services import ServiceOps
from .scripts import ScriptOps
from .stats import StatsOps
from .backups import BackupOps
from .comments import CommentOps
from .agents import AgentOps
from .studio import StudioOps
from .webenum import WebEnumOps

__all__ = [
    'DatabaseBase',
    'ConfigOps',
    'HostOps',
    'ActionOps',
    'QueueOps',
    'VulnerabilityOps',
    'SoftwareOps',
    'CredentialOps',
    'ServiceOps',
    'ScriptOps',
    'StatsOps',
    'BackupOps',
    'CommentOps',
    'AgentOps',
    'StudioOps',
    'WebEnumOps',
]
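The package only exposes the operation classes. A minimal sketch of how they are presumably composed, following the `__init__(self, base)` pattern every *Ops class in this commit uses; the `Database` facade name is hypothetical and not part of the diff:

# Hypothetical composition sketch (not part of this commit): each *Ops class
# receives the shared DatabaseBase instance, mirroring ActionOps/AgentOps below.
from db_utils import DatabaseBase, ConfigOps, HostOps, ActionOps

class Database:
    def __init__(self, db_path: str):
        self.base = DatabaseBase(db_path)    # shared connection + lock
        self.config = ConfigOps(self.base)   # key-value config store
        self.hosts = HostOps(self.base)      # host tracking
        self.actions = ActionOps(self.base)  # action definitions
        # ... the remaining ops (QueueOps, StatsOps, etc.) follow the same pattern
        for ops in (self.config, self.hosts, self.actions):
            ops.create_tables()              # each module owns its schema

db = Database("bjorn.db")
print(db.config.get_config())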
293  db_utils/actions.py  Normal file
@@ -0,0 +1,293 @@
# db_utils/actions.py
# Action definition and management operations

import json
import sqlite3
from functools import lru_cache
from typing import Any, Dict, List, Optional
import logging

from logger import Logger

logger = Logger(name="db_utils.actions", level=logging.DEBUG)


class ActionOps:
    """Action definition and configuration operations"""

    def __init__(self, base):
        self.base = base

    def create_tables(self):
        """Create actions table"""
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS actions (
                b_class TEXT PRIMARY KEY,
                b_module TEXT NOT NULL,
                b_port INTEGER,
                b_status TEXT,
                b_parent TEXT,
                b_args TEXT,
                b_description TEXT,
                b_name TEXT,
                b_author TEXT,
                b_version TEXT,
                b_icon TEXT,
                b_docs_url TEXT,
                b_examples TEXT,
                b_action TEXT DEFAULT 'normal',
                b_service TEXT,
                b_trigger TEXT,
                b_requires TEXT,
                b_priority INTEGER DEFAULT 50,
                b_tags TEXT,
                b_timeout INTEGER DEFAULT 300,
                b_max_retries INTEGER DEFAULT 3,
                b_cooldown INTEGER DEFAULT 0,
                b_rate_limit TEXT,
                b_stealth_level INTEGER DEFAULT 5,
                b_risk_level TEXT DEFAULT 'medium',
                b_enabled INTEGER DEFAULT 1
            );
        """)
        logger.debug("Actions table created/verified")

    # =========================================================================
    # ACTION CRUD OPERATIONS
    # =========================================================================

    def sync_actions(self, actions):
        """Sync action definitions to database"""
        if not actions:
            return

        def _as_int(x, default=None):
            if x is None:
                return default
            if isinstance(x, (list, tuple)):
                x = x[0] if x else default
            try:
                return int(x)
            except Exception:
                return default

        def _as_str(x, default=None):
            if x is None:
                return default
            if isinstance(x, (list, tuple, set, dict)):
                try:
                    return json.dumps(list(x) if not isinstance(x, dict) else x, ensure_ascii=False)
                except Exception:
                    return default
            return str(x)

        def _as_json(x):
            if x is None:
                return None
            if isinstance(x, str):
                xs = x.strip()
                if (xs.startswith("{") and xs.endswith("}")) or (xs.startswith("[") and xs.endswith("]")):
                    return xs
                return json.dumps(x, ensure_ascii=False)
            try:
                return json.dumps(x, ensure_ascii=False)
            except Exception:
                return None

        with self.base.transaction():
            for a in actions:
                # Normalize fields
                b_service = a.get("b_service")
                if isinstance(b_service, (list, tuple, set, dict)):
                    b_service = json.dumps(list(b_service) if not isinstance(b_service, dict) else b_service, ensure_ascii=False)

                b_tags = a.get("b_tags")
                if isinstance(b_tags, (list, tuple, set, dict)):
                    b_tags = json.dumps(list(b_tags) if not isinstance(b_tags, dict) else b_tags, ensure_ascii=False)

                b_trigger = a.get("b_trigger")
                if isinstance(b_trigger, (list, tuple, set, dict)):
                    b_trigger = json.dumps(b_trigger, ensure_ascii=False)

                b_requires = a.get("b_requires")
                if isinstance(b_requires, (list, tuple, set, dict)):
                    b_requires = json.dumps(b_requires, ensure_ascii=False)

                b_args_json = _as_json(a.get("b_args"))

                # Enriched metadata
                b_name = _as_str(a.get("b_name"))
                b_description = _as_str(a.get("b_description"))
                b_author = _as_str(a.get("b_author"))
                b_version = _as_str(a.get("b_version"))
                b_icon = _as_str(a.get("b_icon"))
                b_docs_url = _as_str(a.get("b_docs_url"))
                b_examples = _as_json(a.get("b_examples"))

                # Typed fields
                b_port = _as_int(a.get("b_port"))
                b_priority = _as_int(a.get("b_priority"), 50)
                b_timeout = _as_int(a.get("b_timeout"), 300)
                b_max_retries = _as_int(a.get("b_max_retries"), 3)
                b_cooldown = _as_int(a.get("b_cooldown"), 0)
                b_stealth_level = _as_int(a.get("b_stealth_level"), 5)
                b_enabled = _as_int(a.get("b_enabled"), 1)
                b_rate_limit = _as_str(a.get("b_rate_limit"))
                b_risk_level = _as_str(a.get("b_risk_level"), "medium")

                self.base.execute("""
                    INSERT INTO actions (
                        b_class,b_module,b_port,b_status,b_parent,
                        b_action,b_service,b_trigger,b_requires,b_priority,
                        b_tags,b_timeout,b_max_retries,b_cooldown,b_rate_limit,
                        b_stealth_level,b_risk_level,b_enabled,
                        b_args,
                        b_name, b_description, b_author, b_version, b_icon, b_docs_url, b_examples
                    ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,
                              ?,?,?,?,?,?,?)
                    ON CONFLICT(b_class) DO UPDATE SET
                        b_module = excluded.b_module,
                        b_port = COALESCE(excluded.b_port, actions.b_port),
                        b_status = COALESCE(excluded.b_status, actions.b_status),
                        b_parent = COALESCE(excluded.b_parent, actions.b_parent),
                        b_action = COALESCE(excluded.b_action, actions.b_action),
                        b_service = COALESCE(excluded.b_service, actions.b_service),
                        b_trigger = COALESCE(excluded.b_trigger, actions.b_trigger),
                        b_requires = COALESCE(excluded.b_requires, actions.b_requires),
                        b_priority = COALESCE(excluded.b_priority, actions.b_priority),
                        b_tags = COALESCE(excluded.b_tags, actions.b_tags),
                        b_timeout = COALESCE(excluded.b_timeout, actions.b_timeout),
                        b_max_retries = COALESCE(excluded.b_max_retries, actions.b_max_retries),
                        b_cooldown = COALESCE(excluded.b_cooldown, actions.b_cooldown),
                        b_rate_limit = COALESCE(excluded.b_rate_limit, actions.b_rate_limit),
                        b_stealth_level = COALESCE(excluded.b_stealth_level, actions.b_stealth_level),
                        b_risk_level = COALESCE(excluded.b_risk_level, actions.b_risk_level),
                        b_enabled = COALESCE(excluded.b_enabled, actions.b_enabled),
                        b_args = COALESCE(excluded.b_args, actions.b_args),
                        b_name = COALESCE(excluded.b_name, actions.b_name),
                        b_description = COALESCE(excluded.b_description, actions.b_description),
                        b_author = COALESCE(excluded.b_author, actions.b_author),
                        b_version = COALESCE(excluded.b_version, actions.b_version),
                        b_icon = COALESCE(excluded.b_icon, actions.b_icon),
                        b_docs_url = COALESCE(excluded.b_docs_url, actions.b_docs_url),
                        b_examples = COALESCE(excluded.b_examples, actions.b_examples)
                """, (
                    a.get("b_class"),
                    a.get("b_module"),
                    b_port,
                    a.get("b_status"),
                    a.get("b_parent"),
                    a.get("b_action", "normal"),
                    b_service,
                    b_trigger,
                    b_requires,
                    b_priority,
                    b_tags,
                    b_timeout,
                    b_max_retries,
                    b_cooldown,
                    b_rate_limit,
                    b_stealth_level,
                    b_risk_level,
                    b_enabled,
                    b_args_json,
                    b_name,
                    b_description,
                    b_author,
                    b_version,
                    b_icon,
                    b_docs_url,
                    b_examples
                ))

        # Update action counter in stats
        action_count_row = self.base.query_one("SELECT COUNT(*) as cnt FROM actions WHERE b_enabled = 1")
        if action_count_row:
            try:
                self.base.execute("""
                    UPDATE stats
                    SET actions_count = ?
                    WHERE id = 1
                """, (action_count_row['cnt'],))
            except sqlite3.OperationalError:
                # Column doesn't exist yet, add it
                self.base.execute("ALTER TABLE stats ADD COLUMN actions_count INTEGER DEFAULT 0")
                self.base.execute("""
                    UPDATE stats
                    SET actions_count = ?
                    WHERE id = 1
                """, (action_count_row['cnt'],))

        logger.info(f"Synchronized {len(actions)} actions")

    def list_actions(self):
        """List all action definitions ordered by class name"""
        return self.base.query("SELECT * FROM actions ORDER BY b_class;")

    def list_studio_actions(self):
        """List all studio action definitions"""
        return self.base.query("SELECT * FROM actions_studio ORDER BY b_class;")

    def get_action_by_class(self, b_class: str) -> dict | None:
        """Get action by class name"""
        rows = self.base.query("SELECT * FROM actions WHERE b_class=? LIMIT 1;", (b_class,))
        return rows[0] if rows else None

    def delete_action(self, b_class: str) -> None:
        """Delete action by class name"""
        self.base.execute("DELETE FROM actions WHERE b_class=?;", (b_class,))

    def upsert_simple_action(self, *, b_class: str, b_module: str, **kw) -> None:
        """Minimal upsert of an action by reusing sync_actions"""
        rec = {"b_class": b_class, "b_module": b_module}
        rec.update(kw)
        self.sync_actions([rec])

    def list_action_cards(self) -> list[dict]:
        """Lightweight descriptor of actions for card-based UIs"""
        rows = self.base.query("""
            SELECT b_class, COALESCE(b_enabled, 0) AS b_enabled
            FROM actions
            ORDER BY b_class;
        """)
        out = []
        for r in rows:
            cls = r["b_class"]
            enabled = int(r["b_enabled"])  # 0 stays 0
            out.append({
                "name": cls,
                "image": f"/actions/actions_icons/{cls}.png",
                "enabled": enabled,
            })
        return out

    # def list_action_cards(self) -> list[dict]:
    #     """Lightweight descriptor of actions for card-based UIs"""
    #     rows = self.base.query("""
    #         SELECT b_class, b_enabled
    #         FROM actions
    #         ORDER BY b_class;
    #     """)
    #     out = []
    #     for r in rows:
    #         cls = r["b_class"]
    #         out.append({
    #             "name": cls,
    #             "image": f"/actions/actions_icons/{cls}.png",
    #             "enabled": int(r.get("b_enabled", 1) or 1),
    #         })
    #     return out

    @lru_cache(maxsize=32)
    def get_action_definition(self, b_class: str) -> Optional[Dict[str, Any]]:
        """Cached lookup of an action definition by class name"""
        row = self.base.query("SELECT * FROM actions WHERE b_class=? LIMIT 1;", (b_class,))
        if not row:
            return None
        r = row[0]
        if r.get("b_args"):
            try:
                r["b_args"] = json.loads(r["b_args"])
            except Exception:
                pass
        return r
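A usage sketch for sync_actions, assuming ops is an ActionOps instance wired to a DatabaseBase with the full schema created (including the stats table the counter update touches); the action values are illustrative, the field names come from the schema above:

# Lists/dicts are JSON-encoded on the way in, and the ON CONFLICT ... COALESCE
# clauses keep previously stored values whenever an incoming field is None.
action = {
    "b_class": "NmapScanner",                # hypothetical action class
    "b_module": "actions.nmap_scanner",      # hypothetical module path
    "b_port": None,                          # COALESCE keeps any stored port
    "b_service": ["http", "https"],          # serialized to '["http", "https"]'
    "b_priority": 10,
    "b_tags": ["recon", "network"],
    "b_enabled": 1,
}
ops.sync_actions([action])
card = ops.get_action_by_class("NmapScanner")  # plain dict, b_args still JSON text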
369  db_utils/agents.py  Normal file
@@ -0,0 +1,369 @@
# db_utils/agents.py
# C2 (Command & Control) agent management operations

import json
import os
import sqlite3
from typing import List, Optional
import logging

from logger import Logger

logger = Logger(name="db_utils.agents", level=logging.DEBUG)


class AgentOps:
    """C2 agent tracking and command history operations"""

    def __init__(self, base):
        self.base = base

    def create_tables(self):
        """Create C2 agent tables"""
        # Agents table
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS agents (
                id TEXT PRIMARY KEY,
                hostname TEXT,
                platform TEXT,
                os_version TEXT,
                architecture TEXT,
                ip_address TEXT,
                first_seen TIMESTAMP,
                last_seen TIMESTAMP,
                status TEXT,
                notes TEXT
            );
        """)

        # Indexes for performance
        self.base.execute("CREATE INDEX IF NOT EXISTS idx_agents_last_seen ON agents(last_seen);")
        self.base.execute("CREATE INDEX IF NOT EXISTS idx_agents_status ON agents(status);")

        # Commands table
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS commands (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                agent_id TEXT,
                command TEXT,
                timestamp TIMESTAMP,
                response TEXT,
                success BOOLEAN,
                FOREIGN KEY (agent_id) REFERENCES agents (id)
            );
        """)

        # Agent keys (versioned for rotation)
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS agent_keys (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                agent_id TEXT NOT NULL,
                key_b64 TEXT NOT NULL,
                version INTEGER NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                rotated_at TIMESTAMP,
                revoked_at TIMESTAMP,
                active INTEGER DEFAULT 1,
                UNIQUE(agent_id, version)
            );
        """)
        self.base.execute("CREATE INDEX IF NOT EXISTS idx_agent_keys_active ON agent_keys(agent_id, active);")

        # Loot table
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS loot (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                agent_id TEXT,
                filename TEXT,
                filepath TEXT,
                size INTEGER,
                timestamp TIMESTAMP,
                hash TEXT,
                FOREIGN KEY (agent_id) REFERENCES agents (id)
            );
        """)

        # Telemetry table
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS telemetry (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                agent_id TEXT,
                cpu_percent REAL,
                mem_percent REAL,
                disk_percent REAL,
                uptime INTEGER,
                timestamp TIMESTAMP,
                FOREIGN KEY (agent_id) REFERENCES agents (id)
            );
        """)

        logger.debug("C2 agent tables created/verified")

    # =========================================================================
    # AGENT OPERATIONS
    # =========================================================================

    def save_agent(self, agent_data: dict) -> None:
        """
        Upsert an agent, preserving first_seen and updating last_seen.
        The status field is expected as a str (e.g. 'online'/'offline').
        """
        agent_id = agent_data.get('id')
        hostname = agent_data.get('hostname')
        platform_ = agent_data.get('platform')
        os_version = agent_data.get('os_version')
        arch = agent_data.get('architecture')
        ip_address = agent_data.get('ip_address')
        status = agent_data.get('status') or 'offline'
        notes = agent_data.get('notes')

        if not agent_id:
            raise ValueError("save_agent: 'id' is required in agent_data")

        # Upsert that preserves first_seen and updates last_seen to NOW
        self.base.execute("""
            INSERT INTO agents (id, hostname, platform, os_version, architecture, ip_address,
                                first_seen, last_seen, status, notes)
            VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, ?, ?)
            ON CONFLICT(id) DO UPDATE SET
                hostname = COALESCE(excluded.hostname, agents.hostname),
                platform = COALESCE(excluded.platform, agents.platform),
                os_version = COALESCE(excluded.os_version, agents.os_version),
                architecture = COALESCE(excluded.architecture, agents.architecture),
                ip_address = COALESCE(excluded.ip_address, agents.ip_address),
                first_seen = COALESCE(agents.first_seen, excluded.first_seen, CURRENT_TIMESTAMP),
                last_seen = CURRENT_TIMESTAMP,
                status = COALESCE(excluded.status, agents.status),
                notes = COALESCE(excluded.notes, agents.notes)
        """, (agent_id, hostname, platform_, os_version, arch, ip_address, status, notes))

        # Optionally refresh the zombie counter
        try:
            self._refresh_zombie_counter()
        except Exception:
            pass

    def save_command(self, agent_id: str, command: str,
                     response: str | None = None, success: bool = False) -> None:
        """Record a command history entry"""
        if not agent_id or not command:
            raise ValueError("save_command: 'agent_id' and 'command' are required")
        self.base.execute("""
            INSERT INTO commands (agent_id, command, timestamp, response, success)
            VALUES (?, ?, CURRENT_TIMESTAMP, ?, ?)
        """, (agent_id, command, response, 1 if success else 0))

    def save_telemetry(self, agent_id: str, telemetry: dict) -> None:
        """Record a telemetry snapshot for an agent"""
        if not agent_id:
            raise ValueError("save_telemetry: 'agent_id' is required")
        self.base.execute("""
            INSERT INTO telemetry (agent_id, cpu_percent, mem_percent, disk_percent, uptime, timestamp)
            VALUES (?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
        """, (
            agent_id,
            telemetry.get('cpu_percent'),
            telemetry.get('mem_percent'),
            telemetry.get('disk_percent'),
            telemetry.get('uptime')
        ))

    def save_loot(self, loot: dict) -> None:
        """
        Record a retrieved file (loot).
        Expected: {'agent_id', 'filename', 'filepath', 'size', 'hash'}
        The timestamp is added database-side.
        """
        if not loot or not loot.get('agent_id') or not loot.get('filename'):
            raise ValueError("save_loot: 'agent_id' and 'filename' are required")

        self.base.execute("""
            INSERT INTO loot (agent_id, filename, filepath, size, timestamp, hash)
            VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP, ?)
        """, (
            loot.get('agent_id'),
            loot.get('filename'),
            loot.get('filepath'),
            int(loot.get('size') or 0),
            loot.get('hash')
        ))

    def get_agent_history(self, agent_id: str) -> List[dict]:
        """Return the 100 most recent commands for an agent (most recent first)."""
        if not agent_id:
            return []
        rows = self.base.query("""
            SELECT command, timestamp, response, success
            FROM commands
            WHERE agent_id = ?
            ORDER BY datetime(timestamp) DESC
            LIMIT 100
        """, (agent_id,))
        # Normalize success to bool
        for r in rows:
            r['success'] = bool(r.get('success'))
        return rows

    def purge_stale_agents(self, threshold_seconds: int) -> int:
        """
        Delete agents whose last_seen is older than now - threshold_seconds.
        Returns the number of deleted rows.
        """
        if not threshold_seconds or threshold_seconds <= 0:
            return 0

        return self.base.execute("""
            DELETE FROM agents
            WHERE last_seen IS NOT NULL
              AND datetime(last_seen) < datetime('now', ?)
        """, (f'-{threshold_seconds} seconds',))

    def get_stale_agents(self, threshold_seconds: int) -> list[dict]:
        """
        Return the agents whose last_seen is older than now - threshold_seconds.
        Useful for detecting/purging inactive agents.
        """
        if not threshold_seconds or threshold_seconds <= 0:
            return []

        rows = self.base.query("""
            SELECT *
            FROM agents
            WHERE last_seen IS NOT NULL
              AND datetime(last_seen) < datetime('now', ?)
        """, (f'-{threshold_seconds} seconds',))

        return rows or []

    # =========================================================================
    # AGENT KEY MANAGEMENT
    # =========================================================================

    def get_active_key(self, agent_id: str) -> str | None:
        """Return the active key (base64) for an agent, or None"""
        row = self.base.query_one("""
            SELECT key_b64 FROM agent_keys
            WHERE agent_id=? AND active=1
            ORDER BY version DESC
            LIMIT 1
        """, (agent_id,))
        return row["key_b64"] if row else None

    def list_keys(self, agent_id: str) -> list[dict]:
        """List all keys for an agent (versions, states)"""
        return self.base.query("""
            SELECT id, agent_id, key_b64, version, created_at, rotated_at, revoked_at, active
            FROM agent_keys
            WHERE agent_id=?
            ORDER BY version DESC
        """, (agent_id,))

    def _next_key_version(self, agent_id: str) -> int:
        """Get the next key version number for an agent"""
        row = self.base.query_one("SELECT COALESCE(MAX(version),0) AS v FROM agent_keys WHERE agent_id=?", (agent_id,))
        return int(row["v"] or 0) + 1

    def save_new_key(self, agent_id: str, key_b64: str) -> int:
        """
        Record an initial key for an agent (when it has none yet).
        Returns the version created.
        """
        v = self._next_key_version(agent_id)
        self.base.execute("""
            INSERT INTO agent_keys(agent_id, key_b64, version, active)
            VALUES(?,?,?,1)
        """, (agent_id, key_b64, v))
        return v

    def rotate_key(self, agent_id: str, new_key_b64: str) -> int:
        """
        Rotation: deactivate the old active key (stamping rotated_at), then insert
        the new one at version+1 with active=1. Returns the new version.
        """
        with self.base.transaction():
            # Disable the existing active key
            self.base.execute("""
                UPDATE agent_keys
                SET active=0, rotated_at=CURRENT_TIMESTAMP
                WHERE agent_id=? AND active=1
            """, (agent_id,))
            # Insert the new key
            v = self._next_key_version(agent_id)
            self.base.execute("""
                INSERT INTO agent_keys(agent_id, key_b64, version, active)
                VALUES(?,?,?,1)
            """, (agent_id, new_key_b64, v))
            return v

    def revoke_keys(self, agent_id: str) -> int:
        """
        Full revocation: set active=0 and stamp revoked_at on all of the agent's active keys.
        Returns the number of affected rows.
        """
        return self.base.execute("""
            UPDATE agent_keys
            SET active=0, revoked_at=CURRENT_TIMESTAMP
            WHERE agent_id=? AND active=1
        """, (agent_id,))

    def verify_client_key(self, agent_id: str, key_b64: str) -> bool:
        """True if the provided key matches an active key for this agent"""
        row = self.base.query_one("""
            SELECT 1 FROM agent_keys
            WHERE agent_id=? AND key_b64=? AND active=1
            LIMIT 1
        """, (agent_id, key_b64))
        return bool(row)

    def migrate_keys_from_file(self, json_path: str) -> int:
        """
        One-shot migration from a keys.json in the format {agent_id: key_b64}.
        For each agent without an active key, create one at version 1.
        Returns the number of keys inserted.
        """
        if not json_path or not os.path.exists(json_path):
            return 0
        inserted = 0
        try:
            with open(json_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            if not isinstance(data, dict):
                return 0
            with self.base.transaction():
                for agent_id, key_b64 in data.items():
                    if not self.get_active_key(agent_id):
                        self.save_new_key(agent_id, key_b64)
                        inserted += 1
        except Exception:
            pass
        return inserted

    # =========================================================================
    # HELPER METHODS
    # =========================================================================

    def _refresh_zombie_counter(self) -> None:
        """
        Update stats.zombie_count with the number of online agents.
        Won't fail if the column doesn't exist yet.
        """
        try:
            row = self.base.query_one("SELECT COUNT(*) AS c FROM agents WHERE LOWER(status)='online';")
            count = int(row['c'] if row else 0)
            updated = self.base.execute("UPDATE stats SET zombie_count=? WHERE id=1;", (count,))
            if not updated:
                # Ensure the singleton row exists
                self.base.execute("INSERT OR IGNORE INTO stats(id) VALUES(1);")
                self.base.execute("UPDATE stats SET zombie_count=? WHERE id=1;", (count,))
        except sqlite3.OperationalError:
            # Column absent: add it properly and retry
            try:
                self.base.execute("ALTER TABLE stats ADD COLUMN zombie_count INTEGER DEFAULT 0;")
                self.base.execute("UPDATE stats SET zombie_count=0 WHERE id=1;")
                row = self.base.query_one("SELECT COUNT(*) AS c FROM agents WHERE LOWER(status)='online';")
                count = int(row['c'] if row else 0)
                self.base.execute("UPDATE stats SET zombie_count=? WHERE id=1;", (count,))
            except Exception:
                pass
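A sketch of the key lifecycle these methods implement, assuming agent_ops is an AgentOps instance with its tables created; the agent id and keys are illustrative:

import base64, os

agent_ops.save_agent({"id": "agent-01", "hostname": "lab-pi", "status": "online"})

k1 = base64.b64encode(os.urandom(32)).decode()
v1 = agent_ops.save_new_key("agent-01", k1)              # version 1, active
assert agent_ops.verify_client_key("agent-01", k1)

k2 = base64.b64encode(os.urandom(32)).decode()
v2 = agent_ops.rotate_key("agent-01", k2)                # v1 stamped rotated_at, v2 active
assert not agent_ops.verify_client_key("agent-01", k1)   # old key no longer accepted
assert agent_ops.get_active_key("agent-01") == k2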
76  db_utils/backups.py  Normal file
@@ -0,0 +1,76 @@
# db_utils/backups.py
# Backup registry and management operations

from typing import Any, Dict, List
import logging

from logger import Logger

logger = Logger(name="db_utils.backups", level=logging.DEBUG)


class BackupOps:
    """Backup registry and management operations"""

    def __init__(self, base):
        self.base = base

    def create_tables(self):
        """Create backups registry table"""
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS backups (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                filename TEXT UNIQUE NOT NULL,
                description TEXT,
                date TEXT,
                type TEXT DEFAULT 'User Backup',
                is_default INTEGER DEFAULT 0,
                is_restore INTEGER DEFAULT 0,
                is_github INTEGER DEFAULT 0,
                created_at TEXT DEFAULT CURRENT_TIMESTAMP
            );
        """)
        logger.debug("Backups table created/verified")

    # =========================================================================
    # BACKUP OPERATIONS
    # =========================================================================

    def add_backup(self, filename: str, description: str, date: str,
                   type_: str = "User Backup", is_default: bool = False,
                   is_restore: bool = False, is_github: bool = False):
        """Insert or update a backup registry entry"""
        self.base.execute("""
            INSERT INTO backups(filename,description,date,type,is_default,is_restore,is_github)
            VALUES(?,?,?,?,?,?,?)
            ON CONFLICT(filename) DO UPDATE SET
                description=excluded.description,
                date=excluded.date,
                type=excluded.type,
                is_default=excluded.is_default,
                is_restore=excluded.is_restore,
                is_github=excluded.is_github;
        """, (filename, description, date, type_, int(is_default),
              int(is_restore), int(is_github)))

    def list_backups(self) -> List[Dict[str, Any]]:
        """List all backups ordered by date descending"""
        return self.base.query("""
            SELECT filename, description, date, type,
                   is_default, is_restore, is_github
            FROM backups
            ORDER BY date DESC;
        """)

    def delete_backup(self, filename: str) -> None:
        """Delete a backup entry by filename"""
        self.base.execute("DELETE FROM backups WHERE filename=?;", (filename,))

    def clear_default_backup(self) -> None:
        """Clear the default flag on all backups"""
        self.base.execute("UPDATE backups SET is_default=0;")

    def set_default_backup(self, filename: str) -> None:
        """Set the default flag on a specific backup"""
        self.clear_default_backup()
        self.base.execute("UPDATE backups SET is_default=1 WHERE filename=?;", (filename,))
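Illustrative registry usage, assuming backup_ops is a BackupOps instance with its table created (the filenames and dates are made up):

backup_ops.add_backup(
    filename="backup_2025-12-01.tar.gz",
    description="Pre-upgrade snapshot",
    date="2025-12-01 10:00:00",
    type_="User Backup",
)
backup_ops.set_default_backup("backup_2025-12-01.tar.gz")  # clears other defaults first
for b in backup_ops.list_backups():
    print(b["filename"], bool(b["is_default"]))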
159  db_utils/base.py  Normal file
@@ -0,0 +1,159 @@
# db_utils/base.py
# Base database connection and transaction management

import sqlite3
import time
from contextlib import contextmanager
from threading import RLock
from typing import Any, Dict, Iterable, List, Optional, Tuple
import logging

from logger import Logger

logger = Logger(name="db_utils.base", level=logging.DEBUG)


class DatabaseBase:
    """
    Base database manager providing connection, transaction, and query primitives.
    All specialized operation modules share access to these primitives.
    """

    def __init__(self, db_path: str):
        self.db_path = db_path

        # Connection with optimized settings for constrained devices (e.g., Raspberry Pi)
        self._conn = sqlite3.connect(
            self.db_path,
            check_same_thread=False,
            isolation_level=None  # Autocommit mode (we manage transactions explicitly)
        )
        self._conn.row_factory = sqlite3.Row
        self._lock = RLock()

        # Small in-process cache for frequently refreshed UI counters
        self._cache_ttl = 5.0  # seconds
        self._stats_cache = {'data': None, 'timestamp': 0}

        # Apply PRAGMA tuning
        with self._lock:
            cur = self._conn.cursor()
            # Optimize SQLite for Raspberry Pi / flash storage
            cur.execute("PRAGMA journal_mode=WAL;")
            cur.execute("PRAGMA synchronous=NORMAL;")
            cur.execute("PRAGMA foreign_keys=ON;")
            cur.execute("PRAGMA cache_size=2000;")    # Increase page cache
            cur.execute("PRAGMA temp_store=MEMORY;")  # Use RAM for temporary objects
            cur.close()

        logger.info(f"DatabaseBase initialized: {db_path}")

    # =========================================================================
    # CORE CONCURRENCY + SQL PRIMITIVES
    # =========================================================================

    @contextmanager
    def _cursor(self):
        """Thread-safe cursor context manager"""
        with self._lock:
            cur = self._conn.cursor()
            try:
                yield cur
            finally:
                cur.close()

    @contextmanager
    def transaction(self, immediate: bool = True):
        """Transactional block with automatic rollback on error"""
        with self._lock:
            try:
                self._conn.execute("BEGIN IMMEDIATE;" if immediate else "BEGIN;")
                yield
                self._conn.execute("COMMIT;")
            except Exception:
                self._conn.execute("ROLLBACK;")
                raise

    def execute(self, sql: str, params: Iterable[Any] = (), many: bool = False) -> int:
        """Execute a DML statement. Supports batch mode via `many=True`"""
        with self._cursor() as c:
            if many and params and isinstance(params, (list, tuple)) and isinstance(params[0], (list, tuple)):
                c.executemany(sql, params)
                return c.rowcount if c.rowcount is not None else 0
            c.execute(sql, params)
            return c.rowcount if c.rowcount is not None else 0

    def executemany(self, sql: str, seq_of_params: Iterable[Iterable[Any]]) -> int:
        """Convenience wrapper around `execute(..., many=True)`"""
        return self.execute(sql, seq_of_params, many=True)

    def query(self, sql: str, params: Iterable[Any] = ()) -> List[Dict[str, Any]]:
        """Execute a SELECT and return rows as list[dict]"""
        with self._cursor() as c:
            c.execute(sql, params)
            rows = c.fetchall()
            return [dict(r) for r in rows]

    def query_one(self, sql: str, params: Iterable[Any] = ()) -> Optional[Dict[str, Any]]:
        """Execute a SELECT and return a single row as dict (or None)"""
        with self._cursor() as c:
            c.execute(sql, params)
            row = c.fetchone()
            return dict(row) if row else None

    # =========================================================================
    # CACHE MANAGEMENT
    # =========================================================================

    def invalidate_stats_cache(self):
        """Invalidate the small in-memory stats cache"""
        self._stats_cache = {'data': None, 'timestamp': 0}

    # =========================================================================
    # SCHEMA HELPERS
    # =========================================================================

    def _table_exists(self, name: str) -> bool:
        """Return True if a table exists in the current database"""
        row = self.query("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (name,))
        return bool(row)

    def _column_names(self, table: str) -> List[str]:
        """Return a list of column names for a given table (empty if table missing)"""
        with self._cursor() as c:
            c.execute(f"PRAGMA table_info({table});")
            return [r[1] for r in c.fetchall()]

    def _ensure_column(self, table: str, column: str, ddl: str) -> None:
        """Add a column with the provided DDL if it does not exist yet"""
        cols = self._column_names(table) if self._table_exists(table) else []
        if column not in cols:
            self.execute(f"ALTER TABLE {table} ADD COLUMN {ddl};")

    # =========================================================================
    # MAINTENANCE OPERATIONS
    # =========================================================================

    def checkpoint(self, mode: str = "TRUNCATE") -> Tuple[int, int, int]:
        """
        Force a WAL checkpoint. Returns (busy, log_frames, checkpointed_frames).
        mode ∈ {PASSIVE, FULL, RESTART, TRUNCATE}
        """
        mode = (mode or "PASSIVE").upper()
        if mode not in {"PASSIVE", "FULL", "RESTART", "TRUNCATE"}:
            mode = "PASSIVE"
        with self._cursor() as c:
            c.execute(f"PRAGMA wal_checkpoint({mode});")
            row = c.fetchone()
            if not row:
                return (0, 0, 0)
            vals = tuple(row)
            return (int(vals[0]), int(vals[1]), int(vals[2]))

    def optimize(self) -> None:
        """Run PRAGMA optimize to help the query planner update statistics"""
        self.execute("PRAGMA optimize;")

    def vacuum(self) -> None:
        """Vacuum the database to reclaim space (use sparingly on flash media)"""
        self.execute("VACUUM;")
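A minimal sketch of the primitives above, assuming the project's logger module is importable; the path and table are illustrative:

base = DatabaseBase("/tmp/example.db")
base.execute("CREATE TABLE IF NOT EXISTS kv (k TEXT PRIMARY KEY, v TEXT);")

with base.transaction():  # BEGIN IMMEDIATE ... COMMIT, or ROLLBACK on error
    base.execute("INSERT OR REPLACE INTO kv(k, v) VALUES(?, ?);", ("greeting", "hello"))

row = base.query_one("SELECT v FROM kv WHERE k=?;", ("greeting",))
print(row["v"] if row else None)  # rows come back as plain dicts

base.executemany("INSERT OR REPLACE INTO kv(k, v) VALUES(?, ?);",
                 [("a", "1"), ("b", "2")])  # batch path via execute(many=True)
base.checkpoint("TRUNCATE")                 # force a WAL checkpoint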
126  db_utils/comments.py  Normal file
@@ -0,0 +1,126 @@
# db_utils/comments.py
# Comment and status message operations

import json
import os
from typing import Any, Dict, List, Optional, Tuple
import logging

from logger import Logger

logger = Logger(name="db_utils.comments", level=logging.DEBUG)


class CommentOps:
    """Comment and status message management operations"""

    def __init__(self, base):
        self.base = base

    def create_tables(self):
        """Create comments table"""
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS comments (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                text TEXT NOT NULL,
                status TEXT NOT NULL,
                theme TEXT DEFAULT 'general',
                lang TEXT DEFAULT 'fr',
                weight INTEGER DEFAULT 1,
                created_at TEXT DEFAULT CURRENT_TIMESTAMP
            );
        """)

        try:
            self.base.execute("""
                CREATE UNIQUE INDEX IF NOT EXISTS uq_comments_dedup
                ON comments(text, status, theme, lang);
            """)
        except Exception:
            pass

        logger.debug("Comments table created/verified")

    # =========================================================================
    # COMMENT OPERATIONS
    # =========================================================================

    def count_comments(self) -> int:
        """Return the total number of comment rows"""
        row = self.base.query_one("SELECT COUNT(1) c FROM comments;")
        return int(row["c"]) if row else 0

    def insert_comments(self, comments: List[Tuple[str, str, str, str, int]]):
        """Batch insert of comments (deduplicated via the UNIQUE index and INSERT OR IGNORE)"""
        if not comments:
            return
        self.base.executemany(
            "INSERT OR IGNORE INTO comments(text,status,theme,lang,weight) VALUES(?,?,?,?,?)",
            comments
        )

    def import_comments_from_json(
        self,
        json_path: str,
        lang: Optional[str] = None,
        default_theme: str = "general",
        default_weight: int = 1,
        clear_existing: bool = False
    ) -> int:
        """
        Import comments from a JSON mapping {status: [strings]}.
        The language comes from the `lang` argument, else from the filename
        (comments.xx.json), else falls back to 'en'.
        """
        if not json_path or not os.path.exists(json_path):
            return 0
        try:
            with open(json_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception:
            return 0

        if not isinstance(data, dict):
            return 0

        # Determine language
        if not lang:
            # From filename (comments.xx.json)
            base = os.path.basename(json_path).lower()
            if "comments." in base:
                parts = base.split(".")
                if len(parts) >= 3:
                    lang = parts[-2]

        # Fallback
        lang = (lang or "en").lower()

        rows: List[Tuple[str, str, str, str, int]] = []
        for status, items in data.items():
            if not isinstance(items, list):
                continue
            for txt in items:
                t = str(txt).strip()
                if not t:
                    continue
                # Theme mirrors the status name for imported rows
                rows.append((t, str(status), str(status), lang, int(default_weight)))

        if not rows:
            return 0

        with self.base.transaction(immediate=True):
            if clear_existing:
                self.base.execute("DELETE FROM comments;")
            self.insert_comments(rows)

        return len(rows)

    def random_comment_for(self, status: str, lang: str = "en") -> Optional[Dict[str, Any]]:
        """Pick a random comment for the given status/language"""
        rows = self.base.query("""
            SELECT id, text, status, theme, lang, weight
            FROM comments
            WHERE status=? AND lang=?
            ORDER BY RANDOM()
            LIMIT 1;
        """, (status, lang))
        return rows[0] if rows else None
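An example of the JSON shape import_comments_from_json expects, {status: [strings]}, assuming comment_ops is a CommentOps instance with its table created; the file name (which carries the language code) and phrases are illustrative:

import json, os, tempfile

data = {"IDLE": ["Nothing to do...", "Waiting for targets"],
        "SCANNING": ["Sweeping the subnet"]}
path = os.path.join(tempfile.gettempdir(), "comments.en.json")  # 'en' parsed from the name
with open(path, "w", encoding="utf-8") as f:
    json.dump(data, f)

n = comment_ops.import_comments_from_json(path)    # returns the number of rows prepared
pick = comment_ops.random_comment_for("IDLE", lang="en")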
63  db_utils/config.py  Normal file
@@ -0,0 +1,63 @@
# db_utils/config.py
# Configuration management operations

import json
import ast
from typing import Any, Dict
import logging

from logger import Logger

logger = Logger(name="db_utils.config", level=logging.DEBUG)


class ConfigOps:
    """Configuration key-value store operations"""

    def __init__(self, base):
        self.base = base

    def create_tables(self):
        """Create config table"""
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS config (
                key TEXT PRIMARY KEY,
                value TEXT
            );
        """)
        logger.debug("Config table created/verified")

    def get_config(self) -> Dict[str, Any]:
        """Load config as a typed dict (tries JSON, then literal_eval, then raw)"""
        rows = self.base.query("SELECT key, value FROM config;")
        out: Dict[str, Any] = {}
        for r in rows:
            k = r["key"]
            raw = r["value"]
            try:
                v = json.loads(raw)
            except Exception:
                try:
                    v = ast.literal_eval(raw)
                except Exception:
                    v = raw
            out[k] = v
        return out

    def save_config(self, config: Dict[str, Any]) -> None:
        """Save the full config mapping to the database (JSON-serialized)"""
        if not config:
            return
        pairs = []
        for k, v in config.items():
            try:
                s = json.dumps(v, ensure_ascii=False)
            except Exception:
                s = json.dumps(str(v), ensure_ascii=False)
            pairs.append((str(k), s))

        with self.base.transaction():
            self.base.execute("DELETE FROM config;")
            self.base.executemany("INSERT INTO config(key,value) VALUES(?,?);", pairs)

        logger.info(f"Saved {len(pairs)} config entries")
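A round-trip sketch, assuming config_ops is a ConfigOps instance with its table created: values are JSON-serialized on save and re-typed on load (JSON first, then ast.literal_eval, then the raw string):

config_ops.save_config({"scan_interval": 300,
                        "interfaces": ["wlan0", "eth0"],
                        "debug": True})
cfg = config_ops.get_config()
assert cfg["scan_interval"] == 300              # int restored, not the string '300'
assert cfg["interfaces"] == ["wlan0", "eth0"]   # list restored from JSON text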
124  db_utils/credentials.py  Normal file
@@ -0,0 +1,124 @@
# db_utils/credentials.py
# Credential storage and management operations

import json
import sqlite3
from typing import Any, Dict, List, Optional
import logging

from logger import Logger

logger = Logger(name="db_utils.credentials", level=logging.DEBUG)


class CredentialOps:
    """Credential storage and retrieval operations"""

    def __init__(self, base):
        self.base = base

    def create_tables(self):
        """Create credentials table"""
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS creds (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                service TEXT NOT NULL,
                mac_address TEXT,
                ip TEXT,
                hostname TEXT,
                "user" TEXT,
                "password" TEXT,
                port INTEGER,
                "database" TEXT,
                extra TEXT,
                first_seen TEXT DEFAULT CURRENT_TIMESTAMP,
                last_seen TEXT DEFAULT CURRENT_TIMESTAMP
            );
        """)

        # Indexes to support a real UPSERT and deduplication
        try:
            self.base.execute("""
                CREATE UNIQUE INDEX IF NOT EXISTS uq_creds_identity
                ON creds(service, mac_address, ip, "user", "database", port);
            """)
        except Exception:
            pass

        # Optional NULL-safe dedup guard for future rows
        try:
            self.base.execute("""
                CREATE UNIQUE INDEX IF NOT EXISTS uq_creds_identity_norm
                ON creds(
                    service,
                    COALESCE(mac_address,''),
                    COALESCE(ip,''),
                    COALESCE("user",''),
                    COALESCE("database",''),
                    COALESCE(port,0)
                );
            """)
        except Exception:
            pass

        logger.debug("Credentials table created/verified")

    # =========================================================================
    # CREDENTIAL OPERATIONS
    # =========================================================================

    def insert_cred(self, service: str, mac: Optional[str] = None, ip: Optional[str] = None,
                    hostname: Optional[str] = None, user: Optional[str] = None,
                    password: Optional[str] = None, port: Optional[int] = None,
                    database: Optional[str] = None, extra: Optional[Dict[str, Any]] = None):
        """Insert or update a credential identity; last_seen is touched on update"""
        self.base.invalidate_stats_cache()

        # NULL-safe normalization to keep a single identity form
        mac_n = mac or ""
        ip_n = ip or ""
        user_n = user or ""
        db_n = database or ""
        port_n = int(port or 0)
        js = json.dumps(extra, ensure_ascii=False) if extra else None

        try:
            self.base.execute("""
                INSERT INTO creds(service,mac_address,ip,hostname,"user","password",port,"database",extra)
                VALUES(?,?,?,?,?,?,?,?,?)
                ON CONFLICT(service, mac_address, ip, "user", "database", port) DO UPDATE SET
                    "password"=excluded."password",
                    hostname=COALESCE(excluded.hostname, creds.hostname),
                    last_seen=CURRENT_TIMESTAMP,
                    extra=COALESCE(excluded.extra, creds.extra);
            """, (service, mac_n, ip_n, hostname, user_n, password, port_n, db_n, js))
        except sqlite3.OperationalError:
            # Fallback if the unique index is not available: manual upsert
            row = self.base.query_one("""
                SELECT id FROM creds
                WHERE service=? AND COALESCE(mac_address,'')=? AND COALESCE(ip,'')=?
                  AND COALESCE("user",'')=? AND COALESCE("database",'')=? AND COALESCE(port,0)=?
                LIMIT 1
            """, (service, mac_n, ip_n, user_n, db_n, port_n))
            if row:
                self.base.execute("""
                    UPDATE creds
                    SET "password"=?,
                        hostname=COALESCE(?, hostname),
                        last_seen=CURRENT_TIMESTAMP,
                        extra=COALESCE(?, extra)
                    WHERE id=?
                """, (password, hostname, js, row["id"]))
            else:
                self.base.execute("""
                    INSERT INTO creds(service,mac_address,ip,hostname,"user","password",port,"database",extra)
                    VALUES(?,?,?,?,?,?,?,?,?)
                """, (service, mac_n, ip_n, hostname, user_n, password, port_n, db_n, js))

    def list_creds_grouped(self) -> List[Dict[str, Any]]:
        """List all credential rows grouped/sorted by service/ip/user/port for the UI"""
        return self.base.query("""
            SELECT service, mac_address, ip, hostname, "user", "password", port, "database", last_seen
            FROM creds
            ORDER BY service, ip, "user", port
        """)
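An illustrative upsert, assuming cred_ops is a CredentialOps instance with its table and indexes created: a second call with the same identity tuple (service, mac, ip, user, database, port) updates the password and touches last_seen instead of inserting a duplicate row:

cred_ops.insert_cred(service="ssh", ip="192.168.1.50", user="admin",
                     password="hunter2", port=22)
cred_ops.insert_cred(service="ssh", ip="192.168.1.50", user="admin",
                     password="newpass", port=22)   # hits ON CONFLICT, same row updated
rows = cred_ops.list_creds_grouped()                # one 'ssh' row with the new password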
480  db_utils/hosts.py  Normal file
@@ -0,0 +1,480 @@
# db_utils/hosts.py
# Host and network device management operations

import time
import sqlite3
from typing import Any, Dict, Iterable, List, Optional
import logging

from logger import Logger

logger = Logger(name="db_utils.hosts", level=logging.DEBUG)


class HostOps:
    """Host management and tracking operations"""

    def __init__(self, base):
        self.base = base

    def create_tables(self):
        """Create hosts and related tables"""
        # Main hosts table
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS hosts (
                mac_address TEXT PRIMARY KEY,
                ips TEXT,
                hostnames TEXT,
                alive INTEGER DEFAULT 0,
                ports TEXT,
                vendor TEXT,
                essid TEXT,
                previous_hostnames TEXT,
                previous_ips TEXT,
                previous_ports TEXT,
                previous_essids TEXT,
                first_seen INTEGER,
                last_seen INTEGER,
                updated_at TEXT DEFAULT CURRENT_TIMESTAMP
            );
        """)
        self.base.execute("CREATE INDEX IF NOT EXISTS idx_hosts_alive ON hosts(alive);")

        # Hostname history table
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS hostnames_history(
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                mac_address TEXT NOT NULL,
                hostname TEXT NOT NULL,
                first_seen TEXT DEFAULT CURRENT_TIMESTAMP,
                last_seen TEXT DEFAULT CURRENT_TIMESTAMP,
                is_current INTEGER DEFAULT 1,
                UNIQUE(mac_address, hostname)
            );
        """)

        # Guarantee a single current hostname per MAC
        try:
            # One and only one "current" hostname row per MAC in history
            self.base.execute("""
                CREATE UNIQUE INDEX IF NOT EXISTS uq_hostname_current
                ON hostnames_history(mac_address)
                WHERE is_current=1;
            """)
        except Exception:
            pass

        # Uniqueness for real MACs only (allows legacy stubs in old DBs, but the scanner no longer writes them)
        try:
            self.base.execute("""
                CREATE UNIQUE INDEX IF NOT EXISTS ux_hosts_real_mac
                ON hosts(mac_address)
                WHERE instr(mac_address, ':') > 0;
            """)
        except Exception:
            pass

        logger.debug("Hosts tables created/verified")

    # =========================================================================
    # HOST CRUD OPERATIONS
    # =========================================================================

    def get_all_hosts(self) -> List[Dict[str, Any]]:
        """Get all hosts with current/previous IPs/ports/ESSIDs, ordered by liveness then MAC"""
        return self.base.query("""
            SELECT mac_address, ips, previous_ips,
                   hostnames, previous_hostnames,
                   alive,
                   ports, previous_ports,
                   vendor, essid, previous_essids,
                   first_seen, last_seen
            FROM hosts
            ORDER BY alive DESC, mac_address;
        """)

    def update_host(self, mac_address: str, ips: Optional[str] = None,
                    hostnames: Optional[str] = None, alive: Optional[int] = None,
                    ports: Optional[str] = None, vendor: Optional[str] = None,
                    essid: Optional[str] = None):
        """
        Partial upsert of the host row. None/'' fields do not erase existing values.
        For automatic tracking of previous_* fields, use the update_*_current helpers instead.
        """
        # --- Hardening: normalize and guard ---
        # Always store normalized lowercase MACs; refuse 'ip:' stubs defensively.
        mac_address = (mac_address or "").strip().lower()
        if mac_address.startswith("ip:"):
            raise ValueError("stub MAC not allowed (scanner runs in no-stub mode)")

        self.base.invalidate_stats_cache()

        now = int(time.time())

        self.base.execute("""
            INSERT INTO hosts(mac_address, ips, hostnames, alive, ports, vendor, essid,
                              first_seen, last_seen, updated_at)
            VALUES(?, ?, ?, COALESCE(?, 0), ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
            ON CONFLICT(mac_address) DO UPDATE SET
                ips       = COALESCE(NULLIF(excluded.ips, ''), hosts.ips),
                hostnames = COALESCE(NULLIF(excluded.hostnames, ''), hosts.hostnames),
                alive     = COALESCE(excluded.alive, hosts.alive),
                ports     = COALESCE(NULLIF(excluded.ports, ''), hosts.ports),
                vendor    = COALESCE(NULLIF(excluded.vendor, ''), hosts.vendor),
                essid     = COALESCE(NULLIF(excluded.essid, ''), hosts.essid),
                last_seen = ?,
                updated_at= CURRENT_TIMESTAMP;
        """, (mac_address, ips, hostnames, alive, ports, vendor, essid, now, now, now))

    # =========================================================================
    # HOSTNAME OPERATIONS
    # =========================================================================

    def update_hostname(self, mac_address: str, new_hostname: str):
        """Update the current hostname and track previous/current in both the hosts and history tables"""
        new_hostname = (new_hostname or "").strip()
        if not new_hostname:
            return

        with self.base.transaction(immediate=True):
            row = self.base.query(
                "SELECT hostnames, previous_hostnames FROM hosts WHERE mac_address=? LIMIT 1;",
                (mac_address,)
            )
            curr = (row[0]["hostnames"] or "") if row else ""
            prev = (row[0]["previous_hostnames"] or "") if row else ""

            curr_list = [h for h in curr.split(';') if h]
            prev_list = [h for h in prev.split(';') if h]

            if new_hostname in curr_list:
                curr_list = [new_hostname] + [h for h in curr_list if h != new_hostname]
                next_curr = ';'.join(curr_list)
                next_prev = ';'.join(prev_list)
            else:
                merged_prev = list(dict.fromkeys(curr_list + prev_list))[:50]  # cap at 50
                next_curr = new_hostname
                next_prev = ';'.join(merged_prev)

            self.base.execute("""
                INSERT INTO hosts(mac_address, hostnames, previous_hostnames, updated_at)
                VALUES(?,?,?,CURRENT_TIMESTAMP)
                ON CONFLICT(mac_address) DO UPDATE SET
                    hostnames          = excluded.hostnames,
                    previous_hostnames = excluded.previous_hostnames,
                    updated_at         = CURRENT_TIMESTAMP;
            """, (mac_address, next_curr, next_prev))

            # Update hostname history table
            self.base.execute("""
                UPDATE hostnames_history
                SET is_current=0, last_seen=CURRENT_TIMESTAMP
                WHERE mac_address=? AND is_current=1;
            """, (mac_address,))

            self.base.execute("""
                INSERT INTO hostnames_history(mac_address, hostname, is_current)
                VALUES(?,?,1)
                ON CONFLICT(mac_address, hostname) DO UPDATE SET
                    is_current=1, last_seen=CURRENT_TIMESTAMP;
            """, (mac_address, new_hostname))

    def get_current_hostname(self, mac_address: str) -> Optional[str]:
        """Get the current hostname from history when available; fall back to hosts.hostnames"""
        row = self.base.query("""
            SELECT hostname FROM hostnames_history
            WHERE mac_address=? AND is_current=1 LIMIT 1;
        """, (mac_address,))
        if row:
            return row[0]["hostname"]

        row = self.base.query("SELECT hostnames FROM hosts WHERE mac_address=? LIMIT 1;", (mac_address,))
        if row and row[0]["hostnames"]:
            return row[0]["hostnames"].split(';', 1)[0]
        return None

    def record_hostname_seen(self, mac_address: str, hostname: str):
        """Alias for update_hostname: mark a hostname as seen/current"""
        self.update_hostname(mac_address, hostname)

    def list_hostname_history(self, mac_address: str) -> List[Dict[str, Any]]:
        """Return the full hostname history for a MAC (current first)"""
        return self.base.query("""
            SELECT hostname, first_seen, last_seen, is_current
            FROM hostnames_history
            WHERE mac_address=?
            ORDER BY is_current DESC, last_seen DESC, first_seen DESC;
        """, (mac_address,))
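As an aside, a sketch of the hostname tracking above, assuming host_ops is a HostOps instance with its tables created (the MAC and names are illustrative): the current name rolls into previous_hostnames when it changes, and hostnames_history keeps exactly one is_current=1 row per MAC:

host_ops.update_hostname("aa:bb:cc:dd:ee:ff", "printer-old")
host_ops.update_hostname("aa:bb:cc:dd:ee:ff", "printer-new")
assert host_ops.get_current_hostname("aa:bb:cc:dd:ee:ff") == "printer-new"
history = host_ops.list_hostname_history("aa:bb:cc:dd:ee:ff")  # current row first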
|
||||
# =========================================================================
|
||||
# IP OPERATIONS
|
||||
# =========================================================================
|
||||
|
||||
def update_ips_current(self, mac_address: str, current_ips: Iterable[str], cap_prev: int = 200):
|
||||
"""Replace current IP set and roll removed IPs into previous_ips (deduped, size-capped)"""
|
||||
cur_set = {ip.strip() for ip in (current_ips or []) if ip}
|
||||
row = self.base.query("SELECT ips, previous_ips FROM hosts WHERE mac_address=? LIMIT 1;", (mac_address,))
|
||||
prev_cur = set(self._parse_list(row[0]["ips"])) if row else set()
|
||||
prev_prev = set(self._parse_list(row[0]["previous_ips"])) if row else set()
|
||||
|
||||
removed = prev_cur - cur_set
|
||||
prev_prev |= removed
|
||||
|
||||
if len(prev_prev) > cap_prev:
|
||||
prev_prev = set(sorted(prev_prev, key=self._sort_ip_key)[:cap_prev])
|
||||
|
||||
ips_sorted = ";".join(sorted(cur_set, key=self._sort_ip_key))
|
||||
prev_sorted = ";".join(sorted(prev_prev, key=self._sort_ip_key))
|
||||
|
||||
self.base.execute("""
|
||||
INSERT INTO hosts(mac_address, ips, previous_ips, updated_at)
|
||||
VALUES(?,?,?,CURRENT_TIMESTAMP)
|
||||
ON CONFLICT(mac_address) DO UPDATE SET
|
||||
ips = excluded.ips,
|
||||
previous_ips = excluded.previous_ips,
|
||||
updated_at = CURRENT_TIMESTAMP;
|
||||
""", (mac_address, ips_sorted, prev_sorted))
|
||||
|
||||
# =========================================================================
|
||||
# PORT OPERATIONS
|
||||
# =========================================================================
|
||||
|
||||
def update_ports_current(self, mac_address: str, current_ports: Iterable[int], cap_prev: int = 500):
|
||||
"""Replace current port set and roll removed ports into previous_ports (deduped, size-capped)"""
|
||||
cur_set = set(int(p) for p in (current_ports or []) if str(p).isdigit())
|
||||
row = self.base.query("SELECT ports, previous_ports FROM hosts WHERE mac_address=? LIMIT 1;", (mac_address,))
|
||||
prev_cur = set(int(p) for p in self._parse_list(row[0]["ports"])) if row else set()
|
||||
prev_prev = set(int(p) for p in self._parse_list(row[0]["previous_ports"])) if row else set()
|
||||
|
||||
removed = prev_cur - cur_set
|
||||
prev_prev |= removed
|
||||
|
||||
if len(prev_prev) > cap_prev:
|
||||
prev_prev = set(sorted(prev_prev)[:cap_prev])
|
||||
|
||||
ports_sorted = ";".join(str(p) for p in sorted(cur_set))
|
||||
prev_sorted = ";".join(str(p) for p in sorted(prev_prev))
|
||||
|
||||
self.base.execute("""
|
||||
INSERT INTO hosts(mac_address, ports, previous_ports, updated_at)
|
||||
VALUES(?,?,?,CURRENT_TIMESTAMP)
|
||||
ON CONFLICT(mac_address) DO UPDATE SET
|
||||
ports = excluded.ports,
|
||||
previous_ports = excluded.previous_ports,
|
||||
updated_at = CURRENT_TIMESTAMP;
|
||||
""", (mac_address, ports_sorted, prev_sorted))
|
||||
|
||||
    # =========================================================================
    # ESSID OPERATIONS
    # =========================================================================

    def update_essid_current(self, mac_address: str, new_essid: Optional[str], cap_prev: int = 50):
        """Update the current ESSID and move the previous one into previous_essids if it changed"""
        new_essid = (new_essid or "").strip()

        row = self.base.query(
            "SELECT essid, previous_essids FROM hosts WHERE mac_address=? LIMIT 1;",
            (mac_address,)
        )

        if row:
            old = (row[0]["essid"] or "").strip()
            prev_prev = self._parse_list(row[0]["previous_essids"]) or []
        else:
            old = ""
            prev_prev = []

        if not (old and new_essid and new_essid == old):
            # ESSID changed (or one side is empty): archive the old value, capped
            if old and old not in prev_prev:
                prev_prev = [old] + prev_prev
            prev_prev = prev_prev[:cap_prev]

        essid = new_essid
        prev_joined = ";".join(prev_prev)

        self.base.execute("""
            INSERT INTO hosts(mac_address, essid, previous_essids, updated_at)
            VALUES(?,?,?,CURRENT_TIMESTAMP)
            ON CONFLICT(mac_address) DO UPDATE SET
                essid = excluded.essid,
                previous_essids = excluded.previous_essids,
                updated_at = CURRENT_TIMESTAMP;
        """, (mac_address, essid, prev_joined))
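    # Rollover semantics (hypothetical values): with essid="HomeNet" stored,
    # update_essid_current(mac, "CoffeeShop") sets essid="CoffeeShop" and prepends
    # "HomeNet" to previous_essids; a second call with "CoffeeShop" leaves both
    # columns unchanged.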
    # =========================================================================
    # IP STUB MERGING
    # =========================================================================

    def merge_ip_stub_into_real(self, ip: str, real_mac: str,
                                hostname: Optional[str] = None, essid_hint: Optional[str] = None):
        """
        Merge an 'IP:<ip>' stub host into the host at 'real_mac' (if present), or rename the stub.
        - Unifies ips, hostnames, ports, vendor, essid, first_seen/last_seen, alive.
        - Updates tables that have a 'mac_address' column to point to the real MAC.
        - ESSID tolerance: if one of the two values is empty, keep the non-empty one.
        - If the host 'real_mac' doesn't exist yet, simply rename the stub -> real_mac.
        """
        if not real_mac or ':' not in real_mac:
            return  # nothing to do without a real MAC

        now = int(time.time())
        stub_key = f"IP:{ip}".lower()
        real_key = real_mac.lower()

        with self.base._lock:
            con = self.base._conn
            cur = con.cursor()

            # Retrieve stub candidates (by mac='IP:<ip>') + fallback: mac LIKE 'IP:%' containing the ip
            cur.execute("""
                SELECT * FROM hosts
                WHERE lower(mac_address)=?
                   OR (lower(mac_address) LIKE 'ip:%' AND (ips LIKE '%'||?||'%'))
                ORDER BY lower(mac_address)=? DESC
                LIMIT 1
            """, (stub_key, ip, stub_key))
            stub = cur.fetchone()

            # Look up the real host (if any)
            cur.execute("SELECT * FROM hosts WHERE lower(mac_address)=? LIMIT 1", (real_key,))
            real = cur.fetchone()

            if not stub and not real:
                # No record at all: create the real one directly
                cur.execute("""INSERT OR IGNORE INTO hosts
                               (mac_address, ips, hostnames, ports, vendor, essid, alive, first_seen, last_seen)
                               VALUES (?,?,?,?,?,?,1,?,?)""",
                            (real_key, ip, hostname or None, None, None, essid_hint or None, now, now))
                con.commit()
                return

            if stub and not real:
                # Rename the stub -> real MAC
                ips_merged = self._union_semicol(stub['ips'], ip, sort_ip=True)
                hosts_merged = self._union_semicol(stub['hostnames'], hostname)
                essid_final = stub['essid'] or essid_hint

                cur.execute("""UPDATE hosts SET
                                   mac_address=?,
                                   ips=?,
                                   hostnames=?,
                                   essid=COALESCE(?, essid),
                                   alive=1,
                                   last_seen=?
                               WHERE lower(mac_address)=?""",
                            (real_key, ips_merged, hosts_merged, essid_final, now, stub['mac_address'].lower()))

                # Redirect references from other tables (if they exist)
                self._redirect_mac_references(cur, stub['mac_address'].lower(), real_key)
                con.commit()
                return

            if stub and real:
                # Full merge into the real host, then delete the stub
                ips_merged = self._union_semicol(real['ips'], stub['ips'], sort_ip=True)
                ips_merged = self._union_semicol(ips_merged, ip, sort_ip=True)
                hosts_merged = self._union_semicol(real['hostnames'], stub['hostnames'])
                hosts_merged = self._union_semicol(hosts_merged, hostname)
                ports_merged = self._union_semicol(real['ports'], stub['ports'])
                vendor_final = real['vendor'] or stub['vendor']
                essid_final = real['essid'] or stub['essid'] or essid_hint
                first_seen = min(int(real['first_seen'] or now), int(stub['first_seen'] or now))
                last_seen = max(int(real['last_seen'] or now), int(stub['last_seen'] or now), now)

                cur.execute("""UPDATE hosts SET
                                   ips=?,
                                   hostnames=?,
                                   ports=?,
                                   vendor=COALESCE(?, vendor),
                                   essid=COALESCE(?, essid),
                                   alive=1,
                                   first_seen=?,
                                   last_seen=?
                               WHERE lower(mac_address)=?""",
                            (ips_merged, hosts_merged, ports_merged, vendor_final, essid_final,
                             first_seen, last_seen, real_key))

                # Redirect references to real_key, then delete the stub
                self._redirect_mac_references(cur, stub['mac_address'].lower(), real_key)
                cur.execute("DELETE FROM hosts WHERE lower(mac_address)=?", (stub['mac_address'].lower(),))
                con.commit()
                return

            # No stub but the real host already exists: ensure current IP/hostname are unified
            if real and not stub:
                ips_merged = self._union_semicol(real['ips'], ip, sort_ip=True)
                hosts_merged = self._union_semicol(real['hostnames'], hostname)
                essid_final = real['essid'] or essid_hint
                cur.execute("""UPDATE hosts SET
                                   ips=?,
                                   hostnames=?,
                                   essid=COALESCE(?, essid),
                                   alive=1,
                                   last_seen=?
                               WHERE lower(mac_address)=?""",
                            (ips_merged, hosts_merged, essid_final, now, real_key))
                con.commit()
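    # Merge sketch (hypothetical data): a scanner that only saw 192.168.1.20 creates a stub
    # host keyed "IP:192.168.1.20"; once ARP resolves the MAC, calling
    # merge_ip_stub_into_real("192.168.1.20", "aa:bb:cc:dd:ee:ff") folds the stub's IPs,
    # hostnames and ports into the real row and repoints child tables at the real MAC.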
    def _redirect_mac_references(self, cur, old_mac: str, new_mac: str):
        """Redirect mac_address references in all relevant tables"""
        try:
            # Discover all tables with a mac_address column
            cur.execute("""SELECT name FROM sqlite_master
                           WHERE type='table' AND name NOT LIKE 'sqlite_%'""")
            for (tname,) in cur.fetchall():
                if tname == 'hosts':
                    continue
                try:
                    cur.execute(f"PRAGMA table_info({tname})")
                    cols = [r[1].lower() for r in cur.fetchall()]
                    if 'mac_address' in cols:
                        cur.execute(f"""UPDATE {tname}
                                        SET mac_address=?
                                        WHERE lower(mac_address)=?""",
                                    (new_mac, old_mac))
                except Exception:
                    # Best-effort: skip tables that cannot be inspected or updated
                    pass
        except Exception:
            pass
    # =========================================================================
    # HELPER METHODS
    # =========================================================================

    def _parse_list(self, s: Optional[str]) -> List[str]:
        """Parse a semicolon-separated string into a list, ignoring empties"""
        return [x for x in (s or "").split(";") if x]

    def _sort_ip_key(self, ip: str):
        """Return a sortable key for IPv4 addresses; non-IPv4 values sort first, as (0, 0, 0, 0)"""
        if ip and ip.count(".") == 3:
            try:
                return tuple(int(x) for x in ip.split("."))
            except Exception:
                return (0, 0, 0, 0)
        return (0, 0, 0, 0)

    def _union_semicol(self, *values: Optional[str], sort_ip: bool = False) -> str:
        """Deduplicated union of semicolon-separated lists (ignores empties)"""
        def _key(x):
            # Two-part key so numeric IPv4 keys and plain-string keys stay comparable
            # (a mixed set would otherwise raise TypeError when sorted)
            if sort_ip and x.count('.') == 3:
                try:
                    return (0, tuple(map(int, x.split('.'))))
                except Exception:
                    return (1, x)
            return (1, x) if sort_ip else x

        s = set()
        for v in values:
            if not v:
                continue
            for it in str(v).split(';'):
                it = it.strip()
                if it:
                    s.add(it)
        if not s:
            return ""
        return ';'.join(sorted(s, key=_key))
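    # Quick check of the helper (doctest-style; `ops` names an instance of this class,
    # values are illustrative):
    #
    #     >>> ops._union_semicol("10.0.0.2;10.0.0.10", "10.0.0.1", sort_ip=True)
    #     '10.0.0.1;10.0.0.2;10.0.0.10'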
410
db_utils/queue.py
Normal file
@@ -0,0 +1,410 @@
# db_utils/queue.py
# Action queue management operations

import json
import sqlite3
from typing import Any, Dict, Iterable, List, Optional
import logging

from logger import Logger

logger = Logger(name="db_utils.queue", level=logging.DEBUG)


class QueueOps:
    """Action queue scheduling and execution tracking operations"""

    def __init__(self, base):
        self.base = base

    def create_tables(self):
        """Create action queue table and indexes"""
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS action_queue (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                action_name TEXT NOT NULL,
                mac_address TEXT NOT NULL,
                ip TEXT NOT NULL,
                port INTEGER,
                hostname TEXT,
                service TEXT,
                priority INTEGER DEFAULT 50,
                status TEXT DEFAULT 'pending',
                retry_count INTEGER DEFAULT 0,
                max_retries INTEGER DEFAULT 3,
                created_at TEXT DEFAULT CURRENT_TIMESTAMP,
                scheduled_for TEXT,
                started_at TEXT,
                completed_at TEXT,
                expires_at TEXT,
                trigger_source TEXT,
                dependencies TEXT,
                conditions TEXT,
                result_summary TEXT,
                error_message TEXT,
                tags TEXT,
                metadata TEXT,
                FOREIGN KEY (mac_address) REFERENCES hosts(mac_address)
            );
        """)

        # Optimized indexes for queue operations
        self.base.execute("CREATE INDEX IF NOT EXISTS idx_queue_pending ON action_queue(status) WHERE status='pending';")
        self.base.execute("CREATE INDEX IF NOT EXISTS idx_queue_scheduled ON action_queue(scheduled_for) WHERE status='scheduled';")
        self.base.execute("CREATE INDEX IF NOT EXISTS idx_queue_mac_action ON action_queue(mac_address, action_name);")
        self.base.execute("CREATE INDEX IF NOT EXISTS idx_queue_key_status ON action_queue(action_name, mac_address, port, status);")
        self.base.execute("CREATE INDEX IF NOT EXISTS idx_queue_key_time ON action_queue(action_name, mac_address, port, completed_at);")

        # Unique constraint: at most one upcoming 'scheduled' row per action/target
        self.base.execute("""
            CREATE UNIQUE INDEX IF NOT EXISTS uq_next_scheduled
            ON action_queue(action_name,
                            COALESCE(mac_address,''),
                            COALESCE(service,''),
                            COALESCE(port,-1))
            WHERE status='scheduled';
        """)

        logger.debug("Action queue table created/verified")
    # =========================================================================
    # QUEUE RETRIEVAL OPERATIONS
    # =========================================================================

    def get_next_queued_action(self) -> Optional[Dict[str, Any]]:
        """
        Fetch the next action to execute from the queue.
        Priority is dynamically boosted: +1 per 5 minutes since creation; the
        effective value is capped at 100.
        """
        rows = self.base.query("""
            SELECT *,
                   MIN(100, priority + CAST((strftime('%s','now') - strftime('%s',created_at))/300 AS INTEGER)) AS priority_effective
            FROM action_queue
            WHERE status = 'pending'
              AND (scheduled_for IS NULL OR scheduled_for <= datetime('now'))
            ORDER BY priority_effective DESC,
                     COALESCE(scheduled_for, created_at) ASC
            LIMIT 1
        """)
        return rows[0] if rows else None

    def list_action_queue(self, statuses: Optional[Iterable[str]] = None) -> List[Dict[str, Any]]:
        """List queue entries with a computed `priority_effective` column for pending items"""
        order_sql = """
            CASE status
                WHEN 'running'   THEN 1
                WHEN 'pending'   THEN 2
                WHEN 'scheduled' THEN 3
                WHEN 'failed'    THEN 4
                WHEN 'success'   THEN 5
                WHEN 'expired'   THEN 6
                WHEN 'cancelled' THEN 7
                ELSE 99
            END ASC,
            priority_effective DESC,
            COALESCE(scheduled_for, created_at) ASC
        """

        select_sql = """
            SELECT *,
                   MIN(100, priority + CAST((strftime('%s','now') - strftime('%s',created_at))/300 AS INTEGER)) AS priority_effective
            FROM action_queue
        """

        if statuses:
            # Materialize first: `statuses` may be a one-shot generator
            statuses = tuple(statuses)
            in_clause = ",".join("?" for _ in statuses)
            return self.base.query(f"""
                {select_sql}
                WHERE status IN ({in_clause})
                ORDER BY {order_sql}
            """, statuses)

        return self.base.query(f"""
            {select_sql}
            ORDER BY {order_sql}
        """)

    def get_upcoming_actions_summary(self) -> List[Dict[str, Any]]:
        """Summary: next run per action_name from the schedule"""
        return self.base.query("""
            SELECT action_name, MIN(scheduled_for) AS next_run_at
            FROM action_queue
            WHERE status='scheduled' AND scheduled_for IS NOT NULL
            GROUP BY action_name
            ORDER BY next_run_at ASC
        """)
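    # Aging arithmetic (illustrative): a job created 25 minutes ago with priority 60 gains
    # floor(1500s / 300s) = +5, so priority_effective = 65; because of the MIN(100, ...) cap,
    # a priority-50 job tops out at 100 after 50 * 5 = 250 minutes of waiting.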
    # =========================================================================
    # QUEUE UPDATE OPERATIONS
    # =========================================================================

    def update_queue_status(self, queue_id: int, status: str,
                            error_msg: Optional[str] = None, result: Optional[str] = None):
        """Update a queue entry's status, with retry accounting on failure/expiry"""
        self.base.invalidate_stats_cache()

        if status == 'running':
            self.base.execute(
                "UPDATE action_queue SET status=?, started_at=CURRENT_TIMESTAMP WHERE id=?",
                (status, queue_id)
            )
        elif status in ('failed', 'expired'):
            self.base.execute("""
                UPDATE action_queue
                SET status=?,
                    completed_at=CURRENT_TIMESTAMP,
                    error_message=?,
                    result_summary=COALESCE(?, result_summary),
                    retry_count = MIN(retry_count + 1, max_retries)
                WHERE id=?
            """, (status, error_msg, result, queue_id))
        elif status in ('success', 'cancelled'):
            self.base.execute("""
                UPDATE action_queue
                SET status=?,
                    completed_at=CURRENT_TIMESTAMP,
                    error_message=?,
                    result_summary=COALESCE(?, result_summary)
                WHERE id=?
            """, (status, error_msg, result, queue_id))

        # When execution succeeds, supersede old failed/expired attempts
        if status == 'success':
            row = self.base.query_one("""
                SELECT action_name, mac_address, port,
                       COALESCE(completed_at, started_at, created_at) AS ts
                FROM action_queue WHERE id=? LIMIT 1
            """, (queue_id,))
            if row:
                try:
                    self.supersede_old_attempts(row['action_name'], row['mac_address'], row['port'], row['ts'])
                except Exception:
                    pass

    def promote_due_scheduled_to_pending(self) -> int:
        """Promote due scheduled actions to 'pending' (returns the number of rows affected)"""
        self.base.invalidate_stats_cache()
        return self.base.execute("""
            UPDATE action_queue
            SET status='pending'
            WHERE status='scheduled'
              AND scheduled_for <= CURRENT_TIMESTAMP
        """)
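    # Worker-loop sketch (illustrative; `queue_ops` names an instance of this class and
    # run_action() is a hypothetical dispatcher, not part of this module):
    #
    #     queue_ops.promote_due_scheduled_to_pending()
    #     job = queue_ops.get_next_queued_action()
    #     if job:
    #         queue_ops.update_queue_status(job["id"], "running")
    #         ok, summary = run_action(job)
    #         queue_ops.update_queue_status(job["id"], "success" if ok else "failed", result=summary)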
    # =========================================================================
    # QUEUE INSERTION OPERATIONS
    # =========================================================================

    def ensure_scheduled_occurrence(
        self,
        action_name: str,
        next_run_at: str,
        mac: Optional[str] = "",
        ip: Optional[str] = "",
        *,
        port: Optional[int] = None,
        hostname: Optional[str] = None,
        service: Optional[str] = None,
        priority: int = 40,
        trigger: str = "scheduler",
        tags: Optional[Iterable[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        max_retries: Optional[int] = None,
    ) -> bool:
        """
        Ensure a single upcoming 'scheduled' row exists for the given action/target.
        Returns True if inserted, False if already present (enforced by the unique partial index).
        """
        if isinstance(tags, str):
            js_tags = tags
        elif tags is not None:
            js_tags = json.dumps(list(tags))
        else:
            js_tags = None
        js_meta = json.dumps(metadata, ensure_ascii=False) if metadata else None

        try:
            self.base.execute("""
                INSERT INTO action_queue(
                    action_name, mac_address, ip, port, hostname, service,
                    priority, status, scheduled_for, trigger_source, tags, metadata, max_retries
                ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?)
            """, (
                action_name, mac or "", ip or "", port, hostname, service,
                int(priority), "scheduled", next_run_at, trigger, js_tags, js_meta, max_retries
            ))
            self.base.invalidate_stats_cache()
            return True
        except sqlite3.IntegrityError:
            return False

    def queue_action(self, action_name: str, mac: str, ip: str, port: Optional[int] = None,
                     priority: int = 50, trigger: Optional[str] = None,
                     metadata: Optional[Dict] = None) -> None:
        """Quick enqueue of a 'pending' action"""
        meta_json = json.dumps(metadata, ensure_ascii=False) if metadata else None
        self.base.execute("""
            INSERT INTO action_queue
                (action_name, mac_address, ip, port, priority, trigger_source, metadata)
            VALUES (?,?,?,?,?,?,?)
        """, (action_name, mac, ip, port, priority, trigger, meta_json))

    def queue_action_at(
        self,
        action_name: str,
        mac: Optional[str] = "",
        ip: Optional[str] = "",
        *,
        port: Optional[int] = None,
        hostname: Optional[str] = None,
        service: Optional[str] = None,
        priority: int = 50,
        status: str = "pending",
        scheduled_for: Optional[str] = None,
        trigger: Optional[str] = "scheduler",
        tags: Optional[Iterable[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        max_retries: Optional[int] = None,
    ) -> None:
        """Generic enqueue that can publish 'pending' or 'scheduled' items with a date"""
        if isinstance(tags, str):
            js_tags = tags
        elif tags is not None:
            js_tags = json.dumps(list(tags))
        else:
            js_tags = None
        js_meta = json.dumps(metadata, ensure_ascii=False) if metadata else None
        self.base.execute("""
            INSERT INTO action_queue(
                action_name, mac_address, ip, port, hostname, service,
                priority, status, scheduled_for, trigger_source, tags, metadata, max_retries
            ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?)
        """, (
            action_name, mac or "", ip or "", port, hostname, service,
            int(priority), status, scheduled_for, trigger, js_tags, js_meta, max_retries
        ))
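    # Dedup sketch: thanks to the uq_next_scheduled partial index, scheduling the same
    # occurrence twice is a no-op (action name and timestamp are illustrative):
    #
    #     queue_ops.ensure_scheduled_occurrence("NmapScan", "2025-01-01 03:00:00")  # -> True
    #     queue_ops.ensure_scheduled_occurrence("NmapScan", "2025-01-01 03:00:00")  # -> False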
    # =========================================================================
    # HISTORY AND STATUS OPERATIONS
    # =========================================================================

    def supersede_old_attempts(self, action_name: str, mac_address: str,
                               port: Optional[int] = None, ref_ts: Optional[str] = None) -> int:
        """
        Mark as 'superseded' all old (failed|expired) attempts for the triplet (action, mac, port)
        earlier than or equal to ref_ts (if provided). Returns the affected row count.
        """
        params: List[Any] = [action_name, mac_address, port]
        time_clause = ""
        if ref_ts:
            time_clause = " AND datetime(COALESCE(completed_at, started_at, created_at)) <= datetime(?)"
            params.append(ref_ts)

        return self.base.execute(f"""
            UPDATE action_queue
            SET status='superseded',
                error_message = COALESCE(error_message, 'superseded by newer success'),
                completed_at = COALESCE(completed_at, CURRENT_TIMESTAMP)
            WHERE action_name = ?
              AND mac_address = ?
              AND COALESCE(port,0) = COALESCE(?,0)
              AND status IN ('failed','expired')
              {time_clause}
        """, tuple(params))

    def list_attempt_history(self, action_name: str, mac_address: str,
                             port: Optional[int] = None, limit: int = 20) -> List[Dict[str, Any]]:
        """Return the attempt history for (action, mac, port), most recent first"""
        return self.base.query("""
            SELECT action_name, mac_address, port, status, retry_count, max_retries,
                   COALESCE(completed_at, started_at, scheduled_for, created_at) AS ts
            FROM action_queue
            WHERE action_name=? AND mac_address=? AND COALESCE(port,0)=COALESCE(?,0)
            ORDER BY datetime(ts) DESC
            LIMIT ?
        """, (action_name, mac_address, port, int(limit)))

    def get_action_status_from_queue(
        self,
        action_name: str,
        mac_address: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """Return the latest status row for an action (optionally filtered by MAC)"""
        if mac_address:
            rows = self.base.query("""
                SELECT status, created_at, started_at, completed_at,
                       error_message, result_summary, retry_count, max_retries,
                       mac_address, port, hostname, service, priority
                FROM action_queue
                WHERE mac_address=? AND action_name=?
                ORDER BY datetime(COALESCE(completed_at, started_at, scheduled_for, created_at)) DESC
                LIMIT 1
            """, (mac_address, action_name))
        else:
            rows = self.base.query("""
                SELECT status, created_at, started_at, completed_at,
                       error_message, result_summary, retry_count, max_retries,
                       mac_address, port, hostname, service, priority
                FROM action_queue
                WHERE action_name=?
                ORDER BY datetime(COALESCE(completed_at, started_at, scheduled_for, created_at)) DESC
                LIMIT 1
            """, (action_name,))
        return rows[0] if rows else None

    def get_last_action_status_from_queue(self, mac_address: str, action_name: str) -> Optional[Dict[str, str]]:
        """
        Return {'status': 'success|failed|running|pending', 'raw': 'status_YYYYMMDD_HHMMSS'}
        based only on action_queue.
        """
        rows = self.base.query(
            """
            SELECT status,
                   COALESCE(completed_at, started_at, scheduled_for, created_at) AS ts
            FROM action_queue
            WHERE mac_address=? AND action_name=?
            ORDER BY datetime(COALESCE(completed_at, started_at, scheduled_for, created_at)) DESC
            LIMIT 1
            """,
            (mac_address, action_name)
        )
        if not rows:
            return None
        status = rows[0]["status"]
        ts = self._format_ts_for_raw(rows[0]["ts"])
        return {"status": status, "raw": f"{status}_{ts}"}

    def get_last_action_statuses_for_mac(self, mac_address: str) -> Dict[str, Dict[str, str]]:
        """Map action_name -> {'status':..., 'raw':...} from the latest queue rows for a MAC"""
        rows = self.base.query(
            """
            SELECT action_name, status,
                   COALESCE(completed_at, started_at, scheduled_for, created_at) AS ts
            FROM (
                SELECT action_name, status, completed_at, started_at, scheduled_for, created_at,
                       ROW_NUMBER() OVER (
                           PARTITION BY action_name
                           ORDER BY datetime(COALESCE(completed_at, started_at, scheduled_for, created_at)) DESC
                       ) AS rn
                FROM action_queue
                WHERE mac_address=?
            )
            WHERE rn=1
            """,
            (mac_address,)
        )
        out: Dict[str, Dict[str, str]] = {}
        for r in rows:
            ts = self._format_ts_for_raw(r["ts"])
            st = r["status"]
            out[r["action_name"]] = {"status": st, "raw": f"{st}_{ts}"}
        return out

    # =========================================================================
    # HELPER METHODS
    # =========================================================================

    def _format_ts_for_raw(self, ts_db: Optional[str]) -> str:
        """
        Convert SQLite 'YYYY-MM-DD HH:MM:SS' to 'YYYYMMDD_HHMMSS'.
        Falls back to the current UTC time when no timestamp is available.
        """
        from datetime import datetime as _dt
        ts = (ts_db or "").strip()
        if not ts:
            return _dt.utcnow().strftime("%Y%m%d_%H%M%S")
        return ts.replace("-", "").replace(":", "").replace(" ", "_")
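    # Formatting check (illustrative): _format_ts_for_raw("2025-01-02 03:04:05") returns
    # "20250102_030405", so a success row carries raw="success_20250102_030405".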
62
db_utils/scripts.py
Normal file
@@ -0,0 +1,62 @@
# db_utils/scripts.py
# Script and project metadata operations

from typing import Any, Dict, List, Optional
import logging

from logger import Logger

logger = Logger(name="db_utils.scripts", level=logging.DEBUG)


class ScriptOps:
    """Script and project metadata management operations"""

    def __init__(self, base):
        self.base = base

    def create_tables(self):
        """Create scripts metadata table"""
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS scripts (
                name TEXT PRIMARY KEY,
                type TEXT NOT NULL,
                path TEXT NOT NULL,
                main_file TEXT,
                category TEXT,
                description TEXT,
                created_at TEXT DEFAULT CURRENT_TIMESTAMP
            );
        """)
        logger.debug("Scripts table created/verified")

    # =========================================================================
    # SCRIPT OPERATIONS
    # =========================================================================

    def add_script(self, name: str, type_: str, path: str,
                   main_file: Optional[str] = None, category: Optional[str] = None,
                   description: Optional[str] = None):
        """Insert or update a script/project metadata row"""
        self.base.execute("""
            INSERT INTO scripts(name,type,path,main_file,category,description)
            VALUES(?,?,?,?,?,?)
            ON CONFLICT(name) DO UPDATE SET
                type=excluded.type,
                path=excluded.path,
                main_file=excluded.main_file,
                category=excluded.category,
                description=excluded.description;
        """, (name, type_, path, main_file, category, description))

    def list_scripts(self) -> List[Dict[str, Any]]:
        """List all scripts/projects"""
        return self.base.query("""
            SELECT name, type, path, main_file, category, description, created_at
            FROM scripts
            ORDER BY name;
        """)

    def delete_script(self, name: str) -> None:
        """Delete a script/project metadata row by name"""
        self.base.execute("DELETE FROM scripts WHERE name=?;", (name,))
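    # Usage sketch (names and paths are illustrative): register a project, then list it:
    #
    #     script_ops.add_script("reverse_shell", "project", "/opt/bjorn/scripts/reverse_shell",
    #                           main_file="main.py", category="payload")
    #     script_ops.list_scripts()  # -> [{'name': 'reverse_shell', 'type': 'project', ...}]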
191
db_utils/services.py
Normal file
@@ -0,0 +1,191 @@
# db_utils/services.py
# Per-port service fingerprinting and tracking operations

from typing import Dict, List, Optional
import logging

from logger import Logger

logger = Logger(name="db_utils.services", level=logging.DEBUG)


class ServiceOps:
    """Per-port service fingerprinting and tracking operations"""

    def __init__(self, base):
        self.base = base

    def create_tables(self):
        """Create port services tables"""
        # PORT SERVICES (current view of per-port fingerprinting)
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS port_services (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                mac_address TEXT NOT NULL,
                ip TEXT,
                port INTEGER NOT NULL,
                protocol TEXT DEFAULT 'tcp',
                state TEXT DEFAULT 'open',
                service TEXT,
                product TEXT,
                version TEXT,
                banner TEXT,
                fingerprint TEXT,
                confidence REAL,
                source TEXT DEFAULT 'ml',
                first_seen TEXT DEFAULT CURRENT_TIMESTAMP,
                last_seen TEXT DEFAULT CURRENT_TIMESTAMP,
                is_current INTEGER DEFAULT 1,
                UNIQUE(mac_address, port, protocol)
            );
        """)
        self.base.execute("CREATE INDEX IF NOT EXISTS idx_ps_mac_port ON port_services(mac_address, port);")
        self.base.execute("CREATE INDEX IF NOT EXISTS idx_ps_state ON port_services(state);")

        # Per-port service history (immutable log of changes)
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS port_service_history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                mac_address TEXT NOT NULL,
                ip TEXT,
                port INTEGER NOT NULL,
                protocol TEXT DEFAULT 'tcp',
                state TEXT,
                service TEXT,
                product TEXT,
                version TEXT,
                banner TEXT,
                fingerprint TEXT,
                confidence REAL,
                source TEXT,
                seen_at TEXT DEFAULT CURRENT_TIMESTAMP
            );
        """)

        logger.debug("Port services tables created/verified")
    # =========================================================================
    # SERVICE CRUD OPERATIONS
    # =========================================================================

    def upsert_port_service(
        self,
        mac_address: str,
        ip: Optional[str],
        port: int,
        *,
        protocol: str = "tcp",
        state: str = "open",
        service: Optional[str] = None,
        product: Optional[str] = None,
        version: Optional[str] = None,
        banner: Optional[str] = None,
        fingerprint: Optional[str] = None,
        confidence: Optional[float] = None,
        source: str = "ml",
        touch_history_on_change: bool = True,
    ):
        """
        Create/update the current (service, fingerprint, ...) for a given (mac, port, proto).
        Also refreshes the hosts.ports aggregate so legacy code keeps working.
        """
        self.base.invalidate_stats_cache()

        with self.base.transaction(immediate=True):
            prev = self.base.query(
                """SELECT * FROM port_services
                   WHERE mac_address=? AND port=? AND protocol=? LIMIT 1""",
                (mac_address, int(port), protocol)
            )

            if prev:
                p = prev[0]
                changed = any([
                    state != p.get("state"),
                    service != p.get("service"),
                    product != p.get("product"),
                    version != p.get("version"),
                    banner != p.get("banner"),
                    fingerprint != p.get("fingerprint"),
                    (confidence is not None and confidence != p.get("confidence")),
                ])

                if touch_history_on_change and changed:
                    self.base.execute("""
                        INSERT INTO port_service_history
                            (mac_address, ip, port, protocol, state, service, product, version, banner, fingerprint, confidence, source)
                        VALUES (?,?,?,?,?,?,?,?,?,?,?,?)
                    """, (mac_address, ip, int(port), protocol, state, service, product, version, banner, fingerprint, confidence, source))

                self.base.execute("""
                    UPDATE port_services
                    SET ip=?, state=?, service=?, product=?, version=?,
                        banner=?, fingerprint=?, confidence=?, source=?,
                        last_seen=CURRENT_TIMESTAMP
                    WHERE mac_address=? AND port=? AND protocol=?
                """, (ip, state, service, product, version, banner, fingerprint, confidence, source,
                      mac_address, int(port), protocol))
            else:
                self.base.execute("""
                    INSERT INTO port_services
                        (mac_address, ip, port, protocol, state, service, product, version, banner, fingerprint, confidence, source)
                    VALUES (?,?,?,?,?,?,?,?,?,?,?,?)
                """, (mac_address, ip, int(port), protocol, state, service, product, version, banner, fingerprint, confidence, source))

            # Rebuild host ports for compatibility
            self._rebuild_host_ports(mac_address)

    def _rebuild_host_ports(self, mac_address: str):
        """Rebuild hosts.ports from current port_services rows where state='open' (tcp only)"""
        row = self.base.query("SELECT ports, previous_ports FROM hosts WHERE mac_address=? LIMIT 1;", (mac_address,))
        old_ports = set(int(p) for p in (row[0]["ports"].split(";") if row and row[0].get("ports") else []) if str(p).isdigit())
        old_prev = set(int(p) for p in (row[0]["previous_ports"].split(";") if row and row[0].get("previous_ports") else []) if str(p).isdigit())

        current_rows = self.base.query(
            "SELECT port FROM port_services WHERE mac_address=? AND state='open' AND protocol='tcp'",
            (mac_address,)
        )
        new_ports = set(int(r["port"]) for r in current_rows)

        removed = old_ports - new_ports
        new_prev = old_prev | removed

        ports_txt = ";".join(str(p) for p in sorted(new_ports))
        prev_txt = ";".join(str(p) for p in sorted(new_prev))

        self.base.execute("""
            INSERT INTO hosts(mac_address, ports, previous_ports, updated_at)
            VALUES(?,?,?,CURRENT_TIMESTAMP)
            ON CONFLICT(mac_address) DO UPDATE SET
                ports = excluded.ports,
                previous_ports = excluded.previous_ports,
                updated_at = CURRENT_TIMESTAMP;
        """, (mac_address, ports_txt, prev_txt))
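    # Upsert sketch (values are illustrative): the first call inserts the current row; a later
    # call that changes the fingerprint snapshots the old state into port_service_history first:
    #
    #     service_ops.upsert_port_service("aa:bb:cc:dd:ee:ff", "192.168.1.10", 22,
    #                                     service="ssh", product="OpenSSH", version="9.6")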
    # =========================================================================
    # SERVICE QUERY OPERATIONS
    # =========================================================================

    def get_services_for_host(self, mac_address: str) -> List[Dict]:
        """Return all per-port service rows for the given host, ordered by port"""
        return self.base.query("""
            SELECT port, protocol, state, service, product, version, confidence, last_seen
            FROM port_services
            WHERE mac_address=?
            ORDER BY port
        """, (mac_address,))

    def find_hosts_by_service(self, service: str) -> List[Dict]:
        """Return distinct host MACs that expose the given service (state='open')"""
        return self.base.query("""
            SELECT DISTINCT mac_address FROM port_services
            WHERE service=? AND state='open'
        """, (service,))

    def get_service_for_host_port(self, mac_address: str, port: int, protocol: str = "tcp") -> Optional[Dict]:
        """Return the single port_services row for (mac, port, protocol), if any"""
        rows = self.base.query("""
            SELECT * FROM port_services
            WHERE mac_address=? AND port=? AND protocol=? LIMIT 1
        """, (mac_address, int(port), protocol))
        return rows[0] if rows else None
157
db_utils/software.py
Normal file
@@ -0,0 +1,157 @@
# db_utils/software.py
# Detected software (CPE) inventory operations

from typing import List, Optional
import logging

from logger import Logger

logger = Logger(name="db_utils.software", level=logging.DEBUG)


class SoftwareOps:
    """Detected software (CPE) tracking operations"""

    def __init__(self, base):
        self.base = base

    def create_tables(self):
        """Create detected software tables"""
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS detected_software (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                mac_address TEXT NOT NULL,
                ip TEXT,
                hostname TEXT,
                port INTEGER NOT NULL DEFAULT 0,
                cpe TEXT NOT NULL,
                first_seen TEXT DEFAULT CURRENT_TIMESTAMP,
                last_seen TEXT DEFAULT CURRENT_TIMESTAMP,
                is_active INTEGER DEFAULT 1,
                UNIQUE(mac_address, port, cpe)
            );
        """)

        # Migration for detected_software: normalize legacy NULL ports to 0
        self.base.execute("""
            UPDATE detected_software SET port = 0 WHERE port IS NULL
        """)

        # Detected software history (immutable log)
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS detected_software_history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                mac_address TEXT NOT NULL,
                ip TEXT,
                hostname TEXT,
                port INTEGER NOT NULL DEFAULT 0,
                cpe TEXT NOT NULL,
                event TEXT NOT NULL,
                seen_at TEXT DEFAULT CURRENT_TIMESTAMP
            );
        """)

        logger.debug("Software detection tables created/verified")
    # =========================================================================
    # SOFTWARE CRUD OPERATIONS
    # =========================================================================

    def add_detected_software(self, mac_address: str, cpe: str, ip: Optional[str] = None,
                              hostname: Optional[str] = None, port: Optional[int] = None) -> None:
        """Upsert a (mac, port, cpe) tuple and record history ('new' or 'seen')"""
        p = int(port or 0)
        existed = self.base.query(
            "SELECT id FROM detected_software WHERE mac_address=? AND port=? AND cpe=? LIMIT 1",
            (mac_address, p, cpe)
        )
        if existed:
            self.base.execute("""
                UPDATE detected_software
                SET ip=COALESCE(?, detected_software.ip),
                    hostname=COALESCE(?, detected_software.hostname),
                    last_seen=CURRENT_TIMESTAMP,
                    is_active=1
                WHERE mac_address=? AND port=? AND cpe=?
            """, (ip, hostname, mac_address, p, cpe))
            self.base.execute("""
                INSERT INTO detected_software_history(mac_address, ip, hostname, port, cpe, event)
                VALUES(?,?,?,?,?,'seen')
            """, (mac_address, ip, hostname, p, cpe))
        else:
            self.base.execute("""
                INSERT INTO detected_software(mac_address, ip, hostname, port, cpe, is_active)
                VALUES(?,?,?,?,?,1)
            """, (mac_address, ip, hostname, p, cpe))
            self.base.execute("""
                INSERT INTO detected_software_history(mac_address, ip, hostname, port, cpe, event)
                VALUES(?,?,?,?,?,'new')
            """, (mac_address, ip, hostname, p, cpe))

    def update_detected_software_status(self, mac_address: str, current_cpes: List[str]) -> None:
        """Mark absent CPEs as inactive, present ones as seen, and insert new ones as needed"""
        rows = self.base.query(
            "SELECT cpe FROM detected_software WHERE mac_address=? AND is_active=1",
            (mac_address,)
        )
        existing = {r['cpe'] for r in rows}
        cur = set(current_cpes)

        # Inactive
        for cpe in (existing - cur):
            self.base.execute("""
                UPDATE detected_software
                SET is_active=0, last_seen=CURRENT_TIMESTAMP
                WHERE mac_address=? AND cpe=? AND is_active=1
            """, (mac_address, cpe))
            self.base.execute("""
                INSERT INTO detected_software_history(mac_address, port, cpe, event)
                SELECT mac_address, port, cpe, 'inactive'
                FROM detected_software
                WHERE mac_address=? AND cpe=? LIMIT 1
            """, (mac_address, cpe))

        # New
        for cpe in (cur - existing):
            self.add_detected_software(mac_address, cpe)

        # Seen
        for cpe in (cur & existing):
            self.base.execute("""
                UPDATE detected_software
                SET last_seen=CURRENT_TIMESTAMP
                WHERE mac_address=? AND cpe=? AND is_active=1
            """, (mac_address, cpe))
            self.base.execute("""
                INSERT INTO detected_software_history(mac_address, port, cpe, event)
                SELECT mac_address, port, cpe, 'seen'
                FROM detected_software
                WHERE mac_address=? AND cpe=? LIMIT 1
            """, (mac_address, cpe))
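    # Diff semantics sketch (CPE values are illustrative): if the active set is {A, B} and a
    # new scan reports [B, C], update_detected_software_status marks A inactive, logs B as
    # 'seen', and inserts C with a 'new' history event:
    #
    #     software_ops.update_detected_software_status("aa:bb:cc:dd:ee:ff",
    #                                                  ["cpe:2.3:a:openbsd:openssh:9.6"])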
    # =========================================================================
    # MIGRATION HELPER
    # =========================================================================

    def migrate_cpe_from_vulnerabilities(self) -> int:
        """
        Migrate historical CPE entries wrongly stored in `vulnerabilities.vuln_id`
        into `detected_software`. Returns the number of rows migrated.
        """
        rows = self.base.query("""
            SELECT id, mac_address, ip, hostname, COALESCE(port,0) AS port, vuln_id
            FROM vulnerabilities
            WHERE LOWER(vuln_id) LIKE 'cpe:%' OR UPPER(vuln_id) LIKE 'CPE:%'
        """)
        moved = 0
        for r in rows:
            vid = r['vuln_id']
            cpe = vid.split(':', 1)[1] if vid.upper().startswith('CPE:') else vid
            try:
                self.add_detected_software(r['mac_address'], cpe, r.get('ip'), r.get('hostname'), r.get('port'))
                self.base.execute("DELETE FROM vulnerabilities WHERE id=?", (r['id'],))
                moved += 1
            except Exception:
                # Best-effort migration; keep moving on errors
                pass
        return moved
155
db_utils/stats.py
Normal file
@@ -0,0 +1,155 @@
# db_utils/stats.py
# Statistics tracking and display operations

import time
import sqlite3
from typing import Dict
import logging

from logger import Logger

logger = Logger(name="db_utils.stats", level=logging.DEBUG)


class StatsOps:
    """Statistics tracking and display operations"""

    def __init__(self, base):
        self.base = base

    def create_tables(self):
        """Create stats table"""
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS stats (
                id INTEGER PRIMARY KEY CHECK (id=1),
                total_open_ports INTEGER DEFAULT 0,
                alive_hosts_count INTEGER DEFAULT 0,
                all_known_hosts_count INTEGER DEFAULT 0,
                vulnerabilities_count INTEGER DEFAULT 0,
                actions_count INTEGER DEFAULT 0,
                zombie_count INTEGER DEFAULT 0
            );
        """)
        logger.debug("Stats table created/verified")

    def ensure_stats_initialized(self):
        """Ensure the singleton row in `stats` exists"""
        row = self.base.query("SELECT 1 FROM stats WHERE id=1")
        if not row:
            self.base.execute("INSERT INTO stats(id) VALUES(1);")
    # =========================================================================
    # STATS OPERATIONS
    # =========================================================================

    def get_livestats(self) -> Dict[str, int]:
        """Return the live counters maintained in the `stats` singleton row"""
        row = self.base.query("""
            SELECT total_open_ports, alive_hosts_count, all_known_hosts_count, vulnerabilities_count
            FROM stats WHERE id=1
        """)[0]
        return {
            "total_open_ports": int(row["total_open_ports"]),
            "alive_hosts_count": int(row["alive_hosts_count"]),
            "all_known_hosts_count": int(row["all_known_hosts_count"]),
            "vulnerabilities_count": int(row["vulnerabilities_count"]),
        }

    def update_livestats(self, total_open_ports: int, alive_hosts_count: int,
                         all_known_hosts_count: int, vulnerabilities_count: int):
        """Update the live stats counters in place"""
        self.base.invalidate_stats_cache()
        self.base.execute("""
            UPDATE stats
            SET total_open_ports = ?,
                alive_hosts_count = ?,
                all_known_hosts_count = ?,
                vulnerabilities_count = ?
            WHERE id = 1;
        """, (int(total_open_ports), int(alive_hosts_count),
              int(all_known_hosts_count), int(vulnerabilities_count)))

    def get_stats(self) -> Dict[str, int]:
        """Compatibility alias to retrieve stats; ensures the singleton row exists"""
        row = self.base.query("SELECT total_open_ports, alive_hosts_count, all_known_hosts_count, vulnerabilities_count FROM stats WHERE id=1;")
        if not row:
            self.ensure_stats_initialized()
            row = self.base.query("SELECT total_open_ports, alive_hosts_count, all_known_hosts_count, vulnerabilities_count FROM stats WHERE id=1;")
        r = row[0]
        return {
            "total_open_ports": int(r["total_open_ports"]),
            "alive_hosts_count": int(r["alive_hosts_count"]),
            "all_known_hosts_count": int(r["all_known_hosts_count"]),
            "vulnerabilities_count": int(r["vulnerabilities_count"]),
        }

    def set_stats(self, total_open_ports: int, alive_hosts_count: int,
                  all_known_hosts_count: int, vulnerabilities_count: int):
        """Compatibility alias that forwards to update_livestats"""
        self.update_livestats(total_open_ports, alive_hosts_count, all_known_hosts_count, vulnerabilities_count)
    def get_display_stats(self) -> Dict[str, int]:
        """Cached bundle of counters for quick UI refresh, backed by the stats table"""
        now = time.time()

        # Serve from cache while it is still valid
        if self.base._stats_cache['data'] and (now - self.base._stats_cache['timestamp']) < self.base._cache_ttl:
            return self.base._stats_cache['data'].copy()

        # Compute fresh counters
        with self.base._lock:
            try:
                # Use the stats table for pre-calculated values
                result = self.base.query_one("""
                    SELECT
                        s.total_open_ports,
                        s.alive_hosts_count,
                        s.all_known_hosts_count,
                        s.vulnerabilities_count,
                        COALESCE(s.actions_count,
                                 (SELECT COUNT(*) FROM actions WHERE b_enabled = 1)
                        ) as actions_count,
                        COALESCE(s.zombie_count, 0) as zombie_count,
                        (SELECT COUNT(*) FROM creds) as creds
                    FROM stats s
                    WHERE s.id = 1
                """)

                if result:
                    stats = {
                        'alive_hosts_count': int(result['alive_hosts_count'] or 0),
                        'all_known_hosts_count': int(result['all_known_hosts_count'] or 0),
                        'total_open_ports': int(result['total_open_ports'] or 0),
                        'vulnerabilities_count': int(result['vulnerabilities_count'] or 0),
                        'credentials_count': int(result['creds'] or 0),
                        'actions_count': int(result['actions_count'] or 0),
                        'zombie_count': int(result['zombie_count'] or 0)
                    }
                else:
                    # Fallback if the stats row is missing
                    stats = {
                        'alive_hosts_count': 0,
                        'all_known_hosts_count': 0,
                        'total_open_ports': 0,
                        'vulnerabilities_count': 0,
                        'credentials_count': 0,
                        'actions_count': 0,
                        'zombie_count': 0
                    }

                # Update the cache
                self.base._stats_cache = {'data': stats, 'timestamp': now}
                return stats

            except Exception:
                # On any query error, return zeroed counters rather than raising
                return {
                    'alive_hosts_count': 0,
                    'all_known_hosts_count': 0,
                    'total_open_ports': 0,
                    'vulnerabilities_count': 0,
                    'credentials_count': 0,
                    'actions_count': 0,
                    'zombie_count': 0
                }
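    # Cache behaviour (illustrative): two calls within the base's _cache_ttl window return the
    # in-memory copy; a mutation such as update_livestats() calls invalidate_stats_cache(),
    # so the next get_display_stats() recomputes from the stats table.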
332
db_utils/studio.py
Normal file
@@ -0,0 +1,332 @@
# db_utils/studio.py
# Actions Studio visual editor operations

import json
from typing import Dict, List, Optional
import logging

from logger import Logger

logger = Logger(name="db_utils.studio", level=logging.DEBUG)


class StudioOps:
    """Actions Studio visual editor and workflow operations"""

    def __init__(self, base):
        self.base = base

    def create_tables(self):
        """Create Actions Studio tables"""
        # Studio actions (extended action metadata for the visual editor)
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS actions_studio (
                b_class TEXT PRIMARY KEY,
                studio_x REAL,
                studio_y REAL,
                studio_locked INTEGER DEFAULT 0,
                studio_color TEXT,
                studio_metadata TEXT,
                updated_at TEXT DEFAULT CURRENT_TIMESTAMP
            );
        """)

        # Studio edges (relationships between actions)
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS studio_edges (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                from_action TEXT NOT NULL,
                to_action TEXT NOT NULL,
                edge_type TEXT DEFAULT 'requires',
                edge_label TEXT,
                edge_metadata TEXT,
                created_at TEXT DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (from_action) REFERENCES actions_studio(b_class) ON DELETE CASCADE,
                FOREIGN KEY (to_action) REFERENCES actions_studio(b_class) ON DELETE CASCADE,
                UNIQUE(from_action, to_action, edge_type)
            );
        """)

        # Studio hosts (hosts for test mode)
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS studio_hosts (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                mac_address TEXT UNIQUE NOT NULL,
                ips TEXT,
                hostnames TEXT,
                alive INTEGER DEFAULT 1,
                ports TEXT,
                services TEXT,
                vulns TEXT,
                creds TEXT,
                studio_x REAL,
                studio_y REAL,
                is_simulated INTEGER DEFAULT 1,
                metadata TEXT,
                created_at TEXT DEFAULT CURRENT_TIMESTAMP
            );
        """)

        # Studio layouts (saved layout snapshots)
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS studio_layouts (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT UNIQUE NOT NULL,
                description TEXT,
                layout_data TEXT NOT NULL,
                screenshot BLOB,
                is_active INTEGER DEFAULT 0,
                created_at TEXT DEFAULT CURRENT_TIMESTAMP,
                updated_at TEXT DEFAULT CURRENT_TIMESTAMP
            );
        """)

        logger.debug("Actions Studio tables created/verified")
    # =========================================================================
    # STUDIO ACTION OPERATIONS
    # =========================================================================

    def get_studio_actions(self):
        """Retrieve all studio actions with their positions"""
        # Note: b_priority is copied over from `actions` by
        # _sync_actions_studio_schema_and_rows(), so this query expects the
        # schema sync to have run at least once.
        return self.base.query("""
            SELECT * FROM actions_studio
            ORDER BY b_priority DESC, b_class
        """)

    def get_db_actions(self):
        """Retrieve all actions from the main actions table"""
        return self.base.query("""
            SELECT * FROM actions
            ORDER BY b_priority DESC, b_class
        """)

    def update_studio_action(self, b_class: str, updates: dict):
        """Update a studio action. Keys in `updates` must be trusted column names:
        they are interpolated into the SQL; only the values are parameterized."""
        sets = []
        params = []
        for key, value in updates.items():
            sets.append(f"{key} = ?")
            params.append(value)
        params.append(b_class)

        self.base.execute(f"""
            UPDATE actions_studio
            SET {', '.join(sets)}, updated_at = CURRENT_TIMESTAMP
            WHERE b_class = ?
        """, params)
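    # Usage sketch (the action class name is illustrative): move a node and disable it in one
    # call; because the keys name real columns, callers should whitelist them before passing
    # any user-supplied input:
    #
    #     studio_ops.update_studio_action("ActionNmapScan",
    #                                     {"studio_x": 120.0, "studio_y": 80.0, "b_enabled": 0})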
    # =========================================================================
    # STUDIO EDGE OPERATIONS
    # =========================================================================

    def get_studio_edges(self):
        """Retrieve all studio edges"""
        return self.base.query("SELECT * FROM studio_edges")

    def upsert_studio_edge(self, from_action: str, to_action: str, edge_type: str,
                           metadata: Optional[dict] = None):
        """Create or update a studio edge"""
        meta_json = json.dumps(metadata) if metadata else None
        # Try an UPDATE first
        updated = self.base.execute("""
            UPDATE studio_edges
            SET edge_metadata = ?
            WHERE from_action = ? AND to_action = ? AND edge_type = ?
        """, (meta_json, from_action, to_action, edge_type))
        if not updated:
            # No rows updated: INSERT instead
            self.base.execute("""
                INSERT OR IGNORE INTO studio_edges(from_action, to_action, edge_type, edge_metadata)
                VALUES(?,?,?,?)
            """, (from_action, to_action, edge_type, meta_json))

    def delete_studio_edge(self, edge_id: int):
        """Delete a studio edge"""
        self.base.execute("DELETE FROM studio_edges WHERE id = ?", (edge_id,))
    # =========================================================================
    # STUDIO HOST OPERATIONS
    # =========================================================================

    def get_studio_hosts(self, include_real: bool = True):
        """Retrieve studio hosts"""
        if include_real:
            # Combine real and simulated hosts
            return self.base.query("""
                SELECT mac_address, ips, hostnames, alive, ports,
                       NULL as services, NULL as vulns, NULL as creds,
                       NULL as studio_x, NULL as studio_y, 0 as is_simulated
                FROM hosts
                UNION ALL
                SELECT mac_address, ips, hostnames, alive, ports,
                       services, vulns, creds, studio_x, studio_y, is_simulated
                FROM studio_hosts
            """)
        else:
            return self.base.query("SELECT * FROM studio_hosts WHERE is_simulated = 1")

    def upsert_studio_host(self, mac_address: str, data: dict):
        """Create or update a simulated host"""
        self.base.execute("""
            INSERT INTO studio_hosts (
                mac_address, ips, hostnames, alive, ports, services,
                vulns, creds, studio_x, studio_y, is_simulated, metadata
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(mac_address) DO UPDATE SET
                ips = excluded.ips,
                hostnames = excluded.hostnames,
                alive = excluded.alive,
                ports = excluded.ports,
                services = excluded.services,
                vulns = excluded.vulns,
                creds = excluded.creds,
                studio_x = excluded.studio_x,
                studio_y = excluded.studio_y,
                metadata = excluded.metadata
        """, (
            mac_address,
            data.get('ips'),
            data.get('hostnames'),
            data.get('alive', 1),
            data.get('ports'),
            json.dumps(data.get('services', [])),
            json.dumps(data.get('vulns', [])),
            json.dumps(data.get('creds', [])),
            data.get('studio_x'),
            data.get('studio_y'),
            1,  # is_simulated
            json.dumps(data.get('metadata', {}))
        ))

    def delete_studio_host(self, mac: str):
        """Delete a studio host"""
        self.base.execute("DELETE FROM studio_hosts WHERE mac_address = ?", (mac,))
    # =========================================================================
    # STUDIO LAYOUT OPERATIONS
    # =========================================================================

    def save_studio_layout(self, name: str, layout_data: dict, description: Optional[str] = None):
        """Save a complete layout"""
        self.base.execute("""
            INSERT INTO studio_layouts (name, description, layout_data)
            VALUES (?, ?, ?)
            ON CONFLICT(name) DO UPDATE SET
                description = excluded.description,
                layout_data = excluded.layout_data,
                updated_at = CURRENT_TIMESTAMP
        """, (name, description, json.dumps(layout_data)))

    def load_studio_layout(self, name: str):
        """Load a saved layout (layout_data is decoded back into a dict)"""
        row = self.base.query_one("SELECT * FROM studio_layouts WHERE name = ?", (name,))
        if row:
            row['layout_data'] = json.loads(row['layout_data'])
        return row
# =========================================================================
|
||||
# STUDIO SYNC OPERATIONS
|
||||
# =========================================================================
|
||||
|
||||
def apply_studio_to_runtime(self):
|
||||
"""Apply studio configurations to the main actions table"""
|
||||
self.base.execute("""
|
||||
UPDATE actions
|
||||
SET
|
||||
b_trigger = (SELECT b_trigger FROM actions_studio WHERE actions_studio.b_class = actions.b_class),
|
||||
b_requires = (SELECT b_requires FROM actions_studio WHERE actions_studio.b_class = actions.b_class),
|
||||
b_priority = (SELECT b_priority FROM actions_studio WHERE actions_studio.b_class = actions.b_class),
|
||||
b_enabled = (SELECT b_enabled FROM actions_studio WHERE actions_studio.b_class = actions.b_class),
|
||||
b_timeout = (SELECT b_timeout FROM actions_studio WHERE actions_studio.b_class = actions.b_class),
|
||||
b_max_retries = (SELECT b_max_retries FROM actions_studio WHERE actions_studio.b_class = actions.b_class),
|
||||
b_cooldown = (SELECT b_cooldown FROM actions_studio WHERE actions_studio.b_class = actions.b_class),
|
||||
b_rate_limit = (SELECT b_rate_limit FROM actions_studio WHERE actions_studio.b_class = actions.b_class),
|
||||
b_service = (SELECT b_service FROM actions_studio WHERE actions_studio.b_class = actions.b_class),
|
||||
b_port = (SELECT b_port FROM actions_studio WHERE actions_studio.b_class = actions.b_class)
|
||||
WHERE b_class IN (SELECT b_class FROM actions_studio)
|
||||
""")
|
||||
|
||||
def _replace_actions_studio_with_actions(self, vacuum: bool = False):
|
||||
"""
|
||||
Reset actions_studio (delete all rows) then resync from actions via _sync_actions_studio_schema_and_rows().
|
||||
Optionally run VACUUM.
|
||||
"""
|
||||
# Ensure table exists so DELETE doesn't fail
|
||||
self.base.execute("""
|
||||
CREATE TABLE IF NOT EXISTS actions_studio (
|
||||
b_class TEXT PRIMARY KEY,
|
||||
studio_x REAL,
|
||||
studio_y REAL,
|
||||
studio_locked INTEGER DEFAULT 0,
|
||||
studio_color TEXT,
|
||||
studio_metadata TEXT,
|
||||
updated_at TEXT DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
""")
|
||||
|
||||
# Total purge
|
||||
self.base.execute("DELETE FROM actions_studio;")
|
||||
|
||||
# Optional compaction
|
||||
if vacuum:
|
||||
self.base.execute("VACUUM;")
|
||||
|
||||
# Non-destructive resynchronization from actions
|
||||
self._sync_actions_studio_schema_and_rows()
|
||||
|
||||
    def _sync_actions_studio_schema_and_rows(self):
        """
        Sync actions_studio with the actions table:
        - Create minimal table if needed
        - Add missing columns from actions
        - Insert missing b_class entries
        - Update NULL fields only (non-destructive)
        """
        # 1) Minimal table: PK + studio_* columns
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS actions_studio (
                b_class TEXT PRIMARY KEY,
                studio_x REAL,
                studio_y REAL,
                studio_locked INTEGER DEFAULT 0,
                studio_color TEXT,
                studio_metadata TEXT,
                updated_at TEXT DEFAULT CURRENT_TIMESTAMP
            );
        """)

        # 2) Dynamically add all columns from actions that are missing in
        #    actions_studio (one PRAGMA query yields both names and types)
        act_col_defs = {r["name"]: r["type"] for r in self.base.query("PRAGMA table_info(actions);")}
        act_cols = list(act_col_defs)
        stu_cols = [r["name"] for r in self.base.query("PRAGMA table_info(actions_studio);")]

        for col in act_cols:
            if col == "b_class":
                continue
            if col not in stu_cols:
                col_type = act_col_defs.get(col, "TEXT") or "TEXT"
                self.base.execute(f"ALTER TABLE actions_studio ADD COLUMN {col} {col_type};")

        # 3) Insert missing b_class entries, non-destructive
        self.base.execute("""
            INSERT OR IGNORE INTO actions_studio (b_class)
            SELECT b_class FROM actions;
        """)

        # 4) Pre-fill only NULL fields from actions (without overwriting)
        for col in act_cols:
            if col == "b_class":
                continue
            # Only update if the studio value is NULL
            self.base.execute(f"""
                UPDATE actions_studio
                SET {col} = (SELECT a.{col} FROM actions a
                             WHERE a.b_class = actions_studio.b_class)
                WHERE {col} IS NULL
                  AND EXISTS (SELECT 1 FROM actions a WHERE a.b_class = actions_studio.b_class);
            """)

        # 5) Touch timestamp
        self.base.execute("UPDATE actions_studio SET updated_at = CURRENT_TIMESTAMP;")
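
    # Workflow sketch (hypothetical sequence; `ops` is an instance of this class):
    #
    #   ops._sync_actions_studio_schema_and_rows()   # mirror schema/rows from actions
    #   ... user edits triggers and priorities in the Studio UI ...
    #   ops.apply_studio_to_runtime()                # push the edits back to actions
    #
    # _replace_actions_studio_with_actions(vacuum=True) instead discards all
    # studio edits and rebuilds the mirror from scratch.
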
533
db_utils/vulnerabilities.py
Normal file
@@ -0,0 +1,533 @@
# db_utils/vulnerabilities.py
# Vulnerability tracking and CVE metadata operations

import json
import time
from typing import Any, Dict, List, Optional
import logging

from logger import Logger

logger = Logger(name="db_utils.vulnerabilities", level=logging.DEBUG)


class VulnerabilityOps:
    """Vulnerability tracking and CVE metadata operations"""

    def __init__(self, base):
        self.base = base

    def create_tables(self):
        """Create vulnerability and CVE metadata tables"""
        # CVE metadata cache (NVD/MITRE/EPSS/KEV + Exploit-DB)
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS cve_meta (
                cve_id TEXT PRIMARY KEY,
                description TEXT,
                cvss_json TEXT,
                references_json TEXT,
                last_modified TEXT,
                affected_json TEXT,
                solution TEXT,
                exploits_json TEXT,
                is_kev INTEGER DEFAULT 0,
                epss REAL,
                epss_percentile REAL,
                updated_at INTEGER
            );
        """)
        self.base.execute("CREATE INDEX IF NOT EXISTS idx_cve_meta_updated ON cve_meta(updated_at);")

        # Vulnerabilities table
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS vulnerabilities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                mac_address TEXT NOT NULL,
                ip TEXT,
                hostname TEXT,
                port INTEGER NOT NULL DEFAULT 0,
                vuln_id TEXT NOT NULL,
                previous_vulns TEXT,
                first_seen TEXT DEFAULT CURRENT_TIMESTAMP,
                last_seen TEXT DEFAULT CURRENT_TIMESTAMP,
                is_active INTEGER DEFAULT 1
            );
        """)

        # Migrate legacy rows before creating the unique index, so the index
        # build cannot collide with leftover NULL-port duplicates:
        # convert NULL ports to 0, then remove real duplicates.
        self.base.execute("DROP INDEX IF EXISTS uq_vuln_identity;")
        self.base.execute("UPDATE vulnerabilities SET port = 0 WHERE port IS NULL;")
        self.base.execute("""
            DELETE FROM vulnerabilities
            WHERE rowid NOT IN (
                SELECT MIN(rowid)
                FROM vulnerabilities
                GROUP BY mac_address, vuln_id, port
            );
        """)

        # Unique index without COALESCE, since port is now NOT NULL
        self.base.execute("""
            CREATE UNIQUE INDEX IF NOT EXISTS uq_vuln_identity
            ON vulnerabilities(mac_address, vuln_id, port);
        """)

        self.base.execute("CREATE INDEX IF NOT EXISTS idx_vuln_active ON vulnerabilities(is_active) WHERE is_active=1;")
        self.base.execute("CREATE INDEX IF NOT EXISTS idx_vuln_mac_port ON vulnerabilities(mac_address, port);")

        # Vulnerability history (immutable log)
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS vulnerability_history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                mac_address TEXT NOT NULL,
                ip TEXT,
                hostname TEXT,
                port INTEGER,
                vuln_id TEXT NOT NULL,
                event TEXT NOT NULL,
                seen_at TEXT DEFAULT CURRENT_TIMESTAMP
            );
        """)

        logger.debug("Vulnerability tables created/verified")

    # =========================================================================
    # CVE METADATA OPERATIONS
    # =========================================================================

    def get_cve_meta(self, cve_id: str) -> Optional[Dict[str, Any]]:
        """Get CVE metadata from cache"""
        row = self.base.query_one("SELECT * FROM cve_meta WHERE cve_id=? LIMIT 1;", (cve_id,))
        if not row:
            return None
        # Deserialize JSON fields
        for k in ("cvss_json", "references_json", "affected_json", "exploits_json"):
            if row.get(k):
                try:
                    row[k] = json.loads(row[k])
                except Exception:
                    row[k] = None
        return row

    def upsert_cve_meta(self, meta: Dict[str, Any]) -> None:
        """Insert or update CVE metadata"""
        # Serialize JSON fields
        cvss = json.dumps(meta.get("cvss"), ensure_ascii=False) if meta.get("cvss") is not None else None
        refs = json.dumps(meta.get("references"), ensure_ascii=False) if meta.get("references") is not None else None
        aff = json.dumps(meta.get("affected"), ensure_ascii=False) if meta.get("affected") is not None else None
        exps = json.dumps(meta.get("exploits"), ensure_ascii=False) if meta.get("exploits") is not None else None

        self.base.execute("""
            INSERT INTO cve_meta(
                cve_id, description, cvss_json, references_json, last_modified,
                affected_json, solution, exploits_json, is_kev, epss, epss_percentile, updated_at
            ) VALUES(?,?,?,?,?,?,?,?,?,?,?,?)
            ON CONFLICT(cve_id) DO UPDATE SET
                description = excluded.description,
                cvss_json = excluded.cvss_json,
                references_json = excluded.references_json,
                last_modified = excluded.last_modified,
                affected_json = excluded.affected_json,
                solution = excluded.solution,
                exploits_json = excluded.exploits_json,
                is_kev = excluded.is_kev,
                epss = excluded.epss,
                epss_percentile = excluded.epss_percentile,
                updated_at = excluded.updated_at;
        """, (
            meta.get("cve_id"),
            meta.get("description"),
            cvss, refs, meta.get("lastModified"),
            aff, meta.get("solution"), exps,
            1 if meta.get("is_kev") else 0,
            meta.get("epss"),
            meta.get("epss_percentile"),
            int(meta.get("updated_at") or time.time())
        ))
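
    # Usage sketch (hypothetical payload; key names follow the meta dict read
    # above - "cvss", "references", "is_kev", "epss", and the NVD-style
    # "lastModified"):
    #
    #   vulns = VulnerabilityOps(db)
    #   vulns.upsert_cve_meta({
    #       "cve_id": "CVE-2021-44228",
    #       "description": "Apache Log4j2 JNDI remote code execution",
    #       "cvss": {"baseScore": 10.0},
    #       "is_kev": True,
    #       "epss": 0.97,
    #   })
    #   meta = vulns.get_cve_meta("CVE-2021-44228")
    #   # meta["cvss_json"] comes back deserialized as a dict
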
    def get_cve_meta_bulk(self, cve_ids: List[str]) -> Dict[str, Dict[str, Any]]:
        """Get multiple CVE metadata entries at once"""
        if not cve_ids:
            return {}
        placeholders = ",".join("?" for _ in cve_ids)
        rows = self.base.query(f"SELECT * FROM cve_meta WHERE cve_id IN ({placeholders});", tuple(cve_ids))
        out = {}
        for r in rows:
            for k in ("cvss_json", "references_json", "affected_json", "exploits_json"):
                if r.get(k):
                    try:
                        r[k] = json.loads(r[k])
                    except Exception:
                        r[k] = None
            out[r["cve_id"]] = r
        return out

    # =========================================================================
    # VULNERABILITY CRUD OPERATIONS
    # =========================================================================

    def add_vulnerability(self, mac_address: str, vuln_id: str, ip: Optional[str] = None,
                          hostname: Optional[str] = None, port: Optional[int] = None):
        """Insert/reactivate a vulnerability row and record history (NULL-safe on port)"""
        self.base.invalidate_stats_cache()
        p = int(port or 0)

        try:
            # Try to update an existing row
            updated = self.base.execute(
                """
                UPDATE vulnerabilities
                SET is_active = 1,
                    ip = COALESCE(?, ip),
                    hostname = COALESCE(?, hostname),
                    last_seen = CURRENT_TIMESTAMP
                WHERE mac_address = ? AND vuln_id = ? AND COALESCE(port, 0) = ?
                """,
                (ip, hostname, mac_address, vuln_id, p)
            )

            if updated and updated > 0:
                # Seen again
                self.base.execute(
                    """
                    INSERT INTO vulnerability_history(mac_address, ip, hostname, port, vuln_id, event)
                    VALUES(?,?,?,?,?,'seen')
                    """,
                    (mac_address, ip, hostname, p, vuln_id)
                )
                return

            # Insert a new row (port=0 if unknown)
            self.base.execute(
                """
                INSERT INTO vulnerabilities(mac_address, ip, hostname, port, vuln_id, is_active)
                VALUES(?,?,?,?,?,1)
                """,
                (mac_address, ip, hostname, p, vuln_id)
            )
            self.base.execute(
                """
                INSERT INTO vulnerability_history(mac_address, ip, hostname, port, vuln_id, event)
                VALUES(?,?,?,?,?,'new')
                """,
                (mac_address, ip, hostname, p, vuln_id)
            )

        except Exception:
            # Fallback if the query fails for an exotic reason
            row = self.base.query_one(
                """
                SELECT id FROM vulnerabilities
                WHERE mac_address=? AND vuln_id=? AND COALESCE(port,0)=?
                LIMIT 1
                """,
                (mac_address, vuln_id, p)
            )
            if row:
                self.base.execute(
                    """
                    UPDATE vulnerabilities
                    SET is_active=1,
                        ip=COALESCE(?, ip),
                        hostname=COALESCE(?, hostname),
                        last_seen=CURRENT_TIMESTAMP
                    WHERE id=?
                    """,
                    (ip, hostname, row["id"])
                )
                self.base.execute(
                    """
                    INSERT INTO vulnerability_history(mac_address, ip, hostname, port, vuln_id, event)
                    VALUES(?,?,?,?,?,'seen')
                    """,
                    (mac_address, ip, hostname, p, vuln_id)
                )
            else:
                self.base.execute(
                    """
                    INSERT INTO vulnerabilities(mac_address, ip, hostname, port, vuln_id, is_active)
                    VALUES(?,?,?,?,?,1)
                    """,
                    (mac_address, ip, hostname, p, vuln_id)
                )
                self.base.execute(
                    """
                    INSERT INTO vulnerability_history(mac_address, ip, hostname, port, vuln_id, event)
                    VALUES(?,?,?,?,?,'new')
                    """,
                    (mac_address, ip, hostname, p, vuln_id)
                )
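
    # Usage sketch (hypothetical values): the same call both inserts a brand-new
    # finding (history event 'new') and reactivates a previously recorded one
    # (history event 'seen'):
    #
    #   vulns.add_vulnerability("AA:BB:CC:DD:EE:FF", "CVE-2021-44228",
    #                           ip="192.168.1.10", hostname="web01", port=8080)
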
    def update_vulnerability_status(self, mac_address: str, current_vulns: List[str]):
        """Update vulnerability presence (new/seen/inactive) and touch timestamps/history"""
        existing = self.base.query(
            "SELECT vuln_id FROM vulnerabilities WHERE mac_address=? AND is_active=1",
            (mac_address,)
        )
        existing_ids = {r['vuln_id'] for r in existing}
        current_set = set(current_vulns)

        # Mark inactive
        for vuln_id in (existing_ids - current_set):
            self.base.execute("""
                UPDATE vulnerabilities
                SET is_active=0, last_seen=CURRENT_TIMESTAMP
                WHERE mac_address=? AND vuln_id=? AND is_active=1
            """, (mac_address, vuln_id))

            self.base.execute("""
                INSERT INTO vulnerability_history(mac_address, port, vuln_id, event)
                SELECT mac_address, port, vuln_id, 'inactive'
                FROM vulnerabilities
                WHERE mac_address=? AND vuln_id=? LIMIT 1
            """, (mac_address, vuln_id))

        # Add new
        for vuln_id in (current_set - existing_ids):
            self.add_vulnerability(mac_address, vuln_id)

        # Seen: refresh last_seen and record history
        for vuln_id in (current_set & existing_ids):
            self.base.execute("""
                UPDATE vulnerabilities
                SET last_seen=CURRENT_TIMESTAMP
                WHERE mac_address=? AND vuln_id=? AND is_active=1
            """, (mac_address, vuln_id))

            self.base.execute("""
                INSERT INTO vulnerability_history(mac_address, port, vuln_id, event)
                SELECT mac_address, port, vuln_id, 'seen'
                FROM vulnerabilities
                WHERE mac_address=? AND vuln_id=? LIMIT 1
            """, (mac_address, vuln_id))

    def update_vulnerability_status_by_port(self, mac_address: str, port: int, current_vulns: List[str]):
        """Update vulnerability status for a specific port to avoid NULL conflicts"""
        port = int(port) if port is not None else 0

        existing = self.base.query(
            "SELECT vuln_id FROM vulnerabilities WHERE mac_address=? AND COALESCE(port, 0)=? AND is_active=1",
            (mac_address, port)
        )
        existing_ids = {r['vuln_id'] for r in existing}
        current_set = set(current_vulns)

        # Mark inactive (for this specific port)
        for vuln_id in (existing_ids - current_set):
            self.base.execute("""
                UPDATE vulnerabilities
                SET is_active=0, last_seen=CURRENT_TIMESTAMP
                WHERE mac_address=? AND vuln_id=? AND COALESCE(port, 0)=? AND is_active=1
            """, (mac_address, vuln_id, port))

            self.base.execute("""
                INSERT INTO vulnerability_history(mac_address, ip, hostname, port, vuln_id, event)
                VALUES (?, NULL, NULL, ?, ?, 'inactive')
            """, (mac_address, port, vuln_id))

        # Add new (delegates to add_vulnerability with the port)
        for vuln_id in (current_set - existing_ids):
            self.add_vulnerability(mac_address, vuln_id, port=port)

        # Mark as seen (for this specific port)
        for vuln_id in (current_set & existing_ids):
            self.base.execute("""
                UPDATE vulnerabilities
                SET last_seen=CURRENT_TIMESTAMP
                WHERE mac_address=? AND vuln_id=? AND COALESCE(port, 0)=? AND is_active=1
            """, (mac_address, vuln_id, port))

            self.base.execute("""
                INSERT INTO vulnerability_history(mac_address, ip, hostname, port, vuln_id, event)
                VALUES (?, NULL, NULL, ?, ?, 'seen')
            """, (mac_address, port, vuln_id))

    def save_vulnerabilities(self, mac: str, ip: str, findings: List[Dict]):
        """Separate CPE and CVE findings, update statuses, and record new findings"""
        # Group findings by port to avoid conflicts
        findings_by_port = {}
        for f in findings:
            port = f.get('port', 0)
            if port is None:
                port = 0
            port = int(port)

            if port not in findings_by_port:
                findings_by_port[port] = {'cves': set(), 'cpes': set(), 'findings': []}

            findings_by_port[port]['findings'].append(f)

            vid = str(f.get('vuln_id', ''))
            if vid.upper().startswith('CVE-'):
                findings_by_port[port]['cves'].add(vid)
            elif vid.startswith('CPE:'):
                # Marker prefix: strip the leading "CPE:" and keep the raw CPE string.
                # (Exact-case check: the original .upper() test also swallowed plain
                # lowercase cpe: ids and stripped their scheme, making the branch
                # below unreachable.)
                findings_by_port[port]['cpes'].add(vid.split(':', 1)[1])
            elif vid.lower().startswith('cpe:'):
                findings_by_port[port]['cpes'].add(vid)

        # Process CVEs per port to avoid conflicts
        all_cve_ids = set()
        for port, data in findings_by_port.items():
            if data['cves']:
                try:
                    self.update_vulnerability_status_by_port(mac, port, sorted(data['cves']))
                    all_cve_ids.update(data['cves'])
                except Exception as e:
                    logger.error(f"Failed to update CVE status for port {port}: {e}")

        # Process CPEs globally (as before) - delegated to SoftwareOps
        all_cpe_vals = set()
        for port, data in findings_by_port.items():
            all_cpe_vals.update(data['cpes'])

        # Note: CPE handling would typically be done by SoftwareOps,
        # but the aggregation is kept here for compatibility

        logger.debug(f"Processed: {len(all_cve_ids)} CVEs across {len(findings_by_port)} ports, {len(all_cpe_vals)} CPEs for {mac}")
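
    # Input sketch (hypothetical findings list): each finding carries a port and
    # a vuln_id that is either a CVE or a CPE string:
    #
    #   findings = [
    #       {"port": 443, "vuln_id": "CVE-2021-44228"},
    #       {"port": 443, "vuln_id": "cpe:2.3:a:apache:log4j:2.14.1:*:*:*:*:*:*:*"},
    #       {"port": 0,   "vuln_id": "CVE-2020-1472"},
    #   ]
    #   vulns.save_vulnerabilities("AA:BB:CC:DD:EE:FF", "192.168.1.10", findings)
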
    # =========================================================================
    # VULNERABILITY QUERY OPERATIONS
    # =========================================================================

    def get_all_vulns(self) -> List[Dict[str, Any]]:
        """Get all vulnerabilities with host details"""
        return self.base.query("""
            SELECT v.id, v.mac_address, v.ip, v.hostname, v.port, v.vuln_id, v.is_active, v.first_seen, v.last_seen,
                   h.ips AS host_ips, h.hostnames AS host_hostnames, h.ports AS host_ports, h.vendor AS host_vendor
            FROM vulnerabilities v
            LEFT JOIN hosts h ON v.mac_address = h.mac_address
            ORDER BY v.mac_address, v.vuln_id;
        """)

    def count_vulnerabilities_alive(self, distinct: bool = False, active_only: bool = True) -> int:
        """Count vulnerabilities for hosts with alive=1"""
        where = ["h.alive = 1"]
        if active_only:
            where.append("v.is_active = 1")
        where_sql = " AND ".join(where)

        count_expr = "COUNT(DISTINCT v.vuln_id)" if distinct else "COUNT(*)"
        sql = f"""
            SELECT {count_expr} AS c
            FROM vulnerabilities v
            JOIN hosts h ON h.mac_address = v.mac_address
            WHERE {where_sql}
        """
        row = self.base.query(sql)
        return int(row[0]["c"]) if row else 0

    def count_distinct_vulnerabilities(self, alive_only: bool = False) -> int:
        """Return the number of distinct vulnerabilities (vuln_id)"""
        if alive_only:
            row = self.base.query("""
                SELECT COUNT(DISTINCT v.vuln_id) AS c
                FROM vulnerabilities v
                JOIN hosts h ON h.mac_address = v.mac_address
                WHERE h.alive = 1
            """)
        else:
            row = self.base.query("SELECT COUNT(DISTINCT vuln_id) AS c FROM vulnerabilities")
        return int(row[0]["c"]) if row else 0

    def get_vulnerabilities_for_alive_hosts(self) -> List[str]:
        """Return a list of distinct vuln_id values affecting hosts currently marked alive=1"""
        rows = self.base.query("""
            SELECT DISTINCT v.vuln_id
            FROM vulnerabilities v
            JOIN hosts h ON h.mac_address = v.mac_address
            WHERE h.alive = 1
        """)
        return [r["vuln_id"] for r in rows]

    def list_vulnerability_history(self, cve_id: str | None = None,
                                   mac: str | None = None, limit: int = 500) -> list[dict]:
        """Return vulnerability history (events), most recent first"""
        where = []
        params: list = []
        if cve_id:
            where.append("vuln_id = ?")
            params.append(cve_id)
        if mac:
            where.append("mac_address = ?")
            params.append(mac)
        where_sql = ("WHERE " + " AND ".join(where)) if where else ""
        params.append(int(limit))

        return self.base.query(f"""
            SELECT mac_address, ip, hostname, port, vuln_id, event, seen_at
            FROM vulnerability_history
            {where_sql}
            ORDER BY datetime(seen_at) DESC
            LIMIT ?
        """, tuple(params))
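
    # Query sketch (hypothetical filters): history can be narrowed by CVE,
    # by host MAC, or both:
    #
    #   events = vulns.list_vulnerability_history(cve_id="CVE-2021-44228",
    #                                             mac="AA:BB:CC:DD:EE:FF",
    #                                             limit=50)
    #   for e in events:
    #       print(e["seen_at"], e["event"], e["port"])
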
    # =========================================================================
    # CLEANUP OPERATIONS
    # =========================================================================

    def cleanup_vulnerability_duplicates(self):
        """Clean up duplicate vulnerability rows"""
        self.base.invalidate_stats_cache()

        # Delete entries with a NULL port if an entry with port=0 exists
        self.base.execute("""
            DELETE FROM vulnerabilities
            WHERE port IS NULL
              AND EXISTS (
                  SELECT 1 FROM vulnerabilities v2
                  WHERE v2.mac_address = vulnerabilities.mac_address
                    AND v2.vuln_id = vulnerabilities.vuln_id
                    AND v2.port = 0
              )
        """)

        # Update remaining NULL ports to 0
        self.base.execute("""
            UPDATE vulnerabilities SET port = 0 WHERE port IS NULL
        """)

        # Delete true duplicates (same mac, vuln_id, port) - keep the most recent
        self.base.execute("""
            DELETE FROM vulnerabilities
            WHERE rowid NOT IN (
                SELECT MAX(rowid)
                FROM vulnerabilities
                GROUP BY mac_address, vuln_id, COALESCE(port, 0)
            )
        """)

    def fix_vulnerability_history_nulls(self):
        """Fix history entries with problematic NULL values"""
        # Update history rows where the port is NULL but a matching port=0 row exists
        self.base.execute("""
            UPDATE vulnerability_history
            SET port = 0
            WHERE port IS NULL
              AND EXISTS (
                  SELECT 1 FROM vulnerabilities v
                  WHERE v.mac_address = vulnerability_history.mac_address
                    AND v.vuln_id = vulnerability_history.vuln_id
                    AND v.port = 0
              )
        """)

        # Where the port cannot be determined, default to 0
        self.base.execute("""
            UPDATE vulnerability_history
            SET port = 0
            WHERE port IS NULL
        """)

162
db_utils/webenum.py
Normal file
@@ -0,0 +1,162 @@
# db_utils/webenum.py
# Web enumeration (directory/file discovery) operations

from typing import Any, Dict, List, Optional
import logging

from logger import Logger

logger = Logger(name="db_utils.webenum", level=logging.DEBUG)


class WebEnumOps:
    """Web directory and file enumeration tracking operations"""

    def __init__(self, base):
        self.base = base

    def create_tables(self):
        """Create web enumeration table"""
        self.base.execute("""
            CREATE TABLE IF NOT EXISTS webenum (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                mac_address TEXT NOT NULL,
                ip TEXT NOT NULL,
                hostname TEXT,
                port INTEGER NOT NULL,
                directory TEXT NOT NULL,
                status INTEGER NOT NULL,
                size INTEGER DEFAULT 0,
                response_time INTEGER DEFAULT 0,
                content_type TEXT,
                scan_date TEXT DEFAULT CURRENT_TIMESTAMP,
                tool TEXT DEFAULT 'gobuster',
                method TEXT DEFAULT 'GET',
                user_agent TEXT,
                headers TEXT,
                first_seen TEXT DEFAULT CURRENT_TIMESTAMP,
                last_seen TEXT DEFAULT CURRENT_TIMESTAMP,
                is_active INTEGER DEFAULT 1,
                UNIQUE(mac_address, ip, port, directory)
            );
        """)

        # Indexes for frequent queries
        self.base.execute("CREATE INDEX IF NOT EXISTS idx_webenum_host_port ON webenum(mac_address, port);")
        self.base.execute("CREATE INDEX IF NOT EXISTS idx_webenum_status ON webenum(status);")
        self.base.execute("CREATE INDEX IF NOT EXISTS idx_webenum_scan_date ON webenum(scan_date);")
        self.base.execute("CREATE INDEX IF NOT EXISTS idx_webenum_active ON webenum(is_active) WHERE is_active=1;")

        logger.debug("WebEnum table created/verified")

    # =========================================================================
    # WEB ENUMERATION OPERATIONS
    # =========================================================================

    def add_webenum_result(
        self,
        mac_address: str,
        ip: str,
        port: int,
        directory: str,
        status: int,
        *,
        hostname: Optional[str] = None,
        size: int = 0,
        response_time: int = 0,
        content_type: Optional[str] = None,
        tool: str = "gobuster",
        method: str = "GET",
        user_agent: Optional[str] = None,
        headers: Optional[str] = None
    ):
        """Add or update a web enumeration result"""
        self.base.execute("""
            INSERT INTO webenum (
                mac_address, ip, hostname, port, directory, status,
                size, response_time, content_type, tool, method,
                user_agent, headers, is_active
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 1)
            ON CONFLICT(mac_address, ip, port, directory) DO UPDATE SET
                status = excluded.status,
                size = excluded.size,
                response_time = excluded.response_time,
                content_type = excluded.content_type,
                hostname = COALESCE(excluded.hostname, webenum.hostname),
                last_seen = CURRENT_TIMESTAMP,
                is_active = 1
        """, (
            mac_address, ip, hostname, port, directory, status,
            size, response_time, content_type, tool, method,
            user_agent, headers
        ))
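
    # Usage sketch (hypothetical result, e.g. parsed from gobuster output);
    # re-running the call for the same (mac, ip, port, directory) updates the
    # row and refreshes last_seen instead of inserting a duplicate:
    #
    #   web = WebEnumOps(db)
    #   web.add_webenum_result("AA:BB:CC:DD:EE:FF", "192.168.1.10", 80,
    #                          "/admin", 301, size=312,
    #                          content_type="text/html", tool="gobuster")
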
    def get_webenum_for_host(self, mac_address: str, port: Optional[int] = None) -> List[Dict[str, Any]]:
        """Get all web enumeration results for a host (optionally filtered by port)"""
        if port is not None:
            return self.base.query("""
                SELECT * FROM webenum
                WHERE mac_address = ? AND port = ? AND is_active = 1
                ORDER BY status, directory
            """, (mac_address, port))
        else:
            return self.base.query("""
                SELECT * FROM webenum
                WHERE mac_address = ? AND is_active = 1
                ORDER BY port, status, directory
            """, (mac_address,))

    def get_webenum_by_status(self, status: int) -> List[Dict[str, Any]]:
        """Get all enumeration results with a specific HTTP status code"""
        return self.base.query("""
            SELECT * FROM webenum
            WHERE status = ? AND is_active = 1
            ORDER BY mac_address, port, directory
        """, (status,))

    def mark_webenum_inactive(self, mac_address: str, port: int, directories: List[str]):
        """Mark enumeration results as inactive (e.g., after a rescan)"""
        if not directories:
            return

        placeholders = ",".join("?" for _ in directories)
        self.base.execute(f"""
            UPDATE webenum
            SET is_active = 0, last_seen = CURRENT_TIMESTAMP
            WHERE mac_address = ? AND port = ? AND directory IN ({placeholders})
        """, (mac_address, port, *directories))

    def delete_webenum_for_host(self, mac_address: str, port: Optional[int] = None):
        """Delete all enumeration results for a host (optionally filtered by port)"""
        if port is not None:
            self.base.execute("""
                DELETE FROM webenum
                WHERE mac_address = ? AND port = ?
            """, (mac_address, port))
        else:
            self.base.execute("""
                DELETE FROM webenum
                WHERE mac_address = ?
            """, (mac_address,))

    def count_webenum_results(self, mac_address: Optional[str] = None,
                              active_only: bool = True) -> int:
        """Count enumeration results (optionally for a specific host and/or active only)"""
        where_clauses = []
        params = []

        if mac_address:
            where_clauses.append("mac_address = ?")
            params.append(mac_address)

        if active_only:
            where_clauses.append("is_active = 1")

        where_sql = " AND ".join(where_clauses) if where_clauses else "1=1"

        row = self.base.query_one(f"""
            SELECT COUNT(*) as cnt FROM webenum
            WHERE {where_sql}
        """, tuple(params))

        return int(row["cnt"]) if row else 0
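
    # Rescan sketch (hypothetical flow): after a fresh enumeration, results
    # that no longer respond can be retired without deleting their history:
    #
    #   stale = ["/old-backup", "/tmp-upload"]
    #   web.mark_webenum_inactive("AA:BB:CC:DD:EE:FF", 80, stale)
    #   print(web.count_webenum_results("AA:BB:CC:DD:EE:FF"))  # active rows only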