Skip to main content

wafrift_strategy/
learning_cache.rs

1//! Learning cache — persistent per-WAF, per-payload-type pipeline memory.
2//!
3//! After a successful bypass, the winning pipeline is cached to disk
4//! and re-used on subsequent scans of the same WAF + payload type.
5
6use crate::pipeline::EvasionPipeline;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs;
10use std::path::{Path, PathBuf};
11
12/// Cache key: WAF fingerprint + payload type.
13#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
14pub struct CacheKey {
15    pub waf_fingerprint: String,
16    pub payload_type: String,
17}
18
19impl CacheKey {
20    #[must_use]
21    pub fn new(waf: impl Into<String>, payload: impl Into<String>) -> Self {
22        Self {
23            waf_fingerprint: waf.into(),
24            payload_type: payload.into(),
25        }
26    }
27}
28
29/// A single cached entry: the winning pipeline and its success stats.
30#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
31pub struct CacheEntry {
32    pub pipeline: EvasionPipeline,
33    pub successes: u32,
34    pub attempts: u32,
35    pub last_success_epoch: u64,
36}
37
38impl CacheEntry {
39    #[must_use]
40    pub fn success_rate(&self) -> f64 {
41        if self.attempts == 0 {
42            0.0
43        } else {
44            f64::from(self.successes) / f64::from(self.attempts)
45        }
46    }
47}
48
/// On-disk learning cache.
///
/// Keys are JSON-serialized [`CacheKey`] strings because JSON object keys must be strings.
///
/// A value created through `Default` has no backing path, so `save` on it
/// returns [`LearningCacheError::NoPath`]; use `open` / `open_default` to
/// obtain a cache bound to a file.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LearningCache {
    /// File this cache is persisted to. Skipped by serde, so it is never
    /// written into the JSON and is filled in by `open` after loading.
    #[serde(skip)]
    path: Option<PathBuf>,
    /// Cached entries, keyed by the JSON encoding of their [`CacheKey`].
    entries: HashMap<String, CacheEntry>,
}
58
59fn cache_key_str(k: &CacheKey) -> String {
60    serde_json::to_string(k).unwrap_or_else(|_| {
61        format!(
62            "{{\"waf_fingerprint\":{},\"payload_type\":{}}}",
63            serde_json::to_string(&k.waf_fingerprint).unwrap_or_else(|_| "\"\"".to_string()),
64            serde_json::to_string(&k.payload_type).unwrap_or_else(|_| "\"\"".to_string()),
65        )
66    })
67}
68
69impl LearningCache {
70    /// Open the default cache at `~/.wafrift/learning_cache.json`.
71    ///
72    /// # Errors
73    ///
74    /// Returns an error if the home directory cannot be determined.
75    pub fn open_default() -> Result<Self, LearningCacheError> {
76        let home = dirs::home_dir().ok_or(LearningCacheError::NoHomeDir)?;
77        let path = home.join(".wafrift").join("learning_cache.json");
78        Self::open(path)
79    }
80
81    /// Open or create a cache at a specific path.
82    ///
83    /// A corrupted cache file (kill-9 mid-save, disk corruption, partial
84    /// flush) is moved aside to `<path>.corrupt-<epoch>` and a fresh
85    /// empty cache is returned. Crashing the whole strategy engine on
86    /// one bad JSON file would lose all subsequent learning — better to
87    /// surface the corruption via `tracing::warn` and keep going.
88    ///
89    /// # Errors
90    ///
91    /// Returns an error only if the file exists, looks fine, and the
92    /// underlying I/O still fails (permission denied, etc.).
93    pub fn open(path: impl AsRef<Path>) -> Result<Self, LearningCacheError> {
94        let path = path.as_ref();
95        if path.exists() {
96            // Audit (2026-05-10): pre-fix the cache was loaded with no
97            // size or depth limit on the JSON. A maliciously crafted
98            // ~/.wafrift/learning_cache.json could exhaust memory
99            // (multi-GB file) or stack (deeply nested arrays). Cap the
100            // file at MAX_CACHE_FILE_BYTES; the JSON parser then has
101            // a bounded heap and stack via that bound.
102            const MAX_CACHE_FILE_BYTES: u64 = 16 * 1024 * 1024;
103            let metadata = fs::metadata(path).map_err(LearningCacheError::Io)?;
104            if metadata.len() > MAX_CACHE_FILE_BYTES {
105                tracing::warn!(
106                    path = %path.display(),
107                    bytes = metadata.len(),
108                    cap = MAX_CACHE_FILE_BYTES,
109                    "learning cache file exceeds size cap; moving aside and starting fresh"
110                );
111                let backup = path.with_extension(format!("oversize-{}", current_epoch()));
112                let _ = fs::rename(path, &backup);
113                return Ok(Self {
114                    path: Some(path.to_path_buf()),
115                    entries: HashMap::new(),
116                });
117            }
118            let contents = fs::read_to_string(path).map_err(LearningCacheError::Io)?;
119            match serde_json::from_str::<LearningCache>(&contents) {
120                Ok(mut cache) => {
121                    cache.path = Some(path.to_path_buf());
122                    Ok(cache)
123                }
124                Err(e) => {
125                    let backup = path.with_extension(format!("corrupt-{}", current_epoch()));
126                    let backup_msg = match fs::rename(path, &backup) {
127                        Ok(()) => format!("moved aside to {}", backup.display()),
128                        Err(rename_err) => {
129                            format!("could not rename ({rename_err}); leaving file in place")
130                        }
131                    };
132                    tracing::warn!(
133                        path = %path.display(),
134                        error = %e,
135                        backup = %backup_msg,
136                        "learning cache file corrupted; starting fresh"
137                    );
138                    Ok(Self {
139                        path: Some(path.to_path_buf()),
140                        entries: HashMap::new(),
141                    })
142                }
143            }
144        } else {
145            Ok(Self {
146                path: Some(path.to_path_buf()),
147                entries: HashMap::new(),
148            })
149        }
150    }
151
152    /// Look up a cached pipeline.
153    #[must_use]
154    pub fn get(&self, key: &CacheKey) -> Option<&CacheEntry> {
155        self.entries.get(&cache_key_str(key))
156    }
157
158    /// Record a successful bypass.
159    pub fn record_success(&mut self, key: CacheKey, pipeline: EvasionPipeline) {
160        let now = current_epoch();
161        let entry = self
162            .entries
163            .entry(cache_key_str(&key))
164            .or_insert(CacheEntry {
165                pipeline,
166                successes: 0,
167                attempts: 0,
168                last_success_epoch: 0,
169            });
170        entry.successes = entry.successes.saturating_add(1);
171        entry.attempts = entry.attempts.saturating_add(1);
172        entry.last_success_epoch = now;
173    }
174
175    /// Record a failed attempt.
176    pub fn record_failure(&mut self, key: CacheKey, pipeline: EvasionPipeline) {
177        let entry = self
178            .entries
179            .entry(cache_key_str(&key))
180            .or_insert(CacheEntry {
181                pipeline,
182                successes: 0,
183                attempts: 0,
184                last_success_epoch: 0,
185            });
186        entry.attempts = entry.attempts.saturating_add(1);
187    }
188
189    /// Persist the cache to disk atomically.
190    ///
191    /// Writes to a sibling `<path>.tmp.<pid>.<epoch>` file, fsyncs it,
192    /// then renames over the target path. A kill-9 between `write` and
193    /// `rename` leaves the previous good cache file untouched instead
194    /// of producing the half-written JSON that was poisoning subsequent
195    /// `open` calls.
196    ///
197    /// # Errors
198    ///
199    /// Returns an error if the file cannot be written or renamed.
200    pub fn save(&self) -> Result<(), LearningCacheError> {
201        let path = self.path.as_ref().ok_or(LearningCacheError::NoPath)?;
202        if let Some(parent) = path.parent() {
203            fs::create_dir_all(parent).map_err(LearningCacheError::Io)?;
204        }
205        let json = serde_json::to_string_pretty(self).map_err(LearningCacheError::Serde)?;
206
207        // Sibling tmp file in the same directory so `rename` is atomic
208        // (cross-FS rename on /tmp would silently fall back to copy).
209        let tmp = path.with_extension(format!(
210            "tmp.{}.{}",
211            std::process::id(),
212            current_epoch()
213        ));
214        // Scope the file handle so the OS releases its descriptor before
215        // we rename — Windows would otherwise refuse the rename.
216        {
217            use std::io::Write;
218            let mut f = fs::File::create(&tmp).map_err(LearningCacheError::Io)?;
219            f.write_all(json.as_bytes()).map_err(LearningCacheError::Io)?;
220            f.sync_all().map_err(LearningCacheError::Io)?;
221        }
222        if let Err(e) = fs::rename(&tmp, path) {
223            // Clean up the orphaned tmp file before propagating.
224            let _ = fs::remove_file(&tmp);
225            return Err(LearningCacheError::Io(e));
226        }
227        Ok(())
228    }
229
230    /// All cached keys.
231    #[must_use]
232    pub fn keys(&self) -> Vec<CacheKey> {
233        self.entries
234            .keys()
235            .filter_map(|s| match serde_json::from_str(s) {
236                Ok(k) => Some(k),
237                Err(e) => {
238                    tracing::warn!(key = %s, error = %e, "learning cache key parse failed");
239                    None
240                }
241            })
242            .collect()
243    }
244}
245
/// Errors from learning cache operations.
#[derive(Debug, thiserror::Error)]
pub enum LearningCacheError {
    /// A filesystem operation on the cache file (or its directory) failed.
    #[error("learning cache I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// The cache could not be converted to JSON.
    #[error("learning cache serialization error: {0}")]
    Serde(#[from] serde_json::Error),
    /// The home directory could not be determined when opening the default cache.
    #[error("cannot determine home directory")]
    NoHomeDir,
    /// `save` was called on a cache that has no backing file path.
    #[error("no path set for learning cache")]
    NoPath,
}
258
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn current_epoch() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::EvasionStage;
    use wafrift_types::Technique;

    /// Per-test scratch directory under the system temp dir.
    ///
    /// The directory name includes the test name and the process id so that
    /// concurrent test processes never share files, and the `Drop` impl
    /// removes everything even when an assertion panics.
    struct Scratch {
        dir: PathBuf,
    }

    impl Scratch {
        fn new(test_name: &str) -> Self {
            let dir = std::env::temp_dir().join(format!(
                "wafrift_learning_cache_{test_name}_{}",
                std::process::id()
            ));
            // Start from a clean slate in case an earlier run was killed.
            let _ = fs::remove_dir_all(&dir);
            Self { dir }
        }

        /// Path of the cache file inside the scratch directory.
        fn cache_path(&self) -> PathBuf {
            self.dir.join("learning_cache.json")
        }
    }

    impl Drop for Scratch {
        fn drop(&mut self) {
            let _ = fs::remove_dir_all(&self.dir);
        }
    }

    #[test]
    fn cache_roundtrip() {
        let scratch = Scratch::new("roundtrip");
        let tmp = scratch.cache_path();

        let mut cache = LearningCache::open(&tmp).unwrap();
        let pipeline = EvasionPipeline::new(
            "test",
            vec![EvasionStage {
                technique: Technique::UserAgentRotation,
                context: None,
            }],
            1,
        );
        cache.record_success(CacheKey::new("cloudflare", "sql"), pipeline);
        cache.save().unwrap();

        let cache2 = LearningCache::open(&tmp).unwrap();
        let entry = cache2.get(&CacheKey::new("cloudflare", "sql")).unwrap();
        assert_eq!(entry.successes, 1);
        assert_eq!(entry.attempts, 1);
    }

    #[test]
    fn cache_persists_across_process_restarts() {
        let scratch = Scratch::new("restart");
        let tmp = scratch.cache_path();

        // Process 1
        {
            let mut cache = LearningCache::open(&tmp).unwrap();
            let pipeline = EvasionPipeline::new(
                "win",
                vec![EvasionStage {
                    technique: Technique::GrammarMutation("sql".into()),
                    context: None,
                }],
                2,
            );
            cache.record_success(CacheKey::new("aws_waf", "xss"), pipeline);
            cache.save().unwrap();
        }

        // Process 2
        {
            let cache = LearningCache::open(&tmp).unwrap();
            let entry = cache.get(&CacheKey::new("aws_waf", "xss")).unwrap();
            assert_eq!(entry.successes, 1);
            assert!(entry.last_success_epoch > 0);
        }
    }

    #[test]
    fn cache_failure_tracking() {
        let scratch = Scratch::new("fail");
        let tmp = scratch.cache_path();

        let mut cache = LearningCache::open(&tmp).unwrap();
        let pipeline = EvasionPipeline::new("lose", vec![], 1);
        let key = CacheKey::new("modsecurity", "cmdi");
        cache.record_failure(key.clone(), pipeline);
        cache.save().unwrap();

        let cache2 = LearningCache::open(&tmp).unwrap();
        let entry = cache2.get(&key).unwrap();
        assert_eq!(entry.successes, 0);
        assert_eq!(entry.attempts, 1);
    }

    #[test]
    fn corrupt_cache_is_moved_aside() {
        let scratch = Scratch::new("corrupt");
        let tmp = scratch.cache_path();
        fs::create_dir_all(&scratch.dir).unwrap();
        fs::write(&tmp, "{ this is not json").unwrap();

        // Opening must recover with an empty cache rather than fail.
        let cache = LearningCache::open(&tmp).unwrap();
        assert!(cache.keys().is_empty());
        assert!(cache.get(&CacheKey::new("any", "thing")).is_none());
        // The bad file was renamed out of the way.
        assert!(!tmp.exists());
    }
}