1use once_cell::sync::Lazy;
27use regex::{Regex, RegexSet};
28use serde::Deserialize;
29use std::collections::HashMap;
30use std::sync::RwLock;
31
/// Longest regex source, in bytes, accepted from a rule file; longer patterns
/// are skipped with a warning instead of being compiled.
const MAX_REGEX_PATTERN_LEN: usize = 4 * 1024;

/// Cap on how many body patterns are combined into the body `RegexSet`.
const MAX_BODY_REGEX_PATTERNS: usize = 2_000;

/// Minimum score a WAF must reach when every indicator for it came from the body.
const BODY_ONLY_MIN_CONFIDENCE: f64 = 0.5;
46
47static RULE_DB: Lazy<RwLock<RuleEngine>> = Lazy::new(|| {
49 let engine = RuleEngine::load_embedded().unwrap_or_else(|e| {
50 tracing::warn!("Failed to load embedded WAF rules: {e}");
51 RuleEngine::default()
52 });
53 RwLock::new(engine)
54});
55
/// Compiled WAF fingerprint database.
///
/// `rules` holds each compiled rule keyed by WAF name, and `names` lists the
/// same names in the order they were first loaded. The private `body_*`
/// fields are derived from the rules' body signatures by
/// [`RuleEngine::compile_body_regex_set`].
#[derive(Debug, Default, Clone)]
pub struct RuleEngine {
    /// Compiled rules keyed by WAF name.
    pub rules: HashMap<String, CompiledWafRule>,
    /// WAF names in first-load order; used to iterate `rules` deterministically.
    pub names: Vec<String>,

    /// Every body regex combined into one set so a response body is scanned once.
    /// Populated by `compile_body_regex_set`.
    body_regex_set: Option<RegexSet>,

    /// Entry `i` identifies the WAF and weight behind pattern `i` of `body_regex_set`.
    body_pattern_map: Vec<BodyPatternRef>,

    /// Individually compiled copies of the set's patterns (same indices). A
    /// `RegexSet` only reports *which* patterns matched, so these are used to
    /// recover the matched text for indicators.
    body_regexes: Vec<Regex>,
}
83
/// Maps one pattern index of the combined body `RegexSet` back to its rule.
#[derive(Debug, Clone)]
struct BodyPatternRef {
    /// Name of the WAF rule (a key of `RuleEngine::rules`) owning this pattern.
    waf_name: String,
    /// Index of the originating signature within `CompiledWafRule::signatures`.
    // Written by `compile_body_regex_set` but never read by `detect`, hence the
    // allow. NOTE(review): presumably kept for per-signature attribution or
    // debugging — confirm it is still wanted, otherwise remove field and allow.
    #[allow(dead_code)]
    sig_index: usize,
    /// Score added to the WAF when this pattern matches the body.
    weight: f64,
}
95
/// A WAF rule with all of its signature regexes compiled.
#[derive(Debug, Clone)]
pub struct CompiledWafRule {
    /// Unique rule name; also the key in `RuleEngine::rules`.
    pub name: String,
    /// Vendor label copied from the rule file.
    pub vendor: String,
    /// Minimum summed signature weight for `RuleEngine::detect` to report this
    /// WAF (raised to `BODY_ONLY_MIN_CONFIDENCE` when only body patterns matched).
    pub confidence_threshold: f64,
    /// Evasion technique names returned by `evasions_for` / `suggest_evasion`.
    pub evasions: Vec<String>,
    /// Free-form `source` string from the rule file (empty when omitted).
    /// NOTE(review): presumably a reference for the fingerprint — confirm.
    pub source: String,
    /// Compiled signatures, in file order.
    pub signatures: Vec<CompiledSignature>,
}
106
/// One compiled signature of a WAF rule.
///
/// During detection, a `body_regex` that is part of the compiled body set adds
/// `weight` when it matches; the status/header/cookie conditions together add
/// `weight` once when at least one of them matches.
#[derive(Debug, Clone)]
pub struct CompiledSignature {
    /// Header to inspect, lower-cased at load time; `None` or empty means any header.
    pub header_name: Option<String>,
    /// Pattern searched for in the selected header values.
    pub header_regex: Option<Regex>,
    /// Pattern tested against `Set-Cookie` header values.
    pub cookie_regex: Option<Regex>,
    /// Pattern searched for in the response body.
    pub body_regex: Option<Regex>,
    /// Response status code that matches exactly.
    pub status_code: Option<u16>,
    /// Score contribution when this signature matches.
    pub weight: f64,
}
123
/// Top-level shape of a rules TOML document: a list of `[[waf]]` tables.
#[derive(Debug, Clone, Deserialize)]
struct RawRuleDb {
    /// Rules declared with `[[waf]]`; a document without any yields none.
    #[serde(default)]
    waf: Vec<RawWafRule>,
}
130
/// One `[[waf]]` table exactly as written in a rules file, before compilation.
#[derive(Debug, Clone, Deserialize)]
struct RawWafRule {
    /// Rule name; a later rule with the same name replaces the earlier one.
    name: String,
    /// Vendor label.
    vendor: String,
    /// Score needed to report this WAF; `default_threshold()` when omitted.
    #[serde(default = "default_threshold")]
    confidence_threshold: f64,
    /// Evasion technique names suggested for this WAF.
    #[serde(default)]
    evasions: Vec<String>,
    /// Free-form origin string (presumably a reference/URL — confirm).
    #[serde(default)]
    source: String,
    /// The rule's `[[waf.signature]]` entries.
    #[serde(default)]
    signature: Vec<RawSignature>,
}
145
/// One `[[waf.signature]]` table; regexes are still uncompiled source text.
#[derive(Debug, Clone, Deserialize)]
struct RawSignature {
    /// Header to match `header_regex` against (case-insensitive); any header if absent.
    header_name: Option<String>,
    /// Regex source for header values.
    header_regex: Option<String>,
    /// Regex source for `Set-Cookie` header values.
    cookie_regex: Option<String>,
    /// Regex source for the response body.
    body_regex: Option<String>,
    /// Exact status code to match.
    status_code: Option<u16>,
    /// Score contribution; `default_weight()` when omitted.
    #[serde(default = "default_weight")]
    weight: f64,
}
157
/// Serde default for a rule's `confidence_threshold` when the file omits it.
fn default_threshold() -> f64 {
    const FALLBACK_THRESHOLD: f64 = 0.3;
    FALLBACK_THRESHOLD
}
161
/// Serde default for a signature's `weight` when the file omits it.
fn default_weight() -> f64 {
    const FALLBACK_WEIGHT: f64 = 0.4;
    FALLBACK_WEIGHT
}
165
/// Rule TOML baked into the binary at compile time.
///
/// Read from `$OUT_DIR/embedded_detect_rules.toml`. Cargo only sets `OUT_DIR`
/// for crates with a build script, so that script must produce the file.
/// NOTE(review): presumably it concatenates `rules/detect/*.toml` — confirm
/// against build.rs. `RuleEngine::load_embedded` tries this content first.
const EMBEDDED_RULES_TOML: &str =
    include_str!(concat!(env!("OUT_DIR"), "/embedded_detect_rules.toml"));
173
174impl RuleEngine {
175 pub fn load_embedded() -> Result<Self, DetectRulesError> {
186 let mut engine = RuleEngine {
187 rules: HashMap::new(),
188 names: Vec::new(),
189 body_regex_set: None,
190 body_pattern_map: Vec::new(),
191 body_regexes: Vec::new(),
192 };
193
194 let embedded_ok =
196 engine.load_from_str(EMBEDDED_RULES_TOML).is_ok() && !engine.rules.is_empty();
197
198 if !embedded_ok {
200 let candidates = [
201 std::path::PathBuf::from("rules/detect"),
202 std::path::PathBuf::from("../rules/detect"),
203 std::path::PathBuf::from("../../rules/detect"),
204 ];
205
206 let mut loaded = false;
207 for dir in &candidates {
208 if dir.is_dir() {
209 engine.load_directory(dir)?;
210 loaded = true;
211 break;
212 }
213 }
214
215 if !loaded {
216 return Err(DetectRulesError::Io(std::io::Error::new(
217 std::io::ErrorKind::NotFound,
218 "rules/detect directory not found and no embedded rules available",
219 )));
220 }
221 }
222
223 engine.compile_body_regex_set()?;
225
226 Ok(engine)
227 }
228
229 pub fn load_from_str(&mut self, toml_content: &str) -> Result<(), DetectRulesError> {
233 let raw: RawRuleDb = toml::from_str(toml_content)
234 .map_err(|e| DetectRulesError::Parse(format!("embedded rules: {e}")))?;
235 for waf in raw.waf {
236 let compiled = Self::compile_waf(waf)
237 .map_err(|e| DetectRulesError::Parse(format!("embedded rules: {e}")))?;
238 let key = compiled.name.clone();
239 if !self.rules.contains_key(&key) {
240 self.names.push(key.clone());
241 }
242 self.rules.insert(key, compiled);
243 }
244 Ok(())
245 }
246
247 pub fn load_directory(&mut self, path: &std::path::Path) -> Result<(), DetectRulesError> {
249 let mut entries: Vec<_> = std::fs::read_dir(path)?
250 .filter_map(std::result::Result::ok)
251 .filter(|e| {
252 e.path()
253 .extension()
254 .is_some_and(|ext| ext.eq_ignore_ascii_case("toml"))
255 })
256 .map(|e| e.path())
257 .collect();
258 entries.sort();
259
260 for entry in entries {
261 let content = std::fs::read_to_string(&entry)?;
262 let raw: RawRuleDb = toml::from_str(&content)
263 .map_err(|e| DetectRulesError::Parse(format!("{}: {e}", entry.display())))?;
264 for waf in raw.waf {
265 let compiled = Self::compile_waf(waf)
266 .map_err(|e| DetectRulesError::Parse(format!("{}: {e}", entry.display())))?;
267 let key = compiled.name.clone();
268 if !self.rules.contains_key(&key) {
269 self.names.push(key.clone());
270 }
271 self.rules.insert(key, compiled);
272 }
273 }
274 Ok(())
275 }
276
277 fn compile_waf(raw: RawWafRule) -> Result<CompiledWafRule, String> {
278 let mut signatures = Vec::with_capacity(raw.signature.len());
279 for sig in raw.signature {
280 let header_regex = sig
281 .header_regex
282 .as_ref()
283 .filter(|p| {
284 if p.len() > MAX_REGEX_PATTERN_LEN {
285 tracing::warn!(
286 waf = %raw.name,
287 pattern_len = p.len(),
288 max = MAX_REGEX_PATTERN_LEN,
289 "skipping oversized header regex"
290 );
291 false
292 } else {
293 true
294 }
295 })
296 .map(|p| Regex::new(p).map_err(|e| format!("bad header regex '{p}': {e}")))
297 .transpose()?;
298 let cookie_regex = sig
299 .cookie_regex
300 .as_ref()
301 .filter(|p| {
302 if p.len() > MAX_REGEX_PATTERN_LEN {
303 tracing::warn!(
304 waf = %raw.name,
305 pattern_len = p.len(),
306 max = MAX_REGEX_PATTERN_LEN,
307 "skipping oversized cookie regex"
308 );
309 false
310 } else {
311 true
312 }
313 })
314 .map(|p| Regex::new(p).map_err(|e| format!("bad cookie regex '{p}': {e}")))
315 .transpose()?;
316 let body_regex = sig
317 .body_regex
318 .as_ref()
319 .filter(|p| {
320 if p.len() > MAX_REGEX_PATTERN_LEN {
321 tracing::warn!(
322 waf = %raw.name,
323 pattern_len = p.len(),
324 max = MAX_REGEX_PATTERN_LEN,
325 "skipping oversized body regex"
326 );
327 false
328 } else {
329 true
330 }
331 })
332 .map(|p| Regex::new(p).map_err(|e| format!("bad body regex '{p}': {e}")))
333 .transpose()?;
334 signatures.push(CompiledSignature {
335 header_name: sig.header_name.map(|s| s.to_ascii_lowercase()),
336 header_regex,
337 cookie_regex,
338 body_regex,
339 status_code: sig.status_code,
340 weight: sig.weight,
341 });
342 }
343 Ok(CompiledWafRule {
344 name: raw.name,
345 vendor: raw.vendor,
346 confidence_threshold: raw.confidence_threshold,
347 evasions: raw.evasions,
348 source: raw.source,
349 signatures,
350 })
351 }
352
353 pub fn compile_body_regex_set(&mut self) -> Result<(), DetectRulesError> {
359 let mut patterns: Vec<String> = Vec::new();
360 let mut map: Vec<BodyPatternRef> = Vec::new();
361 let mut regexes: Vec<Regex> = Vec::new();
362
363 for name in &self.names {
364 let rule = &self.rules[name];
365 for (sig_idx, sig) in rule.signatures.iter().enumerate() {
366 if let Some(ref re) = sig.body_regex {
367 if patterns.len() >= MAX_BODY_REGEX_PATTERNS {
368 tracing::warn!(
369 limit = MAX_BODY_REGEX_PATTERNS,
370 "truncating body regex set; some WAF signatures will not match on body text"
371 );
372 break;
373 }
374 patterns.push(re.as_str().to_string());
375 map.push(BodyPatternRef {
376 waf_name: name.clone(),
377 sig_index: sig_idx,
378 weight: sig.weight,
379 });
380 regexes.push(re.clone());
381 }
382 }
383 if patterns.len() >= MAX_BODY_REGEX_PATTERNS {
384 break;
385 }
386 }
387
388 if !patterns.is_empty() {
389 let set = RegexSet::new(&patterns).map_err(|e| {
390 DetectRulesError::Parse(format!("failed to compile body RegexSet: {e}"))
391 })?;
392 self.body_regex_set = Some(set);
393 }
394
395 self.body_pattern_map = map;
396 self.body_regexes = regexes;
397 Ok(())
398 }
399
400 pub fn detect(
406 &self,
407 status: u16,
408 headers: &[(String, String)],
409 body: &str,
410 ) -> Vec<DetectedWaf> {
411 let body_hits: Vec<usize> = self
416 .body_regex_set
417 .as_ref()
418 .map(|set| set.matches(body).into_iter().collect())
419 .unwrap_or_default();
420
421 let mut waf_scores: HashMap<&str, (f64, Vec<String>)> = HashMap::new();
423
424 for &pattern_idx in &body_hits {
425 let pref = &self.body_pattern_map[pattern_idx];
426 let entry = waf_scores
427 .entry(&pref.waf_name)
428 .or_insert_with(|| (0.0, Vec::new()));
429 entry.0 += pref.weight;
430
431 if let Some(m) = self.body_regexes[pattern_idx].find(body) {
433 let snippet = &body[m.start()..m.end().min(m.start() + 40)];
434 entry.1.push(format!("body: {snippet}"));
435 }
436 }
437
438 for name in &self.names {
442 let rule = &self.rules[name];
443 for sig in &rule.signatures {
444 if sig.header_regex.is_none()
446 && sig.cookie_regex.is_none()
447 && sig.status_code.is_none()
448 {
449 continue;
450 }
451
452 let mut matched = false;
453 let entry = waf_scores.entry(name).or_insert_with(|| (0.0, Vec::new()));
454
455 if let Some(expected) = sig.status_code
456 && status == expected
457 {
458 matched = true;
459 entry.1.push(format!("status: {status}"));
460 }
461
462 if let Some(ref re) = sig.header_regex {
463 let hname = sig.header_name.as_deref().unwrap_or("");
464 for (k, v) in headers {
465 if (hname.is_empty() || k.eq_ignore_ascii_case(hname))
466 && let Some(m) = re.find(v)
467 {
468 matched = true;
469 entry.1.push(format!(
470 "header {k}: {}",
471 &v[m.start()..m.end().min(m.start() + 40)]
472 ));
473 break;
474 }
475 }
476 }
477
478 if let Some(ref re) = sig.cookie_regex {
479 for (k, v) in headers {
480 if k.eq_ignore_ascii_case("set-cookie") && re.is_match(v) {
481 matched = true;
482 entry.1.push(format!("cookie: {k}"));
483 break;
484 }
485 }
486 }
487
488 if matched {
489 entry.0 += sig.weight;
490 }
491 }
492 }
493
494 let mut results: Vec<DetectedWaf> = waf_scores
496 .into_iter()
497 .filter_map(|(name, (score, indicators))| {
498 let rule = &self.rules[name];
499 let has_non_body_indicator = indicators
500 .iter()
501 .any(|indicator| !indicator.starts_with("body: "));
502 let effective_threshold = if has_non_body_indicator {
503 rule.confidence_threshold
504 } else {
505 rule.confidence_threshold.max(BODY_ONLY_MIN_CONFIDENCE)
506 };
507 if score >= effective_threshold {
508 Some(DetectedWaf {
509 name: name.to_string(),
510 confidence: score.min(1.0),
511 indicators,
512 })
513 } else {
514 None
515 }
516 })
517 .collect();
518
519 results.sort_by(|a, b| {
520 b.confidence
521 .partial_cmp(&a.confidence)
522 .unwrap_or(std::cmp::Ordering::Equal)
523 .then_with(|| a.name.cmp(&b.name))
524 });
525 results
526 }
527
528 #[must_use]
530 pub fn evasions_for(&self, name: &str) -> Vec<&str> {
531 self.rules
532 .get(name)
533 .map(|r| r.evasions.iter().map(String::as_str).collect())
534 .unwrap_or_default()
535 }
536
537 #[must_use]
539 pub fn len(&self) -> usize {
540 self.rules.len()
541 }
542
543 #[must_use]
544 pub fn is_empty(&self) -> bool {
545 self.rules.is_empty()
546 }
547}
548
/// A WAF reported by detection.
#[derive(Debug, Clone)]
pub struct DetectedWaf {
    /// Rule name of the detected WAF.
    pub name: String,
    /// Summed weight of the matching signatures, capped at 1.0.
    pub confidence: f64,
    /// Human-readable evidence such as `"status: 403"`, `"header <name>: <text>"`,
    /// `"cookie: <header name>"` or `"body: <text>"`.
    pub indicators: Vec<String>,
}
556
/// Errors raised while loading or compiling detection rules.
#[derive(Debug, thiserror::Error)]
pub enum DetectRulesError {
    /// Reading a rules directory or file failed, or no rule source was found.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
    /// Malformed rule TOML, an invalid regex, or another failure described by the message.
    #[error("parse error: {0}")]
    Parse(String),
}
565
566pub fn with_engine<F, R>(f: F) -> R
568where
569 F: FnOnce(&RuleEngine) -> R,
570{
571 let guard = RULE_DB.read().unwrap_or_else(std::sync::PoisonError::into_inner);
572 f(&guard)
573}
574
575pub fn reload() -> Result<(), DetectRulesError> {
577 let new_engine = RuleEngine::load_embedded()?;
578 let mut guard = RULE_DB
579 .write()
580 .map_err(|e| DetectRulesError::Parse(format!("RULE_DB poisoned: {e}")))?;
581 *guard = new_engine;
582 Ok(())
583}
584
585#[must_use]
587pub fn detect(status: u16, headers: &[(String, String)], body: &str) -> Vec<DetectedWaf> {
588 with_engine(|engine| engine.detect(status, headers, body))
589}
590
591#[must_use]
593pub fn supported_wafs() -> Vec<String> {
594 with_engine(|engine| engine.names.clone())
595}
596
597#[must_use]
606pub fn suggest_evasion(waf_name: &str) -> Vec<String> {
607 with_engine(|engine| {
608 engine
609 .rules
610 .get(waf_name).map_or_else(|| {
611 vec![
612 "CaseAlternation".into(),
613 "SqlCommentInsertion".into(),
614 "DoubleUrlEncode".into(),
615 "ContentTypeSwitch".into(),
616 ]
617 }, |r| r.evasions.clone())
618 })
619}
620
/// Post-filtering options for [`detect_with_config`].
#[derive(Debug, Clone, Copy)]
pub struct DetectConfig {
    /// Results with a confidence below this value are discarded.
    pub threshold: f64,
    /// Adjacent results whose confidence gap is below this value are treated
    /// as ambiguous and kept together.
    pub ambiguity_delta: f64,
}

impl Default for DetectConfig {
    /// Threshold 0.3, ambiguity delta 0.15.
    fn default() -> Self {
        DetectConfig {
            ambiguity_delta: 0.15,
            threshold: 0.3,
        }
    }
}
637}
638
639#[must_use]
641pub fn detect_with_config(
642 status: u16,
643 headers: &[(String, String)],
644 body: &str,
645 config: DetectConfig,
646) -> Vec<DetectedWaf> {
647 let mut results = detect(status, headers, body);
648 results.retain(|r| r.confidence >= config.threshold);
649
650 if results.len() >= 2 {
651 let delta = results[0].confidence - results[1].confidence;
652 if delta < config.ambiguity_delta {
653 let mut keep = 2;
655 for window in results.windows(2) {
656 if window[0].confidence - window[1].confidence < config.ambiguity_delta {
657 keep += 1;
658 } else {
659 break;
660 }
661 }
662 results.truncate(keep);
663 } else {
664 results.truncate(1);
665 }
666 }
667 results
668}
669
#[cfg(test)]
mod tests {
    use super::*;

    /// `TestWAF` exercises header, body and status signatures; `AnotherWAF`
    /// has a single body signature and a higher threshold.
    const TEST_TOML: &str = r#"
[[waf]]
name = "TestWAF"
vendor = "test"
confidence_threshold = 0.3
evasions = ["CaseAlternation", "SqlCommentInsertion"]

[[waf.signature]]
header_name = "x-test-waf"
header_regex = "active"
weight = 0.9

[[waf.signature]]
body_regex = "blocked by test"
weight = 0.95

[[waf.signature]]
status_code = 403
weight = 0.5

[[waf]]
name = "AnotherWAF"
vendor = "another"
confidence_threshold = 0.5
evasions = ["DoubleUrlEncode"]

[[waf.signature]]
body_regex = "another waf"
weight = 0.6
"#;

    /// Builds an engine from `toml` with the body `RegexSet` compiled.
    fn engine_from(toml: &str) -> RuleEngine {
        let mut engine = RuleEngine::default();
        engine.load_from_str(toml).expect("load test toml");
        engine.compile_body_regex_set().expect("compile regex set");
        engine
    }

    fn test_engine() -> RuleEngine {
        engine_from(TEST_TOML)
    }

    #[test]
    fn load_from_str_populates_rules() {
        let engine = test_engine();
        assert_eq!(engine.len(), 2);
        assert!(!engine.is_empty());
    }

    #[test]
    fn loading_same_rules_twice_keeps_names_unique() {
        let mut engine = test_engine();
        engine.load_from_str(TEST_TOML).expect("reload test toml");
        assert_eq!(engine.len(), 2);
        assert_eq!(engine.names, ["TestWAF", "AnotherWAF"]);
    }

    #[test]
    fn detect_by_header() {
        let engine = test_engine();
        let headers = vec![("x-test-waf".into(), "active".into())];
        let results = engine.detect(200, &headers, "OK");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].name, "TestWAF");
        assert!(results[0].confidence >= 0.9);
    }

    #[test]
    fn detect_header_name_is_case_insensitive() {
        let engine = test_engine();
        let headers = vec![("X-Test-WAF".to_string(), "active".to_string())];
        let results = engine.detect(200, &headers, "OK");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].name, "TestWAF");
    }

    #[test]
    fn detect_by_body() {
        let engine = test_engine();
        let headers: Vec<(String, String)> = vec![];
        let results = engine.detect(200, &headers, "you are blocked by test engine");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].name, "TestWAF");
        assert!(results[0].confidence >= 0.95);
    }

    #[test]
    fn detect_by_status() {
        let engine = test_engine();
        let headers: Vec<(String, String)> = vec![];
        let results = engine.detect(403, &headers, "");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].name, "TestWAF");
    }

    #[test]
    fn detect_by_cookie() {
        let engine = engine_from(
            r#"
[[waf]]
name = "CookieWAF"
vendor = "test"

[[waf.signature]]
cookie_regex = "^wafsid="
weight = 0.7
"#,
        );
        let headers = vec![("Set-Cookie".to_string(), "wafsid=abc; Path=/".to_string())];
        let results = engine.detect(200, &headers, "");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].name, "CookieWAF");
        assert_eq!(results[0].indicators, ["cookie: Set-Cookie"]);
    }

    #[test]
    fn detect_no_match() {
        let engine = test_engine();
        let headers = vec![("server".into(), "nginx".into())];
        let results = engine.detect(200, &headers, "Welcome");
        assert!(results.is_empty());
    }

    #[test]
    fn detect_confidence_threshold_filters_body_only() {
        let engine = test_engine();
        let results = engine.detect(200, &[], "another waf detected");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].name, "AnotherWAF");
    }

    #[test]
    fn evasions_for_known_waf() {
        let engine = test_engine();
        let evasions = engine.evasions_for("TestWAF");
        assert_eq!(evasions.len(), 2);
        assert!(evasions.contains(&"CaseAlternation"));
    }

    #[test]
    fn evasions_for_unknown_waf_empty() {
        let engine = test_engine();
        assert!(engine.evasions_for("Unknown").is_empty());
    }

    #[test]
    fn detect_body_only_needs_higher_threshold() {
        let engine = engine_from(
            r#"
[[waf]]
name = "LowConfWAF"
vendor = "test"
confidence_threshold = 0.1

[[waf.signature]]
body_regex = "blocked"
weight = 0.4
"#,
        );
        // 0.4 clears the rule's own 0.1 threshold but not the body-only floor.
        let results = engine.detect(200, &[], "blocked");
        assert!(results.is_empty());
    }

    #[test]
    fn empty_engine_returns_empty() {
        let engine = RuleEngine::default();
        assert!(engine.is_empty());
        assert_eq!(engine.len(), 0);
        let results = engine.detect(200, &[], "body");
        assert!(results.is_empty());
    }

    #[test]
    fn detect_sorts_by_confidence_desc() {
        let engine = test_engine();
        let headers = vec![("x-test-waf".into(), "active".into())];
        // TestWAF: header 0.9 + body 0.95 (capped at 1.0); AnotherWAF: body 0.6.
        let results = engine.detect(200, &headers, "blocked by test and another waf");
        let names: Vec<&str> = results.iter().map(|r| r.name.as_str()).collect();
        assert_eq!(names, ["TestWAF", "AnotherWAF"]);
        assert!(results[0].confidence >= results[1].confidence);
    }
}
822}