synapse_pingora/persistence/
mod.rs1use std::fs;
10use std::io;
11use std::path::{Path, PathBuf};
12use std::sync::Arc;
13use std::time::{Duration, SystemTime, UNIX_EPOCH};
14
15use log::{debug, error, info, warn};
16use serde::{Deserialize, Serialize};
17use tokio::time;
18
19use crate::actor::ActorState;
20use crate::correlation::Campaign;
21use crate::detection::StuffingState;
22use crate::entity::EntityState;
23use crate::profiler::EndpointProfile;
24
/// Settings controlling where and how often WAF state is persisted to disk.
#[derive(Debug, Clone)]
pub struct PersistenceConfig {
    /// Directory that holds the snapshot file (`waf_state.json`).
    pub data_dir: PathBuf,
    /// Seconds between background snapshot saves.
    pub save_interval_secs: u64,
    /// Whether an existing snapshot is loaded when the process starts.
    pub load_on_startup: bool,
    /// Master switch: when `false`, the background saver is not started and
    /// nothing is loaded on startup.
    pub enabled: bool,
}

impl Default for PersistenceConfig {
    /// Persists to `./data` every 60 seconds and loads prior state on startup.
    fn default() -> Self {
        let data_dir: PathBuf = "./data".into();
        Self {
            enabled: true,
            load_on_startup: true,
            save_interval_secs: 60,
            data_dir,
        }
    }
}
48
/// Format version stamped into every snapshot's `version` field.
///
/// On load, a snapshot reporting a version newer than this is still returned,
/// but a warning is logged.
const SNAPSHOT_VERSION: u32 = 1;
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct WafSnapshot {
58 pub version: u32,
60 pub saved_at: u64,
62 pub instance_id: String,
64 pub entities: Vec<EntityState>,
66 pub campaigns: Vec<Campaign>,
68 pub actors: Vec<ActorState>,
70 pub profiles: Vec<EndpointProfile>,
72 #[serde(default)]
74 pub credential_stuffing: Option<StuffingState>,
75}
76
77impl WafSnapshot {
78 pub fn new(
80 instance_id: String,
81 entities: Vec<EntityState>,
82 campaigns: Vec<Campaign>,
83 actors: Vec<ActorState>,
84 profiles: Vec<EndpointProfile>,
85 ) -> Self {
86 let saved_at = SystemTime::now()
87 .duration_since(UNIX_EPOCH)
88 .unwrap_or_default()
89 .as_millis() as u64;
90
91 Self {
92 version: SNAPSHOT_VERSION,
93 saved_at,
94 instance_id,
95 entities,
96 campaigns,
97 actors,
98 profiles,
99 credential_stuffing: None,
100 }
101 }
102
103 pub fn with_credential_stuffing(
105 instance_id: String,
106 entities: Vec<EntityState>,
107 campaigns: Vec<Campaign>,
108 actors: Vec<ActorState>,
109 profiles: Vec<EndpointProfile>,
110 credential_stuffing: StuffingState,
111 ) -> Self {
112 let saved_at = SystemTime::now()
113 .duration_since(UNIX_EPOCH)
114 .unwrap_or_default()
115 .as_millis() as u64;
116
117 Self {
118 version: SNAPSHOT_VERSION,
119 saved_at,
120 instance_id,
121 entities,
122 campaigns,
123 actors,
124 profiles,
125 credential_stuffing: Some(credential_stuffing),
126 }
127 }
128
129 pub fn is_empty(&self) -> bool {
131 self.entities.is_empty()
132 && self.campaigns.is_empty()
133 && self.actors.is_empty()
134 && self.profiles.is_empty()
135 && self.credential_stuffing.as_ref().is_none_or(|s| {
136 s.entity_metrics.is_empty()
137 && s.distributed_attacks.is_empty()
138 && s.takeover_alerts.is_empty()
139 })
140 }
141
142 pub fn stats(&self) -> SnapshotStats {
144 let (auth_entities, distributed_attacks, takeover_alerts) =
145 self.credential_stuffing.as_ref().map_or((0, 0, 0), |s| {
146 (
147 s.entity_metrics.len(),
148 s.distributed_attacks.len(),
149 s.takeover_alerts.len(),
150 )
151 });
152
153 SnapshotStats {
154 entities: self.entities.len(),
155 campaigns: self.campaigns.len(),
156 actors: self.actors.len(),
157 profiles: self.profiles.len(),
158 auth_entities,
159 distributed_attacks,
160 takeover_alerts,
161 }
162 }
163}
164
/// Per-category item counts for a snapshot, e.g. for log messages.
#[derive(Debug, Clone)]
pub struct SnapshotStats {
    /// Number of entity states.
    pub entities: usize,
    /// Number of campaigns.
    pub campaigns: usize,
    /// Number of actor states.
    pub actors: usize,
    /// Number of endpoint profiles.
    pub profiles: usize,
    /// Entity metrics in the credential-stuffing state (0 if absent).
    pub auth_entities: usize,
    /// Distributed attacks in the credential-stuffing state (0 if absent).
    pub distributed_attacks: usize,
    /// Takeover alerts in the credential-stuffing state (0 if absent).
    pub takeover_alerts: usize,
}

impl std::fmt::Display for SnapshotStats {
    /// Renders all counts on a single human-readable line.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            entities,
            campaigns,
            actors,
            profiles,
            auth_entities,
            distributed_attacks,
            takeover_alerts,
        } = self;
        write!(
            f,
            "{entities} entities, {campaigns} campaigns, {actors} actors, {profiles} profiles, \
             {auth_entities} auth entities, {distributed_attacks} attacks, \
             {takeover_alerts} takeovers"
        )
    }
}
190
191pub struct SnapshotManager {
193 config: PersistenceConfig,
194}
195
196impl SnapshotManager {
197 pub fn new(config: PersistenceConfig) -> Self {
198 Self { config }
199 }
200
201 pub fn snapshot_path(&self) -> PathBuf {
203 self.config.data_dir.join("waf_state.json")
204 }
205
206 pub fn legacy_profiles_path(&self) -> PathBuf {
208 self.config.data_dir.join("profiles.json")
209 }
210
211 pub fn is_enabled(&self) -> bool {
213 self.config.enabled
214 }
215
216 pub fn start_background_saver<F>(self: Arc<Self>, fetch_snapshot: F) -> io::Result<()>
229 where
230 F: Fn() -> WafSnapshot + Send + Sync + 'static,
231 {
232 if !self.config.enabled {
233 info!("Persistence disabled, skipping background saver");
234 return Ok(());
235 }
236
237 let config = self.config.clone();
238 let log_interval = config.save_interval_secs;
239 let log_dir = config.data_dir.clone();
240
241 std::thread::Builder::new()
244 .name("persistence-saver".into())
245 .spawn(move || {
246 let rt = match tokio::runtime::Builder::new_current_thread()
247 .enable_all()
248 .build()
249 {
250 Ok(rt) => rt,
251 Err(e) => {
252 error!("Failed to create persistence runtime: {}", e);
253 return;
254 }
255 };
256
257 rt.block_on(async move {
258 let mut interval =
259 time::interval(Duration::from_secs(config.save_interval_secs));
260
261 if let Err(e) = tokio::fs::create_dir_all(&config.data_dir).await {
263 error!(
264 "Failed to create data directory {:?}: {}",
265 config.data_dir, e
266 );
267 return;
268 }
269
270 loop {
271 interval.tick().await;
272
273 let snapshot = fetch_snapshot();
274 if snapshot.is_empty() {
275 debug!("Snapshot empty, skipping save");
276 continue;
277 }
278
279 let path = config.data_dir.join("waf_state.json");
280 let path_clone = path.clone();
281 let stats = snapshot.stats();
282
283 let res = tokio::task::spawn_blocking(move || {
285 Self::save_snapshot(&snapshot, &path_clone)
286 })
287 .await;
288
289 match res {
290 Ok(Ok(_)) => info!("Saved WAF state to {:?} ({})", path, stats),
291 Ok(Err(e)) => error!("Failed to save WAF state: {}", e),
292 Err(e) => error!("Save task panicked: {}", e),
293 }
294 }
295 });
296 })?;
297
298 info!(
299 "Background persistence started (interval: {}s, dir: {:?})",
300 log_interval, log_dir
301 );
302
303 Ok(())
304 }
305
306 pub fn save_snapshot(snapshot: &WafSnapshot, path: &Path) -> io::Result<()> {
310 let json = serde_json::to_string_pretty(snapshot)
311 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
312
313 let tmp_path = path.with_extension("tmp");
315 fs::write(&tmp_path, json)?;
316 fs::rename(&tmp_path, path)?;
317 Ok(())
318 }
319
320 pub fn load_snapshot(path: &Path) -> io::Result<Option<WafSnapshot>> {
322 if !path.exists() {
323 return Ok(None);
324 }
325
326 let json = fs::read_to_string(path)?;
327 let snapshot: WafSnapshot = serde_json::from_str(&json)
328 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
329
330 if snapshot.version > SNAPSHOT_VERSION {
332 warn!(
333 "Snapshot version {} is newer than supported version {}",
334 snapshot.version, SNAPSHOT_VERSION
335 );
336 }
337
338 Ok(Some(snapshot))
339 }
340
341 pub fn load_on_startup(&self) -> io::Result<Option<WafSnapshot>> {
345 if !self.config.enabled || !self.config.load_on_startup {
346 return Ok(None);
347 }
348
349 let path = self.snapshot_path();
350 match Self::load_snapshot(&path) {
351 Ok(Some(snapshot)) => {
352 let age_secs = SystemTime::now()
353 .duration_since(UNIX_EPOCH)
354 .unwrap_or_default()
355 .as_millis() as u64
356 - snapshot.saved_at;
357 let age_mins = age_secs / 60_000;
358
359 info!(
360 "Loaded WAF state from {:?} ({}, age: {}m)",
361 path,
362 snapshot.stats(),
363 age_mins
364 );
365 Ok(Some(snapshot))
366 }
367 Ok(None) => {
368 info!("No existing WAF state found at {:?}", path);
369 Ok(None)
370 }
371 Err(e) => {
372 error!("Failed to load WAF state from {:?}: {}", path, e);
373 Err(e)
374 }
375 }
376 }
377
378 pub fn save_profiles(profiles: &[EndpointProfile], path: &Path) -> io::Result<()> {
383 let json = serde_json::to_string_pretty(profiles)
384 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
385
386 let tmp_path = path.with_extension("tmp");
387 fs::write(&tmp_path, json)?;
388 fs::rename(&tmp_path, path)?;
389 Ok(())
390 }
391
392 pub fn load_profiles(path: &Path) -> io::Result<Vec<EndpointProfile>> {
395 if !path.exists() {
396 return Ok(Vec::new());
397 }
398 let json = fs::read_to_string(path)?;
399 let profiles: Vec<EndpointProfile> = serde_json::from_str(&json)
400 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
401 Ok(profiles)
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds an empty snapshot for `instance_id`.
    fn empty_snapshot(instance_id: &str) -> WafSnapshot {
        WafSnapshot::new(instance_id.to_string(), vec![], vec![], vec![], vec![])
    }

    #[test]
    fn test_snapshot_roundtrip() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("test_state.json");

        let snapshot = empty_snapshot("test-sensor");

        SnapshotManager::save_snapshot(&snapshot, &path).unwrap();
        let loaded = SnapshotManager::load_snapshot(&path).unwrap().unwrap();

        assert_eq!(loaded.version, SNAPSHOT_VERSION);
        assert_eq!(loaded.instance_id, "test-sensor");
    }

    #[test]
    fn test_empty_snapshot() {
        let snapshot = empty_snapshot("test");
        assert!(snapshot.is_empty());
    }

    #[test]
    fn test_snapshot_persists_profiles() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("test_profiles.json");

        let mut profile = EndpointProfile::new("/api/users".to_string(), 1000);
        profile.update(128, &[("name", "alice")], Some("application/json"), 2000);

        let snapshot = WafSnapshot::new(
            "test-sensor".to_string(),
            vec![],
            vec![],
            vec![],
            vec![profile.clone()],
        );

        SnapshotManager::save_snapshot(&snapshot, &path).unwrap();
        let loaded = SnapshotManager::load_snapshot(&path).unwrap().unwrap();

        assert_eq!(loaded.profiles.len(), 1);
        assert_eq!(loaded.profiles[0].template, profile.template);
        assert_eq!(loaded.profiles[0].sample_count, profile.sample_count);
    }

    #[test]
    fn test_snapshot_with_profile_is_not_empty() {
        let profile = EndpointProfile::new("/api/users".to_string(), 1000);
        let snapshot =
            WafSnapshot::new("test".to_string(), vec![], vec![], vec![], vec![profile]);
        assert!(!snapshot.is_empty());
        assert_eq!(snapshot.stats().profiles, 1);
    }

    #[test]
    fn test_stats_display() {
        let snapshot = empty_snapshot("test");
        assert_eq!(
            snapshot.stats().to_string(),
            "0 entities, 0 campaigns, 0 actors, 0 profiles, 0 auth entities, 0 attacks, 0 takeovers"
        );
    }

    #[test]
    fn test_load_missing_snapshot_returns_none() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("missing.json");
        assert!(SnapshotManager::load_snapshot(&path).unwrap().is_none());
    }

    #[test]
    fn test_load_missing_profiles_returns_empty() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("missing_profiles.json");
        assert!(SnapshotManager::load_profiles(&path).unwrap().is_empty());
    }

    #[test]
    fn test_load_rejects_malformed_json() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("corrupt.json");
        fs::write(&path, "{ not json").unwrap();

        let err = SnapshotManager::load_snapshot(&path).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_save_replaces_existing_and_leaves_no_temp_file() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("state.json");

        SnapshotManager::save_snapshot(&empty_snapshot("first"), &path).unwrap();
        SnapshotManager::save_snapshot(&empty_snapshot("second"), &path).unwrap();

        let loaded = SnapshotManager::load_snapshot(&path).unwrap().unwrap();
        assert_eq!(loaded.instance_id, "second");
        // Only the final file should remain; the temp file must be renamed away.
        assert_eq!(fs::read_dir(tmp.path()).unwrap().count(), 1);
    }

    #[test]
    fn test_profiles_roundtrip() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("profiles.json");
        let profile = EndpointProfile::new("/api/login".to_string(), 1000);

        SnapshotManager::save_profiles(std::slice::from_ref(&profile), &path).unwrap();
        let loaded = SnapshotManager::load_profiles(&path).unwrap();

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].template, profile.template);
    }

    #[test]
    fn test_load_on_startup_respects_flags() {
        let tmp = TempDir::new().unwrap();

        let disabled = SnapshotManager::new(PersistenceConfig {
            data_dir: tmp.path().to_path_buf(),
            load_on_startup: false,
            ..Default::default()
        });
        SnapshotManager::save_snapshot(&empty_snapshot("sensor"), &disabled.snapshot_path())
            .unwrap();
        assert!(disabled.load_on_startup().unwrap().is_none());

        let enabled = SnapshotManager::new(PersistenceConfig {
            data_dir: tmp.path().to_path_buf(),
            ..Default::default()
        });
        let loaded = enabled.load_on_startup().unwrap().unwrap();
        assert_eq!(loaded.instance_id, "sensor");
    }
}