Skip to main content

synapse_pingora/persistence/
mod.rs

//! Persistence module for saving and loading WAF state.
//!
//! Handles periodic snapshots of all WAF state:
//! - Learned Endpoint Profiles (ProfileStore)
//! - IP Reputation Data (EntityStore)
//! - Campaign Correlations (CampaignManager)
//! - Actor States (ActorManager)
//! - Credential Stuffing Detector State (optional)
8
9use std::fs;
10use std::io;
11use std::path::{Path, PathBuf};
12use std::sync::Arc;
13use std::time::{Duration, SystemTime, UNIX_EPOCH};
14
15use log::{debug, error, info, warn};
16use serde::{Deserialize, Serialize};
17use tokio::time;
18
19use crate::actor::ActorState;
20use crate::correlation::Campaign;
21use crate::detection::StuffingState;
22use crate::entity::EntityState;
23use crate::profiler::EndpointProfile;
24
/// Settings that decide whether, where, and how often WAF state is persisted.
#[derive(Debug, Clone)]
pub struct PersistenceConfig {
    /// Directory that snapshot files are written to and read from.
    pub data_dir: PathBuf,
    /// Interval between background snapshot saves, in seconds.
    pub save_interval_secs: u64,
    /// When `true`, an existing snapshot is loaded at startup.
    pub load_on_startup: bool,
    /// Master switch; when `false`, the background saver and the startup load are skipped.
    pub enabled: bool,
}

impl Default for PersistenceConfig {
    /// Persist to `./data` every 60 seconds, loading on startup, with persistence enabled.
    fn default() -> Self {
        let data_dir: PathBuf = "./data".into();
        Self {
            enabled: true,
            load_on_startup: true,
            save_interval_secs: 60,
            data_dir,
        }
    }
}
48
49/// Current snapshot format version.
50/// Increment when making breaking changes to the snapshot structure.
51const SNAPSHOT_VERSION: u32 = 1;
52
53/// Unified WAF state snapshot for atomic persistence.
54///
55/// All state is saved together to ensure consistency across restarts.
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct WafSnapshot {
58    /// Snapshot format version for forward compatibility
59    pub version: u32,
60    /// Timestamp when snapshot was created (ms since epoch)
61    pub saved_at: u64,
62    /// Sensor instance ID
63    pub instance_id: String,
64    /// Entity states (IP reputation)
65    pub entities: Vec<EntityState>,
66    /// Campaign correlations
67    pub campaigns: Vec<Campaign>,
68    /// Actor states
69    pub actors: Vec<ActorState>,
70    /// Learned endpoint profiles
71    pub profiles: Vec<EndpointProfile>,
72    /// Credential stuffing detector state
73    #[serde(default)]
74    pub credential_stuffing: Option<StuffingState>,
75}
76
77impl WafSnapshot {
78    /// Create a new snapshot with the given state.
79    pub fn new(
80        instance_id: String,
81        entities: Vec<EntityState>,
82        campaigns: Vec<Campaign>,
83        actors: Vec<ActorState>,
84        profiles: Vec<EndpointProfile>,
85    ) -> Self {
86        let saved_at = SystemTime::now()
87            .duration_since(UNIX_EPOCH)
88            .unwrap_or_default()
89            .as_millis() as u64;
90
91        Self {
92            version: SNAPSHOT_VERSION,
93            saved_at,
94            instance_id,
95            entities,
96            campaigns,
97            actors,
98            profiles,
99            credential_stuffing: None,
100        }
101    }
102
103    /// Create a new snapshot with credential stuffing state.
104    pub fn with_credential_stuffing(
105        instance_id: String,
106        entities: Vec<EntityState>,
107        campaigns: Vec<Campaign>,
108        actors: Vec<ActorState>,
109        profiles: Vec<EndpointProfile>,
110        credential_stuffing: StuffingState,
111    ) -> Self {
112        let saved_at = SystemTime::now()
113            .duration_since(UNIX_EPOCH)
114            .unwrap_or_default()
115            .as_millis() as u64;
116
117        Self {
118            version: SNAPSHOT_VERSION,
119            saved_at,
120            instance_id,
121            entities,
122            campaigns,
123            actors,
124            profiles,
125            credential_stuffing: Some(credential_stuffing),
126        }
127    }
128
129    /// Check if this snapshot is empty (no state to persist).
130    pub fn is_empty(&self) -> bool {
131        self.entities.is_empty()
132            && self.campaigns.is_empty()
133            && self.actors.is_empty()
134            && self.profiles.is_empty()
135            && self.credential_stuffing.as_ref().is_none_or(|s| {
136                s.entity_metrics.is_empty()
137                    && s.distributed_attacks.is_empty()
138                    && s.takeover_alerts.is_empty()
139            })
140    }
141
142    /// Get summary stats for logging.
143    pub fn stats(&self) -> SnapshotStats {
144        let (auth_entities, distributed_attacks, takeover_alerts) =
145            self.credential_stuffing.as_ref().map_or((0, 0, 0), |s| {
146                (
147                    s.entity_metrics.len(),
148                    s.distributed_attacks.len(),
149                    s.takeover_alerts.len(),
150                )
151            });
152
153        SnapshotStats {
154            entities: self.entities.len(),
155            campaigns: self.campaigns.len(),
156            actors: self.actors.len(),
157            profiles: self.profiles.len(),
158            auth_entities,
159            distributed_attacks,
160            takeover_alerts,
161        }
162    }
163}
164
/// Element counts for each part of a snapshot, rendered for log output.
#[derive(Debug, Clone)]
pub struct SnapshotStats {
    /// Number of entity (IP reputation) records.
    pub entities: usize,
    /// Number of correlated campaigns.
    pub campaigns: usize,
    /// Number of tracked actors.
    pub actors: usize,
    /// Number of learned endpoint profiles.
    pub profiles: usize,
    /// Credential stuffing: auth entity metrics
    pub auth_entities: usize,
    /// Credential stuffing: distributed attacks being tracked
    pub distributed_attacks: usize,
    /// Credential stuffing: takeover alerts
    pub takeover_alerts: usize,
}

impl std::fmt::Display for SnapshotStats {
    /// Renders every count on one line, e.g. `"3 entities, 1 campaigns, ..."`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            entities,
            campaigns,
            actors,
            profiles,
            auth_entities,
            distributed_attacks,
            takeover_alerts,
        } = self;

        write!(
            f,
            "{entities} entities, {campaigns} campaigns, {actors} actors, {profiles} profiles, \
             {auth_entities} auth entities, {distributed_attacks} attacks, {takeover_alerts} takeovers"
        )
    }
}
190
191/// Manager for handling state snapshots.
192pub struct SnapshotManager {
193    config: PersistenceConfig,
194}
195
196impl SnapshotManager {
197    pub fn new(config: PersistenceConfig) -> Self {
198        Self { config }
199    }
200
201    /// Get the path to the snapshot file.
202    pub fn snapshot_path(&self) -> PathBuf {
203        self.config.data_dir.join("waf_state.json")
204    }
205
206    /// Get the path to the legacy profiles file (for migration).
207    pub fn legacy_profiles_path(&self) -> PathBuf {
208        self.config.data_dir.join("profiles.json")
209    }
210
211    /// Check if persistence is enabled.
212    pub fn is_enabled(&self) -> bool {
213        self.config.enabled
214    }
215
216    /// Start the background saver task with unified snapshot.
217    ///
218    /// # Arguments
219    /// * `fetch_snapshot` - A closure that returns the current WAF state snapshot.
220    ///
221    /// # Returns
222    /// `Ok(())` if the background saver started successfully, or an error if:
223    /// - Persistence is disabled (returns Ok with early return)
224    /// - Thread spawning failed
225    ///
226    /// # Errors
227    /// Returns `io::Error` if the background thread cannot be spawned.
228    pub fn start_background_saver<F>(self: Arc<Self>, fetch_snapshot: F) -> io::Result<()>
229    where
230        F: Fn() -> WafSnapshot + Send + Sync + 'static,
231    {
232        if !self.config.enabled {
233            info!("Persistence disabled, skipping background saver");
234            return Ok(());
235        }
236
237        let config = self.config.clone();
238        let log_interval = config.save_interval_secs;
239        let log_dir = config.data_dir.clone();
240
241        // Spawn a dedicated thread with its own tokio runtime
242        // This avoids requiring a pre-existing runtime context
243        std::thread::Builder::new()
244            .name("persistence-saver".into())
245            .spawn(move || {
246                let rt = match tokio::runtime::Builder::new_current_thread()
247                    .enable_all()
248                    .build()
249                {
250                    Ok(rt) => rt,
251                    Err(e) => {
252                        error!("Failed to create persistence runtime: {}", e);
253                        return;
254                    }
255                };
256
257                rt.block_on(async move {
258                    let mut interval =
259                        time::interval(Duration::from_secs(config.save_interval_secs));
260
261                    // Ensure data directory exists
262                    if let Err(e) = tokio::fs::create_dir_all(&config.data_dir).await {
263                        error!(
264                            "Failed to create data directory {:?}: {}",
265                            config.data_dir, e
266                        );
267                        return;
268                    }
269
270                    loop {
271                        interval.tick().await;
272
273                        let snapshot = fetch_snapshot();
274                        if snapshot.is_empty() {
275                            debug!("Snapshot empty, skipping save");
276                            continue;
277                        }
278
279                        let path = config.data_dir.join("waf_state.json");
280                        let path_clone = path.clone();
281                        let stats = snapshot.stats();
282
283                        // Offload CPU-intensive serialization and blocking I/O to a worker thread
284                        let res = tokio::task::spawn_blocking(move || {
285                            Self::save_snapshot(&snapshot, &path_clone)
286                        })
287                        .await;
288
289                        match res {
290                            Ok(Ok(_)) => info!("Saved WAF state to {:?} ({})", path, stats),
291                            Ok(Err(e)) => error!("Failed to save WAF state: {}", e),
292                            Err(e) => error!("Save task panicked: {}", e),
293                        }
294                    }
295                });
296            })?;
297
298        info!(
299            "Background persistence started (interval: {}s, dir: {:?})",
300            log_interval, log_dir
301        );
302
303        Ok(())
304    }
305
306    /// Save a unified snapshot to disk (Synchronous/Blocking).
307    ///
308    /// Uses atomic write (temp file + rename) to prevent corruption.
309    pub fn save_snapshot(snapshot: &WafSnapshot, path: &Path) -> io::Result<()> {
310        let json = serde_json::to_string_pretty(snapshot)
311            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
312
313        // Write to temp file then rename for atomic write
314        let tmp_path = path.with_extension("tmp");
315        fs::write(&tmp_path, json)?;
316        fs::rename(&tmp_path, path)?;
317        Ok(())
318    }
319
320    /// Load a unified snapshot from disk (Synchronous/Blocking).
321    pub fn load_snapshot(path: &Path) -> io::Result<Option<WafSnapshot>> {
322        if !path.exists() {
323            return Ok(None);
324        }
325
326        let json = fs::read_to_string(path)?;
327        let snapshot: WafSnapshot = serde_json::from_str(&json)
328            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
329
330        // Version check for future compatibility
331        if snapshot.version > SNAPSHOT_VERSION {
332            warn!(
333                "Snapshot version {} is newer than supported version {}",
334                snapshot.version, SNAPSHOT_VERSION
335            );
336        }
337
338        Ok(Some(snapshot))
339    }
340
341    /// Load snapshot on startup if configured.
342    ///
343    /// Returns the loaded snapshot or None if loading is disabled or file doesn't exist.
344    pub fn load_on_startup(&self) -> io::Result<Option<WafSnapshot>> {
345        if !self.config.enabled || !self.config.load_on_startup {
346            return Ok(None);
347        }
348
349        let path = self.snapshot_path();
350        match Self::load_snapshot(&path) {
351            Ok(Some(snapshot)) => {
352                let age_secs = SystemTime::now()
353                    .duration_since(UNIX_EPOCH)
354                    .unwrap_or_default()
355                    .as_millis() as u64
356                    - snapshot.saved_at;
357                let age_mins = age_secs / 60_000;
358
359                info!(
360                    "Loaded WAF state from {:?} ({}, age: {}m)",
361                    path,
362                    snapshot.stats(),
363                    age_mins
364                );
365                Ok(Some(snapshot))
366            }
367            Ok(None) => {
368                info!("No existing WAF state found at {:?}", path);
369                Ok(None)
370            }
371            Err(e) => {
372                error!("Failed to load WAF state from {:?}: {}", path, e);
373                Err(e)
374            }
375        }
376    }
377
378    // ========== Legacy Methods (for backwards compatibility) ==========
379
380    /// Save profiles to disk (Synchronous/Blocking).
381    /// @deprecated Use save_snapshot instead.
382    pub fn save_profiles(profiles: &[EndpointProfile], path: &Path) -> io::Result<()> {
383        let json = serde_json::to_string_pretty(profiles)
384            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
385
386        let tmp_path = path.with_extension("tmp");
387        fs::write(&tmp_path, json)?;
388        fs::rename(&tmp_path, path)?;
389        Ok(())
390    }
391
392    /// Load profiles from disk (Synchronous/Blocking).
393    /// @deprecated Use load_snapshot instead.
394    pub fn load_profiles(path: &Path) -> io::Result<Vec<EndpointProfile>> {
395        if !path.exists() {
396            return Ok(Vec::new());
397        }
398        let json = fs::read_to_string(path)?;
399        let profiles: Vec<EndpointProfile> = serde_json::from_str(&json)
400            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
401        Ok(profiles)
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Snapshot for `instance_id` with every state collection left empty.
    fn blank_snapshot(instance_id: &str) -> WafSnapshot {
        WafSnapshot::new(instance_id.to_string(), vec![], vec![], vec![], vec![])
    }

    /// Save `snapshot` to `path` and read it straight back, panicking on any failure.
    fn save_and_reload(snapshot: &WafSnapshot, path: &Path) -> WafSnapshot {
        SnapshotManager::save_snapshot(snapshot, path).unwrap();
        SnapshotManager::load_snapshot(path)
            .unwrap()
            .expect("snapshot file should exist right after saving")
    }

    #[test]
    fn test_snapshot_roundtrip() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("test_state.json");

        let reloaded = save_and_reload(&blank_snapshot("test-sensor"), &file);

        assert_eq!(reloaded.version, SNAPSHOT_VERSION);
        assert_eq!(reloaded.instance_id, "test-sensor");
    }

    #[test]
    fn test_empty_snapshot() {
        assert!(blank_snapshot("test").is_empty());
    }

    #[test]
    fn test_snapshot_persists_profiles() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("test_profiles.json");

        let mut original = EndpointProfile::new("/api/users".to_string(), 1000);
        original.update(128, &[("name", "alice")], Some("application/json"), 2000);

        let snapshot = WafSnapshot::new(
            "test-sensor".to_string(),
            vec![],
            vec![],
            vec![],
            vec![original.clone()],
        );
        let reloaded = save_and_reload(&snapshot, &file);

        assert_eq!(reloaded.profiles.len(), 1);
        let restored = &reloaded.profiles[0];
        assert_eq!(restored.template, original.template);
        assert_eq!(restored.sample_count, original.sample_count);
    }
}