1use crate::pipeline::EvasionPipeline;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs;
10use std::path::{Path, PathBuf};
11
/// Identifies one cache slot: a WAF fingerprint paired with a payload type.
///
/// `cache_key_str` serializes this struct to JSON and uses the resulting text
/// as the key of the cache's entry map, so renaming or reordering the fields
/// changes the keys under which previously saved entries are looked up.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CacheKey {
    /// Fingerprint string of the WAF (the tests in this file use values such
    /// as `"cloudflare"`).
    pub waf_fingerprint: String,
    /// Payload category (the tests in this file use values such as `"sql"`).
    pub payload_type: String,
}
18
19impl CacheKey {
20 #[must_use]
21 pub fn new(waf: impl Into<String>, payload: impl Into<String>) -> Self {
22 Self {
23 waf_fingerprint: waf.into(),
24 payload_type: payload.into(),
25 }
26 }
27}
28
/// Outcome statistics stored in the cache for one [`CacheKey`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CacheEntry {
    /// Pipeline supplied when the entry was first created.
    pub pipeline: EvasionPipeline,
    /// Number of recorded successes.
    pub successes: u32,
    /// Number of recorded attempts (successes and failures together).
    pub attempts: u32,
    /// Unix time, in seconds, of the most recent recorded success; `0` until
    /// a success has been recorded.
    pub last_success_epoch: u64,
}
37
38impl CacheEntry {
39 #[must_use]
40 pub fn success_rate(&self) -> f64 {
41 if self.attempts == 0 {
42 0.0
43 } else {
44 f64::from(self.successes) / f64::from(self.attempts)
45 }
46 }
47}
48
/// Cache of per-key outcome statistics that can be loaded from and saved to a
/// JSON file.
///
/// A value built with `Default` has no backing path; `save` on such a value
/// returns `LearningCacheError::NoPath`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LearningCache {
    /// Backing file. Excluded from (de)serialization; `open` fills it in
    /// after loading.
    #[serde(skip)]
    path: Option<PathBuf>,
    /// Entries indexed by the string produced by `cache_key_str` (the JSON
    /// encoding of a [`CacheKey`]).
    entries: HashMap<String, CacheEntry>,
}
58
59fn cache_key_str(k: &CacheKey) -> String {
60 serde_json::to_string(k).unwrap_or_else(|_| {
61 format!(
62 "{{\"waf_fingerprint\":{},\"payload_type\":{}}}",
63 serde_json::to_string(&k.waf_fingerprint).unwrap_or_else(|_| "\"\"".to_string()),
64 serde_json::to_string(&k.payload_type).unwrap_or_else(|_| "\"\"".to_string()),
65 )
66 })
67}
68
69impl LearningCache {
70 pub fn open_default() -> Result<Self, LearningCacheError> {
76 let home = dirs::home_dir().ok_or(LearningCacheError::NoHomeDir)?;
77 let path = home.join(".wafrift").join("learning_cache.json");
78 Self::open(path)
79 }
80
81 pub fn open(path: impl AsRef<Path>) -> Result<Self, LearningCacheError> {
94 let path = path.as_ref();
95 if path.exists() {
96 const MAX_CACHE_FILE_BYTES: u64 = 16 * 1024 * 1024;
103 let metadata = fs::metadata(path).map_err(LearningCacheError::Io)?;
104 if metadata.len() > MAX_CACHE_FILE_BYTES {
105 tracing::warn!(
106 path = %path.display(),
107 bytes = metadata.len(),
108 cap = MAX_CACHE_FILE_BYTES,
109 "learning cache file exceeds size cap; moving aside and starting fresh"
110 );
111 let backup = path.with_extension(format!("oversize-{}", current_epoch()));
112 let _ = fs::rename(path, &backup);
113 return Ok(Self {
114 path: Some(path.to_path_buf()),
115 entries: HashMap::new(),
116 });
117 }
118 let contents = fs::read_to_string(path).map_err(LearningCacheError::Io)?;
119 match serde_json::from_str::<LearningCache>(&contents) {
120 Ok(mut cache) => {
121 cache.path = Some(path.to_path_buf());
122 Ok(cache)
123 }
124 Err(e) => {
125 let backup = path.with_extension(format!("corrupt-{}", current_epoch()));
126 let backup_msg = match fs::rename(path, &backup) {
127 Ok(()) => format!("moved aside to {}", backup.display()),
128 Err(rename_err) => {
129 format!("could not rename ({rename_err}); leaving file in place")
130 }
131 };
132 tracing::warn!(
133 path = %path.display(),
134 error = %e,
135 backup = %backup_msg,
136 "learning cache file corrupted; starting fresh"
137 );
138 Ok(Self {
139 path: Some(path.to_path_buf()),
140 entries: HashMap::new(),
141 })
142 }
143 }
144 } else {
145 Ok(Self {
146 path: Some(path.to_path_buf()),
147 entries: HashMap::new(),
148 })
149 }
150 }
151
152 #[must_use]
154 pub fn get(&self, key: &CacheKey) -> Option<&CacheEntry> {
155 self.entries.get(&cache_key_str(key))
156 }
157
158 pub fn record_success(&mut self, key: CacheKey, pipeline: EvasionPipeline) {
160 let now = current_epoch();
161 let entry = self
162 .entries
163 .entry(cache_key_str(&key))
164 .or_insert(CacheEntry {
165 pipeline,
166 successes: 0,
167 attempts: 0,
168 last_success_epoch: 0,
169 });
170 entry.successes = entry.successes.saturating_add(1);
171 entry.attempts = entry.attempts.saturating_add(1);
172 entry.last_success_epoch = now;
173 }
174
175 pub fn record_failure(&mut self, key: CacheKey, pipeline: EvasionPipeline) {
177 let entry = self
178 .entries
179 .entry(cache_key_str(&key))
180 .or_insert(CacheEntry {
181 pipeline,
182 successes: 0,
183 attempts: 0,
184 last_success_epoch: 0,
185 });
186 entry.attempts = entry.attempts.saturating_add(1);
187 }
188
189 pub fn save(&self) -> Result<(), LearningCacheError> {
201 let path = self.path.as_ref().ok_or(LearningCacheError::NoPath)?;
202 if let Some(parent) = path.parent() {
203 fs::create_dir_all(parent).map_err(LearningCacheError::Io)?;
204 }
205 let json = serde_json::to_string_pretty(self).map_err(LearningCacheError::Serde)?;
206
207 let tmp = path.with_extension(format!(
210 "tmp.{}.{}",
211 std::process::id(),
212 current_epoch()
213 ));
214 {
217 use std::io::Write;
218 let mut f = fs::File::create(&tmp).map_err(LearningCacheError::Io)?;
219 f.write_all(json.as_bytes()).map_err(LearningCacheError::Io)?;
220 f.sync_all().map_err(LearningCacheError::Io)?;
221 }
222 if let Err(e) = fs::rename(&tmp, path) {
223 let _ = fs::remove_file(&tmp);
225 return Err(LearningCacheError::Io(e));
226 }
227 Ok(())
228 }
229
230 #[must_use]
232 pub fn keys(&self) -> Vec<CacheKey> {
233 self.entries
234 .keys()
235 .filter_map(|s| match serde_json::from_str(s) {
236 Ok(k) => Some(k),
237 Err(e) => {
238 tracing::warn!(key = %s, error = %e, "learning cache key parse failed");
239 None
240 }
241 })
242 .collect()
243 }
244}
245
/// Errors returned by [`LearningCache`] operations.
#[derive(Debug, thiserror::Error)]
pub enum LearningCacheError {
    /// A filesystem operation failed.
    #[error("learning cache I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// A `serde_json` operation on the cache failed.
    #[error("learning cache serialization error: {0}")]
    Serde(#[from] serde_json::Error),
    /// The home directory could not be determined, so the default cache
    /// location cannot be built.
    #[error("cannot determine home directory")]
    NoHomeDir,
    /// The cache has no backing path, so it cannot be saved.
    #[error("no path set for learning cache")]
    NoPath,
}
258
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` when the system clock reports a time before the epoch.
fn current_epoch() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::EvasionStage;
    use wafrift_types::Technique;

    /// Returns a scratch file path in the system temp directory for `tag`,
    /// removing any leftover file first.
    ///
    /// The process id is part of the file name because the temp directory is
    /// shared: with a fixed name, two test runs on the same machine could
    /// overwrite or delete each other's file.
    fn scratch_path(tag: &str) -> std::path::PathBuf {
        let path = std::env::temp_dir().join(format!(
            "wafrift_learning_cache_{tag}_{}.json",
            std::process::id()
        ));
        let _ = fs::remove_file(&path);
        path
    }

    /// A success recorded and saved is visible after reopening the file.
    #[test]
    fn cache_roundtrip() {
        let tmp = scratch_path("test");

        let mut cache = LearningCache::open(&tmp).unwrap();
        let pipeline = EvasionPipeline::new(
            "test",
            vec![EvasionStage {
                technique: Technique::UserAgentRotation,
                context: None,
            }],
            1,
        );
        cache.record_success(CacheKey::new("cloudflare", "sql"), pipeline);
        cache.save().unwrap();

        let cache2 = LearningCache::open(&tmp).unwrap();
        let entry = cache2.get(&CacheKey::new("cloudflare", "sql")).unwrap();
        assert_eq!(entry.successes, 1);
        assert_eq!(entry.attempts, 1);

        let _ = fs::remove_file(&tmp);
    }

    /// Dropping the first handle and opening a second one preserves the
    /// entry, including its success timestamp.
    #[test]
    fn cache_persists_across_process_restarts() {
        let tmp = scratch_path("restart");

        {
            let mut cache = LearningCache::open(&tmp).unwrap();
            let pipeline = EvasionPipeline::new(
                "win",
                vec![EvasionStage {
                    technique: Technique::GrammarMutation("sql".into()),
                    context: None,
                }],
                2,
            );
            cache.record_success(CacheKey::new("aws_waf", "xss"), pipeline);
            cache.save().unwrap();
        }

        {
            let cache = LearningCache::open(&tmp).unwrap();
            let entry = cache.get(&CacheKey::new("aws_waf", "xss")).unwrap();
            assert_eq!(entry.successes, 1);
            assert!(entry.last_success_epoch > 0);
        }

        let _ = fs::remove_file(&tmp);
    }

    /// A failure counts as an attempt but not as a success.
    #[test]
    fn cache_failure_tracking() {
        let tmp = scratch_path("fail");

        let mut cache = LearningCache::open(&tmp).unwrap();
        let pipeline = EvasionPipeline::new("lose", vec![], 1);
        let key = CacheKey::new("modsecurity", "cmdi");
        cache.record_failure(key.clone(), pipeline);
        cache.save().unwrap();

        let cache2 = LearningCache::open(&tmp).unwrap();
        let entry = cache2.get(&key).unwrap();
        assert_eq!(entry.successes, 0);
        assert_eq!(entry.attempts, 1);

        let _ = fs::remove_file(&tmp);
    }

    /// A cache built with `Default` has no backing path and refuses to save.
    #[test]
    fn save_without_path_reports_no_path() {
        let cache = LearningCache::default();
        assert!(matches!(cache.save(), Err(LearningCacheError::NoPath)));
    }
}