// synapse_pingora/waf/index.rs

1//! Rule indexing for fast candidate selection.
2
3use std::collections::{HashMap, HashSet};
4use std::sync::Arc;
5
6use crate::waf::rule::{boolean_operands, MatchCondition, MatchValue, WafRule};
7
/// Method-mask bit for `GET`.
pub const METHOD_GET: u8 = 0x01;
/// Method-mask bit for `POST`.
pub const METHOD_POST: u8 = 0x02;
/// Method-mask bit for `HEAD`.
pub const METHOD_HEAD: u8 = 0x04;
/// Method-mask bit for `PUT`.
pub const METHOD_PUT: u8 = 0x08;
/// Method-mask bit for `PATCH`.
pub const METHOD_PATCH: u8 = 0x10;

/// Feature bit required by `args` conditions.
pub const REQ_ARGS: u16 = 0x01;
/// Feature bit required by `named_argument` / `extract_argument` conditions.
pub const REQ_ARG_ENTRIES: u16 = 0x02;
/// Feature bit required (with [`REQ_MULTIPART`]) by `parse_multipart` conditions.
pub const REQ_BODY: u16 = 0x04;
/// Feature bit required by `request_json` conditions.
pub const REQ_JSON: u16 = 0x08;
/// Feature bit required by `response_code` conditions.
pub const REQ_RESPONSE: u16 = 0x10;
/// Feature bit required by `response` conditions.
pub const REQ_RESPONSE_BODY: u16 = 0x20;
/// Feature bit required (with [`REQ_BODY`]) by `parse_multipart` conditions.
pub const REQ_MULTIPART: u16 = 0x40;
23
24/// Rule index for fast candidate selection.
25#[derive(Default)]
26pub struct RuleIndex {
27    pub header_bits: Vec<String>,
28    pub rules: Vec<IndexedRule>,
29}
30
31/// Indexed rule metadata.
32#[derive(Clone, Debug, Default)]
33pub struct IndexedRule {
34    pub method_mask: Option<u8>,
35    pub uri_anchors: Vec<UriAnchor>,
36    pub requirements: RuleRequirements,
37}
38
/// A literal that some view of the request URI must contain or start with.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct UriAnchor {
    /// Whether `pattern` may appear anywhere or must appear at the start.
    pub kind: UriAnchorKind,
    /// Which view of the URI the pattern is tested against.
    pub transform: UriTransform,
    /// The literal text to look for.
    pub pattern: String,
}

/// How a [`UriAnchor`]'s pattern is compared with the URI view.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum UriAnchorKind {
    /// The view contains the pattern as a substring.
    Contains,
    /// The view begins with the pattern.
    Prefix,
}

/// The view of the URI an anchor is evaluated against.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum UriTransform {
    /// The URI string unchanged.
    Raw,
    /// The URI lowercased.
    Lower,
    /// The URI percent-decoded.
    PercentDecoded,
    /// The URI percent-decoded, then lowercased.
    PercentDecodedLower,
}

impl UriTransform {
    /// Transform obtained by lowercasing on top of `self`.
    ///
    /// Lowercasing twice equals lowercasing once, so the lowercase variants map
    /// to themselves.
    pub fn apply_lower(self) -> Self {
        match self {
            Self::Raw | Self::Lower => Self::Lower,
            Self::PercentDecoded | Self::PercentDecodedLower => Self::PercentDecodedLower,
        }
    }

    /// Transform obtained by percent-decoding on top of `self`.
    ///
    /// This mapping is approximate. Percent-decoding is not idempotent
    /// (`%2541` -> `%41` -> `A`), yet `PercentDecoded` maps to itself. A decode
    /// on top of `Lower` maps to `PercentDecodedLower`, which means
    /// decode-then-lowercase, so the two steps are reordered.
    pub fn apply_percent_decode(self) -> Self {
        match self {
            Self::Raw | Self::PercentDecoded => Self::PercentDecoded,
            Self::Lower | Self::PercentDecodedLower => Self::PercentDecodedLower,
        }
    }
}
80
/// Cheap prerequisites a request must satisfy before a rule is worth evaluating.
#[derive(Debug, Default, Clone)]
pub struct RuleRequirements {
    /// Union of `REQ_*` feature bits that must all be available.
    pub features: u16,
    /// `Some(true)`: static requests only; `Some(false)`: non-static only;
    /// `None`: either.
    pub static_required: Option<bool>,
    /// Header bits (positions in `RuleIndex::header_bits`) that must all be set
    /// in the request's header mask.
    pub required_headers_mask: u64,
}
88
/// Bounded least-recently-used cache of candidate rule lists for repeated URIs.
///
/// Entries are grouped by [`CandidateCacheKey`] and then by the exact URI
/// string. A `max_entries` of zero (including the `Default` value) disables the
/// cache: `get` always misses and `insert` is a no-op.
///
/// Recency is tracked with a logical clock (`tick`) plus an ordered index from
/// tick to entry. Touching an entry and evicting the oldest one are therefore
/// O(log n). A full scan per eviction would let a stream of unique URIs make
/// every insert cost O(capacity).
#[derive(Default)]
pub struct CandidateCache {
    /// Capacity limit (clamped to 65 536 by [`CandidateCache::new`]).
    max_entries: usize,
    /// Logical clock, bumped on every `get`/`insert` call.
    tick: u64,
    /// Total number of cached URIs across all keys (kept equal to `lru.len()`).
    len: usize,
    by_key: HashMap<CandidateCacheKey, HashMap<Arc<str>, CandidateCacheEntry>>,
    /// Recency index: `last_used` tick -> location of the entry stamped with it.
    /// Each live entry appears exactly once, so the smallest key is the LRU entry.
    /// (A tick collision would need 2^64 operations and is not guarded against.)
    lru: std::collections::BTreeMap<u64, (CandidateCacheKey, Arc<str>)>,
}

/// Request attributes, besides the URI, under which a candidate list is cached.
///
/// NOTE(review): fields appear to mirror the same-named parameters of
/// `get_candidate_rule_indices` — confirm at the call site.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct CandidateCacheKey {
    pub method_bit: u8,
    pub available_features: u16,
    pub is_static: bool,
    pub header_mask: u64,
}

#[derive(Clone, Debug)]
struct CandidateCacheEntry {
    candidates: Arc<[usize]>,
    /// Tick of the most recent `get`/`insert` that touched this entry.
    last_used: u64,
}

impl CandidateCache {
    /// Creates a cache holding at most `max_entries` URIs (capped at 65 536).
    pub fn new(max_entries: usize) -> Self {
        Self {
            max_entries: max_entries.min(65_536),
            ..Default::default()
        }
    }

    /// Removes every entry and resets the logical clock.
    pub fn clear(&mut self) {
        self.by_key.clear();
        self.lru.clear();
        self.len = 0;
        self.tick = 0;
    }

    /// Returns the cached candidates for (`key`, `uri`), marking the entry as
    /// most recently used. Always `None` when the cache is disabled.
    pub fn get(&mut self, key: &CandidateCacheKey, uri: &str) -> Option<Arc<[usize]>> {
        if self.max_entries == 0 {
            return None;
        }
        self.tick = self.tick.wrapping_add(1);
        let tick = self.tick;
        let entry = self.by_key.get_mut(key)?.get_mut(uri)?;
        let previous = std::mem::replace(&mut entry.last_used, tick);
        let candidates = Arc::clone(&entry.candidates);
        // Move the entry's recency record to the new tick (no allocation).
        if let Some(slot) = self.lru.remove(&previous) {
            self.lru.insert(tick, slot);
        }
        Some(candidates)
    }

    /// Stores `candidates` for (`key`, `uri`), replacing and refreshing any
    /// existing entry, then evicts least-recently-used entries while over
    /// capacity. No-op when the cache is disabled.
    pub fn insert(&mut self, key: CandidateCacheKey, uri: String, candidates: Arc<[usize]>) {
        if self.max_entries == 0 {
            return;
        }
        self.tick = self.tick.wrapping_add(1);
        let tick = self.tick;
        let inner = self.by_key.entry(key).or_default();
        if let Some(existing) = inner.get_mut(uri.as_str()) {
            existing.candidates = candidates;
            let previous = std::mem::replace(&mut existing.last_used, tick);
            if let Some(slot) = self.lru.remove(&previous) {
                self.lru.insert(tick, slot);
            }
            return;
        }
        // Share one allocation between the lookup map and the recency index.
        let uri: Arc<str> = Arc::from(uri);
        inner.insert(
            Arc::clone(&uri),
            CandidateCacheEntry {
                candidates,
                last_used: tick,
            },
        );
        self.lru.insert(tick, (key, uri));
        self.len += 1;
        self.evict_if_needed();
    }

    /// Evicts least-recently-used entries until `len <= max_entries`.
    fn evict_if_needed(&mut self) {
        while self.len > self.max_entries {
            let Some(oldest) = self.lru.keys().next().copied() else {
                // No recency records left; nothing further can be evicted.
                break;
            };
            let Some((key, uri)) = self.lru.remove(&oldest) else {
                break;
            };
            if let Some(inner) = self.by_key.get_mut(&key) {
                if inner.remove(&*uri).is_some() {
                    self.len = self.len.saturating_sub(1);
                }
                // Drop empty per-key maps so they do not accumulate.
                if inner.is_empty() {
                    self.by_key.remove(&key);
                }
            }
        }
    }
}
190
191/// Convert method string to bit mask.
192pub fn method_to_mask(method: &str) -> Option<u8> {
193    if method.eq_ignore_ascii_case("GET") {
194        return Some(METHOD_GET);
195    }
196    if method.eq_ignore_ascii_case("POST") {
197        return Some(METHOD_POST);
198    }
199    if method.eq_ignore_ascii_case("HEAD") {
200        return Some(METHOD_HEAD);
201    }
202    if method.eq_ignore_ascii_case("PUT") {
203        return Some(METHOD_PUT);
204    }
205    if method.eq_ignore_ascii_case("PATCH") {
206        return Some(METHOD_PATCH);
207    }
208    None
209}
210
211/// Build rule index from rules.
212pub fn build_rule_index(rules: &[WafRule]) -> RuleIndex {
213    let mut index = RuleIndex::default();
214
215    // Collect all header fields
216    let mut header_names = HashSet::<String>::new();
217    for rule in rules {
218        for cond in &rule.matches {
219            collect_header_fields(cond, &mut header_names);
220        }
221    }
222
223    let mut header_bits: Vec<String> = header_names.into_iter().collect();
224    header_bits.sort();
225    if header_bits.len() > 64 {
226        header_bits.truncate(64);
227    }
228
229    let header_to_bit: HashMap<String, u8> = header_bits
230        .iter()
231        .enumerate()
232        .map(|(idx, header)| (header.clone(), idx as u8))
233        .collect();
234
235    index.header_bits = header_bits;
236    index.rules.reserve(rules.len());
237
238    for rule in rules {
239        let method_mask = extract_rule_method_mask(rule);
240        let mut uri_anchors = extract_rule_uri_anchors(rule);
241        if !uri_anchors.is_empty() {
242            let mut seen = HashSet::new();
243            uri_anchors.retain(|a| seen.insert(a.clone()));
244        }
245        let requirements = extract_rule_requirements(rule, &header_to_bit);
246        index.rules.push(IndexedRule {
247            method_mask,
248            uri_anchors,
249            requirements,
250        });
251    }
252
253    index
254}
255
256/// Get candidate rule indices for a request.
257pub fn get_candidate_rule_indices(
258    index: &RuleIndex,
259    method_bit: u8,
260    uri: &str,
261    available_features: u16,
262    is_static: bool,
263    header_mask: u64,
264    rule_count: usize,
265    percent_decode: impl Fn(&str) -> String,
266) -> Vec<usize> {
267    let mut out = Vec::new();
268    let req_method_mask = if method_bit == 0 {
269        None
270    } else {
271        Some(method_bit)
272    };
273
274    let mut uri_lower: Option<String> = None;
275    let mut uri_percent_decoded: Option<String> = None;
276    let mut uri_percent_decoded_lower: Option<String> = None;
277
278    let count = rule_count.min(index.rules.len());
279    for (idx, rule) in index.rules.iter().enumerate().take(count) {
280        // Check method
281        if let Some(rule_method_mask) = rule.method_mask {
282            let Some(req_method_mask) = req_method_mask else {
283                continue;
284            };
285            if (rule_method_mask & req_method_mask) == 0 {
286                continue;
287            }
288        }
289
290        // Check requirements
291        let requirements = &rule.requirements;
292        if (requirements.features & !available_features) != 0 {
293            continue;
294        }
295        if requirements.static_required == Some(true) && !is_static {
296            continue;
297        }
298        if requirements.static_required == Some(false) && is_static {
299            continue;
300        }
301        if (requirements.required_headers_mask & !header_mask) != 0 {
302            continue;
303        }
304
305        // Check URI anchors
306        if !rule.uri_anchors.is_empty() {
307            let mut matched = false;
308            for anchor in &rule.uri_anchors {
309                let haystack: &str = match anchor.transform {
310                    UriTransform::Raw => uri,
311                    UriTransform::Lower => {
312                        if uri_lower.is_none() {
313                            uri_lower = Some(uri.to_lowercase());
314                        }
315                        uri_lower.as_deref().unwrap_or(uri)
316                    }
317                    UriTransform::PercentDecoded => {
318                        if uri_percent_decoded.is_none() {
319                            uri_percent_decoded = Some(percent_decode(uri));
320                        }
321                        uri_percent_decoded.as_deref().unwrap_or(uri)
322                    }
323                    UriTransform::PercentDecodedLower => {
324                        if uri_percent_decoded_lower.is_none() {
325                            if uri_percent_decoded.is_none() {
326                                uri_percent_decoded = Some(percent_decode(uri));
327                            }
328                            uri_percent_decoded_lower =
329                                Some(uri_percent_decoded.as_deref().unwrap_or(uri).to_lowercase());
330                        }
331                        uri_percent_decoded_lower.as_deref().unwrap_or(uri)
332                    }
333                };
334
335                matched = match anchor.kind {
336                    UriAnchorKind::Contains => haystack.contains(anchor.pattern.as_str()),
337                    UriAnchorKind::Prefix => haystack.starts_with(anchor.pattern.as_str()),
338                };
339                if matched {
340                    break;
341                }
342            }
343
344            if !matched {
345                continue;
346            }
347        }
348
349        out.push(idx);
350    }
351
352    out
353}
354
355// Helper functions for index building
356
357fn method_mask_from_match_value(match_value: &MatchValue) -> Option<u8> {
358    match match_value {
359        MatchValue::Str(s) => method_to_mask(s),
360        MatchValue::Arr(items) => {
361            let mut mask = 0u8;
362            for item in items {
363                let Some(s) = item.as_str() else { continue };
364                let Some(bit) = method_to_mask(s) else {
365                    return None;
366                };
367                mask |= bit;
368            }
369            if mask == 0 {
370                None
371            } else {
372                Some(mask)
373            }
374        }
375        _ => None,
376    }
377}
378
379fn possible_method_mask(condition: &MatchCondition) -> Option<u8> {
380    match condition.kind.as_str() {
381        "method" => condition
382            .match_value
383            .as_ref()
384            .and_then(method_mask_from_match_value),
385        "boolean" => {
386            let op = condition.op.as_deref().unwrap_or("and");
387            let operands = boolean_operands(condition);
388            if operands.is_empty() {
389                return None;
390            }
391
392            match op {
393                "and" => {
394                    let mut out: Option<u8> = None;
395                    for operand in operands {
396                        let Some(mask) = possible_method_mask(operand) else {
397                            continue;
398                        };
399                        out = Some(match out {
400                            None => mask,
401                            Some(existing) => existing & mask,
402                        });
403                    }
404                    out.filter(|m| *m != 0)
405                }
406                "or" => {
407                    let mut mask = 0u8;
408                    for operand in operands {
409                        let Some(child_mask) = possible_method_mask(operand) else {
410                            return None;
411                        };
412                        mask |= child_mask;
413                    }
414                    if mask == 0 {
415                        None
416                    } else {
417                        Some(mask)
418                    }
419                }
420                _ => None,
421            }
422        }
423        _ => None,
424    }
425}
426
427fn extract_rule_method_mask(rule: &WafRule) -> Option<u8> {
428    let mut out: Option<u8> = None;
429    for condition in &rule.matches {
430        let Some(mask) = possible_method_mask(condition) else {
431            continue;
432        };
433        out = Some(match out {
434            None => mask,
435            Some(existing) => existing & mask,
436        });
437    }
438    out.filter(|m| *m != 0)
439}
440
441fn extract_rule_uri_anchors(rule: &WafRule) -> Vec<UriAnchor> {
442    let mut out = Vec::new();
443    for condition in &rule.matches {
444        if let Some(mut anchors) = implied_uri_anchors(condition) {
445            out.append(&mut anchors);
446        }
447    }
448    out.retain(|a| !a.pattern.is_empty());
449    out
450}
451
452fn implied_uri_anchors(condition: &MatchCondition) -> Option<Vec<UriAnchor>> {
453    match condition.kind.as_str() {
454        "uri" => {
455            uri_anchors_from_uri_match_value(condition.match_value.as_ref(), UriTransform::Raw)
456        }
457        "boolean" => {
458            let op = condition.op.as_deref().unwrap_or("and");
459            let operands = boolean_operands(condition);
460            if operands.is_empty() {
461                return None;
462            }
463            match op {
464                "and" => {
465                    let mut out = Vec::new();
466                    for operand in operands {
467                        if let Some(mut anchors) = implied_uri_anchors(operand) {
468                            out.append(&mut anchors);
469                        }
470                    }
471                    if out.is_empty() {
472                        None
473                    } else {
474                        Some(out)
475                    }
476                }
477                "or" => {
478                    let mut out = Vec::new();
479                    for operand in operands {
480                        let Some(mut anchors) = implied_uri_anchors(operand) else {
481                            return None;
482                        };
483                        out.append(&mut anchors);
484                    }
485                    if out.is_empty() {
486                        None
487                    } else {
488                        Some(out)
489                    }
490                }
491                _ => None,
492            }
493        }
494        _ => None,
495    }
496}
497
498fn uri_anchors_from_uri_match_value(
499    match_value: Option<&MatchValue>,
500    transform: UriTransform,
501) -> Option<Vec<UriAnchor>> {
502    match match_value {
503        Some(MatchValue::Str(s)) => Some(vec![UriAnchor {
504            kind: UriAnchorKind::Contains,
505            transform,
506            pattern: s.clone(),
507        }]),
508        Some(MatchValue::Cond(child)) => uri_anchors_from_uri_match(child, transform),
509        _ => None,
510    }
511}
512
513fn uri_anchors_from_uri_match(
514    condition: &MatchCondition,
515    transform: UriTransform,
516) -> Option<Vec<UriAnchor>> {
517    match condition.kind.as_str() {
518        "contains" => condition
519            .match_value
520            .as_ref()
521            .and_then(|m| m.as_str())
522            .map(|pattern| {
523                vec![UriAnchor {
524                    kind: UriAnchorKind::Contains,
525                    transform,
526                    pattern: pattern.to_string(),
527                }]
528            }),
529        "starts_with" => condition
530            .match_value
531            .as_ref()
532            .and_then(|m| m.as_str())
533            .map(|prefix| {
534                vec![UriAnchor {
535                    kind: UriAnchorKind::Prefix,
536                    transform,
537                    pattern: prefix.to_string(),
538                }]
539            }),
540        "equals" => condition
541            .match_value
542            .as_ref()
543            .and_then(|m| m.as_str())
544            .map(|pattern| {
545                vec![UriAnchor {
546                    kind: UriAnchorKind::Contains,
547                    transform,
548                    pattern: pattern.to_string(),
549                }]
550            }),
551        "to_lowercase" => {
552            let child = condition.match_value.as_ref()?.as_cond()?;
553            uri_anchors_from_uri_match(child, transform.apply_lower())
554        }
555        "percent_decode" => {
556            let child = condition.match_value.as_ref()?.as_cond()?;
557            uri_anchors_from_uri_match(child, transform.apply_percent_decode())
558        }
559        "boolean" => {
560            let op = condition.op.as_deref().unwrap_or("and");
561            let operands = boolean_operands(condition);
562            if operands.is_empty() {
563                return None;
564            }
565            match op {
566                "and" => {
567                    let mut out = Vec::new();
568                    for operand in operands {
569                        if let Some(mut anchors) = uri_anchors_from_uri_match(operand, transform) {
570                            out.append(&mut anchors);
571                        }
572                    }
573                    if out.is_empty() {
574                        None
575                    } else {
576                        Some(out)
577                    }
578                }
579                "or" => {
580                    let mut out = Vec::new();
581                    for operand in operands {
582                        let Some(mut anchors) = uri_anchors_from_uri_match(operand, transform)
583                        else {
584                            return None;
585                        };
586                        out.append(&mut anchors);
587                    }
588                    if out.is_empty() {
589                        None
590                    } else {
591                        Some(out)
592                    }
593                }
594                _ => None,
595            }
596        }
597        _ => None,
598    }
599}
600
/// Working form of `RuleRequirements` used while walking a rule's conditions.
///
/// Headers are kept by name here and only translated to bits at the end.
#[derive(Debug, Default, Clone)]
struct RequirementsSet {
    /// Union of required `REQ_*` feature bits.
    features: u16,
    /// Required static-content state, if any.
    static_required: Option<bool>,
    /// Lowercased names of headers the rule needs.
    required_headers: HashSet<String>,
}
607
/// Combines two static-content requirements that must both hold.
///
/// Agreeing or one-sided requirements are kept. Conflicting ones (`Some(true)`
/// vs `Some(false)`) collapse to `None`, i.e. no filtering. That is
/// conservative, even though such a conjunction can never be satisfied.
fn merge_and_static(a: Option<bool>, b: Option<bool>) -> Option<bool> {
    match (a, b) {
        (None, other) | (other, None) => other,
        (Some(left), Some(right)) => (left == right).then_some(left),
    }
}
616
617fn req_and(mut left: RequirementsSet, right: RequirementsSet) -> RequirementsSet {
618    left.features |= right.features;
619    left.required_headers.extend(right.required_headers);
620    left.static_required = merge_and_static(left.static_required, right.static_required);
621    left
622}
623
624fn req_or(left: RequirementsSet, right: RequirementsSet) -> RequirementsSet {
625    let mut out = RequirementsSet::default();
626    out.features = left.features & right.features;
627    out.static_required = match (left.static_required, right.static_required) {
628        (Some(l), Some(r)) if l == r => Some(l),
629        _ => None,
630    };
631    out.required_headers = left
632        .required_headers
633        .intersection(&right.required_headers)
634        .cloned()
635        .collect();
636    out
637}
638
639fn extract_rule_requirements(
640    rule: &WafRule,
641    header_to_bit: &HashMap<String, u8>,
642) -> RuleRequirements {
643    let mut req = RequirementsSet::default();
644    for condition in &rule.matches {
645        req = req_and(req, requirements_for_condition(condition));
646    }
647    let mut required_headers_mask: u64 = 0;
648    for header in req.required_headers {
649        if let Some(bit) = header_to_bit.get(header.as_str()).copied() {
650            if bit < 64 {
651                required_headers_mask |= 1u64 << bit;
652            }
653        }
654    }
655    RuleRequirements {
656        features: req.features,
657        static_required: req.static_required,
658        required_headers_mask,
659    }
660}
661
662fn requirements_for_condition(condition: &MatchCondition) -> RequirementsSet {
663    match condition.kind.as_str() {
664        "boolean" => {
665            let op = condition.op.as_deref().unwrap_or("and");
666            let operands = boolean_operands(condition);
667            if operands.is_empty() {
668                return RequirementsSet::default();
669            }
670            match op {
671                "and" => {
672                    let mut out = RequirementsSet::default();
673                    for operand in operands {
674                        out = req_and(out, requirements_for_condition(operand));
675                    }
676                    out
677                }
678                "or" => {
679                    let mut iter = operands.into_iter();
680                    let mut out = requirements_for_condition(iter.next().unwrap());
681                    for operand in iter {
682                        out = req_or(out, requirements_for_condition(operand));
683                    }
684                    out
685                }
686                _ => RequirementsSet::default(),
687            }
688        }
689        "args" => {
690            let mut out = RequirementsSet {
691                features: REQ_ARGS,
692                ..Default::default()
693            };
694            if let Some(child) = condition.match_value.as_ref().and_then(|m| m.as_cond()) {
695                out = req_and(out, requirements_for_condition(child));
696            }
697            out
698        }
699        "named_argument" | "extract_argument" => {
700            let mut out = RequirementsSet {
701                features: REQ_ARG_ENTRIES,
702                ..Default::default()
703            };
704            if let Some(child) = condition.match_value.as_ref().and_then(|m| m.as_cond()) {
705                out = req_and(out, requirements_for_condition(child));
706            }
707            out
708        }
709        "header" => {
710            let mut out = RequirementsSet::default();
711            if let Some(field) = condition.field.as_deref() {
712                out.required_headers.insert(field.to_ascii_lowercase());
713            }
714            if let Some(child) = condition.match_value.as_ref().and_then(|m| m.as_cond()) {
715                out = req_and(out, requirements_for_condition(child));
716            }
717            out
718        }
719        "request_json" => {
720            let mut out = RequirementsSet {
721                features: REQ_JSON,
722                ..Default::default()
723            };
724            if let Some(child) = condition.match_value.as_ref().and_then(|m| m.as_cond()) {
725                out = req_and(out, requirements_for_condition(child));
726            }
727            out
728        }
729        "response_code" => RequirementsSet {
730            features: REQ_RESPONSE,
731            ..Default::default()
732        },
733        "response" => {
734            let mut out = RequirementsSet {
735                features: REQ_RESPONSE_BODY,
736                ..Default::default()
737            };
738            if let Some(child) = condition.match_value.as_ref().and_then(|m| m.as_cond()) {
739                out = req_and(out, requirements_for_condition(child));
740            }
741            out
742        }
743        "parse_multipart" => {
744            let mut out = RequirementsSet {
745                features: REQ_BODY | REQ_MULTIPART,
746                ..Default::default()
747            };
748            if let Some(child) = condition.match_value.as_ref().and_then(|m| m.as_cond()) {
749                out = req_and(out, requirements_for_condition(child));
750            }
751            out
752        }
753        "static_content" => {
754            let mut out = RequirementsSet::default();
755            if let Some(target) = condition.match_value.as_ref().and_then(|m| m.as_bool()) {
756                out.static_required = Some(target);
757            }
758            out
759        }
760        _ => {
761            let mut out = RequirementsSet::default();
762            if let Some(child) = condition.match_value.as_ref().and_then(|m| m.as_cond()) {
763                out = req_and(out, requirements_for_condition(child));
764            }
765            out
766        }
767    }
768}
769
770fn collect_header_fields(condition: &MatchCondition, out: &mut HashSet<String>) {
771    if condition.kind == "header" {
772        if let Some(field) = condition.field.as_deref() {
773            out.insert(field.to_ascii_lowercase());
774        }
775    }
776
777    if let Some(mv) = condition.match_value.as_ref() {
778        if let Some(child) = mv.as_cond() {
779            collect_header_fields(child, out);
780        } else if let Some(arr) = mv.as_arr() {
781            for item in arr {
782                if let Some(child) = item.as_cond() {
783                    collect_header_fields(child, out);
784                }
785            }
786        }
787    }
788
789    if let Some(selector) = condition.selector.as_ref() {
790        collect_header_fields(selector, out);
791    }
792}
793
794#[cfg(test)]
795mod tests {
796    use super::*;
797    use crate::waf::rule::{MatchCondition, MatchValue, WafRule};
798
799    /// Helper: build a minimal WafRule with a method match condition.
800    fn rule_with_method(id: u32, methods: &[&str]) -> WafRule {
801        let match_value = if methods.len() == 1 {
802            MatchValue::Str(methods[0].to_string())
803        } else {
804            MatchValue::Arr(
805                methods
806                    .iter()
807                    .map(|m| MatchValue::Str(m.to_string()))
808                    .collect(),
809            )
810        };
811
812        WafRule {
813            id,
814            description: format!("rule-{}", id),
815            contributing_score: None,
816            risk: Some(5.0),
817            blocking: None,
818            matches: vec![MatchCondition {
819                kind: "method".to_string(),
820                match_value: Some(match_value),
821                op: None,
822                field: None,
823                direction: None,
824                field_type: None,
825                name: None,
826                selector: None,
827                cleanup_after: None,
828                count: None,
829                timeframe: None,
830            }],
831        }
832    }
833
834    /// Helper: build a WafRule with a URI contains anchor.
835    fn rule_with_uri_contains(id: u32, pattern: &str) -> WafRule {
836        WafRule {
837            id,
838            description: format!("rule-{}", id),
839            contributing_score: None,
840            risk: Some(5.0),
841            blocking: None,
842            matches: vec![MatchCondition {
843                kind: "uri".to_string(),
844                match_value: Some(MatchValue::Cond(Box::new(MatchCondition {
845                    kind: "contains".to_string(),
846                    match_value: Some(MatchValue::Str(pattern.to_string())),
847                    op: None,
848                    field: None,
849                    direction: None,
850                    field_type: None,
851                    name: None,
852                    selector: None,
853                    cleanup_after: None,
854                    count: None,
855                    timeframe: None,
856                }))),
857                op: None,
858                field: None,
859                direction: None,
860                field_type: None,
861                name: None,
862                selector: None,
863                cleanup_after: None,
864                count: None,
865                timeframe: None,
866            }],
867        }
868    }
869
870    fn noop_percent_decode(s: &str) -> String {
871        s.to_string()
872    }
873
874    #[test]
875    fn test_method_to_mask_known_methods() {
876        assert_eq!(method_to_mask("GET"), Some(METHOD_GET));
877        assert_eq!(method_to_mask("POST"), Some(METHOD_POST));
878        assert_eq!(method_to_mask("HEAD"), Some(METHOD_HEAD));
879        assert_eq!(method_to_mask("PUT"), Some(METHOD_PUT));
880        assert_eq!(method_to_mask("PATCH"), Some(METHOD_PATCH));
881    }
882
883    #[test]
884    fn test_method_to_mask_case_insensitive() {
885        assert_eq!(method_to_mask("get"), Some(METHOD_GET));
886        assert_eq!(method_to_mask("Post"), Some(METHOD_POST));
887    }
888
889    #[test]
890    fn test_method_to_mask_unknown_returns_none() {
891        assert_eq!(method_to_mask("DELETE"), None);
892        assert_eq!(method_to_mask("OPTIONS"), None);
893        assert_eq!(method_to_mask("CONNECT"), None);
894    }
895
896    #[test]
897    fn test_build_rule_index_method_filtering() {
898        let rules = vec![
899            rule_with_method(1, &["GET"]),
900            rule_with_method(2, &["POST"]),
901            rule_with_method(3, &["GET", "POST"]),
902        ];
903
904        let index = build_rule_index(&rules);
905        assert_eq!(index.rules.len(), 3);
906
907        // Rule 0 (GET only)
908        assert_eq!(index.rules[0].method_mask, Some(METHOD_GET));
909        // Rule 1 (POST only)
910        assert_eq!(index.rules[1].method_mask, Some(METHOD_POST));
911        // Rule 2 (GET | POST)
912        assert_eq!(index.rules[2].method_mask, Some(METHOD_GET | METHOD_POST));
913    }
914
915    #[test]
916    fn test_get_candidates_get_method_returns_only_get_rules() {
917        let rules = vec![
918            rule_with_method(1, &["GET"]),
919            rule_with_method(2, &["POST"]),
920            rule_with_method(3, &["GET", "POST"]),
921        ];
922
923        let index = build_rule_index(&rules);
924
925        let candidates = get_candidate_rule_indices(
926            &index,
927            METHOD_GET,
928            "/any-path",
929            0,     // no feature requirements
930            false, // not static
931            0,     // no header mask
932            rules.len(),
933            noop_percent_decode,
934        );
935
936        // Should include rule 0 (GET) and rule 2 (GET|POST), but NOT rule 1 (POST)
937        assert!(candidates.contains(&0), "GET rule should be a candidate");
938        assert!(
939            !candidates.contains(&1),
940            "POST-only rule should NOT be a candidate for GET"
941        );
942        assert!(
943            candidates.contains(&2),
944            "GET|POST rule should be a candidate for GET"
945        );
946    }
947
948    #[test]
949    fn test_get_candidates_post_method_returns_only_post_rules() {
950        let rules = vec![
951            rule_with_method(1, &["GET"]),
952            rule_with_method(2, &["POST"]),
953            rule_with_method(3, &["GET", "POST"]),
954        ];
955
956        let index = build_rule_index(&rules);
957
958        let candidates = get_candidate_rule_indices(
959            &index,
960            METHOD_POST,
961            "/any-path",
962            0,
963            false,
964            0,
965            rules.len(),
966            noop_percent_decode,
967        );
968
969        assert!(
970            !candidates.contains(&0),
971            "GET-only rule should NOT be a candidate for POST"
972        );
973        assert!(candidates.contains(&1), "POST rule should be a candidate");
974        assert!(
975            candidates.contains(&2),
976            "GET|POST rule should be a candidate for POST"
977        );
978    }
979
980    #[test]
981    fn test_get_candidates_uri_anchor_filtering() {
982        let rules = vec![
983            rule_with_uri_contains(1, "/admin"),
984            rule_with_uri_contains(2, "/api"),
985        ];
986
987        let index = build_rule_index(&rules);
988
989        // Request to /admin/dashboard
990        let candidates = get_candidate_rule_indices(
991            &index,
992            0, // no method filter (unknown method)
993            "/admin/dashboard",
994            0,
995            false,
996            0,
997            rules.len(),
998            noop_percent_decode,
999        );
1000        assert!(
1001            candidates.contains(&0),
1002            "/admin rule should match /admin/dashboard"
1003        );
1004        assert!(
1005            !candidates.contains(&1),
1006            "/api rule should NOT match /admin/dashboard"
1007        );
1008
1009        // Request to /api/v1/users
1010        let candidates = get_candidate_rule_indices(
1011            &index,
1012            0,
1013            "/api/v1/users",
1014            0,
1015            false,
1016            0,
1017            rules.len(),
1018            noop_percent_decode,
1019        );
1020        assert!(
1021            !candidates.contains(&0),
1022            "/admin rule should NOT match /api/v1/users"
1023        );
1024        assert!(
1025            candidates.contains(&1),
1026            "/api rule should match /api/v1/users"
1027        );
1028    }
1029
1030    #[test]
1031    fn test_get_candidates_no_method_constraint_matches_all() {
1032        // A rule without method constraint should match any request method
1033        let rules = vec![rule_with_uri_contains(1, "/health")];
1034
1035        let index = build_rule_index(&rules);
1036        // Method mask is None for the rule (no method condition)
1037        assert!(index.rules[0].method_mask.is_none());
1038
1039        let candidates = get_candidate_rule_indices(
1040            &index,
1041            METHOD_GET,
1042            "/health",
1043            0,
1044            false,
1045            0,
1046            rules.len(),
1047            noop_percent_decode,
1048        );
1049        assert!(
1050            candidates.contains(&0),
1051            "rule without method constraint should match GET"
1052        );
1053
1054        let candidates = get_candidate_rule_indices(
1055            &index,
1056            METHOD_POST,
1057            "/health",
1058            0,
1059            false,
1060            0,
1061            rules.len(),
1062            noop_percent_decode,
1063        );
1064        assert!(
1065            candidates.contains(&0),
1066            "rule without method constraint should match POST"
1067        );
1068    }
1069
1070    #[test]
1071    fn test_candidate_cache_insert_and_get() {
1072        let mut cache = CandidateCache::new(10);
1073        let key = CandidateCacheKey {
1074            method_bit: METHOD_GET,
1075            available_features: 0,
1076            is_static: false,
1077            header_mask: 0,
1078        };
1079        let candidates: Arc<[usize]> = Arc::from(vec![0, 2, 5].as_slice());
1080        cache.insert(key, "/test".to_string(), candidates.clone());
1081
1082        let result = cache.get(&key, "/test");
1083        assert!(result.is_some());
1084        assert_eq!(result.unwrap().as_ref(), &[0, 2, 5]);
1085    }
1086
1087    #[test]
1088    fn test_candidate_cache_eviction() {
1089        let mut cache = CandidateCache::new(2);
1090        let key = CandidateCacheKey {
1091            method_bit: METHOD_GET,
1092            available_features: 0,
1093            is_static: false,
1094            header_mask: 0,
1095        };
1096
1097        cache.insert(key, "/a".to_string(), Arc::from(vec![0].as_slice()));
1098        cache.insert(key, "/b".to_string(), Arc::from(vec![1].as_slice()));
1099        cache.insert(key, "/c".to_string(), Arc::from(vec![2].as_slice()));
1100
1101        // Cache has capacity 2, so /a should have been evicted
1102        assert!(cache.get(&key, "/a").is_none());
1103        // /b and /c should still exist
1104        assert!(cache.get(&key, "/b").is_some());
1105        assert!(cache.get(&key, "/c").is_some());
1106    }
1107}