structured_proxy/shield/
matcher.rs1use globset::GlobMatcher;
4
5use super::rate::Rate;
6use crate::config::{EndpointClassConfig, IdentifierEndpointConfig};
7use std::time::Duration;
8
9pub struct EndpointClass {
12 pub matcher: GlobMatcher,
13 pub class: String,
14 pub rate: Rate,
15}
16
17pub struct IdentifierEndpoint {
20 pub matcher: GlobMatcher,
21 pub body_field: String,
22 pub rate: Rate,
23}
24
25fn path_glob(pattern: &str) -> Result<GlobMatcher, String> {
28 globset::GlobBuilder::new(pattern)
29 .literal_separator(true)
30 .build()
31 .map(|g| g.compile_matcher())
32 .map_err(|e| format!("invalid glob pattern {pattern:?}: {e}"))
33}
34
35pub fn compile_endpoint_classes(
37 configs: &[EndpointClassConfig],
38 default_window: Duration,
39) -> Result<Vec<EndpointClass>, String> {
40 configs
41 .iter()
42 .map(|c| {
43 Ok(EndpointClass {
44 matcher: path_glob(&c.pattern)?,
45 class: c.class.clone(),
46 rate: Rate::parse(&c.rate, default_window)?,
47 })
48 })
49 .collect()
50}
51
52pub fn compile_identifier_endpoints(
54 configs: &[IdentifierEndpointConfig],
55 default_window: Duration,
56) -> Result<Vec<IdentifierEndpoint>, String> {
57 configs
58 .iter()
59 .map(|c| {
60 Ok(IdentifierEndpoint {
61 matcher: path_glob(&c.path)?,
62 body_field: c.body_field.clone(),
63 rate: Rate::parse(&c.rate, default_window)?,
64 })
65 })
66 .collect()
67}
68
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `EndpointClassConfig` from string literals.
    fn ec(pattern: &str, class: &str, rate: &str) -> EndpointClassConfig {
        EndpointClassConfig {
            pattern: pattern.to_string(),
            class: class.to_string(),
            rate: rate.to_string(),
        }
    }

    #[test]
    fn endpoint_class_glob_respects_segments() {
        let classes = compile_endpoint_classes(
            &[ec("/api/v1/heavy-*", "heavy", "10/min")],
            Duration::from_secs(60),
        )
        .unwrap();
        let m = &classes[0].matcher;
        assert!(m.is_match("/api/v1/heavy-export"));
        // A single `*` must not swallow a `/`.
        assert!(!m.is_match("/api/v1/heavy-export/sub"));
        assert!(!m.is_match("/api/v1/light"));
    }

    #[test]
    fn double_star_spans_segments() {
        let classes = compile_endpoint_classes(
            &[ec("/v1/auth/**", "auth", "20/min")],
            Duration::from_secs(60),
        )
        .unwrap();
        let m = &classes[0].matcher;
        assert!(m.is_match("/v1/auth/login"));
        assert!(m.is_match("/v1/auth/opaque/start"));
        // `**` spans segments but the literal prefix must still match exactly.
        assert!(!m.is_match("/v1/authx/login"));
    }

    #[test]
    fn invalid_rate_fails_compilation() {
        let err = compile_endpoint_classes(&[ec("/x", "c", "nonsense")], Duration::from_secs(60));
        assert!(err.is_err());
    }

    #[test]
    fn invalid_glob_fails_compilation() {
        // Unclosed character class is rejected by the glob parser before the
        // rate is ever looked at.
        let err = compile_endpoint_classes(&[ec("/api/[", "c", "10/min")], Duration::from_secs(60));
        assert!(err.is_err());
    }

    #[test]
    fn empty_config_compiles_to_empty_list() {
        let classes = compile_endpoint_classes(&[], Duration::from_secs(60)).unwrap();
        assert!(classes.is_empty());
    }
}