1use std::collections::{HashMap, HashSet};
4use std::sync::Arc;
5
6use crate::waf::rule::{boolean_operands, MatchCondition, MatchValue, WafRule};
7
/// Method bit for `GET` requests (see `method_to_mask`).
pub const METHOD_GET: u8 = 0b0000_0001;
/// Method bit for `POST` requests.
pub const METHOD_POST: u8 = 0b0000_0010;
/// Method bit for `HEAD` requests.
pub const METHOD_HEAD: u8 = 0b0000_0100;
/// Method bit for `PUT` requests.
pub const METHOD_PUT: u8 = 0b0000_1000;
/// Method bit for `PATCH` requests.
pub const METHOD_PATCH: u8 = 0b0001_0000;
14
/// Feature bit required by `args` conditions.
pub const REQ_ARGS: u16 = 0x0001;
/// Feature bit required by `named_argument` / `extract_argument` conditions.
pub const REQ_ARG_ENTRIES: u16 = 0x0002;
/// Feature bit required (together with `REQ_MULTIPART`) by `parse_multipart` conditions.
pub const REQ_BODY: u16 = 0x0004;
/// Feature bit required by `request_json` conditions.
pub const REQ_JSON: u16 = 0x0008;
/// Feature bit required by `response_code` conditions.
pub const REQ_RESPONSE: u16 = 0x0010;
/// Feature bit required by `response` conditions.
pub const REQ_RESPONSE_BODY: u16 = 0x0020;
/// Feature bit required by `parse_multipart` conditions.
pub const REQ_MULTIPART: u16 = 0x0040;
23
/// Precomputed prefilter data for a rule set, produced by `build_rule_index`.
#[derive(Default)]
pub struct RuleIndex {
    /// Sorted, lowercased header names referenced by `header` conditions (at most 64).
    /// Position `i` corresponds to bit `i` of `RuleRequirements::required_headers_mask`.
    pub header_bits: Vec<String>,
    /// One entry per input rule, in the same order as the rules passed to `build_rule_index`.
    pub rules: Vec<IndexedRule>,
}
30
/// Prefilter summary of a single rule.
#[derive(Clone, Debug, Default)]
pub struct IndexedRule {
    /// Methods the rule can possibly match; `None` means unconstrained. When set, a request whose
    /// method bit is 0 or does not intersect the mask is not a candidate.
    pub method_mask: Option<u8>,
    /// URI anchors, at least one of which must match; an empty list means no URI constraint.
    pub uri_anchors: Vec<UriAnchor>,
    /// Feature, static-content and header requirements the request must satisfy.
    pub requirements: RuleRequirements,
}
38
/// A literal that must appear in some view of the request URI for a rule to be able to match.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct UriAnchor {
    /// How `pattern` is tested against the URI view.
    pub kind: UriAnchorKind,
    /// Which view of the URI (raw, lowercased, percent-decoded, ...) is searched.
    pub transform: UriTransform,
    /// The literal to look for.
    pub pattern: String,
}
46
/// How a `UriAnchor` pattern is matched against the URI view.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum UriAnchorKind {
    /// The view must contain the pattern as a substring (`str::contains`).
    Contains,
    /// The view must start with the pattern (`str::starts_with`).
    Prefix,
}
52
/// Which view of the request URI an anchor is evaluated against.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum UriTransform {
    Raw,
    Lower,
    PercentDecoded,
    PercentDecodedLower,
}

impl UriTransform {
    /// Returns the view obtained by additionally lowercasing this one.
    ///
    /// Lowercasing is idempotent, so already-lowercased views map to themselves.
    pub fn apply_lower(self) -> Self {
        match self {
            UriTransform::Raw | UriTransform::Lower => UriTransform::Lower,
            UriTransform::PercentDecoded | UriTransform::PercentDecodedLower => {
                UriTransform::PercentDecodedLower
            }
        }
    }

    /// Returns the view obtained by additionally percent-decoding this one.
    ///
    /// This mapping is idempotent: already-decoded views map to themselves, so it models at
    /// most one decode pass. It also maps `Lower` to `PercentDecodedLower`, i.e. it does not
    /// distinguish "lowercase then decode" from "decode then lowercase". Callers that need exact
    /// composition semantics must special-case those inputs.
    pub fn apply_percent_decode(self) -> Self {
        match self {
            UriTransform::Raw | UriTransform::PercentDecoded => UriTransform::PercentDecoded,
            UriTransform::Lower | UriTransform::PercentDecodedLower => {
                UriTransform::PercentDecodedLower
            }
        }
    }
}
80
/// Request-level preconditions a rule needs before it can match.
#[derive(Clone, Debug, Default)]
pub struct RuleRequirements {
    /// `REQ_*` bits that must all be present in the request's available features.
    pub features: u16,
    /// `Some(true)`: only static-content requests; `Some(false)`: only non-static requests;
    /// `None`: either.
    pub static_required: Option<bool>,
    /// Bits (indices into `RuleIndex::header_bits`) that must all be set in the request's header
    /// mask.
    pub required_headers_mask: u64,
}
88
/// Bounded cache of candidate-rule lists keyed by request shape and URI.
///
/// Once more than `max_entries` URIs are stored, the least recently used entry is evicted.
/// Eviction scans every entry, so it costs O(entries). A cache built via `Default` has
/// `max_entries == 0` and is disabled.
#[derive(Default)]
pub struct CandidateCache {
    // Capacity limit; 0 disables all caching.
    max_entries: usize,
    // Logical clock advanced on every `get`/`insert`; used as the recency stamp.
    tick: u64,
    // Number of URI entries across all inner maps.
    len: usize,
    by_key: HashMap<CandidateCacheKey, HashMap<String, CandidateCacheEntry>>,
}

/// Request attributes, other than the URI, that determine the candidate list.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct CandidateCacheKey {
    pub method_bit: u8,
    pub available_features: u16,
    pub is_static: bool,
    pub header_mask: u64,
}

#[derive(Clone, Debug)]
struct CandidateCacheEntry {
    candidates: Arc<[usize]>,
    last_used: u64,
}

impl CandidateCache {
    /// Creates a cache holding at most `max_entries` URIs (capped at 65 536; 0 disables it).
    pub fn new(max_entries: usize) -> Self {
        let max_entries = max_entries.min(65_536);
        Self {
            max_entries,
            ..Self::default()
        }
    }

    /// Drops every entry and resets the recency clock.
    pub fn clear(&mut self) {
        self.tick = 0;
        self.len = 0;
        self.by_key.clear();
    }

    /// Looks up the cached candidates for `(key, uri)`, marking the entry as recently used.
    pub fn get(&mut self, key: &CandidateCacheKey, uri: &str) -> Option<Arc<[usize]>> {
        if self.max_entries == 0 {
            return None;
        }
        // The clock advances even on a miss.
        self.tick = self.tick.wrapping_add(1);
        let now = self.tick;
        self.by_key
            .get_mut(key)
            .and_then(|inner| inner.get_mut(uri))
            .map(|entry| {
                entry.last_used = now;
                Arc::clone(&entry.candidates)
            })
    }

    /// Stores `candidates` for `(key, uri)`, replacing any existing entry, then evicts if over
    /// capacity.
    pub fn insert(&mut self, key: CandidateCacheKey, uri: String, candidates: Arc<[usize]>) {
        if self.max_entries == 0 {
            return;
        }
        self.tick = self.tick.wrapping_add(1);
        let now = self.tick;
        let inner = self.by_key.entry(key).or_default();
        match inner.get_mut(uri.as_str()) {
            Some(existing) => {
                existing.candidates = candidates;
                existing.last_used = now;
            }
            None => {
                inner.insert(
                    uri,
                    CandidateCacheEntry {
                        candidates,
                        last_used: now,
                    },
                );
                self.len += 1;
                self.evict_if_needed();
            }
        }
    }

    /// Removes least-recently-used entries until the cache is within capacity.
    fn evict_if_needed(&mut self) {
        while self.len > self.max_entries {
            // Oldest stamp wins; ties resolve to the first entry in iteration order. A stamp of
            // u64::MAX is never chosen.
            let victim = self
                .by_key
                .iter()
                .flat_map(|(key, inner)| {
                    inner
                        .iter()
                        .map(move |(uri, entry)| (key, uri, entry.last_used))
                })
                .filter(|&(_, _, stamp)| stamp < u64::MAX)
                .min_by_key(|&(_, _, stamp)| stamp)
                .map(|(key, uri, _)| (*key, uri.clone()));

            let Some((key, uri)) = victim else {
                break;
            };

            let inner_now_empty = match self.by_key.get_mut(&key) {
                Some(inner) => {
                    if inner.remove(uri.as_str()).is_some() {
                        self.len = self.len.saturating_sub(1);
                    }
                    inner.is_empty()
                }
                None => false,
            };
            if inner_now_empty {
                self.by_key.remove(&key);
            }
        }
    }
}
190
191pub fn method_to_mask(method: &str) -> Option<u8> {
193 if method.eq_ignore_ascii_case("GET") {
194 return Some(METHOD_GET);
195 }
196 if method.eq_ignore_ascii_case("POST") {
197 return Some(METHOD_POST);
198 }
199 if method.eq_ignore_ascii_case("HEAD") {
200 return Some(METHOD_HEAD);
201 }
202 if method.eq_ignore_ascii_case("PUT") {
203 return Some(METHOD_PUT);
204 }
205 if method.eq_ignore_ascii_case("PATCH") {
206 return Some(METHOD_PATCH);
207 }
208 None
209}
210
211pub fn build_rule_index(rules: &[WafRule]) -> RuleIndex {
213 let mut index = RuleIndex::default();
214
215 let mut header_names = HashSet::<String>::new();
217 for rule in rules {
218 for cond in &rule.matches {
219 collect_header_fields(cond, &mut header_names);
220 }
221 }
222
223 let mut header_bits: Vec<String> = header_names.into_iter().collect();
224 header_bits.sort();
225 if header_bits.len() > 64 {
226 header_bits.truncate(64);
227 }
228
229 let header_to_bit: HashMap<String, u8> = header_bits
230 .iter()
231 .enumerate()
232 .map(|(idx, header)| (header.clone(), idx as u8))
233 .collect();
234
235 index.header_bits = header_bits;
236 index.rules.reserve(rules.len());
237
238 for rule in rules {
239 let method_mask = extract_rule_method_mask(rule);
240 let mut uri_anchors = extract_rule_uri_anchors(rule);
241 if !uri_anchors.is_empty() {
242 let mut seen = HashSet::new();
243 uri_anchors.retain(|a| seen.insert(a.clone()));
244 }
245 let requirements = extract_rule_requirements(rule, &header_to_bit);
246 index.rules.push(IndexedRule {
247 method_mask,
248 uri_anchors,
249 requirements,
250 });
251 }
252
253 index
254}
255
256pub fn get_candidate_rule_indices(
258 index: &RuleIndex,
259 method_bit: u8,
260 uri: &str,
261 available_features: u16,
262 is_static: bool,
263 header_mask: u64,
264 rule_count: usize,
265 percent_decode: impl Fn(&str) -> String,
266) -> Vec<usize> {
267 let mut out = Vec::new();
268 let req_method_mask = if method_bit == 0 {
269 None
270 } else {
271 Some(method_bit)
272 };
273
274 let mut uri_lower: Option<String> = None;
275 let mut uri_percent_decoded: Option<String> = None;
276 let mut uri_percent_decoded_lower: Option<String> = None;
277
278 let count = rule_count.min(index.rules.len());
279 for (idx, rule) in index.rules.iter().enumerate().take(count) {
280 if let Some(rule_method_mask) = rule.method_mask {
282 let Some(req_method_mask) = req_method_mask else {
283 continue;
284 };
285 if (rule_method_mask & req_method_mask) == 0 {
286 continue;
287 }
288 }
289
290 let requirements = &rule.requirements;
292 if (requirements.features & !available_features) != 0 {
293 continue;
294 }
295 if requirements.static_required == Some(true) && !is_static {
296 continue;
297 }
298 if requirements.static_required == Some(false) && is_static {
299 continue;
300 }
301 if (requirements.required_headers_mask & !header_mask) != 0 {
302 continue;
303 }
304
305 if !rule.uri_anchors.is_empty() {
307 let mut matched = false;
308 for anchor in &rule.uri_anchors {
309 let haystack: &str = match anchor.transform {
310 UriTransform::Raw => uri,
311 UriTransform::Lower => {
312 if uri_lower.is_none() {
313 uri_lower = Some(uri.to_lowercase());
314 }
315 uri_lower.as_deref().unwrap_or(uri)
316 }
317 UriTransform::PercentDecoded => {
318 if uri_percent_decoded.is_none() {
319 uri_percent_decoded = Some(percent_decode(uri));
320 }
321 uri_percent_decoded.as_deref().unwrap_or(uri)
322 }
323 UriTransform::PercentDecodedLower => {
324 if uri_percent_decoded_lower.is_none() {
325 if uri_percent_decoded.is_none() {
326 uri_percent_decoded = Some(percent_decode(uri));
327 }
328 uri_percent_decoded_lower =
329 Some(uri_percent_decoded.as_deref().unwrap_or(uri).to_lowercase());
330 }
331 uri_percent_decoded_lower.as_deref().unwrap_or(uri)
332 }
333 };
334
335 matched = match anchor.kind {
336 UriAnchorKind::Contains => haystack.contains(anchor.pattern.as_str()),
337 UriAnchorKind::Prefix => haystack.starts_with(anchor.pattern.as_str()),
338 };
339 if matched {
340 break;
341 }
342 }
343
344 if !matched {
345 continue;
346 }
347 }
348
349 out.push(idx);
350 }
351
352 out
353}
354
355fn method_mask_from_match_value(match_value: &MatchValue) -> Option<u8> {
358 match match_value {
359 MatchValue::Str(s) => method_to_mask(s),
360 MatchValue::Arr(items) => {
361 let mut mask = 0u8;
362 for item in items {
363 let Some(s) = item.as_str() else { continue };
364 let Some(bit) = method_to_mask(s) else {
365 return None;
366 };
367 mask |= bit;
368 }
369 if mask == 0 {
370 None
371 } else {
372 Some(mask)
373 }
374 }
375 _ => None,
376 }
377}
378
379fn possible_method_mask(condition: &MatchCondition) -> Option<u8> {
380 match condition.kind.as_str() {
381 "method" => condition
382 .match_value
383 .as_ref()
384 .and_then(method_mask_from_match_value),
385 "boolean" => {
386 let op = condition.op.as_deref().unwrap_or("and");
387 let operands = boolean_operands(condition);
388 if operands.is_empty() {
389 return None;
390 }
391
392 match op {
393 "and" => {
394 let mut out: Option<u8> = None;
395 for operand in operands {
396 let Some(mask) = possible_method_mask(operand) else {
397 continue;
398 };
399 out = Some(match out {
400 None => mask,
401 Some(existing) => existing & mask,
402 });
403 }
404 out.filter(|m| *m != 0)
405 }
406 "or" => {
407 let mut mask = 0u8;
408 for operand in operands {
409 let Some(child_mask) = possible_method_mask(operand) else {
410 return None;
411 };
412 mask |= child_mask;
413 }
414 if mask == 0 {
415 None
416 } else {
417 Some(mask)
418 }
419 }
420 _ => None,
421 }
422 }
423 _ => None,
424 }
425}
426
427fn extract_rule_method_mask(rule: &WafRule) -> Option<u8> {
428 let mut out: Option<u8> = None;
429 for condition in &rule.matches {
430 let Some(mask) = possible_method_mask(condition) else {
431 continue;
432 };
433 out = Some(match out {
434 None => mask,
435 Some(existing) => existing & mask,
436 });
437 }
438 out.filter(|m| *m != 0)
439}
440
441fn extract_rule_uri_anchors(rule: &WafRule) -> Vec<UriAnchor> {
442 let mut out = Vec::new();
443 for condition in &rule.matches {
444 if let Some(mut anchors) = implied_uri_anchors(condition) {
445 out.append(&mut anchors);
446 }
447 }
448 out.retain(|a| !a.pattern.is_empty());
449 out
450}
451
452fn implied_uri_anchors(condition: &MatchCondition) -> Option<Vec<UriAnchor>> {
453 match condition.kind.as_str() {
454 "uri" => {
455 uri_anchors_from_uri_match_value(condition.match_value.as_ref(), UriTransform::Raw)
456 }
457 "boolean" => {
458 let op = condition.op.as_deref().unwrap_or("and");
459 let operands = boolean_operands(condition);
460 if operands.is_empty() {
461 return None;
462 }
463 match op {
464 "and" => {
465 let mut out = Vec::new();
466 for operand in operands {
467 if let Some(mut anchors) = implied_uri_anchors(operand) {
468 out.append(&mut anchors);
469 }
470 }
471 if out.is_empty() {
472 None
473 } else {
474 Some(out)
475 }
476 }
477 "or" => {
478 let mut out = Vec::new();
479 for operand in operands {
480 let Some(mut anchors) = implied_uri_anchors(operand) else {
481 return None;
482 };
483 out.append(&mut anchors);
484 }
485 if out.is_empty() {
486 None
487 } else {
488 Some(out)
489 }
490 }
491 _ => None,
492 }
493 }
494 _ => None,
495 }
496}
497
498fn uri_anchors_from_uri_match_value(
499 match_value: Option<&MatchValue>,
500 transform: UriTransform,
501) -> Option<Vec<UriAnchor>> {
502 match match_value {
503 Some(MatchValue::Str(s)) => Some(vec![UriAnchor {
504 kind: UriAnchorKind::Contains,
505 transform,
506 pattern: s.clone(),
507 }]),
508 Some(MatchValue::Cond(child)) => uri_anchors_from_uri_match(child, transform),
509 _ => None,
510 }
511}
512
513fn uri_anchors_from_uri_match(
514 condition: &MatchCondition,
515 transform: UriTransform,
516) -> Option<Vec<UriAnchor>> {
517 match condition.kind.as_str() {
518 "contains" => condition
519 .match_value
520 .as_ref()
521 .and_then(|m| m.as_str())
522 .map(|pattern| {
523 vec![UriAnchor {
524 kind: UriAnchorKind::Contains,
525 transform,
526 pattern: pattern.to_string(),
527 }]
528 }),
529 "starts_with" => condition
530 .match_value
531 .as_ref()
532 .and_then(|m| m.as_str())
533 .map(|prefix| {
534 vec![UriAnchor {
535 kind: UriAnchorKind::Prefix,
536 transform,
537 pattern: prefix.to_string(),
538 }]
539 }),
540 "equals" => condition
541 .match_value
542 .as_ref()
543 .and_then(|m| m.as_str())
544 .map(|pattern| {
545 vec![UriAnchor {
546 kind: UriAnchorKind::Contains,
547 transform,
548 pattern: pattern.to_string(),
549 }]
550 }),
551 "to_lowercase" => {
552 let child = condition.match_value.as_ref()?.as_cond()?;
553 uri_anchors_from_uri_match(child, transform.apply_lower())
554 }
555 "percent_decode" => {
556 let child = condition.match_value.as_ref()?.as_cond()?;
557 uri_anchors_from_uri_match(child, transform.apply_percent_decode())
558 }
559 "boolean" => {
560 let op = condition.op.as_deref().unwrap_or("and");
561 let operands = boolean_operands(condition);
562 if operands.is_empty() {
563 return None;
564 }
565 match op {
566 "and" => {
567 let mut out = Vec::new();
568 for operand in operands {
569 if let Some(mut anchors) = uri_anchors_from_uri_match(operand, transform) {
570 out.append(&mut anchors);
571 }
572 }
573 if out.is_empty() {
574 None
575 } else {
576 Some(out)
577 }
578 }
579 "or" => {
580 let mut out = Vec::new();
581 for operand in operands {
582 let Some(mut anchors) = uri_anchors_from_uri_match(operand, transform)
583 else {
584 return None;
585 };
586 out.append(&mut anchors);
587 }
588 if out.is_empty() {
589 None
590 } else {
591 Some(out)
592 }
593 }
594 _ => None,
595 }
596 }
597 _ => None,
598 }
599}
600
/// Intermediate requirement set used while walking a rule's condition tree; converted into
/// `RuleRequirements` by `extract_rule_requirements`.
#[derive(Clone, Debug, Default)]
struct RequirementsSet {
    // Union of `REQ_*` bits the conditions need.
    features: u16,
    // Required static-content flag, if any.
    static_required: Option<bool>,
    // Lowercased header names that must be present.
    required_headers: HashSet<String>,
}
607
/// Combines two static-content requirements under AND.
///
/// A missing side defers to the other. Agreeing sides keep their value. Conflicting sides
/// (which can never both hold) fall back to `None`, i.e. no constraint.
fn merge_and_static(a: Option<bool>, b: Option<bool>) -> Option<bool> {
    match (a, b) {
        (None, other) | (other, None) => other,
        (Some(left), Some(right)) => (left == right).then_some(left),
    }
}
616
617fn req_and(mut left: RequirementsSet, right: RequirementsSet) -> RequirementsSet {
618 left.features |= right.features;
619 left.required_headers.extend(right.required_headers);
620 left.static_required = merge_and_static(left.static_required, right.static_required);
621 left
622}
623
624fn req_or(left: RequirementsSet, right: RequirementsSet) -> RequirementsSet {
625 let mut out = RequirementsSet::default();
626 out.features = left.features & right.features;
627 out.static_required = match (left.static_required, right.static_required) {
628 (Some(l), Some(r)) if l == r => Some(l),
629 _ => None,
630 };
631 out.required_headers = left
632 .required_headers
633 .intersection(&right.required_headers)
634 .cloned()
635 .collect();
636 out
637}
638
639fn extract_rule_requirements(
640 rule: &WafRule,
641 header_to_bit: &HashMap<String, u8>,
642) -> RuleRequirements {
643 let mut req = RequirementsSet::default();
644 for condition in &rule.matches {
645 req = req_and(req, requirements_for_condition(condition));
646 }
647 let mut required_headers_mask: u64 = 0;
648 for header in req.required_headers {
649 if let Some(bit) = header_to_bit.get(header.as_str()).copied() {
650 if bit < 64 {
651 required_headers_mask |= 1u64 << bit;
652 }
653 }
654 }
655 RuleRequirements {
656 features: req.features,
657 static_required: req.static_required,
658 required_headers_mask,
659 }
660}
661
662fn requirements_for_condition(condition: &MatchCondition) -> RequirementsSet {
663 match condition.kind.as_str() {
664 "boolean" => {
665 let op = condition.op.as_deref().unwrap_or("and");
666 let operands = boolean_operands(condition);
667 if operands.is_empty() {
668 return RequirementsSet::default();
669 }
670 match op {
671 "and" => {
672 let mut out = RequirementsSet::default();
673 for operand in operands {
674 out = req_and(out, requirements_for_condition(operand));
675 }
676 out
677 }
678 "or" => {
679 let mut iter = operands.into_iter();
680 let mut out = requirements_for_condition(iter.next().unwrap());
681 for operand in iter {
682 out = req_or(out, requirements_for_condition(operand));
683 }
684 out
685 }
686 _ => RequirementsSet::default(),
687 }
688 }
689 "args" => {
690 let mut out = RequirementsSet {
691 features: REQ_ARGS,
692 ..Default::default()
693 };
694 if let Some(child) = condition.match_value.as_ref().and_then(|m| m.as_cond()) {
695 out = req_and(out, requirements_for_condition(child));
696 }
697 out
698 }
699 "named_argument" | "extract_argument" => {
700 let mut out = RequirementsSet {
701 features: REQ_ARG_ENTRIES,
702 ..Default::default()
703 };
704 if let Some(child) = condition.match_value.as_ref().and_then(|m| m.as_cond()) {
705 out = req_and(out, requirements_for_condition(child));
706 }
707 out
708 }
709 "header" => {
710 let mut out = RequirementsSet::default();
711 if let Some(field) = condition.field.as_deref() {
712 out.required_headers.insert(field.to_ascii_lowercase());
713 }
714 if let Some(child) = condition.match_value.as_ref().and_then(|m| m.as_cond()) {
715 out = req_and(out, requirements_for_condition(child));
716 }
717 out
718 }
719 "request_json" => {
720 let mut out = RequirementsSet {
721 features: REQ_JSON,
722 ..Default::default()
723 };
724 if let Some(child) = condition.match_value.as_ref().and_then(|m| m.as_cond()) {
725 out = req_and(out, requirements_for_condition(child));
726 }
727 out
728 }
729 "response_code" => RequirementsSet {
730 features: REQ_RESPONSE,
731 ..Default::default()
732 },
733 "response" => {
734 let mut out = RequirementsSet {
735 features: REQ_RESPONSE_BODY,
736 ..Default::default()
737 };
738 if let Some(child) = condition.match_value.as_ref().and_then(|m| m.as_cond()) {
739 out = req_and(out, requirements_for_condition(child));
740 }
741 out
742 }
743 "parse_multipart" => {
744 let mut out = RequirementsSet {
745 features: REQ_BODY | REQ_MULTIPART,
746 ..Default::default()
747 };
748 if let Some(child) = condition.match_value.as_ref().and_then(|m| m.as_cond()) {
749 out = req_and(out, requirements_for_condition(child));
750 }
751 out
752 }
753 "static_content" => {
754 let mut out = RequirementsSet::default();
755 if let Some(target) = condition.match_value.as_ref().and_then(|m| m.as_bool()) {
756 out.static_required = Some(target);
757 }
758 out
759 }
760 _ => {
761 let mut out = RequirementsSet::default();
762 if let Some(child) = condition.match_value.as_ref().and_then(|m| m.as_cond()) {
763 out = req_and(out, requirements_for_condition(child));
764 }
765 out
766 }
767 }
768}
769
770fn collect_header_fields(condition: &MatchCondition, out: &mut HashSet<String>) {
771 if condition.kind == "header" {
772 if let Some(field) = condition.field.as_deref() {
773 out.insert(field.to_ascii_lowercase());
774 }
775 }
776
777 if let Some(mv) = condition.match_value.as_ref() {
778 if let Some(child) = mv.as_cond() {
779 collect_header_fields(child, out);
780 } else if let Some(arr) = mv.as_arr() {
781 for item in arr {
782 if let Some(child) = item.as_cond() {
783 collect_header_fields(child, out);
784 }
785 }
786 }
787 }
788
789 if let Some(selector) = condition.selector.as_ref() {
790 collect_header_fields(selector, out);
791 }
792}
793
#[cfg(test)]
mod tests {
    use super::*;
    use crate::waf::rule::{MatchCondition, MatchValue, WafRule};

    /// Builds a condition of `kind` carrying `match_value`, with every other attribute unset.
    fn condition(kind: &str, match_value: MatchValue) -> MatchCondition {
        MatchCondition {
            kind: kind.to_string(),
            match_value: Some(match_value),
            op: None,
            field: None,
            direction: None,
            field_type: None,
            name: None,
            selector: None,
            cleanup_after: None,
            count: None,
            timeframe: None,
        }
    }

    /// Wraps a single top-level condition into a rule with default metadata.
    fn rule(id: u32, only_match: MatchCondition) -> WafRule {
        WafRule {
            id,
            description: format!("rule-{id}"),
            contributing_score: None,
            risk: Some(5.0),
            blocking: None,
            matches: vec![only_match],
        }
    }

    fn rule_with_method(id: u32, methods: &[&str]) -> WafRule {
        let value = match methods {
            [single] => MatchValue::Str(single.to_string()),
            several => MatchValue::Arr(
                several
                    .iter()
                    .map(|m| MatchValue::Str(m.to_string()))
                    .collect(),
            ),
        };
        rule(id, condition("method", value))
    }

    fn rule_with_uri_contains(id: u32, pattern: &str) -> WafRule {
        let contains = condition("contains", MatchValue::Str(pattern.to_string()));
        rule(id, condition("uri", MatchValue::Cond(Box::new(contains))))
    }

    fn noop_percent_decode(s: &str) -> String {
        s.to_string()
    }

    /// Candidate lookup with no features, a non-static request and an empty header mask.
    fn candidates_for(index: &RuleIndex, method_bit: u8, uri: &str, rule_count: usize) -> Vec<usize> {
        get_candidate_rule_indices(
            index,
            method_bit,
            uri,
            0,
            false,
            0,
            rule_count,
            noop_percent_decode,
        )
    }

    fn get_key() -> CandidateCacheKey {
        CandidateCacheKey {
            method_bit: METHOD_GET,
            available_features: 0,
            is_static: false,
            header_mask: 0,
        }
    }

    fn mixed_method_rules() -> Vec<WafRule> {
        vec![
            rule_with_method(1, &["GET"]),
            rule_with_method(2, &["POST"]),
            rule_with_method(3, &["GET", "POST"]),
        ]
    }

    #[test]
    fn test_method_to_mask_known_methods() {
        let expected = [
            ("GET", METHOD_GET),
            ("POST", METHOD_POST),
            ("HEAD", METHOD_HEAD),
            ("PUT", METHOD_PUT),
            ("PATCH", METHOD_PATCH),
        ];
        for (name, bit) in expected {
            assert_eq!(method_to_mask(name), Some(bit), "{name}");
        }
    }

    #[test]
    fn test_method_to_mask_case_insensitive() {
        assert_eq!(method_to_mask("get"), Some(METHOD_GET));
        assert_eq!(method_to_mask("Post"), Some(METHOD_POST));
    }

    #[test]
    fn test_method_to_mask_unknown_returns_none() {
        for name in ["DELETE", "OPTIONS", "CONNECT"] {
            assert_eq!(method_to_mask(name), None, "{name}");
        }
    }

    #[test]
    fn test_build_rule_index_method_filtering() {
        let rules = mixed_method_rules();
        let index = build_rule_index(&rules);
        assert_eq!(index.rules.len(), 3);

        let masks: Vec<Option<u8>> = index.rules.iter().map(|r| r.method_mask).collect();
        assert_eq!(masks[0], Some(METHOD_GET));
        assert_eq!(masks[1], Some(METHOD_POST));
        assert_eq!(masks[2], Some(METHOD_GET | METHOD_POST));
    }

    #[test]
    fn test_get_candidates_get_method_returns_only_get_rules() {
        let rules = mixed_method_rules();
        let index = build_rule_index(&rules);
        let candidates = candidates_for(&index, METHOD_GET, "/any-path", rules.len());

        assert!(candidates.contains(&0), "GET rule should be a candidate");
        assert!(
            !candidates.contains(&1),
            "POST-only rule should NOT be a candidate for GET"
        );
        assert!(
            candidates.contains(&2),
            "GET|POST rule should be a candidate for GET"
        );
    }

    #[test]
    fn test_get_candidates_post_method_returns_only_post_rules() {
        let rules = mixed_method_rules();
        let index = build_rule_index(&rules);
        let candidates = candidates_for(&index, METHOD_POST, "/any-path", rules.len());

        assert!(
            !candidates.contains(&0),
            "GET-only rule should NOT be a candidate for POST"
        );
        assert!(candidates.contains(&1), "POST rule should be a candidate");
        assert!(
            candidates.contains(&2),
            "GET|POST rule should be a candidate for POST"
        );
    }

    #[test]
    fn test_get_candidates_uri_anchor_filtering() {
        let rules = vec![
            rule_with_uri_contains(1, "/admin"),
            rule_with_uri_contains(2, "/api"),
        ];
        let index = build_rule_index(&rules);

        let admin = candidates_for(&index, 0, "/admin/dashboard", rules.len());
        assert!(admin.contains(&0), "/admin rule should match /admin/dashboard");
        assert!(
            !admin.contains(&1),
            "/api rule should NOT match /admin/dashboard"
        );

        let api = candidates_for(&index, 0, "/api/v1/users", rules.len());
        assert!(
            !api.contains(&0),
            "/admin rule should NOT match /api/v1/users"
        );
        assert!(api.contains(&1), "/api rule should match /api/v1/users");
    }

    #[test]
    fn test_get_candidates_no_method_constraint_matches_all() {
        let rules = vec![rule_with_uri_contains(1, "/health")];
        let index = build_rule_index(&rules);
        assert!(index.rules[0].method_mask.is_none());

        let for_get = candidates_for(&index, METHOD_GET, "/health", rules.len());
        assert!(
            for_get.contains(&0),
            "rule without method constraint should match GET"
        );

        let for_post = candidates_for(&index, METHOD_POST, "/health", rules.len());
        assert!(
            for_post.contains(&0),
            "rule without method constraint should match POST"
        );
    }

    #[test]
    fn test_candidate_cache_insert_and_get() {
        let mut cache = CandidateCache::new(10);
        let key = get_key();
        let candidates: Arc<[usize]> = Arc::from(&[0, 2, 5][..]);
        cache.insert(key, "/test".to_string(), Arc::clone(&candidates));

        let cached = cache.get(&key, "/test").expect("entry was just inserted");
        assert_eq!(cached.as_ref(), &[0, 2, 5]);
    }

    #[test]
    fn test_candidate_cache_eviction() {
        let mut cache = CandidateCache::new(2);
        let key = get_key();

        for (i, uri) in ["/a", "/b", "/c"].iter().enumerate() {
            cache.insert(key, uri.to_string(), Arc::from(&[i][..]));
        }

        assert!(cache.get(&key, "/a").is_none());
        assert!(cache.get(&key, "/b").is_some());
        assert!(cache.get(&key, "/c").is_some());
    }
}