1use std::hash::{Hash, Hasher};
2use std::net::IpAddr;
3use std::sync::Arc;
4
5use bytes::Bytes;
6use ipnet::IpNet;
7
8use crate::body::Request;
9use crate::conn_context::ConnContext;
10
/// Names the single connection, TLS, or HTTP field that a predicate check inspects.
///
/// Serde (de)serialization is derived with `snake_case` variant names. The
/// dotted spelling shown on each variant is the one produced by
/// `FieldPath::display_name` and accepted by `parse_field_path`.
#[derive(Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FieldPath {
    /// `transport` — connection transport, compared as the string `"tcp"` or `"udp"`.
    Transport,
    /// `remote.ip` — IP address of the connection's remote endpoint.
    RemoteIp,
    /// `remote.port` — port of the connection's remote endpoint.
    RemotePort,
    /// `local.ip` — IP address of the connection's local endpoint.
    LocalIp,
    /// `local.port` — port of the connection's local endpoint.
    LocalPort,
    /// `peek` — the peek byte buffer; only an L4 view can supply one.
    Peek,
    /// `tls.sni` — the `sni` value recorded in the connection's TLS state.
    TlsSni,
    /// `tls.alpn` — the `alpn` value recorded in the connection's TLS state (bytes).
    TlsAlpn,
    /// `tls.version` — TLS version, compared as the string `"1.2"` or `"1.3"`.
    TlsVersion,
    /// `tls.peer_cert.present` — whether the TLS state holds a peer certificate.
    TlsPeerCertPresent,
    /// `tls.peer_cert.subject_cn` — the peer certificate's `subject_cn`, when set.
    TlsPeerCertSubjectCn,
    /// `tls.peer_cert.san_dns` — the peer certificate's `san_dns` string list.
    TlsPeerCertSanDns,
    /// `tls.peer_cert.fingerprint_sha256` — the peer certificate's `fingerprint_sha256`.
    TlsPeerCertFingerprintSha256,
    /// `tls.peer_cert.spki_sha256` — the peer certificate's `spki_sha256`.
    TlsPeerCertSpkiSha256,
    /// `tls.peer_cert.issuer_cn` — the peer certificate's `issuer_cn`, when set.
    TlsPeerCertIssuerCn,
    /// `tls.peer_cert.serial` — the peer certificate's `serial`.
    TlsPeerCertSerial,
    /// `http.method` — request method as a string.
    HttpMethod,
    /// `http.uri.path` — path component of the request URI.
    HttpUriPath,
    /// `http.uri.query` — query component of the request URI (empty string when absent).
    HttpUriQuery,
    /// `http.header.<name>` — request header looked up by the carried name.
    HttpHeader(Arc<str>),
    /// `http.body` — request body bytes.
    HttpBody,
}
49
/// Value category of a [`FieldPath`], as reported by `FieldPath::value_type`.
///
/// `OperatorFamily::accepts` uses it to decide which operator families may be
/// applied to a field.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum FieldValueType {
    /// Text value.
    Str,
    /// Raw byte value.
    Bytes,
    /// Integer value (ports are widened to `i64` when tested).
    Int,
    /// IP address value.
    IpAddr,
    /// One of a fixed set of strings (transport, TLS version, HTTP method).
    Enum,
    /// Boolean value.
    Bool,
    /// List of text values.
    VecStr,
}
64
65impl FieldValueType {
66 #[must_use]
67 pub fn name(self) -> &'static str {
68 match self {
69 Self::Str => "Str",
70 Self::Bytes => "Bytes",
71 Self::Int => "Int",
72 Self::IpAddr => "IpAddr",
73 Self::Enum => "enum",
74 Self::Bool => "Bool",
75 Self::VecStr => "Vec<Str>",
76 }
77 }
78}
79
80impl FieldPath {
81 #[must_use]
85 pub fn value_type(&self) -> FieldValueType {
86 match self {
87 Self::Transport | Self::TlsVersion | Self::HttpMethod => FieldValueType::Enum,
88 Self::RemoteIp | Self::LocalIp => FieldValueType::IpAddr,
89 Self::RemotePort | Self::LocalPort => FieldValueType::Int,
90 Self::Peek | Self::TlsAlpn | Self::HttpBody => FieldValueType::Bytes,
91 Self::TlsPeerCertPresent => FieldValueType::Bool,
92 Self::TlsPeerCertSanDns => FieldValueType::VecStr,
93 Self::TlsSni
94 | Self::TlsPeerCertSubjectCn
95 | Self::TlsPeerCertFingerprintSha256
96 | Self::TlsPeerCertSpkiSha256
97 | Self::TlsPeerCertIssuerCn
98 | Self::TlsPeerCertSerial
99 | Self::HttpUriPath
100 | Self::HttpUriQuery
101 | Self::HttpHeader(_) => FieldValueType::Str,
102 }
103 }
104
105 #[must_use]
107 pub fn display_name(&self) -> String {
108 match self {
109 Self::Transport => "transport".to_string(),
110 Self::RemoteIp => "remote.ip".to_string(),
111 Self::RemotePort => "remote.port".to_string(),
112 Self::LocalIp => "local.ip".to_string(),
113 Self::LocalPort => "local.port".to_string(),
114 Self::Peek => "peek".to_string(),
115 Self::TlsSni => "tls.sni".to_string(),
116 Self::TlsAlpn => "tls.alpn".to_string(),
117 Self::TlsVersion => "tls.version".to_string(),
118 Self::TlsPeerCertPresent => "tls.peer_cert.present".to_string(),
119 Self::TlsPeerCertSubjectCn => "tls.peer_cert.subject_cn".to_string(),
120 Self::TlsPeerCertSanDns => "tls.peer_cert.san_dns".to_string(),
121 Self::TlsPeerCertFingerprintSha256 => "tls.peer_cert.fingerprint_sha256".to_string(),
122 Self::TlsPeerCertSpkiSha256 => "tls.peer_cert.spki_sha256".to_string(),
123 Self::TlsPeerCertIssuerCn => "tls.peer_cert.issuer_cn".to_string(),
124 Self::TlsPeerCertSerial => "tls.peer_cert.serial".to_string(),
125 Self::HttpMethod => "http.method".to_string(),
126 Self::HttpUriPath => "http.uri.path".to_string(),
127 Self::HttpUriQuery => "http.uri.query".to_string(),
128 Self::HttpHeader(name) => format!("http.header.{name}"),
129 Self::HttpBody => "http.body".to_string(),
130 }
131 }
132}
133
/// Groups operators that share the same set of acceptable field value types.
///
/// `Operator::family` maps each operator to its family and
/// `OperatorFamily::accepts` checks a family against a [`FieldValueType`].
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum OperatorFamily {
    /// `equals` / `not_equals`.
    Equality,
    /// `contains` / `not_contains`.
    StringSubstr,
    /// `prefix` / `suffix`.
    StringPrefSuf,
    /// `matches`.
    RegexMatches,
    /// `in` / `not_in`.
    InList,
    /// `gt` / `gte` / `lt` / `lte`.
    NumericCmp,
    /// `cidr`.
    CidrMatch,
}
148
149impl Operator {
150 #[must_use]
151 pub fn family(&self) -> OperatorFamily {
152 match self {
153 Self::Equals(_) | Self::NotEquals(_) => OperatorFamily::Equality,
154 Self::Contains(_) | Self::NotContains(_) => OperatorFamily::StringSubstr,
155 Self::Prefix(_) | Self::Suffix(_) => OperatorFamily::StringPrefSuf,
156 Self::Matches(_) => OperatorFamily::RegexMatches,
157 Self::In(_) | Self::NotIn(_) => OperatorFamily::InList,
158 Self::Gt(_) | Self::Gte(_) | Self::Lt(_) | Self::Lte(_) => OperatorFamily::NumericCmp,
159 Self::Cidr(_) => OperatorFamily::CidrMatch,
160 }
161 }
162
163 #[must_use]
164 pub fn name(&self) -> &'static str {
165 match self {
166 Self::Equals(_) => "equals",
167 Self::NotEquals(_) => "not_equals",
168 Self::Contains(_) => "contains",
169 Self::NotContains(_) => "not_contains",
170 Self::Prefix(_) => "prefix",
171 Self::Suffix(_) => "suffix",
172 Self::Matches(_) => "matches",
173 Self::In(_) => "in",
174 Self::NotIn(_) => "not_in",
175 Self::Gt(_) => "gt",
176 Self::Gte(_) => "gte",
177 Self::Lt(_) => "lt",
178 Self::Lte(_) => "lte",
179 Self::Cidr(_) => "cidr",
180 }
181 }
182}
183
184impl OperatorFamily {
185 #[must_use]
190 pub fn accepts(self, vt: FieldValueType) -> bool {
191 use FieldValueType as V;
192 use OperatorFamily as F;
193 matches!(
194 (self, vt),
195 (F::Equality, V::Str | V::Bytes | V::Int | V::IpAddr | V::Enum | V::Bool)
200 | (F::InList, V::Str | V::Bytes | V::Int | V::IpAddr | V::Enum)
203 | (F::StringSubstr, V::Str | V::Bytes | V::VecStr)
204 | (F::StringPrefSuf, V::Str | V::Bytes)
205 | (F::RegexMatches, V::Str)
206 | (F::NumericCmp, V::Int)
207 | (F::CidrMatch, V::IpAddr),
208 )
209 }
210
211 #[must_use]
213 pub fn family_expectation(self) -> &'static str {
214 match self {
215 Self::Equality => "any of Str/Bytes/Int/IpAddr/enum/Bool",
216 Self::InList => "any of Str/Bytes/Int/IpAddr/enum",
217 Self::StringSubstr => "Str, Bytes, or Vec<Str>",
218 Self::StringPrefSuf => "Str or Bytes",
219 Self::RegexMatches => "Str",
220 Self::NumericCmp => "numeric",
221 Self::CidrMatch => "IpAddr",
222 }
223 }
224}
225
/// A typed literal operand carried by a compiled operator.
///
/// `PartialEq`, `Eq` and `Hash` are implemented by hand elsewhere in this
/// file; serde support goes through `serde_impls::ValueShadow`.
#[derive(Clone, Debug)]
pub enum CompiledValue {
    /// Text operand.
    Str(Arc<str>),
    /// Byte-string operand.
    Bytes(Bytes),
    /// Integer operand.
    Int(i64),
    /// Boolean operand.
    Bool(bool),
    /// IP address operand.
    Addr(IpAddr),
}
234
235impl PartialEq for CompiledValue {
236 fn eq(&self, other: &Self) -> bool {
237 match (self, other) {
238 (Self::Str(a), Self::Str(b)) => a.as_ref() == b.as_ref(),
239 (Self::Bytes(a), Self::Bytes(b)) => a == b,
240 (Self::Int(a), Self::Int(b)) => a == b,
241 (Self::Bool(a), Self::Bool(b)) => a == b,
242 (Self::Addr(a), Self::Addr(b)) => a == b,
243 _ => false,
244 }
245 }
246}
247
// Marker impl: the hand-written `PartialEq` for `CompiledValue` only compares
// strings, bytes, integers, booleans and IP addresses.
impl Eq for CompiledValue {}
249
250impl Hash for CompiledValue {
251 fn hash<H: Hasher>(&self, state: &mut H) {
252 std::mem::discriminant(self).hash(state);
253 match self {
254 Self::Str(s) => s.as_ref().hash(state),
255 Self::Bytes(b) => b.hash(state),
256 Self::Int(i) => i.hash(state),
257 Self::Bool(b) => b.hash(state),
258 Self::Addr(a) => a.hash(state),
259 }
260 }
261}
262
/// An operator in its ready-to-evaluate form: operands are already typed
/// ([`CompiledValue`], `Bytes`, `i64`, `IpNet`) and the regex is already built.
///
/// `PartialEq`, `Eq` and `Hash` are implemented by hand elsewhere in this file
/// (the regex variant is compared and hashed by its pattern source); serde
/// support goes through `serde_impls::OperatorShadow`.
#[derive(Clone, Debug)]
pub enum CompiledOperator {
    /// Field equals the operand.
    Equals(CompiledValue),
    /// Field differs from the operand.
    NotEquals(CompiledValue),
    /// Field contains the operand bytes (for string lists: has an element equal to them).
    Contains(Bytes),
    /// Negation of `Contains`.
    NotContains(Bytes),
    /// Field starts with the operand bytes.
    Prefix(Bytes),
    /// Field ends with the operand bytes.
    Suffix(Bytes),
    /// Field matches the regular expression.
    Matches(fancy_regex::Regex),
    /// Field equals one of the listed operands.
    In(Vec<CompiledValue>),
    /// Field equals none of the listed operands.
    NotIn(Vec<CompiledValue>),
    /// Field is strictly greater than the operand.
    Gt(i64),
    /// Field is greater than or equal to the operand.
    Gte(i64),
    /// Field is strictly less than the operand.
    Lt(i64),
    /// Field is less than or equal to the operand.
    Lte(i64),
    /// Field (an IP address) lies inside the network.
    Cidr(IpNet),
}
280
281impl PartialEq for CompiledOperator {
282 fn eq(&self, other: &Self) -> bool {
283 match (self, other) {
284 (Self::Equals(a), Self::Equals(b)) | (Self::NotEquals(a), Self::NotEquals(b)) => a == b,
285 (Self::Contains(a), Self::Contains(b))
286 | (Self::NotContains(a), Self::NotContains(b))
287 | (Self::Prefix(a), Self::Prefix(b))
288 | (Self::Suffix(a), Self::Suffix(b)) => a == b,
289 (Self::Matches(a), Self::Matches(b)) => a.as_str() == b.as_str(),
290 (Self::In(a), Self::In(b)) | (Self::NotIn(a), Self::NotIn(b)) => a == b,
291 (Self::Gt(a), Self::Gt(b))
292 | (Self::Gte(a), Self::Gte(b))
293 | (Self::Lt(a), Self::Lt(b))
294 | (Self::Lte(a), Self::Lte(b)) => a == b,
295 (Self::Cidr(a), Self::Cidr(b)) => a == b,
296 _ => false,
297 }
298 }
299}
300
// Marker impl: the hand-written `PartialEq` for `CompiledOperator` compares
// regexes by their pattern source and all other operands by value.
impl Eq for CompiledOperator {}
302
303impl Hash for CompiledOperator {
304 fn hash<H: Hasher>(&self, state: &mut H) {
305 std::mem::discriminant(self).hash(state);
306 match self {
307 Self::Equals(v) | Self::NotEquals(v) => v.hash(state),
308 Self::Contains(b) | Self::NotContains(b) | Self::Prefix(b) | Self::Suffix(b) => {
309 b.hash(state);
310 }
311 Self::Matches(r) => r.as_str().hash(state),
312 Self::In(v) | Self::NotIn(v) => v.hash(state),
313 Self::Gt(i) | Self::Gte(i) | Self::Lt(i) | Self::Lte(i) => i.hash(state),
314 Self::Cidr(n) => n.hash(state),
315 }
316 }
317}
318
/// A single compiled check: one field path paired with one compiled operator.
///
/// Evaluated with `PredicateInst::test`.
#[derive(Clone, Debug, Eq, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
pub struct PredicateInst {
    /// Field to read from the connection or request.
    pub path: FieldPath,
    /// Operator (with its operand) applied to the field's value.
    pub op: CompiledOperator,
}
324
/// Borrowed view of the data a predicate may read.
///
/// Both variants carry the connection context; they differ in what else is
/// available.
pub enum PredicateView<'a> {
    /// Connection-level view: no request, optionally a peek byte buffer.
    L4 { conn: &'a Arc<ConnContext>, peek: Option<&'a [u8]> },
    /// Request-level view: the request is available, the peek buffer is not.
    L7Req { conn: &'a Arc<ConnContext>, req: &'a Request },
}
329
330impl<'a> PredicateView<'a> {
331 #[must_use]
341 pub fn build(
342 conn: &'a Arc<ConnContext>,
343 req: Option<&'a Request>,
344 _l4: Option<&'a crate::l4::L4Conn>,
345 peek: Option<&'a [u8]>,
346 ) -> Self {
347 match req {
348 Some(r) => Self::L7Req { conn, req: r },
349 None => Self::L4 { conn, peek },
350 }
351 }
352
353 fn conn(&self) -> &Arc<ConnContext> {
354 match self {
355 Self::L4 { conn, .. } | Self::L7Req { conn, .. } => conn,
356 }
357 }
358
359 fn request(&self) -> Option<&Request> {
360 match self {
361 Self::L7Req { req, .. } => Some(req),
362 Self::L4 { .. } => None,
363 }
364 }
365
366 fn peek_buffer(&self) -> Option<&[u8]> {
367 match self {
368 Self::L4 { peek, .. } => *peek,
369 Self::L7Req { .. } => None,
370 }
371 }
372}
373
impl PredicateInst {
    /// Evaluates this check against `view` and reports whether it holds.
    ///
    /// Each arm reads exactly one field and hands it to the type-specific
    /// helper (`test_str`, `test_bytes`, `test_int`, `test_addr`, `test_bool`,
    /// `test_vec_str`). When the field cannot be read the result is `false`:
    /// an HTTP field on a view without a request, `peek` on a view without a
    /// peek buffer, or a TLS field that is not set. The two exceptions are
    /// `tls.peer_cert.present`, which tests the boolean `false`, and
    /// `tls.peer_cert.san_dns`, which tests an empty list.
    ///
    /// # Panics
    ///
    /// The `http.body` arm panics if `req.body().as_static()` yields no value
    /// (the `expect` message names this the lazy-buffer invariant).
    #[must_use]
    #[allow(
        clippy::too_many_lines,
        reason = "truth-table dispatch over the full FieldPath enum: each arm reads one connection / request field per spec/crates/core.md § Predicate. Splitting by sub-arm scatters the field-by-field reading across helpers and adds nothing beyond satisfying the lint"
    )]
    pub fn test(&self, view: &PredicateView<'_>) -> bool {
        match &self.path {
            FieldPath::Transport => {
                // The transport is compared through its string label.
                let s = match view.conn().transport {
                    crate::conn_context::Transport::Tcp => "tcp",
                    crate::conn_context::Transport::Udp => "udp",
                };
                test_str(&self.op, s)
            }
            FieldPath::RemoteIp => test_addr(&self.op, view.conn().remote.ip()),
            // Ports are widened to i64 so they share the integer test path.
            FieldPath::RemotePort => test_int(&self.op, i64::from(view.conn().remote.port())),
            FieldPath::LocalIp => test_addr(&self.op, view.conn().local.ip()),
            FieldPath::LocalPort => test_int(&self.op, i64::from(view.conn().local.port())),
            // No peek buffer (including every request-level view) => false.
            FieldPath::Peek => view.peek_buffer().is_some_and(|b| test_bytes(&self.op, b)),
            // The TLS arms below all lock `conn.tls`, read one field out of it,
            // and yield false when the TLS state or that field is absent.
            FieldPath::TlsSni => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.sni.clone())
                .is_some_and(|got| test_str(&self.op, &got)),
            FieldPath::TlsAlpn => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.alpn.clone())
                .is_some_and(|got| test_bytes(&self.op, &got)),
            FieldPath::TlsVersion => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.version)
                .is_some_and(|v| test_str(&self.op, tls_version_str(v))),
            FieldPath::TlsPeerCertPresent => {
                // Missing TLS state counts as "no certificate" rather than
                // short-circuiting, so `equals false` can match.
                let present = view.conn().tls.lock().as_ref().is_some_and(|t| t.peer_cert.is_some());
                test_bool(&self.op, present)
            }
            FieldPath::TlsPeerCertSubjectCn => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.peer_cert.as_ref().and_then(|p| p.subject_cn.clone()))
                .is_some_and(|cn| test_str(&self.op, &cn)),
            FieldPath::TlsPeerCertSanDns => {
                // Take a cheap `Arc` clone of the list; the guard is a
                // temporary of this `let` statement.
                let dns: Option<Arc<[Arc<str>]>> = view
                    .conn()
                    .tls
                    .lock()
                    .as_ref()
                    .and_then(|t| t.peer_cert.as_ref().map(|p| Arc::clone(&p.san_dns)));
                match dns {
                    Some(list) => test_vec_str(&self.op, &list),
                    // No certificate is tested as an empty list, so
                    // `not_contains` holds and `contains` does not.
                    None => test_vec_str::<&str>(&self.op, &[]),
                }
            }
            FieldPath::TlsPeerCertFingerprintSha256 => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.peer_cert.as_ref().map(|p| Arc::clone(&p.fingerprint_sha256)))
                .is_some_and(|s| test_str(&self.op, &s)),
            FieldPath::TlsPeerCertSpkiSha256 => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.peer_cert.as_ref().map(|p| Arc::clone(&p.spki_sha256)))
                .is_some_and(|s| test_str(&self.op, &s)),
            FieldPath::TlsPeerCertIssuerCn => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.peer_cert.as_ref().and_then(|p| p.issuer_cn.clone()))
                .is_some_and(|s| test_str(&self.op, &s)),
            FieldPath::TlsPeerCertSerial => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.peer_cert.as_ref().map(|p| Arc::clone(&p.serial)))
                .is_some_and(|s| test_str(&self.op, &s)),
            // The HTTP arms below yield false on a view that has no request.
            FieldPath::HttpMethod => {
                let Some(req) = view.request() else { return false };
                test_str(&self.op, req.method().as_str())
            }
            FieldPath::HttpUriPath => {
                let Some(req) = view.request() else { return false };
                test_str(&self.op, req.uri().path())
            }
            FieldPath::HttpUriQuery => {
                let Some(req) = view.request() else { return false };
                // A URI without a query is tested as the empty string.
                test_str(&self.op, req.uri().query().unwrap_or(""))
            }
            FieldPath::HttpHeader(name) => {
                let Some(req) = view.request() else { return false };
                // A missing header yields false.
                let Some(value) = req.headers().get(name.as_ref()) else { return false };
                // A value for which `to_str` fails also yields false.
                let Ok(s) = value.to_str() else {
                    return false;
                };
                test_str(&self.op, s)
            }
            FieldPath::HttpBody => {
                let Some(req) = view.request() else { return false };
                // Panics when `as_static()` has nothing to return; see `# Panics`.
                let bytes = req.body().as_static().expect("lazy-buffer invariant");
                test_bytes(&self.op, bytes.as_ref())
            }
        }
    }
}
547
548fn tls_version_str(v: crate::conn_context::TlsVersion) -> &'static str {
549 match v {
550 crate::conn_context::TlsVersion::Tls12 => "1.2",
551 crate::conn_context::TlsVersion::Tls13 => "1.3",
552 }
553}
554
555fn test_bool(op: &CompiledOperator, value: bool) -> bool {
567 match op {
568 CompiledOperator::Equals(CompiledValue::Bool(expected)) => value == *expected,
569 CompiledOperator::NotEquals(CompiledValue::Bool(expected)) => value != *expected,
570 _ => false,
571 }
572}
573
574fn test_vec_str<S: AsRef<str>>(op: &CompiledOperator, values: &[S]) -> bool {
580 match op {
581 CompiledOperator::Contains(needle) => {
582 values.iter().any(|v| v.as_ref().as_bytes() == needle.as_ref())
583 }
584 CompiledOperator::NotContains(needle) => {
585 !values.iter().any(|v| v.as_ref().as_bytes() == needle.as_ref())
586 }
587 _ => false,
588 }
589}
590
591fn test_str(op: &CompiledOperator, value: &str) -> bool {
596 match op {
597 CompiledOperator::Equals(CompiledValue::Str(expected)) => value == expected.as_ref(),
598 CompiledOperator::NotEquals(CompiledValue::Str(expected)) => value != expected.as_ref(),
599 CompiledOperator::Contains(b) => contains_bytes(value.as_bytes(), b),
600 CompiledOperator::NotContains(b) => !contains_bytes(value.as_bytes(), b),
601 CompiledOperator::Prefix(b) => value.as_bytes().starts_with(b.as_ref()),
602 CompiledOperator::Suffix(b) => value.as_bytes().ends_with(b.as_ref()),
603 CompiledOperator::Matches(re) => re.is_match(value).unwrap_or(false),
604 CompiledOperator::In(values) => {
605 values.iter().any(|v| matches!(v, CompiledValue::Str(s) if value == s.as_ref()))
606 }
607 CompiledOperator::NotIn(values) => {
608 !values.iter().any(|v| matches!(v, CompiledValue::Str(s) if value == s.as_ref()))
609 }
610 _ => false,
611 }
612}
613
614fn test_bytes(op: &CompiledOperator, value: &[u8]) -> bool {
618 match op {
619 CompiledOperator::Equals(CompiledValue::Bytes(expected)) => value == expected.as_ref(),
620 CompiledOperator::NotEquals(CompiledValue::Bytes(expected)) => value != expected.as_ref(),
621 CompiledOperator::Contains(b) => contains_bytes(value, b),
622 CompiledOperator::NotContains(b) => !contains_bytes(value, b),
623 CompiledOperator::Prefix(b) => value.starts_with(b.as_ref()),
624 CompiledOperator::Suffix(b) => value.ends_with(b.as_ref()),
625 CompiledOperator::In(values) => {
626 values.iter().any(|v| matches!(v, CompiledValue::Bytes(b) if value == b.as_ref()))
627 }
628 CompiledOperator::NotIn(values) => {
629 !values.iter().any(|v| matches!(v, CompiledValue::Bytes(b) if value == b.as_ref()))
630 }
631 _ => false,
632 }
633}
634
635fn test_int(op: &CompiledOperator, value: i64) -> bool {
638 match op {
639 CompiledOperator::Equals(CompiledValue::Int(expected)) => value == *expected,
640 CompiledOperator::NotEquals(CompiledValue::Int(expected)) => value != *expected,
641 CompiledOperator::Gt(n) => value > *n,
642 CompiledOperator::Gte(n) => value >= *n,
643 CompiledOperator::Lt(n) => value < *n,
644 CompiledOperator::Lte(n) => value <= *n,
645 CompiledOperator::In(values) => {
646 values.iter().any(|v| matches!(v, CompiledValue::Int(i) if value == *i))
647 }
648 CompiledOperator::NotIn(values) => {
649 !values.iter().any(|v| matches!(v, CompiledValue::Int(i) if value == *i))
650 }
651 _ => false,
652 }
653}
654
655fn test_addr(op: &CompiledOperator, value: std::net::IpAddr) -> bool {
659 match op {
660 CompiledOperator::Equals(CompiledValue::Addr(expected)) => value == *expected,
661 CompiledOperator::NotEquals(CompiledValue::Addr(expected)) => value != *expected,
662 CompiledOperator::Cidr(net) => net.contains(&value),
663 CompiledOperator::In(values) => {
664 values.iter().any(|v| matches!(v, CompiledValue::Addr(a) if value == *a))
665 }
666 CompiledOperator::NotIn(values) => {
667 !values.iter().any(|v| matches!(v, CompiledValue::Addr(a) if value == *a))
668 }
669 _ => false,
670 }
671}
672
/// Reports whether `needle` occurs as a contiguous run inside `haystack`.
///
/// An empty needle is contained in everything, including an empty haystack.
fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
    match needle.len() {
        // `windows(0)` would panic, and the empty needle always matches.
        0 => true,
        // A needle longer than the haystack cannot occur in it.
        len if len > haystack.len() => false,
        len => haystack.windows(len).any(|window| window == needle),
    }
}
682
/// Maximum byte length (4 KiB) of a `matches` pattern source;
/// `validate_operator` rejects anything longer.
pub const REGEX_PATTERN_MAX_BYTES: usize = 4 * 1024;

/// NOTE(review): not referenced in the visible part of this file. Presumably
/// the backtracking budget given to the regex engine when a pattern is
/// compiled — confirm at the use site.
pub const REGEX_BACKTRACK_LIMIT: usize = 1_000_000;

/// NOTE(review): not referenced in the visible part of this file. Presumably
/// the size limit (4 MiB) for the regex engine's delegated automata — confirm
/// at the use site.
pub const REGEX_DELEGATE_SIZE_LIMIT: usize = 4 * 1024 * 1024;

/// NOTE(review): not referenced in the visible part of this file. Presumably
/// the length of the probe input used to smoke-test a freshly compiled
/// pattern — confirm at the use site.
pub const REGEX_SMOKE_TEST_INPUT_LEN: usize = 64;

/// Deepest nesting `check_max_depth` allows for a [`Predicate`] tree, with
/// the root counted as level 1.
pub const MAX_PREDICATE_DEPTH: usize = 64;
712
/// Uncompiled predicate tree as it appears in configuration.
///
/// `Serialize` is derived; `Deserialize` is implemented by hand elsewhere in
/// this file and picks the variant from the shape of the JSON object.
#[derive(Debug, Clone, serde::Serialize)]
pub enum Predicate {
    /// `{"any_of": [...]}` — a list of child predicates.
    AnyOf(AnyOfP),
    /// `{"all_of": [...]}` — a list of child predicates.
    AllOf(AllOfP),
    /// `{"not": ...}` — a single wrapped child predicate.
    Not(NotP),
    /// `{"<field-path>": {"<operator>": <value>}}` — a leaf check.
    Check(CheckMap),
}
720
721pub fn check_max_depth(pred: &Predicate) -> Result<(), crate::error::Error> {
729 let mut stack: Vec<(&Predicate, usize)> = vec![(pred, 1)];
730 while let Some((p, depth)) = stack.pop() {
731 if depth > MAX_PREDICATE_DEPTH {
732 return Err(crate::error::Error::compile(format!(
733 "predicate nests deeper than {MAX_PREDICATE_DEPTH} levels (MAX_PREDICATE_DEPTH)"
734 )));
735 }
736 match p {
737 Predicate::Check(_) => {}
738 Predicate::AnyOf(a) => {
739 for child in &a.any_of {
740 stack.push((child, depth + 1));
741 }
742 }
743 Predicate::AllOf(a) => {
744 for child in &a.all_of {
745 stack.push((child, depth + 1));
746 }
747 }
748 Predicate::Not(n) => stack.push((n.not.as_ref(), depth + 1)),
749 }
750 }
751 Ok(())
752}
753
/// Payload of [`Predicate::AnyOf`]: an object with the single key `any_of`.
///
/// Unknown keys are rejected during deserialization.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(deny_unknown_fields)]
pub struct AnyOfP {
    /// Child predicates.
    pub any_of: Vec<Predicate>,
}
759
/// Payload of [`Predicate::AllOf`]: an object with the single key `all_of`.
///
/// Unknown keys are rejected during deserialization.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(deny_unknown_fields)]
pub struct AllOfP {
    /// Child predicates.
    pub all_of: Vec<Predicate>,
}
765
/// Payload of [`Predicate::Not`]: an object with the single key `not`.
///
/// Unknown keys are rejected during deserialization.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(deny_unknown_fields)]
pub struct NotP {
    /// The wrapped child predicate (boxed because the type is recursive).
    pub not: Box<Predicate>,
}
771
/// Payload of [`Predicate::Check`]: one field path with one uncompiled operator.
///
/// `Serialize` is derived; `Deserialize` is implemented by hand elsewhere in
/// this file and reads the single-key form
/// `{"<field-path>": {"<operator>": <value>}}`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct CheckMap {
    /// Field the check reads.
    pub path: FieldPath,
    /// Operator and operand as written in configuration.
    pub op: Operator,
}
777
/// Uncompiled operator as written in configuration.
///
/// Serde uses `snake_case` variant names, which are the same keywords that
/// `Operator::name` returns.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Operator {
    /// `equals`.
    Equals(Value),
    /// `not_equals`.
    NotEquals(Value),
    /// `contains`.
    Contains(Value),
    /// `not_contains`.
    NotContains(Value),
    /// `prefix`.
    Prefix(Value),
    /// `suffix`.
    Suffix(Value),
    /// `matches` — regex pattern source; its length is limited by `validate_operator`.
    Matches(String),
    /// `in`.
    In(Vec<Value>),
    /// `not_in`.
    NotIn(Vec<Value>),
    /// `gt`.
    Gt(i64),
    /// `gte`.
    Gte(i64),
    /// `lt`.
    Lt(i64),
    /// `lte`.
    Lte(i64),
    /// `cidr` — network kept as text at this stage.
    Cidr(String),
}
796
/// Untyped literal as written in configuration.
///
/// The enum is `untagged`, so the literal carries no variant name; serde tries
/// the variants in declaration order (`Bool`, then `Int`, then `Str`).
#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
#[serde(untagged)]
pub enum Value {
    /// Boolean literal.
    Bool(bool),
    /// Integer literal.
    Int(i64),
    /// String literal.
    Str(String),
}
804
805impl<'de> serde::Deserialize<'de> for Predicate {
806 fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
807 let v = serde_json::Value::deserialize(de)?;
808 let serde_json::Value::Object(ref map) = v else {
809 return Err(serde::de::Error::custom("predicate must be a JSON object"));
810 };
811 if map.len() == 1 {
812 let (key, _) = map.iter().next().expect("len == 1");
813 match key.as_str() {
814 "any_of" => {
815 return serde_json::from_value::<AnyOfP>(v)
816 .map(Predicate::AnyOf)
817 .map_err(serde::de::Error::custom);
818 }
819 "all_of" => {
820 return serde_json::from_value::<AllOfP>(v)
821 .map(Predicate::AllOf)
822 .map_err(serde::de::Error::custom);
823 }
824 "not" => {
825 return serde_json::from_value::<NotP>(v)
826 .map(Predicate::Not)
827 .map_err(serde::de::Error::custom);
828 }
829 _ => {}
830 }
831 }
832 serde_json::from_value::<CheckMap>(v).map(Predicate::Check).map_err(serde::de::Error::custom)
833 }
834}
835
836impl<'de> serde::Deserialize<'de> for CheckMap {
837 fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
838 struct Visitor;
839
840 impl<'de> serde::de::Visitor<'de> for Visitor {
841 type Value = CheckMap;
842
843 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
844 f.write_str("a single-key object of the form {\"<field-path>\": {\"<operator>\": <value>}}")
845 }
846
847 fn visit_map<M: serde::de::MapAccess<'de>>(self, mut map: M) -> Result<CheckMap, M::Error> {
848 let Some(key) = map.next_key::<String>()? else {
849 return Err(serde::de::Error::invalid_length(0, &"exactly one key"));
850 };
851 let path = parse_field_path(&key).map_err(serde::de::Error::custom)?;
852 let op: Operator = map.next_value()?;
853 if map.next_key::<serde::de::IgnoredAny>()?.is_some() {
854 return Err(serde::de::Error::custom("check object must have exactly one key"));
855 }
856 validate_operator(&op).map_err(serde::de::Error::custom)?;
857 Ok(CheckMap { path, op })
858 }
859 }
860
861 de.deserialize_map(Visitor)
862 }
863}
864
865fn parse_field_path(s: &str) -> Result<FieldPath, String> {
866 if s.chars().any(|c| c.is_ascii_uppercase()) {
867 return Err(format!(
868 "field path must be lowercase: {:?} — did you mean {:?}?",
869 s,
870 s.to_ascii_lowercase(),
871 ));
872 }
873 match s {
874 "transport" => Ok(FieldPath::Transport),
875 "remote.ip" => Ok(FieldPath::RemoteIp),
876 "remote.port" => Ok(FieldPath::RemotePort),
877 "local.ip" => Ok(FieldPath::LocalIp),
878 "local.port" => Ok(FieldPath::LocalPort),
879 "peek" => Ok(FieldPath::Peek),
880 "tls.sni" => Ok(FieldPath::TlsSni),
881 "tls.alpn" => Ok(FieldPath::TlsAlpn),
882 "tls.version" => Ok(FieldPath::TlsVersion),
883 "tls.peer_cert.present" => Ok(FieldPath::TlsPeerCertPresent),
884 "tls.peer_cert.subject_cn" => Ok(FieldPath::TlsPeerCertSubjectCn),
885 "tls.peer_cert.san_dns" => Ok(FieldPath::TlsPeerCertSanDns),
886 "tls.peer_cert.fingerprint_sha256" => Ok(FieldPath::TlsPeerCertFingerprintSha256),
887 "tls.peer_cert.spki_sha256" => Ok(FieldPath::TlsPeerCertSpkiSha256),
888 "tls.peer_cert.issuer_cn" => Ok(FieldPath::TlsPeerCertIssuerCn),
889 "tls.peer_cert.serial" => Ok(FieldPath::TlsPeerCertSerial),
890 "http.method" => Ok(FieldPath::HttpMethod),
891 "http.uri.path" => Ok(FieldPath::HttpUriPath),
892 "http.uri.query" => Ok(FieldPath::HttpUriQuery),
893 "http.body" => Ok(FieldPath::HttpBody),
894 other if other.starts_with("http.header.") => {
895 let name = &other["http.header.".len()..];
896 if name.is_empty() {
897 return Err(format!("http.header.* requires a header name: {other:?}"));
898 }
899 Ok(FieldPath::HttpHeader(Arc::from(name)))
900 }
901 other => Err(format!("unknown field path: {other:?}")),
902 }
903}
904
905fn validate_operator(op: &Operator) -> Result<(), String> {
906 if let Operator::Matches(pattern) = op
907 && pattern.len() > REGEX_PATTERN_MAX_BYTES
908 {
909 return Err(format!(
910 "regex pattern source exceeds {REGEX_PATTERN_MAX_BYTES}-byte limit: got {} bytes",
911 pattern.len(),
912 ));
913 }
914 Ok(())
915}
916
917mod serde_impls {
918 use base64::Engine as _;
919 use base64::engine::general_purpose::STANDARD as B64;
920 use bytes::Bytes;
921 use std::net::IpAddr;
922 use std::sync::Arc;
923
924 use super::{CompiledOperator, CompiledValue};
925
926 pub(super) fn ser_bytes<S: serde::Serializer>(b: &Bytes, s: S) -> Result<S::Ok, S::Error> {
927 s.serialize_str(&B64.encode(b))
928 }
929
930 pub(super) fn de_bytes<'de, D: serde::Deserializer<'de>>(d: D) -> Result<Bytes, D::Error> {
931 use serde::Deserialize as _;
932 let s = String::deserialize(d)?;
933 B64.decode(s.as_bytes()).map(Bytes::from).map_err(serde::de::Error::custom)
934 }
935
936 pub(super) fn ser_regex<S: serde::Serializer>(
937 r: &fancy_regex::Regex,
938 s: S,
939 ) -> Result<S::Ok, S::Error> {
940 s.serialize_str(r.as_str())
941 }
942
943 pub(super) fn de_regex<'de, D: serde::Deserializer<'de>>(
944 d: D,
945 ) -> Result<fancy_regex::Regex, D::Error> {
946 use serde::Deserialize as _;
947 let s = String::deserialize(d)?;
948 fancy_regex::Regex::new(&s)
949 .map_err(|e| serde::de::Error::custom(format!("invalid regex {s:?}: {e}")))
950 }
951
952 #[derive(serde::Serialize, serde::Deserialize)]
954 #[serde(rename_all = "snake_case")]
955 pub(super) enum ValueShadow {
956 Str(Arc<str>),
957 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
958 Bytes(Bytes),
959 Int(i64),
960 Bool(bool),
961 Addr(IpAddr),
962 }
963
964 impl From<&CompiledValue> for ValueShadow {
965 fn from(v: &CompiledValue) -> Self {
966 match v {
967 CompiledValue::Str(s) => Self::Str(Arc::clone(s)),
968 CompiledValue::Bytes(b) => Self::Bytes(b.clone()),
969 CompiledValue::Int(i) => Self::Int(*i),
970 CompiledValue::Bool(b) => Self::Bool(*b),
971 CompiledValue::Addr(a) => Self::Addr(*a),
972 }
973 }
974 }
975
976 impl From<ValueShadow> for CompiledValue {
977 fn from(v: ValueShadow) -> Self {
978 match v {
979 ValueShadow::Str(s) => Self::Str(s),
980 ValueShadow::Bytes(b) => Self::Bytes(b),
981 ValueShadow::Int(i) => Self::Int(i),
982 ValueShadow::Bool(b) => Self::Bool(b),
983 ValueShadow::Addr(a) => Self::Addr(a),
984 }
985 }
986 }
987
988 #[derive(serde::Serialize, serde::Deserialize)]
991 #[serde(rename_all = "snake_case")]
992 pub(super) enum OperatorShadow {
993 Equals(CompiledValue),
994 NotEquals(CompiledValue),
995 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
996 Contains(Bytes),
997 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
998 NotContains(Bytes),
999 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
1000 Prefix(Bytes),
1001 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
1002 Suffix(Bytes),
1003 #[serde(serialize_with = "ser_regex", deserialize_with = "de_regex")]
1004 Matches(fancy_regex::Regex),
1005 In(Vec<CompiledValue>),
1006 NotIn(Vec<CompiledValue>),
1007 Gt(i64),
1008 Gte(i64),
1009 Lt(i64),
1010 Lte(i64),
1011 Cidr(ipnet::IpNet),
1012 }
1013
1014 impl From<&CompiledOperator> for OperatorShadow {
1015 fn from(op: &CompiledOperator) -> Self {
1016 match op {
1017 CompiledOperator::Equals(v) => Self::Equals(v.clone()),
1018 CompiledOperator::NotEquals(v) => Self::NotEquals(v.clone()),
1019 CompiledOperator::Contains(b) => Self::Contains(b.clone()),
1020 CompiledOperator::NotContains(b) => Self::NotContains(b.clone()),
1021 CompiledOperator::Prefix(b) => Self::Prefix(b.clone()),
1022 CompiledOperator::Suffix(b) => Self::Suffix(b.clone()),
1023 CompiledOperator::Matches(r) => {
1024 Self::Matches(fancy_regex::Regex::new(r.as_str()).expect("round-trippable"))
1025 }
1026 CompiledOperator::In(vs) => Self::In(vs.clone()),
1027 CompiledOperator::NotIn(vs) => Self::NotIn(vs.clone()),
1028 CompiledOperator::Gt(i) => Self::Gt(*i),
1029 CompiledOperator::Gte(i) => Self::Gte(*i),
1030 CompiledOperator::Lt(i) => Self::Lt(*i),
1031 CompiledOperator::Lte(i) => Self::Lte(*i),
1032 CompiledOperator::Cidr(n) => Self::Cidr(*n),
1033 }
1034 }
1035 }
1036
1037 impl From<OperatorShadow> for CompiledOperator {
1038 fn from(op: OperatorShadow) -> Self {
1039 match op {
1040 OperatorShadow::Equals(v) => Self::Equals(v),
1041 OperatorShadow::NotEquals(v) => Self::NotEquals(v),
1042 OperatorShadow::Contains(b) => Self::Contains(b),
1043 OperatorShadow::NotContains(b) => Self::NotContains(b),
1044 OperatorShadow::Prefix(b) => Self::Prefix(b),
1045 OperatorShadow::Suffix(b) => Self::Suffix(b),
1046 OperatorShadow::Matches(r) => Self::Matches(r),
1047 OperatorShadow::In(vs) => Self::In(vs),
1048 OperatorShadow::NotIn(vs) => Self::NotIn(vs),
1049 OperatorShadow::Gt(i) => Self::Gt(i),
1050 OperatorShadow::Gte(i) => Self::Gte(i),
1051 OperatorShadow::Lt(i) => Self::Lt(i),
1052 OperatorShadow::Lte(i) => Self::Lte(i),
1053 OperatorShadow::Cidr(n) => Self::Cidr(n),
1054 }
1055 }
1056 }
1057}
1058
1059impl serde::Serialize for CompiledValue {
1060 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
1061 serde_impls::ValueShadow::from(self).serialize(s)
1062 }
1063}
1064
1065impl<'de> serde::Deserialize<'de> for CompiledValue {
1066 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
1067 serde_impls::ValueShadow::deserialize(d).map(Self::from)
1068 }
1069}
1070
1071impl serde::Serialize for CompiledOperator {
1072 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
1073 serde_impls::OperatorShadow::from(self).serialize(s)
1074 }
1075}
1076
1077impl<'de> serde::Deserialize<'de> for CompiledOperator {
1078 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
1079 serde_impls::OperatorShadow::deserialize(d).map(Self::from)
1080 }
1081}
1082
1083#[cfg(test)]
1084mod tests {
1085 use std::collections::hash_map::DefaultHasher;
1086 use std::hash::Hash;
1087 use std::net::{Ipv4Addr, Ipv6Addr};
1088 use std::str::FromStr;
1089 use std::sync::OnceLock;
1090 use std::time::Instant;
1091
1092 use bytes::Bytes;
1093 use fancy_regex::Regex;
1094 use ipnet::IpNet;
1095 use parking_lot::Mutex;
1096
1097 use super::*;
1098 use crate::body::{Body, Request};
1099 use crate::conn_context::{ConnId, Transport};
1100
1101 fn hash_of<T: Hash>(v: &T) -> u64 {
1106 let mut h = DefaultHasher::new();
1107 v.hash(&mut h);
1108 h.finish()
1109 }
1110
1111 fn make_conn() -> Arc<ConnContext> {
1112 Arc::new(ConnContext {
1113 id: ConnId(1),
1114 remote: "127.0.0.1:0".parse().expect("parse remote"),
1115 local: "127.0.0.1:0".parse().expect("parse local"),
1116 transport: Transport::Tcp,
1117 entered_at: Instant::now(),
1118 tls: Mutex::new(None),
1119 http_version: OnceLock::new(),
1120 user: Mutex::new(http::Extensions::new()),
1121 })
1122 }
1123
1124 #[test]
1125 fn field_path_http_header_is_equal_by_string_content_not_arc_identity() {
1126 let a = FieldPath::HttpHeader(Arc::from("host"));
1127 let b = FieldPath::HttpHeader(Arc::from("host"));
1128 assert_eq!(a, b);
1129 assert_eq!(hash_of(&a), hash_of(&b));
1130 let upper = FieldPath::HttpHeader(Arc::from("Host"));
1135 assert_ne!(a, upper);
1136 }
1137
1138 #[test]
1139 fn field_path_simple_variants_are_self_equal_and_mutually_distinct() {
1140 let paths = [
1141 FieldPath::Transport,
1142 FieldPath::RemoteIp,
1143 FieldPath::RemotePort,
1144 FieldPath::LocalIp,
1145 FieldPath::LocalPort,
1146 FieldPath::Peek,
1147 FieldPath::TlsSni,
1148 FieldPath::TlsAlpn,
1149 FieldPath::TlsVersion,
1150 FieldPath::TlsPeerCertSubjectCn,
1151 FieldPath::HttpMethod,
1152 FieldPath::HttpUriPath,
1153 FieldPath::HttpUriQuery,
1154 FieldPath::HttpBody,
1155 ];
1156 for (i, a) in paths.iter().enumerate() {
1157 for (j, b) in paths.iter().enumerate() {
1158 if i == j {
1159 assert_eq!(a, b);
1160 } else {
1161 assert_ne!(a, b);
1162 }
1163 }
1164 }
1165 }
1166
1167 #[test]
1168 fn compiled_value_str_is_equal_by_content_not_arc_identity() {
1169 let a = CompiledValue::Str(Arc::<str>::from("x"));
1170 let b = CompiledValue::Str(Arc::<str>::from("x"));
1171 assert_eq!(a, b);
1172 assert_eq!(hash_of(&a), hash_of(&b));
1173 let c = CompiledValue::Str(Arc::<str>::from("y"));
1174 assert_ne!(a, c);
1175 }
1176
/// A `Str` holding "42" and an `Int` holding 42 are different values.
#[test]
fn compiled_value_cross_variant_inequality() {
    let as_text = CompiledValue::Str(Arc::<str>::from("42"));
    let as_number = CompiledValue::Int(42);
    assert_ne!(as_text, as_number);
}
1183
/// Same-variant comparisons for `Bytes`, `Int`, `Bool` and `Addr`: equal payloads compare
/// equal and differing payloads compare unequal.
#[test]
fn compiled_value_bytes_int_bool_addr_self_equal() {
    // Bytes built from a static slice and from a copied slice carry the same content.
    let from_static = CompiledValue::Bytes(Bytes::from_static(b"abc"));
    let from_copy = CompiledValue::Bytes(Bytes::copy_from_slice(b"abc"));
    assert_eq!(from_static, from_copy);

    let seven = CompiledValue::Int(7);
    assert_eq!(seven, CompiledValue::Int(7));
    assert_ne!(seven, CompiledValue::Int(8));

    let yes = CompiledValue::Bool(true);
    assert_eq!(yes, CompiledValue::Bool(true));
    assert_ne!(yes, CompiledValue::Bool(false));

    let v4 = CompiledValue::Addr(Ipv4Addr::new(10, 0, 0, 1).into());
    assert_eq!(v4, CompiledValue::Addr(Ipv4Addr::new(10, 0, 0, 1).into()));
    assert_ne!(v4, CompiledValue::Addr(Ipv6Addr::LOCALHOST.into()));
}
1203
/// Two `Matches` operators compiled separately from the same pattern text are equal and
/// produce the same hash.
#[test]
fn compiled_operator_matches_equal_by_pattern_source() {
    let compile = |label: &str| CompiledOperator::Matches(Regex::new("^/api").expect(label));
    let first = compile("compile a");
    let second = compile("compile b");
    assert_eq!(first, second);
    assert_eq!(hash_of(&first), hash_of(&second));
}
1211
/// `Matches` operators built from the pattern texts "a|b" and "b|a" compare unequal.
#[test]
fn compiled_operator_matches_distinct_patterns_unequal() {
    let a_then_b = Regex::new("a|b").expect("compile a");
    let b_then_a = Regex::new("b|a").expect("compile b");
    assert_ne!(CompiledOperator::Matches(a_then_b), CompiledOperator::Matches(b_then_a));
}
1220
/// Two `Cidr` operators parsed independently from "10.0.0.0/8" are equal and hash the same.
#[test]
fn compiled_operator_cidr_equal_by_canonical_form() {
    let first = CompiledOperator::Cidr("10.0.0.0/8".parse::<IpNet>().expect("parse a"));
    let second = CompiledOperator::Cidr("10.0.0.0/8".parse::<IpNet>().expect("parse b"));
    assert_eq!(first, second);
    assert_eq!(hash_of(&first), hash_of(&second));
}
1228
/// `Cidr` operators for 10.0.0.0/8 and 10.0.0.0/16 compare unequal.
#[test]
fn compiled_operator_cidr_distinct_networks_unequal() {
    let slash_8 = "10.0.0.0/8".parse::<IpNet>().expect("parse a");
    let slash_16 = "10.0.0.0/16".parse::<IpNet>().expect("parse b");
    assert_ne!(CompiledOperator::Cidr(slash_8), CompiledOperator::Cidr(slash_16));
}
1235
/// `In` and `NotIn` operators holding the same two strings in opposite order compare unequal.
#[test]
fn compiled_operator_in_is_order_sensitive() {
    let text = |s: &str| CompiledValue::Str(Arc::<str>::from(s));
    let a_then_b = vec![text("a"), text("b")];
    let b_then_a = vec![text("b"), text("a")];

    assert_ne!(CompiledOperator::In(a_then_b.clone()), CompiledOperator::In(b_then_a.clone()));
    assert_ne!(CompiledOperator::NotIn(a_then_b), CompiledOperator::NotIn(b_then_a));
}
1245
/// `Gt`, `Gte`, `Lt` and `Lte` carrying the same operand (10) each equal themselves and differ
/// from every other variant.
#[test]
fn compiled_operator_numeric_comparisons_distinct_per_variant() {
    let variants = [
        CompiledOperator::Gt(10),
        CompiledOperator::Gte(10),
        CompiledOperator::Lt(10),
        CompiledOperator::Lte(10),
    ];
    for (row, lhs) in variants.iter().enumerate() {
        for (col, rhs) in variants.iter().enumerate() {
            if row != col {
                assert_ne!(lhs, rhs);
            } else {
                assert_eq!(lhs, rhs);
            }
        }
    }
}
1265
/// `Contains`, `NotContains`, `Prefix` and `Suffix` built over the same payload each equal
/// themselves and differ from every other variant.
#[test]
fn compiled_operator_bytes_variants_distinguished() {
    let shared = Bytes::from_static(b"abc");
    let variants = [
        CompiledOperator::Contains(shared.clone()),
        CompiledOperator::NotContains(shared.clone()),
        CompiledOperator::Prefix(shared.clone()),
        CompiledOperator::Suffix(shared),
    ];
    for (row, lhs) in variants.iter().enumerate() {
        for (col, rhs) in variants.iter().enumerate() {
            if row != col {
                assert_ne!(lhs, rhs);
            } else {
                assert_eq!(lhs, rhs);
            }
        }
    }
}
1285
/// Two `PredicateInst`s built separately with the same header path and `Equals` operand are
/// equal and hash the same.
#[test]
fn predicate_inst_equal_across_independent_construction_paths() {
    // Each call allocates its own `Arc`s, so nothing is shared between the two instances.
    let build = || PredicateInst {
        path: FieldPath::HttpHeader(Arc::from("host")),
        op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from("example.com"))),
    };
    let lhs = build();
    let rhs = build();
    assert_eq!(lhs, rhs);
    assert_eq!(hash_of(&lhs), hash_of(&rhs));
}
1299
/// Two `PredicateInst`s whose `Matches` operators were compiled separately from "^/" are equal
/// and hash the same.
#[test]
fn predicate_inst_equal_with_regex_operator_from_separate_compiles() {
    let build = |label: &str| PredicateInst {
        path: FieldPath::HttpUriPath,
        op: CompiledOperator::Matches(Regex::new("^/").expect(label)),
    };
    let lhs = build("compile a");
    let rhs = build("compile b");
    assert_eq!(lhs, rhs);
    assert_eq!(hash_of(&lhs), hash_of(&rhs));
}
1313
/// Instances sharing the same `Equals` operand but targeting different paths compare unequal.
#[test]
fn predicate_inst_unequal_on_path_difference() {
    let operand = CompiledValue::Str(Arc::<str>::from("x"));
    let on_path = PredicateInst {
        path: FieldPath::HttpUriPath,
        op: CompiledOperator::Equals(operand.clone()),
    };
    let on_query = PredicateInst {
        path: FieldPath::HttpUriQuery,
        op: CompiledOperator::Equals(operand),
    };
    assert_ne!(on_path, on_query);
}
1322
/// Constructs both `PredicateView` variants and checks that each one matches its own pattern;
/// for `L4` the three peeked bytes must be visible through the view.
#[test]
fn predicate_view_variants_construct() {
    let l4_conn = make_conn();
    let peeked: &[u8] = b"\x16\x03\x01";
    let l4_view = PredicateView::L4 { conn: &l4_conn, peek: Some(peeked) };
    let PredicateView::L4 { peek, .. } = l4_view else {
        panic!("wrong variant");
    };
    assert_eq!(peek.map(<[u8]>::len), Some(3));

    let l7_conn = make_conn();
    let request: Request =
        http::Request::builder().method("GET").uri("/").body(Body::Empty).expect("build request");
    let l7_view = PredicateView::L7Req { conn: &l7_conn, req: &request };
    if let PredicateView::L4 { .. } = l7_view {
        panic!("wrong variant");
    }
}
1342
1343 fn parse_predicate(v: serde_json::Value) -> Result<Predicate, serde_json::Error> {
1347 serde_json::from_value(v)
1348 }
1349
1350 fn expect_check(p: &Predicate) -> &CheckMap {
1351 match p {
1352 Predicate::Check(c) => c,
1353 other => panic!("expected Predicate::Check, got {other:?}"),
1354 }
1355 }
1356
/// Parses an `any_of` combinator holding two `tls.sni` equality checks and verifies that both
/// children come back, in order, as checks on `FieldPath::TlsSni` with operands "a" and "b".
#[test]
fn parse_any_of_happy_path() {
    let input = serde_json::json!({
        "any_of": [
            { "tls.sni": { "equals": "a" } },
            { "tls.sni": { "equals": "b" } },
        ],
    });
    let parsed = parse_predicate(input).expect("parse any_of");
    let Predicate::AnyOf(AnyOfP { any_of: branches }) = parsed else {
        panic!("expected AnyOf");
    };
    assert_eq!(branches.len(), 2);

    let first = expect_check(&branches[0]);
    let second = expect_check(&branches[1]);
    assert_eq!(first.path, FieldPath::TlsSni);
    assert_eq!(second.path, FieldPath::TlsSni);

    match (&first.op, &second.op) {
        (Operator::Equals(Value::Str(lhs)), Operator::Equals(Value::Str(rhs))) => {
            assert_eq!(lhs, "a");
            assert_eq!(rhs, "b");
        }
        (lhs, rhs) => panic!("unexpected ops: {lhs:?} / {rhs:?}"),
    }
}
1382
/// Parses a `not` combinator wrapping a `tls.sni` equality check and verifies the wrapped
/// predicate is a check on `FieldPath::TlsSni` with the string operand "internal".
#[test]
fn parse_not_happy_path() {
    let raw = serde_json::json!({
        "not": { "tls.sni": { "equals": "internal" } },
    });
    let p = parse_predicate(raw).expect("parse not");
    let Predicate::Not(NotP { not }) = p else {
        panic!("expected Not");
    };
    // Borrow the wrapped predicate (`&not`); the argument had been garbled into a stray `¬`.
    let inner = expect_check(&not);
    assert_eq!(inner.path, FieldPath::TlsSni);
    match &inner.op {
        Operator::Equals(Value::Str(s)) => assert_eq!(s, "internal"),
        other => panic!("unexpected op: {other:?}"),
    }
}
1399
/// Parses an `all_of` combinator with a header check and a URI-path check and verifies the two
/// children resolve to `HttpHeader("upgrade")` and `HttpUriPath` respectively.
#[test]
fn parse_all_of_happy_path() {
    let input = serde_json::json!({
        "all_of": [
            { "http.header.upgrade": { "equals": "websocket" } },
            { "http.uri.path": { "prefix": "/ws" } },
        ],
    });
    let parsed = parse_predicate(input).expect("parse all_of");
    let Predicate::AllOf(AllOfP { all_of: children }) = parsed else {
        panic!("expected AllOf");
    };
    assert_eq!(children.len(), 2);

    let header_check = expect_check(&children[0]);
    let path_check = expect_check(&children[1]);
    assert_eq!(header_check.path, FieldPath::HttpHeader(Arc::from("upgrade")));
    assert_eq!(path_check.path, FieldPath::HttpUriPath);
}
1418
/// An `all_of` with an empty array parses successfully into an `AllOf` with no children.
#[test]
fn parse_all_of_empty_array_parses() {
    let input = serde_json::json!({ "all_of": [] });
    let parsed = parse_predicate(input).expect("empty all_of parses");
    let Predicate::AllOf(AllOfP { all_of: children }) = parsed else {
        panic!("expected AllOf");
    };
    assert!(children.is_empty());
}
1430
/// An `all_of` whose children are a plain check followed by a nested `any_of` parses into an
/// `AllOf` with a `Check` first and an `AnyOf` second.
#[test]
fn parse_all_of_nested_with_check_and_any_of() {
    let input = serde_json::json!({
        "all_of": [
            { "tls.sni": { "equals": "api.example.com" } },
            { "any_of": [
                { "remote.ip": { "cidr": "10.0.0.0/8" } },
                { "remote.ip": { "cidr": "192.168.0.0/16" } },
            ]},
        ],
    });
    let parsed = parse_predicate(input).expect("parse nested all_of/any_of");
    let Predicate::AllOf(AllOfP { all_of: children }) = parsed else {
        panic!("expected AllOf");
    };
    assert_eq!(children.len(), 2);
    assert!(matches!(children[0], Predicate::Check(_)));
    assert!(matches!(children[1], Predicate::AnyOf(_)));
}
1450
/// An object carrying `all_of` plus an additional key must fail to parse; the error is only
/// rendered, its text is not inspected.
#[test]
fn parse_all_of_with_extra_key_is_rejected() {
    let input = serde_json::json!({
        "all_of": [ { "tls.sni": { "equals": "a" } } ],
        "extra": "unwanted",
    });
    let failure = parse_predicate(input).expect_err("must reject extra key on all_of");
    let _rendered = failure.to_string();
}
1461
/// The key `http.header.all_of` parses as a check on the header named "all_of".
#[test]
fn parse_http_header_all_of_is_a_check_not_combinator() {
    let input = serde_json::json!({ "http.header.all_of": { "equals": "x" } });
    let parsed = parse_predicate(input).expect("parse http.header.all_of");
    let check = expect_check(&parsed);
    assert_eq!(check.path, FieldPath::HttpHeader(Arc::from("all_of")));
}
1471
/// Table-driven check that a selection of dotted path keys parse to the expected `FieldPath`.
#[test]
fn parse_check_across_representative_paths() {
    let table = [
        (serde_json::json!({ "tls.sni": { "equals": "api.example.com" } }), FieldPath::TlsSni),
        (serde_json::json!({ "remote.port": { "gt": 1024 } }), FieldPath::RemotePort),
        (serde_json::json!({ "http.method": { "equals": "GET" } }), FieldPath::HttpMethod),
        (serde_json::json!({ "http.uri.path": { "prefix": "/api" } }), FieldPath::HttpUriPath),
        (
            serde_json::json!({ "http.header.host": { "equals": "a.example.com" } }),
            FieldPath::HttpHeader(Arc::from("host")),
        ),
        (serde_json::json!({ "http.body": { "contains": "hello" } }), FieldPath::HttpBody),
    ];
    for (input, want_path) in table {
        // Parse a clone so the original JSON stays available for the failure messages.
        let parsed = match parse_predicate(input.clone()) {
            Ok(predicate) => predicate,
            Err(e) => panic!("parse {input}: {e}"),
        };
        let check = expect_check(&parsed);
        assert_eq!(check.path, want_path, "input: {input}");
    }
}
1491
/// An object carrying `any_of` plus an additional key must fail to parse; the error is only
/// rendered, its text is not inspected.
#[test]
fn parse_any_of_with_extra_key_is_rejected() {
    let input = serde_json::json!({
        "any_of": [ { "tls.sni": { "equals": "a" } } ],
        "extra": true,
    });
    let failure = parse_predicate(input).expect_err("must reject extra key on any_of");
    let _rendered = failure.to_string();
}
1503
/// The key `http.header.any_of` parses as a check on the header named "any_of".
#[test]
fn parse_http_header_any_of_is_a_check_not_combinator() {
    let input = serde_json::json!({ "http.header.any_of": { "equals": "x" } });
    let parsed = parse_predicate(input).expect("parse");
    let check = expect_check(&parsed);
    assert_eq!(check.path, FieldPath::HttpHeader(Arc::from("any_of")));
}
1513
/// A path key containing uppercase letters is rejected, and the error text quotes the input,
/// contains the phrase "did you mean", and includes the lowercased path.
#[test]
fn parse_uppercase_field_path_suggests_lowercase() {
    let input = serde_json::json!({ "http.header.Host": { "equals": "x" } });
    let rendered = parse_predicate(input).expect_err("uppercase must fail").to_string();
    assert!(rendered.contains("http.header.Host"), "error mentions offending input: {rendered}");
    assert!(rendered.contains("did you mean"), "error includes suggestion phrase: {rendered}");
    assert!(rendered.contains("http.header.host"), "error contains lowercased form: {rendered}");
}
1523
/// An object with two path keys must fail to parse; the error is only rendered, its text is
/// not inspected.
#[test]
fn parse_multi_key_check_is_rejected() {
    let input = serde_json::json!({
        "http.uri.path": { "matches": "^/" },
        "http.method": { "equals": "GET" },
    });
    let failure = parse_predicate(input).expect_err("multi-key check must fail");
    let _rendered = failure.to_string();
}
1533
/// The key `http.header.` (no header name after the final dot) must fail to parse.
#[test]
fn parse_empty_http_header_name_is_rejected() {
    let input = serde_json::json!({ "http.header.": { "equals": "x" } });
    let failure = parse_predicate(input).expect_err("empty header name must fail");
    let _rendered = failure.to_string();
}
1540
/// An unrecognised path key is rejected and the error text names the offending path.
#[test]
fn parse_unknown_field_path_is_rejected_with_name() {
    let input = serde_json::json!({ "http.nope": { "equals": "x" } });
    let rendered = parse_predicate(input).expect_err("unknown path must fail").to_string();
    assert!(rendered.contains("http.nope"), "error mentions offending path: {rendered}");
}
1548
1549 fn parse_op(v: serde_json::Value) -> Operator {
1550 let mut map = serde_json::Map::new();
1551 map.insert("tls.sni".to_string(), v);
1552 let raw = serde_json::Value::Object(map);
1553 match parse_predicate(raw).expect("parse check") {
1554 Predicate::Check(c) => c.op,
1555 other => panic!("expected Check, got {other:?}"),
1556 }
1557 }
1558
/// `equals` and `not_equals` with a JSON string operand parse to `Equals` / `NotEquals`
/// carrying `Value::Str` with the same text.
#[test]
fn operator_equals_and_not_equals_on_string() {
    match parse_op(serde_json::json!({ "equals": "api" })) {
        Operator::Equals(Value::Str(text)) => assert_eq!(text, "api"),
        unexpected => panic!("expected equals/str: {unexpected:?}"),
    }
    match parse_op(serde_json::json!({ "not_equals": "api" })) {
        Operator::NotEquals(Value::Str(text)) => assert_eq!(text, "api"),
        unexpected => panic!("expected not_equals/str: {unexpected:?}"),
    }
}
1572
/// `contains` and `not_contains` with a JSON string operand parse to `Contains` /
/// `NotContains` carrying `Value::Str` with the same text.
#[test]
fn operator_contains_and_not_contains_on_string() {
    match parse_op(serde_json::json!({ "contains": "foo" })) {
        Operator::Contains(Value::Str(text)) => assert_eq!(text, "foo"),
        unexpected => panic!("expected contains/str: {unexpected:?}"),
    }
    match parse_op(serde_json::json!({ "not_contains": "foo" })) {
        Operator::NotContains(Value::Str(text)) => assert_eq!(text, "foo"),
        unexpected => panic!("expected not_contains/str: {unexpected:?}"),
    }
}
1586
/// `prefix` and `suffix` with a JSON string operand parse to `Prefix` / `Suffix` carrying
/// `Value::Str` with the same text.
#[test]
fn operator_prefix_and_suffix_on_string() {
    match parse_op(serde_json::json!({ "prefix": "/api" })) {
        Operator::Prefix(Value::Str(text)) => assert_eq!(text, "/api"),
        unexpected => panic!("expected prefix/str: {unexpected:?}"),
    }
    match parse_op(serde_json::json!({ "suffix": ".json" })) {
        Operator::Suffix(Value::Str(text)) => assert_eq!(text, ".json"),
        unexpected => panic!("expected suffix/str: {unexpected:?}"),
    }
}
1600
/// `matches` parses to `Operator::Matches` holding the pattern text exactly as supplied.
#[test]
fn operator_matches_carries_pattern_source() {
    match parse_op(serde_json::json!({ "matches": "^/api/v\\d+" })) {
        Operator::Matches(source) => assert_eq!(source, "^/api/v\\d+"),
        unexpected => panic!("expected matches: {unexpected:?}"),
    }
}
1609
/// `in` and `not_in` accept a list mixing a string and an integer, preserving order and the
/// per-element value type.
#[test]
fn operator_in_and_not_in_accept_mixed_scalar_types() {
    let Operator::In(members) = parse_op(serde_json::json!({ "in": ["foo", 42] })) else {
        panic!("expected in");
    };
    assert_eq!(members.len(), 2);
    assert_eq!(members[0], Value::Str("foo".into()));
    assert_eq!(members[1], Value::Int(42));

    let Operator::NotIn(excluded) = parse_op(serde_json::json!({ "not_in": ["bar", 7] })) else {
        panic!("expected not_in");
    };
    assert_eq!(excluded.len(), 2);
    assert_eq!(excluded[0], Value::Str("bar".into()));
    assert_eq!(excluded[1], Value::Int(7));
}
1627
/// `gt`, `gte`, `lt` and `lte` with the integer 10 parse to the matching operator variant
/// carrying 10.
#[test]
fn operator_numeric_comparisons() {
    let gt = parse_op(serde_json::json!({ "gt": 10 }));
    assert!(matches!(gt, Operator::Gt(10)));

    let gte = parse_op(serde_json::json!({ "gte": 10 }));
    assert!(matches!(gte, Operator::Gte(10)));

    let lt = parse_op(serde_json::json!({ "lt": 10 }));
    assert!(matches!(lt, Operator::Lt(10)));

    let lte = parse_op(serde_json::json!({ "lte": 10 }));
    assert!(matches!(lte, Operator::Lte(10)));
}
1635
/// `cidr` parses to `Operator::Cidr` holding the network text exactly as supplied.
#[test]
fn operator_cidr_carries_source_string() {
    match parse_op(serde_json::json!({ "cidr": "10.0.0.0/8" })) {
        Operator::Cidr(source) => assert_eq!(source, "10.0.0.0/8"),
        unexpected => panic!("expected cidr: {unexpected:?}"),
    }
}
1644
/// JSON `true` / `false` operands parse to `Value::Bool` with the corresponding flag.
#[test]
fn value_untagged_priority_bool_before_str() {
    let parsed_true = parse_op(serde_json::json!({ "equals": true }));
    assert!(matches!(parsed_true, Operator::Equals(Value::Bool(true))));

    let parsed_false = parse_op(serde_json::json!({ "equals": false }));
    assert!(matches!(parsed_false, Operator::Equals(Value::Bool(false))));
}
1654
/// A JSON integer operand parses to `Value::Int`.
#[test]
fn value_untagged_priority_int_before_str() {
    let parsed = parse_op(serde_json::json!({ "equals": 42 }));
    assert!(matches!(parsed, Operator::Equals(Value::Int(42))));
}
1661
/// A JSON string whose text happens to be digits ("42") still parses to `Value::Str`.
#[test]
fn value_untagged_json_string_stays_str() {
    match parse_op(serde_json::json!({ "equals": "42" })) {
        Operator::Equals(Value::Str(text)) => assert_eq!(text, "42"),
        unexpected => panic!("expected equals/str(\"42\"): {unexpected:?}"),
    }
}
1672
/// Pins `REGEX_PATTERN_MAX_BYTES` at 4 KiB and verifies that a pattern of exactly that length
/// parses and is carried through at full length.
#[test]
fn regex_pattern_exactly_at_limit_parses() {
    assert_eq!(REGEX_PATTERN_MAX_BYTES, 4 * 1024);

    let at_limit = "a".repeat(REGEX_PATTERN_MAX_BYTES);
    let input = serde_json::json!({ "http.uri.path": { "matches": at_limit } });
    let parsed = parse_predicate(input).expect("4 KiB pattern parses");
    let check = expect_check(&parsed);
    match &check.op {
        Operator::Matches(source) => assert_eq!(source.len(), REGEX_PATTERN_MAX_BYTES),
        unexpected => panic!("expected matches: {unexpected:?}"),
    }
}
1686
/// A pattern one byte longer than `REGEX_PATTERN_MAX_BYTES` must be rejected, and the error
/// text must contain the numeric limit.
#[test]
fn regex_pattern_over_limit_rejected_with_limit_in_message() {
    let pattern = "a".repeat(REGEX_PATTERN_MAX_BYTES + 1);
    let raw = serde_json::json!({ "http.uri.path": { "matches": pattern } });
    let err = parse_predicate(raw).expect_err("over-limit pattern must fail");
    let msg = err.to_string();
    // Search for the decimal rendering of the limit; the `&REGEX…` borrow had been garbled
    // into a stray `®` character.
    assert!(
        msg.contains(&REGEX_PATTERN_MAX_BYTES.to_string()),
        "error mentions the limit ({REGEX_PATTERN_MAX_BYTES}): {msg}",
    );
}
1698
1699 fn value_round_trip(v: &CompiledValue) -> CompiledValue {
1706 let encoded = serde_json::to_string(v).expect("serialize value");
1707 serde_json::from_str(&encoded).expect("deserialize value")
1708 }
1709
/// `Str` values survive a JSON round trip, both for a non-empty and for an empty string.
#[test]
fn compiled_value_str_round_trip_including_empty() {
    for text in ["x", ""] {
        let original = CompiledValue::Str(Arc::<str>::from(text));
        assert_eq!(value_round_trip(&original), original);
    }
}
1717
/// `Bytes` values survive a JSON round trip for ASCII content, an empty buffer and a payload
/// containing non-ASCII bytes.
#[test]
fn compiled_value_bytes_round_trip_including_empty_and_binary() {
    let samples = [
        Bytes::from_static(b"hello"),
        Bytes::new(),
        Bytes::from_static(&[0xff, 0x00, 0x13]),
    ];
    for payload in samples {
        let original = CompiledValue::Bytes(payload);
        assert_eq!(value_round_trip(&original), original);
    }
}
1727
/// `Int` values survive a JSON round trip for zero and for both ends of the `i64` range.
#[test]
fn compiled_value_int_round_trip_including_extremes() {
    [0_i64, i64::MIN, i64::MAX].into_iter().map(CompiledValue::Int).for_each(|original| {
        assert_eq!(value_round_trip(&original), original);
    });
}
1735
/// `Bool` values survive a JSON round trip for both `true` and `false`.
#[test]
fn compiled_value_bool_round_trip_both_variants() {
    [true, false].into_iter().map(CompiledValue::Bool).for_each(|original| {
        assert_eq!(value_round_trip(&original), original);
    });
}
1743
/// `Addr` values survive a JSON round trip for the IPv4 and IPv6 loopback addresses.
#[test]
fn compiled_value_addr_round_trip_v4_and_v6() {
    let loopback_v4 = CompiledValue::Addr(Ipv4Addr::LOCALHOST.into());
    let loopback_v6 = CompiledValue::Addr(Ipv6Addr::LOCALHOST.into());
    assert_eq!(value_round_trip(&loopback_v4), loopback_v4);
    assert_eq!(value_round_trip(&loopback_v6), loopback_v6);
}
1751
/// Pins the serialized form of `Bytes(b"hello")` to the exact JSON text
/// `{"bytes":"aGVsbG8="}`.
#[test]
fn compiled_value_bytes_emits_standard_base64_literal() {
    let hello = CompiledValue::Bytes(Bytes::from_static(b"hello"));
    let wire = serde_json::to_string(&hello).expect("serialize");
    assert_eq!(wire, r#"{"bytes":"aGVsbG8="}"#);
}
1761
1762 fn op_round_trip(op: &CompiledOperator) -> CompiledOperator {
1763 let encoded = serde_json::to_string(op).expect("serialize op");
1764 serde_json::from_str(&encoded).expect("deserialize op")
1765 }
1766
/// `Equals` and `NotEquals` with a string operand survive a JSON round trip.
#[test]
fn compiled_operator_equals_and_not_equals_round_trip() {
    let operand = || CompiledValue::Str(Arc::<str>::from("x"));

    let equals = CompiledOperator::Equals(operand());
    assert_eq!(op_round_trip(&equals), equals);

    let not_equals = CompiledOperator::NotEquals(operand());
    assert_eq!(op_round_trip(&not_equals), not_equals);
}
1774
/// `Contains`, `NotContains`, `Prefix` and `Suffix` over a bytes payload survive a JSON round
/// trip.
#[test]
fn compiled_operator_bytes_variants_round_trip() {
    let shared = Bytes::from_static(b"hello");
    let variants = [
        CompiledOperator::Contains(shared.clone()),
        CompiledOperator::NotContains(shared.clone()),
        CompiledOperator::Prefix(shared.clone()),
        CompiledOperator::Suffix(shared),
    ];
    for original in variants {
        let decoded = op_round_trip(&original);
        assert_eq!(decoded, original);
    }
}
1788
/// A `Matches` operator survives a JSON round trip, and the decoded regex reports the same
/// pattern text through `as_str`.
#[test]
fn compiled_operator_matches_round_trip_preserves_pattern_source() {
    let original = CompiledOperator::Matches(Regex::new("^/api/v[0-9]+").expect("compile"));
    let decoded = op_round_trip(&original);
    assert_eq!(decoded, original);
    match decoded {
        CompiledOperator::Matches(regex) => assert_eq!(regex.as_str(), "^/api/v[0-9]+"),
        unexpected => panic!("expected matches, got {unexpected:?}"),
    }
}
1800
/// `In` and `NotIn` operators holding a mixed `Str` / `Int` list survive a JSON round trip.
#[test]
fn compiled_operator_in_and_not_in_round_trip_mixed_values() {
    let xs = vec![CompiledValue::Str(Arc::<str>::from("a")), CompiledValue::Int(42)];
    let in_op = CompiledOperator::In(xs.clone());
    assert_eq!(op_round_trip(&in_op), in_op);
    let not_in_op = CompiledOperator::NotIn(xs);
    // Pass a borrow of `not_in_op`; the `&not…` prefix had been garbled into a stray `¬`.
    assert_eq!(op_round_trip(&not_in_op), not_in_op);
}
1809
/// `Gt`, `Gte`, `Lt` and `Lte` with operand 100 survive a JSON round trip.
#[test]
fn compiled_operator_numeric_comparisons_round_trip() {
    let variants = [
        CompiledOperator::Gt(100),
        CompiledOperator::Gte(100),
        CompiledOperator::Lt(100),
        CompiledOperator::Lte(100),
    ];
    for original in variants {
        let decoded = op_round_trip(&original);
        assert_eq!(decoded, original);
    }
}
1822
/// A `Cidr` operator for 10.0.0.0/8 survives a JSON round trip.
#[test]
fn compiled_operator_cidr_round_trip_preserves_canonical_form() {
    let network = "10.0.0.0/8".parse::<IpNet>().expect("parse");
    let original = CompiledOperator::Cidr(network);
    assert_eq!(op_round_trip(&original), original);
}
1828
/// Deserializing a `matches` operator whose pattern is the invalid regex `[` fails, and the
/// error text contains the `[` character.
#[test]
fn compiled_operator_matches_with_invalid_regex_is_rejected() {
    let wire = r#"{"matches":"["}"#;
    let rendered = serde_json::from_str::<CompiledOperator>(wire)
        .expect_err("invalid regex must fail to deserialize")
        .to_string();
    assert!(rendered.contains('['), "error mentions offending regex source: {rendered}");
}
1840
/// Pins the exact JSON emitted for a header-equals `PredicateInst` and checks that the same
/// text deserializes back to an equal instance.
#[test]
fn predicate_inst_pins_exact_wire_shape_for_http_header_equals() {
    let operand = CompiledValue::Str(Arc::<str>::from("example.com"));
    let original = PredicateInst {
        path: FieldPath::HttpHeader(Arc::from("host")),
        op: CompiledOperator::Equals(operand),
    };

    let wire = serde_json::to_string(&original).expect("serialize");
    assert_eq!(wire, r#"{"path":{"http_header":"host"},"op":{"equals":{"str":"example.com"}}}"#);

    let decoded: PredicateInst = serde_json::from_str(&wire).expect("deserialize");
    assert_eq!(decoded, original);
}
1852
/// A `PredicateInst` carrying a `Matches` operator survives a JSON round trip.
#[test]
fn predicate_inst_round_trip_with_regex_operator() {
    let regex = Regex::new("^/api").expect("compile");
    let original =
        PredicateInst { path: FieldPath::HttpUriPath, op: CompiledOperator::Matches(regex) };

    let wire = serde_json::to_string(&original).expect("serialize");
    let decoded = serde_json::from_str::<PredicateInst>(&wire).expect("deserialize");
    assert_eq!(decoded, original);
}
1863
1864 fn http_header_equals(name: &str, value: &str) -> PredicateInst {
1871 PredicateInst {
1872 path: FieldPath::HttpHeader(Arc::from(name)),
1873 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from(value))),
1874 }
1875 }
1876
1877 fn http_uri_path_equals(value: &str) -> PredicateInst {
1878 PredicateInst {
1879 path: FieldPath::HttpUriPath,
1880 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from(value))),
1881 }
1882 }
1883
1884 fn http_uri_path_prefix(value: &str) -> PredicateInst {
1885 PredicateInst {
1886 path: FieldPath::HttpUriPath,
1887 op: CompiledOperator::Prefix(Bytes::copy_from_slice(value.as_bytes())),
1888 }
1889 }
1890
1891 fn tls_sni_equals(value: &str) -> PredicateInst {
1892 PredicateInst {
1893 path: FieldPath::TlsSni,
1894 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from(value))),
1895 }
1896 }
1897
1898 fn conn_with_sni(sni: &str) -> Arc<ConnContext> {
1899 let conn = make_conn();
1900 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
1901 sni: Some(Arc::from(sni)),
1902 alpn: None,
1903 version: None,
1904 peer_cert: None,
1905 zero_rtt_used: false,
1906 });
1907 conn
1908 }
1909
1910 fn req_with_header(name: &str, value: &str) -> Request {
1911 http::Request::builder()
1912 .method("GET")
1913 .uri("/")
1914 .header(name, value)
1915 .body(Body::Empty)
1916 .expect("build req")
1917 }
1918
1919 fn req_with_uri(uri: &str) -> Request {
1920 http::Request::builder().method("GET").uri(uri).body(Body::Empty).expect("build req")
1921 }
1922
/// A header-equals predicate matches when the request carries that header with the same value.
#[test]
fn predicate_test_http_header_equals_matches_when_present_and_equal() {
    let conn = make_conn();
    let request = req_with_header("upgrade", "websocket");
    let view = PredicateView::L7Req { conn: &conn, req: &request };
    let predicate = http_header_equals("upgrade", "websocket");
    assert!(predicate.test(&view));
}
1930
/// A header-equals predicate does not match when the request lacks the named header.
#[test]
fn predicate_test_http_header_equals_misses_when_header_absent() {
    let conn = make_conn();
    let request = req_with_header("host", "example.com");
    let view = PredicateView::L7Req { conn: &conn, req: &request };
    let predicate = http_header_equals("upgrade", "websocket");
    assert!(!predicate.test(&view));
}
1938
/// A header value of "WebSocket" does not satisfy an equals predicate for "websocket".
#[test]
fn predicate_test_http_header_equals_value_is_case_sensitive() {
    let conn = make_conn();
    let request = req_with_header("upgrade", "WebSocket");
    let view = PredicateView::L7Req { conn: &conn, req: &request };
    let predicate = http_header_equals("upgrade", "websocket");
    assert!(!predicate.test(&view));
}
1950
/// A request built with the header name "Upgrade" satisfies a predicate keyed on "upgrade".
#[test]
fn predicate_test_http_header_equals_name_lookup_is_case_insensitive() {
    let conn = make_conn();
    let request = req_with_header("Upgrade", "websocket");
    let view = PredicateView::L7Req { conn: &conn, req: &request };
    let predicate = http_header_equals("upgrade", "websocket");
    assert!(predicate.test(&view));
}
1963
/// A header-equals predicate does not match when evaluated against an `L4` view.
#[test]
fn predicate_test_http_header_equals_misses_on_l4_view() {
    let conn = make_conn();
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let predicate = http_header_equals("upgrade", "websocket");
    assert!(!predicate.test(&view));
}
1973
/// A path-equals predicate matches when the request path is exactly the expected string.
#[test]
fn predicate_test_http_uri_path_equals_matches_exact() {
    let conn = make_conn();
    let request = req_with_uri("/api/v1/users");
    let view = PredicateView::L7Req { conn: &conn, req: &request };
    let predicate = http_uri_path_equals("/api/v1/users");
    assert!(predicate.test(&view));
}
1981
/// A path-equals predicate for "/api" does not match the longer path "/api/v1/users".
#[test]
fn predicate_test_http_uri_path_equals_misses_on_substring() {
    let conn = make_conn();
    let request = req_with_uri("/api/v1/users");
    let view = PredicateView::L7Req { conn: &conn, req: &request };
    let predicate = http_uri_path_equals("/api");
    assert!(!predicate.test(&view));
}
1992
/// A path-prefix predicate for "/api" matches the path "/api/v1/users".
#[test]
fn predicate_test_http_uri_path_prefix_matches_when_path_starts_with() {
    let conn = make_conn();
    let request = req_with_uri("/api/v1/users");
    let view = PredicateView::L7Req { conn: &conn, req: &request };
    let predicate = http_uri_path_prefix("/api");
    assert!(predicate.test(&view));
}
2000
/// A path-prefix predicate for "/api" does not match the path "/admin".
#[test]
fn predicate_test_http_uri_path_prefix_misses_when_no_prefix() {
    let conn = make_conn();
    let request = req_with_uri("/admin");
    let view = PredicateView::L7Req { conn: &conn, req: &request };
    let predicate = http_uri_path_prefix("/api");
    assert!(!predicate.test(&view));
}
2008
/// An SNI-equals predicate matches in an `L7Req` view when the connection's TLS info carries
/// the same SNI.
#[test]
fn predicate_test_tls_sni_equals_matches_when_set() {
    let conn = conn_with_sni("api.example.com");
    let request = req_with_uri("/");
    let view = PredicateView::L7Req { conn: &conn, req: &request };
    let predicate = tls_sni_equals("api.example.com");
    assert!(predicate.test(&view));
}
2019
/// An SNI-equals predicate does not match when the connection has no TLS info.
#[test]
fn predicate_test_tls_sni_equals_misses_when_unset() {
    let conn = make_conn();
    let request = req_with_uri("/");
    let view = PredicateView::L7Req { conn: &conn, req: &request };
    let predicate = tls_sni_equals("api.example.com");
    assert!(!predicate.test(&view));
}
2029
/// An SNI-equals predicate also matches when evaluated against an `L4` view of a connection
/// carrying that SNI.
#[test]
fn predicate_test_tls_sni_equals_works_in_l4_view_too() {
    let conn = conn_with_sni("api.example.com");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let predicate = tls_sni_equals("api.example.com");
    assert!(predicate.test(&view));
}
2041
2042 fn pred(path: FieldPath, op: CompiledOperator) -> PredicateInst {
2049 PredicateInst { path, op }
2050 }
2051
2052 fn str_val(s: &str) -> CompiledValue {
2053 CompiledValue::Str(Arc::<str>::from(s))
2054 }
2055
2056 fn bytes_val(b: &[u8]) -> CompiledValue {
2057 CompiledValue::Bytes(Bytes::copy_from_slice(b))
2058 }
2059
2060 fn b(b: &[u8]) -> Bytes {
2061 Bytes::copy_from_slice(b)
2062 }
2063
2064 fn make_conn_with(remote: &str, local: &str) -> Arc<ConnContext> {
2065 Arc::new(ConnContext {
2066 id: ConnId(1),
2067 remote: remote.parse().expect("parse remote"),
2068 local: local.parse().expect("parse local"),
2069 transport: Transport::Tcp,
2070 entered_at: Instant::now(),
2071 tls: Mutex::new(None),
2072 http_version: OnceLock::new(),
2073 user: Mutex::new(http::Extensions::new()),
2074 })
2075 }
2076
2077 fn make_conn_with_transport(t: Transport) -> Arc<ConnContext> {
2078 Arc::new(ConnContext {
2079 id: ConnId(1),
2080 remote: "127.0.0.1:0".parse().expect("remote"),
2081 local: "127.0.0.1:0".parse().expect("local"),
2082 transport: t,
2083 entered_at: Instant::now(),
2084 tls: Mutex::new(None),
2085 http_version: OnceLock::new(),
2086 user: Mutex::new(http::Extensions::new()),
2087 })
2088 }
2089
2090 fn conn_with_tls_alpn(alpn: &[u8]) -> Arc<ConnContext> {
2091 let conn = make_conn();
2092 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
2093 sni: None,
2094 alpn: Some(Arc::from(alpn)),
2095 version: None,
2096 peer_cert: None,
2097 zero_rtt_used: false,
2098 });
2099 conn
2100 }
2101
2102 fn conn_with_tls_version(v: crate::conn_context::TlsVersion) -> Arc<ConnContext> {
2103 let conn = make_conn();
2104 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
2105 sni: None,
2106 alpn: None,
2107 version: Some(v),
2108 peer_cert: None,
2109 zero_rtt_used: false,
2110 });
2111 conn
2112 }
2113
/// `Equals` / `NotEquals` on the string-typed `TlsSni` field: both operators are checked
/// against the connection's actual SNI and against a different host name.
#[test]
fn matrix_equality_str_happy_and_miss() {
    let conn = conn_with_sni("api.example.com");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let sni_eq =
        |host: &str| pred(FieldPath::TlsSni, CompiledOperator::Equals(str_val(host))).test(&view);
    let sni_ne = |host: &str| {
        pred(FieldPath::TlsSni, CompiledOperator::NotEquals(str_val(host))).test(&view)
    };

    assert!(sni_eq("api.example.com"));
    assert!(!sni_eq("other.example.com"));
    assert!(sni_ne("other.example.com"));
    assert!(!sni_ne("api.example.com"));
}
2131
/// `Equals` / `NotEquals` on the bytes-typed `TlsAlpn` field: both operators are checked
/// against the connection's actual ALPN ("h2") and against "http/1.1".
#[test]
fn matrix_equality_bytes_happy_and_miss() {
    let conn = conn_with_tls_alpn(b"h2");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let alpn_eq = |proto: &[u8]| {
        pred(FieldPath::TlsAlpn, CompiledOperator::Equals(bytes_val(proto))).test(&view)
    };
    let alpn_ne = |proto: &[u8]| {
        pred(FieldPath::TlsAlpn, CompiledOperator::NotEquals(bytes_val(proto))).test(&view)
    };

    assert!(alpn_eq(b"h2"));
    assert!(!alpn_eq(b"http/1.1"));
    assert!(alpn_ne(b"http/1.1"));
    assert!(!alpn_ne(b"h2"));
}
2142
/// `Equals` / `NotEquals` on the integer-typed `RemotePort` field: both operators are checked
/// against the connection's actual remote port (9090) and against 81.
#[test]
fn matrix_equality_int_happy_and_miss() {
    let conn = make_conn_with("127.0.0.1:9090", "127.0.0.1:80");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let port_eq = |port: i64| {
        pred(FieldPath::RemotePort, CompiledOperator::Equals(CompiledValue::Int(port))).test(&view)
    };
    let port_ne = |port: i64| {
        pred(FieldPath::RemotePort, CompiledOperator::NotEquals(CompiledValue::Int(port)))
            .test(&view)
    };

    assert!(port_eq(9090));
    assert!(!port_eq(81));
    assert!(port_ne(81));
    assert!(!port_ne(9090));
}
2160
2161 #[test]
2162 fn matrix_equality_addr_happy_and_miss() {
2163 let conn = make_conn_with("10.0.0.5:55555", "127.0.0.1:80");
2164 let v = PredicateView::L4 { conn: &conn, peek: None };
2165 let ten: std::net::IpAddr = "10.0.0.5".parse().unwrap();
2166 let other: std::net::IpAddr = "10.0.0.6".parse().unwrap();
2167 assert!(pred(FieldPath::RemoteIp, CompiledOperator::Equals(CompiledValue::Addr(ten))).test(&v));
2168 assert!(
2169 !pred(FieldPath::RemoteIp, CompiledOperator::Equals(CompiledValue::Addr(other))).test(&v)
2170 );
2171 assert!(
2172 pred(FieldPath::RemoteIp, CompiledOperator::NotEquals(CompiledValue::Addr(other))).test(&v)
2173 );
2174 assert!(
2175 !pred(FieldPath::RemoteIp, CompiledOperator::NotEquals(CompiledValue::Addr(ten))).test(&v)
2176 );
2177 }
2178
2179 #[test]
2180 fn matrix_equality_enum_transport_happy_and_miss() {
2181 let tcp = make_conn_with_transport(Transport::Tcp);
2182 let udp = make_conn_with_transport(Transport::Udp);
2183 let v_tcp = PredicateView::L4 { conn: &tcp, peek: None };
2184 let v_udp = PredicateView::L4 { conn: &udp, peek: None };
2185 assert!(pred(FieldPath::Transport, CompiledOperator::Equals(str_val("tcp"))).test(&v_tcp));
2186 assert!(!pred(FieldPath::Transport, CompiledOperator::Equals(str_val("udp"))).test(&v_tcp));
2187 assert!(pred(FieldPath::Transport, CompiledOperator::Equals(str_val("udp"))).test(&v_udp));
2188 }
2189
2190 #[test]
2191 fn matrix_equality_enum_tls_version_happy_and_miss() {
2192 let conn = conn_with_tls_version(crate::conn_context::TlsVersion::Tls13);
2193 let v = PredicateView::L4 { conn: &conn, peek: None };
2194 assert!(pred(FieldPath::TlsVersion, CompiledOperator::Equals(str_val("1.3"))).test(&v));
2195 assert!(!pred(FieldPath::TlsVersion, CompiledOperator::Equals(str_val("1.2"))).test(&v));
2196 assert!(pred(FieldPath::TlsVersion, CompiledOperator::NotEquals(str_val("1.2"))).test(&v));
2197 }
2198
2199 #[test]
2200 fn matrix_equality_enum_tls_version_misses_when_absent() {
2201 let conn = make_conn();
2203 let v = PredicateView::L4 { conn: &conn, peek: None };
2204 assert!(!pred(FieldPath::TlsVersion, CompiledOperator::Equals(str_val("1.3"))).test(&v));
2205 assert!(!pred(FieldPath::TlsVersion, CompiledOperator::NotEquals(str_val("1.3"))).test(&v));
2207 }
2208
2209 #[test]
2210 fn matrix_equality_enum_http_method_happy_and_miss() {
2211 let conn = make_conn();
2212 let req = http::Request::builder().method("POST").uri("/").body(Body::Empty).unwrap();
2213 let v = PredicateView::L7Req { conn: &conn, req: &req };
2214 assert!(pred(FieldPath::HttpMethod, CompiledOperator::Equals(str_val("POST"))).test(&v));
2215 assert!(!pred(FieldPath::HttpMethod, CompiledOperator::Equals(str_val("GET"))).test(&v));
2216 assert!(pred(FieldPath::HttpMethod, CompiledOperator::NotEquals(str_val("GET"))).test(&v));
2217 }
2218
2219 #[test]
2221 fn matrix_in_list_str_happy_and_miss() {
2222 let conn = conn_with_sni("api.example.com");
2223 let v = PredicateView::L4 { conn: &conn, peek: None };
2224 let list = vec![str_val("a.example.com"), str_val("api.example.com")];
2225 assert!(pred(FieldPath::TlsSni, CompiledOperator::In(list.clone())).test(&v));
2226 let list_miss = vec![str_val("a.example.com"), str_val("b.example.com")];
2227 assert!(!pred(FieldPath::TlsSni, CompiledOperator::In(list_miss.clone())).test(&v));
2228 assert!(pred(FieldPath::TlsSni, CompiledOperator::NotIn(list_miss)).test(&v));
2229 assert!(!pred(FieldPath::TlsSni, CompiledOperator::NotIn(list)).test(&v));
2230 }
2231
2232 #[test]
2233 fn matrix_in_list_bytes_happy_and_miss() {
2234 let conn = conn_with_tls_alpn(b"h2");
2235 let v = PredicateView::L4 { conn: &conn, peek: None };
2236 let list = vec![bytes_val(b"http/1.1"), bytes_val(b"h2")];
2237 assert!(pred(FieldPath::TlsAlpn, CompiledOperator::In(list.clone())).test(&v));
2238 let list_miss = vec![bytes_val(b"http/1.0"), bytes_val(b"http/1.1")];
2239 assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::In(list_miss.clone())).test(&v));
2240 assert!(pred(FieldPath::TlsAlpn, CompiledOperator::NotIn(list_miss)).test(&v));
2241 }
2242
2243 #[test]
2244 fn matrix_in_list_int_happy_and_miss() {
2245 let conn = make_conn_with("127.0.0.1:443", "127.0.0.1:80");
2246 let v = PredicateView::L4 { conn: &conn, peek: None };
2247 let in_list = vec![CompiledValue::Int(80), CompiledValue::Int(443)];
2248 assert!(pred(FieldPath::RemotePort, CompiledOperator::In(in_list.clone())).test(&v));
2249 let miss_list = vec![CompiledValue::Int(80), CompiledValue::Int(81)];
2250 assert!(!pred(FieldPath::RemotePort, CompiledOperator::In(miss_list.clone())).test(&v));
2251 assert!(pred(FieldPath::RemotePort, CompiledOperator::NotIn(miss_list)).test(&v));
2252 }
2253
2254 #[test]
2255 fn matrix_in_list_addr_happy_and_miss_mixed_family() {
2256 let conn = make_conn_with("10.0.0.5:55555", "127.0.0.1:80");
2257 let v = PredicateView::L4 { conn: &conn, peek: None };
2258 let v4: std::net::IpAddr = "10.0.0.5".parse().unwrap();
2259 let v6: std::net::IpAddr = "::1".parse().unwrap();
2260 let list = vec![CompiledValue::Addr(v6), CompiledValue::Addr(v4)];
2261 assert!(pred(FieldPath::RemoteIp, CompiledOperator::In(list.clone())).test(&v));
2262 let miss = vec![CompiledValue::Addr(v6)];
2263 assert!(!pred(FieldPath::RemoteIp, CompiledOperator::In(miss.clone())).test(&v));
2264 assert!(pred(FieldPath::RemoteIp, CompiledOperator::NotIn(miss)).test(&v));
2265 }
2266
2267 #[test]
2268 fn matrix_in_list_enum_transport_happy_and_miss() {
2269 let conn = make_conn_with_transport(Transport::Udp);
2270 let v = PredicateView::L4 { conn: &conn, peek: None };
2271 let list = vec![str_val("tcp"), str_val("udp")];
2272 assert!(pred(FieldPath::Transport, CompiledOperator::In(list)).test(&v));
2273 let miss = vec![str_val("tcp")];
2274 assert!(!pred(FieldPath::Transport, CompiledOperator::In(miss.clone())).test(&v));
2275 assert!(pred(FieldPath::Transport, CompiledOperator::NotIn(miss)).test(&v));
2276 }
2277
2278 #[test]
2280 fn matrix_substring_on_str_happy_and_miss() {
2281 let conn = make_conn();
2282 let req =
2283 http::Request::builder().method("GET").uri("/api/v1/users").body(Body::Empty).unwrap();
2284 let v = PredicateView::L7Req { conn: &conn, req: &req };
2285 assert!(pred(FieldPath::HttpUriPath, CompiledOperator::Contains(b(b"/v1/"))).test(&v));
2286 assert!(!pred(FieldPath::HttpUriPath, CompiledOperator::Contains(b(b"/v2/"))).test(&v));
2287 assert!(pred(FieldPath::HttpUriPath, CompiledOperator::NotContains(b(b"/v2/"))).test(&v));
2288 assert!(!pred(FieldPath::HttpUriPath, CompiledOperator::NotContains(b(b"/v1/"))).test(&v));
2289 }
2290
2291 #[test]
2292 fn matrix_substring_on_bytes_happy_and_miss() {
2293 let conn = conn_with_tls_alpn(b"http/1.1");
2294 let v = PredicateView::L4 { conn: &conn, peek: None };
2295 assert!(pred(FieldPath::TlsAlpn, CompiledOperator::Contains(b(b"/1."))).test(&v));
2296 assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::Contains(b(b"/2."))).test(&v));
2297 assert!(pred(FieldPath::TlsAlpn, CompiledOperator::NotContains(b(b"/2."))).test(&v));
2298 }
2299
2300 #[test]
2302 fn matrix_prefix_suffix_on_str_happy_and_miss() {
2303 let conn = make_conn();
2304 let req =
2305 http::Request::builder().method("GET").uri("/api/file.json?q=1").body(Body::Empty).unwrap();
2306 let v = PredicateView::L7Req { conn: &conn, req: &req };
2307 assert!(pred(FieldPath::HttpUriPath, CompiledOperator::Prefix(b(b"/api"))).test(&v));
2308 assert!(!pred(FieldPath::HttpUriPath, CompiledOperator::Prefix(b(b"/admin"))).test(&v));
2309 assert!(pred(FieldPath::HttpUriPath, CompiledOperator::Suffix(b(b".json"))).test(&v));
2310 assert!(!pred(FieldPath::HttpUriPath, CompiledOperator::Suffix(b(b".html"))).test(&v));
2311 }
2312
2313 #[test]
2314 fn matrix_prefix_suffix_on_bytes_happy_and_miss() {
2315 let conn = conn_with_tls_alpn(b"http/1.1");
2316 let v = PredicateView::L4 { conn: &conn, peek: None };
2317 assert!(pred(FieldPath::TlsAlpn, CompiledOperator::Prefix(b(b"http"))).test(&v));
2318 assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::Prefix(b(b"h2"))).test(&v));
2319 assert!(pred(FieldPath::TlsAlpn, CompiledOperator::Suffix(b(b"1.1"))).test(&v));
2320 assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::Suffix(b(b"2.0"))).test(&v));
2321 }
2322
2323 #[test]
2325 fn matrix_regex_matches_on_str_happy_and_miss() {
2326 let conn = make_conn();
2327 let req =
2328 http::Request::builder().method("GET").uri("/api/v3/orders").body(Body::Empty).unwrap();
2329 let v = PredicateView::L7Req { conn: &conn, req: &req };
2330 let re = Regex::new(r"^/api/v\d+/orders").expect("compile regex");
2331 assert!(pred(FieldPath::HttpUriPath, CompiledOperator::Matches(re)).test(&v));
2332 let re_miss = Regex::new(r"^/admin").expect("compile regex");
2333 assert!(!pred(FieldPath::HttpUriPath, CompiledOperator::Matches(re_miss)).test(&v));
2334 }
2335
2336 #[test]
2337 fn matrix_regex_matches_on_header_happy_and_miss() {
2338 let conn = make_conn();
2339 let req = http::Request::builder()
2340 .method("GET")
2341 .uri("/")
2342 .header("user-agent", "Mozilla/5.0 (Macintosh; Intel)")
2343 .body(Body::Empty)
2344 .unwrap();
2345 let v = PredicateView::L7Req { conn: &conn, req: &req };
2346 let re = Regex::new(r"(?i)mozilla").expect("compile");
2347 assert!(
2348 pred(FieldPath::HttpHeader(Arc::from("user-agent")), CompiledOperator::Matches(re)).test(&v)
2349 );
2350 let re_miss = Regex::new(r"^curl").expect("compile");
2351 assert!(
2352 !pred(FieldPath::HttpHeader(Arc::from("user-agent")), CompiledOperator::Matches(re_miss))
2353 .test(&v)
2354 );
2355 }
2356
2357 #[test]
2359 fn matrix_numeric_cmp_gt_gte_lt_lte_happy_and_miss() {
2360 let conn = make_conn_with("127.0.0.1:1024", "127.0.0.1:443");
2361 let v = PredicateView::L4 { conn: &conn, peek: None };
2362 assert!(pred(FieldPath::RemotePort, CompiledOperator::Gt(1023)).test(&v));
2364 assert!(!pred(FieldPath::RemotePort, CompiledOperator::Gt(1024)).test(&v));
2365 assert!(pred(FieldPath::RemotePort, CompiledOperator::Gte(1024)).test(&v));
2367 assert!(!pred(FieldPath::RemotePort, CompiledOperator::Gte(1025)).test(&v));
2368 assert!(pred(FieldPath::RemotePort, CompiledOperator::Lt(1025)).test(&v));
2370 assert!(!pred(FieldPath::RemotePort, CompiledOperator::Lt(1024)).test(&v));
2371 assert!(pred(FieldPath::RemotePort, CompiledOperator::Lte(1024)).test(&v));
2373 assert!(!pred(FieldPath::RemotePort, CompiledOperator::Lte(1023)).test(&v));
2374 }
2375
2376 #[test]
2377 fn matrix_numeric_cmp_local_port_too() {
2378 let conn = make_conn_with("127.0.0.1:0", "127.0.0.1:8443");
2380 let v = PredicateView::L4 { conn: &conn, peek: None };
2381 assert!(pred(FieldPath::LocalPort, CompiledOperator::Gt(8000)).test(&v));
2382 assert!(!pred(FieldPath::LocalPort, CompiledOperator::Gt(9000)).test(&v));
2383 }
2384
2385 #[test]
2387 fn matrix_cidr_v4_happy_and_miss() {
2388 let conn = make_conn_with("10.0.5.7:0", "127.0.0.1:0");
2389 let v = PredicateView::L4 { conn: &conn, peek: None };
2390 let ten = IpNet::from_str("10.0.0.0/8").unwrap();
2391 let nineteen2 = IpNet::from_str("192.168.0.0/16").unwrap();
2392 assert!(pred(FieldPath::RemoteIp, CompiledOperator::Cidr(ten)).test(&v));
2393 assert!(!pred(FieldPath::RemoteIp, CompiledOperator::Cidr(nineteen2)).test(&v));
2394 }
2395
2396 #[test]
2397 fn matrix_cidr_v6_happy_and_miss() {
2398 let conn = make_conn_with("[2001:db8::5]:0", "127.0.0.1:0");
2399 let v = PredicateView::L4 { conn: &conn, peek: None };
2400 let net = IpNet::from_str("2001:db8::/32").unwrap();
2401 let other = IpNet::from_str("2001:dead::/32").unwrap();
2402 assert!(pred(FieldPath::RemoteIp, CompiledOperator::Cidr(net)).test(&v));
2403 assert!(!pred(FieldPath::RemoteIp, CompiledOperator::Cidr(other)).test(&v));
2404 }
2405
2406 #[test]
2407 fn matrix_cidr_v4_against_v6_addr_misses() {
2408 let conn = make_conn_with("[2001:db8::5]:0", "127.0.0.1:0");
2410 let v = PredicateView::L4 { conn: &conn, peek: None };
2411 let v4 = IpNet::from_str("0.0.0.0/0").unwrap();
2412 assert!(!pred(FieldPath::RemoteIp, CompiledOperator::Cidr(v4)).test(&v));
2413 }
2414
2415 #[test]
2419 fn http_uri_query_reader_returns_empty_when_query_absent() {
2420 let conn = make_conn();
2423 let req = http::Request::builder().method("GET").uri("/no-q").body(Body::Empty).unwrap();
2424 let v = PredicateView::L7Req { conn: &conn, req: &req };
2425 assert!(pred(FieldPath::HttpUriQuery, CompiledOperator::Equals(str_val(""))).test(&v));
2426 assert!(!pred(FieldPath::HttpUriQuery, CompiledOperator::Equals(str_val("q=1"))).test(&v));
2427 }
2428
2429 #[test]
2430 fn http_uri_query_reader_matches_present_query() {
2431 let conn = make_conn();
2432 let req = http::Request::builder().method("GET").uri("/x?a=1&b=2").body(Body::Empty).unwrap();
2433 let v = PredicateView::L7Req { conn: &conn, req: &req };
2434 assert!(pred(FieldPath::HttpUriQuery, CompiledOperator::Equals(str_val("a=1&b=2"))).test(&v));
2435 assert!(pred(FieldPath::HttpUriQuery, CompiledOperator::Contains(b(b"b=2"))).test(&v));
2436 }
2437
2438 #[test]
2439 fn local_ip_reader_uses_local_socket() {
2440 let conn = make_conn_with("10.0.0.5:0", "127.0.0.1:8443");
2441 let v = PredicateView::L4 { conn: &conn, peek: None };
2442 let local: std::net::IpAddr = "127.0.0.1".parse().unwrap();
2443 assert!(
2444 pred(FieldPath::LocalIp, CompiledOperator::Equals(CompiledValue::Addr(local))).test(&v)
2445 );
2446 }
2447
2448 #[test]
2449 fn http_header_lookup_misses_for_non_utf8_value() {
2450 let conn = make_conn();
2453 let bad =
2454 http::HeaderValue::from_bytes(&[0xff, 0xfe, 0xfd]).expect("non-utf8 header value parses");
2455 let mut builder = http::Request::builder().method("GET").uri("/");
2456 builder.headers_mut().expect("headers").insert("x-bad", bad);
2457 let req: Request = builder.body(Body::Empty).expect("build request");
2458 let v = PredicateView::L7Req { conn: &conn, req: &req };
2459 assert!(
2460 !pred(
2461 FieldPath::HttpHeader(Arc::from("x-bad")),
2462 CompiledOperator::Equals(str_val("anything")),
2463 )
2464 .test(&v)
2465 );
2466 }
2467
2468 fn rcgen_cert_with_cn(cn: &str) -> rustls_pki_types::CertificateDer<'static> {
2470 let mut params = rcgen::CertificateParams::default();
2471 params.distinguished_name = rcgen::DistinguishedName::new();
2472 params.distinguished_name.push(rcgen::DnType::CommonName, cn);
2473 let key = rcgen::KeyPair::generate().expect("rcgen keypair");
2474 let cert = params.self_signed(&key).expect("self-sign cert");
2475 cert.der().clone()
2476 }
2477
2478 fn rcgen_cert_no_cn() -> rustls_pki_types::CertificateDer<'static> {
2479 let params = rcgen::CertificateParams::default();
2482 let mut params = params;
2485 params.distinguished_name = rcgen::DistinguishedName::new();
2486 let key = rcgen::KeyPair::generate().expect("rcgen keypair");
2487 let cert = params.self_signed(&key).expect("self-sign cert");
2488 cert.der().clone()
2489 }
2490
2491 fn conn_with_peer_cert(cert: &rustls_pki_types::CertificateDer<'static>) -> Arc<ConnContext> {
2492 let pc = crate::conn_context::PeerCertificate::from_der(cert)
2493 .expect("rcgen-issued cert must parse via PeerCertificate::from_der");
2494 let conn = make_conn();
2495 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
2496 sni: None,
2497 alpn: None,
2498 version: None,
2499 peer_cert: Some(Arc::new(pc)),
2500 zero_rtt_used: false,
2501 });
2502 conn
2503 }
2504
2505 #[test]
2506 fn peer_cert_from_der_extracts_cn() {
2507 let cert = rcgen_cert_with_cn("client.internal");
2508 let pc = crate::conn_context::PeerCertificate::from_der(&cert).expect("parse");
2509 assert_eq!(pc.subject_cn.as_deref(), Some("client.internal"));
2510 }
2511
2512 #[test]
2513 fn peer_cert_from_der_returns_none_for_malformed_der() {
2514 let raw = rustls_pki_types::CertificateDer::from(vec![0x30, 0x80, 0x00, 0x00]);
2515 assert!(crate::conn_context::PeerCertificate::from_der(&raw).is_none());
2516 let raw = rustls_pki_types::CertificateDer::from(b"not a cert at all".to_vec());
2517 assert!(crate::conn_context::PeerCertificate::from_der(&raw).is_none());
2518 }
2519
2520 #[test]
2521 fn peer_cert_from_der_returns_some_with_no_cn_when_dn_has_no_cn() {
2522 let cert = rcgen_cert_no_cn();
2524 let pc = crate::conn_context::PeerCertificate::from_der(&cert).expect("parse");
2525 assert!(pc.subject_cn.is_none());
2526 }
2527
2528 #[test]
2529 fn matrix_peer_cert_subject_cn_equals_happy_and_miss() {
2530 let cert = rcgen_cert_with_cn("ops-bot");
2531 let conn = conn_with_peer_cert(&cert);
2532 let v = PredicateView::L4 { conn: &conn, peek: None };
2533 assert!(
2534 pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Equals(str_val("ops-bot"))).test(&v)
2535 );
2536 assert!(
2537 !pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Equals(str_val("attacker")))
2538 .test(&v)
2539 );
2540 }
2541
2542 #[test]
2543 fn matrix_peer_cert_subject_cn_string_ops_happy_and_miss() {
2544 let cert = rcgen_cert_with_cn("svc-payments-prod");
2545 let conn = conn_with_peer_cert(&cert);
2546 let v = PredicateView::L4 { conn: &conn, peek: None };
2547 assert!(pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Prefix(b(b"svc-"))).test(&v));
2549 assert!(
2550 !pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Prefix(b(b"client-"))).test(&v)
2551 );
2552 assert!(pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Suffix(b(b"-prod"))).test(&v));
2554 assert!(
2556 pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Contains(b(b"payments"))).test(&v)
2557 );
2558 let re = Regex::new(r"^svc-[a-z]+-(prod|stg)$").expect("regex");
2560 assert!(pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Matches(re)).test(&v));
2561 let list = vec![str_val("svc-other-prod"), str_val("svc-payments-prod")];
2563 assert!(pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::In(list)).test(&v));
2564 }
2565
2566 #[test]
2567 fn peer_cert_subject_cn_misses_when_cert_absent() {
2568 let conn = make_conn();
2571 let v = PredicateView::L4 { conn: &conn, peek: None };
2572 assert!(
2573 !pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Equals(str_val("anything")))
2574 .test(&v)
2575 );
2576 }
2577
2578 #[test]
2579 fn peer_cert_subject_cn_misses_when_cert_has_no_cn() {
2580 let cert = rcgen_cert_no_cn();
2583 let conn = conn_with_peer_cert(&cert);
2584 let v = PredicateView::L4 { conn: &conn, peek: None };
2585 assert!(
2586 !pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Equals(str_val("ops-bot"))).test(&v)
2587 );
2588 }
2589
2590 fn rcgen_cert_with_san_dns(cn: &str, dns: &[&str]) -> rustls_pki_types::CertificateDer<'static> {
2592 let san: Vec<String> = dns.iter().map(|s| (*s).to_owned()).collect();
2593 let mut params = rcgen::CertificateParams::new(san).expect("rcgen params");
2594 params.distinguished_name = rcgen::DistinguishedName::new();
2595 params.distinguished_name.push(rcgen::DnType::CommonName, cn);
2596 let key = rcgen::KeyPair::generate().expect("rcgen keypair");
2597 let cert = params.self_signed(&key).expect("self-sign cert");
2598 cert.der().clone()
2599 }
2600
2601 #[test]
2602 fn each_new_field_path_parses_from_string_form() {
2603 use super::parse_field_path;
2604 assert_eq!(parse_field_path("tls.peer_cert.present"), Ok(FieldPath::TlsPeerCertPresent));
2605 assert_eq!(parse_field_path("tls.peer_cert.san_dns"), Ok(FieldPath::TlsPeerCertSanDns));
2606 assert_eq!(
2607 parse_field_path("tls.peer_cert.fingerprint_sha256"),
2608 Ok(FieldPath::TlsPeerCertFingerprintSha256),
2609 );
2610 assert_eq!(parse_field_path("tls.peer_cert.spki_sha256"), Ok(FieldPath::TlsPeerCertSpkiSha256),);
2611 assert_eq!(parse_field_path("tls.peer_cert.issuer_cn"), Ok(FieldPath::TlsPeerCertIssuerCn));
2612 assert_eq!(parse_field_path("tls.peer_cert.serial"), Ok(FieldPath::TlsPeerCertSerial));
2613 }
2614
2615 #[test]
2616 fn peer_cert_present_true_when_cert_attached() {
2617 let cert = rcgen_cert_with_cn("client.internal");
2618 let conn = conn_with_peer_cert(&cert);
2619 let v = PredicateView::L4 { conn: &conn, peek: None };
2620 assert!(
2621 pred(FieldPath::TlsPeerCertPresent, CompiledOperator::Equals(CompiledValue::Bool(true)))
2622 .test(&v)
2623 );
2624 assert!(
2625 !pred(FieldPath::TlsPeerCertPresent, CompiledOperator::Equals(CompiledValue::Bool(false)))
2626 .test(&v)
2627 );
2628 }
2629
2630 #[test]
2631 fn peer_cert_present_false_when_cert_absent() {
2632 let conn = make_conn();
2635 let v = PredicateView::L4 { conn: &conn, peek: None };
2636 assert!(
2637 pred(FieldPath::TlsPeerCertPresent, CompiledOperator::Equals(CompiledValue::Bool(false)))
2638 .test(&v)
2639 );
2640 assert!(
2641 !pred(FieldPath::TlsPeerCertPresent, CompiledOperator::Equals(CompiledValue::Bool(true)))
2642 .test(&v)
2643 );
2644 }
2645
2646 #[test]
2647 fn peer_cert_san_dns_contains_matches_listed_element() {
2648 let cert = rcgen_cert_with_san_dns("svc-a", &["svc-a.internal", "svc-b.internal"]);
2649 let conn = conn_with_peer_cert(&cert);
2650 let v = PredicateView::L4 { conn: &conn, peek: None };
2651 assert!(
2652 pred(FieldPath::TlsPeerCertSanDns, CompiledOperator::Contains(b(b"svc-a.internal"))).test(&v)
2653 );
2654 assert!(
2655 !pred(FieldPath::TlsPeerCertSanDns, CompiledOperator::Contains(b(b"svc-c.internal")))
2656 .test(&v),
2657 );
2658 assert!(
2659 pred(FieldPath::TlsPeerCertSanDns, CompiledOperator::NotContains(b(b"svc-c.internal")))
2660 .test(&v),
2661 );
2662 }
2663
2664 #[test]
2665 fn peer_cert_san_dns_misses_when_cert_absent() {
2666 let conn = make_conn();
2667 let v = PredicateView::L4 { conn: &conn, peek: None };
2668 assert!(
2669 !pred(FieldPath::TlsPeerCertSanDns, CompiledOperator::Contains(b(b"anything"))).test(&v)
2670 );
2671 }
2672
2673 #[test]
2674 fn peer_cert_fingerprint_sha256_is_lowercase_hex_of_full_der() {
2675 use sha2::{Digest, Sha256};
2676 let cert = rcgen_cert_with_cn("fingerprinted");
2677 let mut h = Sha256::new();
2678 h.update(cert.as_ref());
2679 let want = h.finalize().iter().fold(String::new(), |mut s, b| {
2680 use std::fmt::Write as _;
2681 let _ = write!(s, "{b:02x}");
2682 s
2683 });
2684
2685 let conn = conn_with_peer_cert(&cert);
2686 let v = PredicateView::L4 { conn: &conn, peek: None };
2687 assert!(
2688 pred(FieldPath::TlsPeerCertFingerprintSha256, CompiledOperator::Equals(str_val(&want)),)
2689 .test(&v),
2690 );
2691 }
2692
2693 #[test]
2694 fn peer_cert_issuer_and_serial_present_for_self_signed_cert() {
2695 let cert = rcgen_cert_with_cn("issuer-test");
2698 let conn = conn_with_peer_cert(&cert);
2699 let v = PredicateView::L4 { conn: &conn, peek: None };
2700 assert!(
2702 pred(FieldPath::TlsPeerCertIssuerCn, CompiledOperator::Equals(str_val("issuer-test")))
2703 .test(&v)
2704 );
2705 let pc = conn.tls.lock().as_ref().unwrap().peer_cert.as_ref().unwrap().clone();
2709 assert!(!pc.serial.is_empty(), "serial extracted");
2710 assert!(pc.serial.chars().all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase()));
2711 }
2712
2713 #[test]
2714 fn peer_cert_present_value_type_is_bool() {
2715 assert_eq!(FieldPath::TlsPeerCertPresent.value_type(), FieldValueType::Bool);
2716 }
2717
2718 #[test]
2719 fn peer_cert_san_dns_value_type_is_vec_str() {
2720 assert_eq!(FieldPath::TlsPeerCertSanDns.value_type(), FieldValueType::VecStr);
2721 }
2722
2723 #[test]
2724 fn matrix_rejects_string_pref_suf_on_bool_field() {
2725 assert!(!OperatorFamily::StringPrefSuf.accepts(FieldValueType::Bool));
2728 assert!(!OperatorFamily::StringSubstr.accepts(FieldValueType::Bool));
2729 assert!(!OperatorFamily::RegexMatches.accepts(FieldValueType::Bool));
2730 assert!(OperatorFamily::Equality.accepts(FieldValueType::Bool));
2732 }
2733
2734 #[test]
2735 fn matrix_rejects_equals_on_vec_str_field() {
2736 assert!(!OperatorFamily::Equality.accepts(FieldValueType::VecStr));
2739 assert!(!OperatorFamily::InList.accepts(FieldValueType::VecStr));
2740 assert!(!OperatorFamily::StringPrefSuf.accepts(FieldValueType::VecStr));
2741 assert!(!OperatorFamily::RegexMatches.accepts(FieldValueType::VecStr));
2742 assert!(OperatorFamily::StringSubstr.accepts(FieldValueType::VecStr));
2743 }
2744
2745 fn req_with_body(body_bytes: &[u8]) -> Request {
2752 http::Request::builder()
2753 .method("POST")
2754 .uri("/upload")
2755 .body(Body::Static(Bytes::copy_from_slice(body_bytes)))
2756 .expect("build req with body")
2757 }
2758
2759 #[test]
2760 fn matrix_http_body_equality_happy_and_miss() {
2761 let conn = make_conn();
2762 let req = req_with_body(b"hello world");
2763 let v = PredicateView::L7Req { conn: &conn, req: &req };
2764 assert!(
2765 pred(FieldPath::HttpBody, CompiledOperator::Equals(bytes_val(b"hello world"))).test(&v)
2766 );
2767 assert!(!pred(FieldPath::HttpBody, CompiledOperator::Equals(bytes_val(b"wrong"))).test(&v));
2768 assert!(pred(FieldPath::HttpBody, CompiledOperator::NotEquals(bytes_val(b"wrong"))).test(&v));
2769 }
2770
2771 #[test]
2772 fn matrix_http_body_substring_happy_and_miss() {
2773 let conn = make_conn();
2774 let req = req_with_body(b"prelude payload trailer");
2775 let v = PredicateView::L7Req { conn: &conn, req: &req };
2776 assert!(pred(FieldPath::HttpBody, CompiledOperator::Contains(b(b"payload"))).test(&v));
2777 assert!(!pred(FieldPath::HttpBody, CompiledOperator::Contains(b(b"missing"))).test(&v));
2778 assert!(pred(FieldPath::HttpBody, CompiledOperator::NotContains(b(b"missing"))).test(&v));
2779 }
2780
2781 #[test]
2782 fn matrix_http_body_prefix_suffix_happy_and_miss() {
2783 let conn = make_conn();
2784 let req = req_with_body(b"START middle END");
2785 let v = PredicateView::L7Req { conn: &conn, req: &req };
2786 assert!(pred(FieldPath::HttpBody, CompiledOperator::Prefix(b(b"START"))).test(&v));
2787 assert!(!pred(FieldPath::HttpBody, CompiledOperator::Prefix(b(b"BEGIN"))).test(&v));
2788 assert!(pred(FieldPath::HttpBody, CompiledOperator::Suffix(b(b"END"))).test(&v));
2789 assert!(!pred(FieldPath::HttpBody, CompiledOperator::Suffix(b(b"FIN"))).test(&v));
2790 }
2791
2792 #[test]
2793 fn matrix_http_body_in_list_happy_and_miss() {
2794 let conn = make_conn();
2795 let req = req_with_body(b"one");
2796 let v = PredicateView::L7Req { conn: &conn, req: &req };
2797 let list = vec![bytes_val(b"two"), bytes_val(b"one")];
2798 assert!(pred(FieldPath::HttpBody, CompiledOperator::In(list)).test(&v));
2799 let miss = vec![bytes_val(b"two"), bytes_val(b"three")];
2800 assert!(!pred(FieldPath::HttpBody, CompiledOperator::In(miss.clone())).test(&v));
2801 assert!(pred(FieldPath::HttpBody, CompiledOperator::NotIn(miss)).test(&v));
2802 }
2803
2804 #[test]
2805 fn http_body_misses_on_l4_view() {
2806 let conn = make_conn();
2809 let v = PredicateView::L4 { conn: &conn, peek: None };
2810 assert!(!pred(FieldPath::HttpBody, CompiledOperator::Contains(b(b"x"))).test(&v));
2811 }
2812
2813 #[test]
2814 #[should_panic(expected = "lazy-buffer invariant")]
2815 fn http_body_panics_when_lazy_buffer_invariant_violated() {
2816 let conn = make_conn();
2824 let req = http::Request::builder().method("POST").uri("/").body(Body::Empty).unwrap();
2825 let v = PredicateView::L7Req { conn: &conn, req: &req };
2826 let _ = pred(FieldPath::HttpBody, CompiledOperator::Contains(b(b"x"))).test(&v);
2827 }
2828
2829 #[test]
2837 fn matrix_peek_substring_happy_and_miss() {
2838 let buf: &[u8] = &[0x16, 0x03, 0x01, 0x00, 0x40, 0x01];
2840 let conn = make_conn();
2841 let v = PredicateView::L4 { conn: &conn, peek: Some(buf) };
2842 assert!(pred(FieldPath::Peek, CompiledOperator::Prefix(b(b"\x16\x03"))).test(&v));
2843 assert!(!pred(FieldPath::Peek, CompiledOperator::Prefix(b(b"\x14\x03"))).test(&v));
2844 assert!(pred(FieldPath::Peek, CompiledOperator::Contains(b(b"\x03\x01"))).test(&v));
2845 assert!(!pred(FieldPath::Peek, CompiledOperator::Contains(b(b"\xff\xff"))).test(&v));
2846 }
2847
2848 #[test]
2849 fn matrix_peek_equality_happy_and_miss() {
2850 let buf: &[u8] = b"GET";
2851 let conn = make_conn();
2852 let v = PredicateView::L4 { conn: &conn, peek: Some(buf) };
2853 assert!(pred(FieldPath::Peek, CompiledOperator::Equals(bytes_val(b"GET"))).test(&v));
2854 assert!(!pred(FieldPath::Peek, CompiledOperator::Equals(bytes_val(b"PUT"))).test(&v));
2855 assert!(pred(FieldPath::Peek, CompiledOperator::NotEquals(bytes_val(b"PUT"))).test(&v));
2856 }
2857
2858 #[test]
2859 fn matrix_peek_in_list_happy_and_miss() {
2860 let buf: &[u8] = b"PRI ";
2861 let conn = make_conn();
2862 let v = PredicateView::L4 { conn: &conn, peek: Some(buf) };
2863 let list = vec![bytes_val(b"GET "), bytes_val(b"PRI ")];
2865 assert!(pred(FieldPath::Peek, CompiledOperator::In(list)).test(&v));
2866 let miss = vec![bytes_val(b"POST"), bytes_val(b"HEAD")];
2867 assert!(!pred(FieldPath::Peek, CompiledOperator::In(miss.clone())).test(&v));
2868 assert!(pred(FieldPath::Peek, CompiledOperator::NotIn(miss)).test(&v));
2869 }
2870
2871 #[test]
2872 fn peek_misses_when_buffer_absent_on_l4_view() {
2873 let conn = make_conn();
2876 let v = PredicateView::L4 { conn: &conn, peek: None };
2877 assert!(!pred(FieldPath::Peek, CompiledOperator::Prefix(b(b"\x16"))).test(&v));
2878 let req = http::Request::builder().method("GET").uri("/").body(Body::Empty).unwrap();
2880 let v7 = PredicateView::L7Req { conn: &conn, req: &req };
2881 assert!(!pred(FieldPath::Peek, CompiledOperator::Prefix(b(b"\x16"))).test(&v7));
2882 }
2883}