1use std::hash::{Hash, Hasher};
2use std::net::IpAddr;
3use std::sync::Arc;
4
5use bytes::Bytes;
6use ipnet::IpNet;
7
8use crate::body::Request;
9use crate::conn_context::ConnContext;
10
11#[derive(Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum FieldPath {
14 Transport,
15 RemoteIp,
16 RemotePort,
17 LocalIp,
18 LocalPort,
19 Peek,
20 TlsSni,
21 TlsAlpn,
22 TlsVersion,
23 TlsPeerCertPresent,
26 TlsPeerCertSubjectCn,
27 TlsPeerCertSanDns,
33 TlsPeerCertFingerprintSha256,
35 TlsPeerCertSpkiSha256,
38 TlsPeerCertIssuerCn,
41 TlsPeerCertSerial,
44 HttpMethod,
45 HttpUriPath,
46 HttpUriQuery,
47 HttpHeader(Arc<str>),
48 HttpBody,
49}
50
/// The kind of value a field yields; used to check operator/field compatibility.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FieldValueType {
    Str,
    Bytes,
    Int,
    IpAddr,
    Enum,
    Bool,
    VecStr,
}

impl FieldValueType {
    /// Short display name for this value kind.
    ///
    /// `Enum` is spelled lowercase (`"enum"`), matching the wording used by
    /// `OperatorFamily::family_expectation`.
    #[must_use]
    pub fn name(self) -> &'static str {
        use FieldValueType as Kind;
        match self {
            Kind::VecStr => "Vec<Str>",
            Kind::Bool => "Bool",
            Kind::Enum => "enum",
            Kind::IpAddr => "IpAddr",
            Kind::Int => "Int",
            Kind::Bytes => "Bytes",
            Kind::Str => "Str",
        }
    }
}
80
81impl FieldPath {
82 #[must_use]
86 pub fn value_type(&self) -> FieldValueType {
87 match self {
88 Self::Transport | Self::TlsVersion | Self::HttpMethod => FieldValueType::Enum,
89 Self::RemoteIp | Self::LocalIp => FieldValueType::IpAddr,
90 Self::RemotePort | Self::LocalPort => FieldValueType::Int,
91 Self::Peek | Self::TlsAlpn | Self::HttpBody => FieldValueType::Bytes,
92 Self::TlsPeerCertPresent => FieldValueType::Bool,
93 Self::TlsPeerCertSanDns => FieldValueType::VecStr,
94 Self::TlsSni
95 | Self::TlsPeerCertSubjectCn
96 | Self::TlsPeerCertFingerprintSha256
97 | Self::TlsPeerCertSpkiSha256
98 | Self::TlsPeerCertIssuerCn
99 | Self::TlsPeerCertSerial
100 | Self::HttpUriPath
101 | Self::HttpUriQuery
102 | Self::HttpHeader(_) => FieldValueType::Str,
103 }
104 }
105
106 #[must_use]
108 pub fn display_name(&self) -> String {
109 match self {
110 Self::Transport => "transport".to_string(),
111 Self::RemoteIp => "remote.ip".to_string(),
112 Self::RemotePort => "remote.port".to_string(),
113 Self::LocalIp => "local.ip".to_string(),
114 Self::LocalPort => "local.port".to_string(),
115 Self::Peek => "peek".to_string(),
116 Self::TlsSni => "tls.sni".to_string(),
117 Self::TlsAlpn => "tls.alpn".to_string(),
118 Self::TlsVersion => "tls.version".to_string(),
119 Self::TlsPeerCertPresent => "tls.peer_cert.present".to_string(),
120 Self::TlsPeerCertSubjectCn => "tls.peer_cert.subject_cn".to_string(),
121 Self::TlsPeerCertSanDns => "tls.peer_cert.san_dns".to_string(),
122 Self::TlsPeerCertFingerprintSha256 => "tls.peer_cert.fingerprint_sha256".to_string(),
123 Self::TlsPeerCertSpkiSha256 => "tls.peer_cert.spki_sha256".to_string(),
124 Self::TlsPeerCertIssuerCn => "tls.peer_cert.issuer_cn".to_string(),
125 Self::TlsPeerCertSerial => "tls.peer_cert.serial".to_string(),
126 Self::HttpMethod => "http.method".to_string(),
127 Self::HttpUriPath => "http.uri.path".to_string(),
128 Self::HttpUriQuery => "http.uri.query".to_string(),
129 Self::HttpHeader(name) => format!("http.header.{name}"),
130 Self::HttpBody => "http.body".to_string(),
131 }
132 }
133}
134
/// Groups operators that share type requirements on the field they test.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OperatorFamily {
    /// `equals` / `not_equals`
    Equality,
    /// `contains` / `not_contains`
    StringSubstr,
    /// `prefix` / `suffix`
    StringPrefSuf,
    /// `matches`
    RegexMatches,
    /// `in` / `not_in`
    InList,
    /// `gt` / `gte` / `lt` / `lte`
    NumericCmp,
    /// `cidr`
    CidrMatch,
}
149
150impl Operator {
151 #[must_use]
152 pub fn family(&self) -> OperatorFamily {
153 match self {
154 Self::Equals(_) | Self::NotEquals(_) => OperatorFamily::Equality,
155 Self::Contains(_) | Self::NotContains(_) => OperatorFamily::StringSubstr,
156 Self::Prefix(_) | Self::Suffix(_) => OperatorFamily::StringPrefSuf,
157 Self::Matches(_) => OperatorFamily::RegexMatches,
158 Self::In(_) | Self::NotIn(_) => OperatorFamily::InList,
159 Self::Gt(_) | Self::Gte(_) | Self::Lt(_) | Self::Lte(_) => OperatorFamily::NumericCmp,
160 Self::Cidr(_) => OperatorFamily::CidrMatch,
161 }
162 }
163
164 #[must_use]
165 pub fn name(&self) -> &'static str {
166 match self {
167 Self::Equals(_) => "equals",
168 Self::NotEquals(_) => "not_equals",
169 Self::Contains(_) => "contains",
170 Self::NotContains(_) => "not_contains",
171 Self::Prefix(_) => "prefix",
172 Self::Suffix(_) => "suffix",
173 Self::Matches(_) => "matches",
174 Self::In(_) => "in",
175 Self::NotIn(_) => "not_in",
176 Self::Gt(_) => "gt",
177 Self::Gte(_) => "gte",
178 Self::Lt(_) => "lt",
179 Self::Lte(_) => "lte",
180 Self::Cidr(_) => "cidr",
181 }
182 }
183}
184
185impl OperatorFamily {
186 #[must_use]
191 pub fn accepts(self, vt: FieldValueType) -> bool {
192 use FieldValueType as V;
193 use OperatorFamily as F;
194 matches!(
195 (self, vt),
196 (F::Equality, V::Str | V::Bytes | V::Int | V::IpAddr | V::Enum | V::Bool)
201 | (F::InList, V::Str | V::Bytes | V::Int | V::IpAddr | V::Enum)
204 | (F::StringSubstr, V::Str | V::Bytes | V::VecStr)
205 | (F::StringPrefSuf, V::Str | V::Bytes)
206 | (F::RegexMatches, V::Str)
207 | (F::NumericCmp, V::Int)
208 | (F::CidrMatch, V::IpAddr),
209 )
210 }
211
212 #[must_use]
214 pub fn family_expectation(self) -> &'static str {
215 match self {
216 Self::Equality => "any of Str/Bytes/Int/IpAddr/enum/Bool",
217 Self::InList => "any of Str/Bytes/Int/IpAddr/enum",
218 Self::StringSubstr => "Str, Bytes, or Vec<Str>",
219 Self::StringPrefSuf => "Str or Bytes",
220 Self::RegexMatches => "Str",
221 Self::NumericCmp => "numeric",
222 Self::CidrMatch => "IpAddr",
223 }
224 }
225}
226
227#[derive(Clone, Debug)]
228pub enum CompiledValue {
229 Str(Arc<str>),
230 Bytes(Bytes),
231 Int(i64),
232 Bool(bool),
233 Addr(IpAddr),
234}
235
236impl PartialEq for CompiledValue {
237 fn eq(&self, other: &Self) -> bool {
238 match (self, other) {
239 (Self::Str(a), Self::Str(b)) => a.as_ref() == b.as_ref(),
240 (Self::Bytes(a), Self::Bytes(b)) => a == b,
241 (Self::Int(a), Self::Int(b)) => a == b,
242 (Self::Bool(a), Self::Bool(b)) => a == b,
243 (Self::Addr(a), Self::Addr(b)) => a == b,
244 _ => false,
245 }
246 }
247}
248
249impl Eq for CompiledValue {}
250
251impl Hash for CompiledValue {
252 fn hash<H: Hasher>(&self, state: &mut H) {
253 std::mem::discriminant(self).hash(state);
254 match self {
255 Self::Str(s) => s.as_ref().hash(state),
256 Self::Bytes(b) => b.hash(state),
257 Self::Int(i) => i.hash(state),
258 Self::Bool(b) => b.hash(state),
259 Self::Addr(a) => a.hash(state),
260 }
261 }
262}
263
264#[derive(Clone, Debug)]
265pub enum CompiledOperator {
266 Equals(CompiledValue),
267 NotEquals(CompiledValue),
268 Contains(Bytes),
269 NotContains(Bytes),
270 Prefix(Bytes),
271 Suffix(Bytes),
272 Matches(fancy_regex::Regex),
273 In(Vec<CompiledValue>),
274 NotIn(Vec<CompiledValue>),
275 Gt(i64),
276 Gte(i64),
277 Lt(i64),
278 Lte(i64),
279 Cidr(IpNet),
280}
281
282impl PartialEq for CompiledOperator {
283 fn eq(&self, other: &Self) -> bool {
284 match (self, other) {
285 (Self::Equals(a), Self::Equals(b)) | (Self::NotEquals(a), Self::NotEquals(b)) => a == b,
286 (Self::Contains(a), Self::Contains(b))
287 | (Self::NotContains(a), Self::NotContains(b))
288 | (Self::Prefix(a), Self::Prefix(b))
289 | (Self::Suffix(a), Self::Suffix(b)) => a == b,
290 (Self::Matches(a), Self::Matches(b)) => a.as_str() == b.as_str(),
291 (Self::In(a), Self::In(b)) | (Self::NotIn(a), Self::NotIn(b)) => a == b,
292 (Self::Gt(a), Self::Gt(b))
293 | (Self::Gte(a), Self::Gte(b))
294 | (Self::Lt(a), Self::Lt(b))
295 | (Self::Lte(a), Self::Lte(b)) => a == b,
296 (Self::Cidr(a), Self::Cidr(b)) => a == b,
297 _ => false,
298 }
299 }
300}
301
302impl Eq for CompiledOperator {}
303
304impl Hash for CompiledOperator {
305 fn hash<H: Hasher>(&self, state: &mut H) {
306 std::mem::discriminant(self).hash(state);
307 match self {
308 Self::Equals(v) | Self::NotEquals(v) => v.hash(state),
309 Self::Contains(b) | Self::NotContains(b) | Self::Prefix(b) | Self::Suffix(b) => {
310 b.hash(state);
311 }
312 Self::Matches(r) => r.as_str().hash(state),
313 Self::In(v) | Self::NotIn(v) => v.hash(state),
314 Self::Gt(i) | Self::Gte(i) | Self::Lt(i) | Self::Lte(i) => i.hash(state),
315 Self::Cidr(n) => n.hash(state),
316 }
317 }
318}
319
320#[derive(Clone, Debug, Eq, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
321pub struct PredicateInst {
322 pub path: FieldPath,
323 pub op: CompiledOperator,
324}
325
326pub enum PredicateView<'a> {
327 L4 { conn: &'a Arc<ConnContext>, peek: Option<&'a [u8]> },
328 L7Req { conn: &'a Arc<ConnContext>, req: &'a Request },
329}
330
331impl<'a> PredicateView<'a> {
332 #[must_use]
342 pub fn build(
343 conn: &'a Arc<ConnContext>,
344 req: Option<&'a Request>,
345 _l4: Option<&'a crate::l4::L4Conn>,
346 peek: Option<&'a [u8]>,
347 ) -> Self {
348 match req {
349 Some(r) => Self::L7Req { conn, req: r },
350 None => Self::L4 { conn, peek },
351 }
352 }
353
354 fn conn(&self) -> &Arc<ConnContext> {
355 match self {
356 Self::L4 { conn, .. } | Self::L7Req { conn, .. } => conn,
357 }
358 }
359
360 fn request(&self) -> Option<&Request> {
361 match self {
362 Self::L7Req { req, .. } => Some(req),
363 Self::L4 { .. } => None,
364 }
365 }
366
367 fn peek_buffer(&self) -> Option<&[u8]> {
368 match self {
369 Self::L4 { peek, .. } => *peek,
370 Self::L7Req { .. } => None,
371 }
372 }
373}
374
impl PredicateInst {
    /// Evaluates this compiled check against `view`.
    ///
    /// Reads the field named by `self.path` from the connection context or HTTP
    /// request in `view` and applies `self.op` through the matching `test_*` helper
    /// (`test_str`, `test_bytes`, `test_int`, `test_addr`, `test_bool`, `test_vec_str`).
    ///
    /// Availability rules:
    /// - HTTP fields return `false` on an L4 view.
    /// - `peek` returns `false` on an L7 view.
    /// - TLS string and bytes fields return `false` when TLS state or the attribute
    ///   is absent, for every operator (negated ones included).
    /// - A missing URI query is tested as `""`.
    /// - `tls.peer_cert.san_dns` without a peer cert is tested as an empty list.
    ///
    /// # Panics
    ///
    /// On `http.body`, panics if `req.body().as_static()` does not yield a value
    /// (`expect("lazy-buffer invariant")`).
    #[must_use]
    #[allow(clippy::too_many_lines)]
    pub fn test(&self, view: &PredicateView<'_>) -> bool {
        match &self.path {
            FieldPath::Transport => {
                let s = match view.conn().transport {
                    crate::conn_context::Transport::Tcp => "tcp",
                    crate::conn_context::Transport::Udp => "udp",
                };
                test_str(&self.op, s)
            }
            FieldPath::RemoteIp => test_addr(&self.op, view.conn().remote.ip()),
            FieldPath::RemotePort => test_int(&self.op, i64::from(view.conn().remote.port())),
            FieldPath::LocalIp => test_addr(&self.op, view.conn().local.ip()),
            FieldPath::LocalPort => test_int(&self.op, i64::from(view.conn().local.port())),
            // `peek_buffer()` is `None` on L7 views, so `peek` never matches there.
            FieldPath::Peek => view.peek_buffer().is_some_and(|b| test_bytes(&self.op, b)),
            // TLS arms: the `tls` lock guard is a temporary of the arm expression,
            // so the lock stays held while the operator is evaluated.
            FieldPath::TlsSni => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.sni.clone())
                .is_some_and(|got| test_str(&self.op, got.as_str())),
            FieldPath::TlsAlpn => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.alpn.clone())
                .is_some_and(|got| test_bytes(&self.op, got.as_slice())),
            FieldPath::TlsVersion => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.version)
                .is_some_and(|v| test_str(&self.op, tls_version_str(v))),
            FieldPath::TlsPeerCertPresent => {
                // No TLS state at all counts as "no peer cert".
                let present = view.conn().tls.lock().as_ref().is_some_and(|t| t.peer_cert.is_some());
                test_bool(&self.op, present)
            }
            FieldPath::TlsPeerCertSubjectCn => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.peer_cert.as_ref().and_then(|p| p.subject_cn.clone()))
                .is_some_and(|cn| test_str(&self.op, cn.as_str())),
            FieldPath::TlsPeerCertSanDns => {
                // The list is cloned in a `let`, so the lock guard is dropped at the
                // end of this statement, before matching.
                // NOTE(review): unlike the other TLS fields, a missing cert yields an
                // empty list here, so `not_contains` evaluates to `true` — confirm
                // this asymmetry is intended.
                let dns_list: Vec<String> = view
                    .conn()
                    .tls
                    .lock()
                    .as_ref()
                    .and_then(|t| t.peer_cert.as_ref().map(|p| p.san_dns.clone()))
                    .unwrap_or_default();
                test_vec_str(&self.op, &dns_list)
            }
            FieldPath::TlsPeerCertFingerprintSha256 => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.peer_cert.as_ref().map(|p| p.fingerprint_sha256.clone()))
                .is_some_and(|s| test_str(&self.op, s.as_str())),
            FieldPath::TlsPeerCertSpkiSha256 => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.peer_cert.as_ref().map(|p| p.spki_sha256.clone()))
                .is_some_and(|s| test_str(&self.op, s.as_str())),
            FieldPath::TlsPeerCertIssuerCn => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.peer_cert.as_ref().and_then(|p| p.issuer_cn.clone()))
                .is_some_and(|s| test_str(&self.op, s.as_str())),
            FieldPath::TlsPeerCertSerial => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.peer_cert.as_ref().map(|p| p.serial.clone()))
                .is_some_and(|s| test_str(&self.op, s.as_str())),
            FieldPath::HttpMethod => {
                let Some(req) = view.request() else { return false };
                test_str(&self.op, req.method().as_str())
            }
            FieldPath::HttpUriPath => {
                let Some(req) = view.request() else { return false };
                test_str(&self.op, req.uri().path())
            }
            FieldPath::HttpUriQuery => {
                let Some(req) = view.request() else { return false };
                // A URI without a query string is tested as the empty string.
                test_str(&self.op, req.uri().query().unwrap_or(""))
            }
            FieldPath::HttpHeader(name) => {
                let Some(req) = view.request() else { return false };
                // NOTE(review): presumably `http::HeaderMap::get`, which returns only
                // the first value of a repeated header — confirm.
                let Some(value) = req.headers().get(name.as_ref()) else { return false };
                // Values that `to_str()` rejects never match, for any operator.
                let Ok(s) = value.to_str() else {
                    return false;
                };
                test_str(&self.op, s)
            }
            FieldPath::HttpBody => {
                let Some(req) = view.request() else { return false };
                // Assumes the body has already been buffered before predicates run;
                // otherwise this panics (see `# Panics`).
                let bytes = req.body().as_static().expect("lazy-buffer invariant");
                test_bytes(&self.op, bytes.as_ref())
            }
        }
    }
}
540
541fn tls_version_str(v: crate::conn_context::TlsVersion) -> &'static str {
542 match v {
543 crate::conn_context::TlsVersion::Tls12 => "1.2",
544 crate::conn_context::TlsVersion::Tls13 => "1.3",
545 }
546}
547
548fn test_bool(op: &CompiledOperator, value: bool) -> bool {
560 match op {
561 CompiledOperator::Equals(CompiledValue::Bool(expected)) => value == *expected,
562 CompiledOperator::NotEquals(CompiledValue::Bool(expected)) => value != *expected,
563 _ => false,
564 }
565}
566
567fn test_vec_str(op: &CompiledOperator, values: &[String]) -> bool {
573 match op {
574 CompiledOperator::Contains(needle) => values.iter().any(|v| v.as_bytes() == needle.as_ref()),
575 CompiledOperator::NotContains(needle) => {
576 !values.iter().any(|v| v.as_bytes() == needle.as_ref())
577 }
578 _ => false,
579 }
580}
581
582fn test_str(op: &CompiledOperator, value: &str) -> bool {
587 match op {
588 CompiledOperator::Equals(CompiledValue::Str(expected)) => value == expected.as_ref(),
589 CompiledOperator::NotEquals(CompiledValue::Str(expected)) => value != expected.as_ref(),
590 CompiledOperator::Contains(b) => contains_bytes(value.as_bytes(), b),
591 CompiledOperator::NotContains(b) => !contains_bytes(value.as_bytes(), b),
592 CompiledOperator::Prefix(b) => value.as_bytes().starts_with(b.as_ref()),
593 CompiledOperator::Suffix(b) => value.as_bytes().ends_with(b.as_ref()),
594 CompiledOperator::Matches(re) => re.is_match(value).unwrap_or(false),
595 CompiledOperator::In(values) => {
596 values.iter().any(|v| matches!(v, CompiledValue::Str(s) if value == s.as_ref()))
597 }
598 CompiledOperator::NotIn(values) => {
599 !values.iter().any(|v| matches!(v, CompiledValue::Str(s) if value == s.as_ref()))
600 }
601 _ => false,
602 }
603}
604
605fn test_bytes(op: &CompiledOperator, value: &[u8]) -> bool {
609 match op {
610 CompiledOperator::Equals(CompiledValue::Bytes(expected)) => value == expected.as_ref(),
611 CompiledOperator::NotEquals(CompiledValue::Bytes(expected)) => value != expected.as_ref(),
612 CompiledOperator::Contains(b) => contains_bytes(value, b),
613 CompiledOperator::NotContains(b) => !contains_bytes(value, b),
614 CompiledOperator::Prefix(b) => value.starts_with(b.as_ref()),
615 CompiledOperator::Suffix(b) => value.ends_with(b.as_ref()),
616 CompiledOperator::In(values) => {
617 values.iter().any(|v| matches!(v, CompiledValue::Bytes(b) if value == b.as_ref()))
618 }
619 CompiledOperator::NotIn(values) => {
620 !values.iter().any(|v| matches!(v, CompiledValue::Bytes(b) if value == b.as_ref()))
621 }
622 _ => false,
623 }
624}
625
626fn test_int(op: &CompiledOperator, value: i64) -> bool {
629 match op {
630 CompiledOperator::Equals(CompiledValue::Int(expected)) => value == *expected,
631 CompiledOperator::NotEquals(CompiledValue::Int(expected)) => value != *expected,
632 CompiledOperator::Gt(n) => value > *n,
633 CompiledOperator::Gte(n) => value >= *n,
634 CompiledOperator::Lt(n) => value < *n,
635 CompiledOperator::Lte(n) => value <= *n,
636 CompiledOperator::In(values) => {
637 values.iter().any(|v| matches!(v, CompiledValue::Int(i) if value == *i))
638 }
639 CompiledOperator::NotIn(values) => {
640 !values.iter().any(|v| matches!(v, CompiledValue::Int(i) if value == *i))
641 }
642 _ => false,
643 }
644}
645
646fn test_addr(op: &CompiledOperator, value: std::net::IpAddr) -> bool {
650 match op {
651 CompiledOperator::Equals(CompiledValue::Addr(expected)) => value == *expected,
652 CompiledOperator::NotEquals(CompiledValue::Addr(expected)) => value != *expected,
653 CompiledOperator::Cidr(net) => net.contains(&value),
654 CompiledOperator::In(values) => {
655 values.iter().any(|v| matches!(v, CompiledValue::Addr(a) if value == *a))
656 }
657 CompiledOperator::NotIn(values) => {
658 !values.iter().any(|v| matches!(v, CompiledValue::Addr(a) if value == *a))
659 }
660 _ => false,
661 }
662}
663
/// Naive substring search: `true` if `needle` occurs anywhere in `haystack`.
///
/// An empty needle always matches. That case is handled up front because
/// `windows(0)` would panic.
fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
    match needle.len() {
        0 => true,
        width if width > haystack.len() => false,
        width => haystack.windows(width).any(|window| window == needle),
    }
}
673
674pub const REGEX_PATTERN_MAX_BYTES: usize = 4 * 1024;
675
676#[derive(Debug, Clone, serde::Serialize)]
677pub enum Predicate {
678 AnyOf(AnyOfP),
679 AllOf(AllOfP),
680 Not(NotP),
681 Check(CheckMap),
682}
683
684#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
685#[serde(deny_unknown_fields)]
686pub struct AnyOfP {
687 pub any_of: Vec<Predicate>,
688}
689
690#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
691#[serde(deny_unknown_fields)]
692pub struct AllOfP {
693 pub all_of: Vec<Predicate>,
694}
695
696#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
697#[serde(deny_unknown_fields)]
698pub struct NotP {
699 pub not: Box<Predicate>,
700}
701
702#[derive(Debug, Clone, serde::Serialize)]
703pub struct CheckMap {
704 pub path: FieldPath,
705 pub op: Operator,
706}
707
708#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
709#[serde(rename_all = "snake_case")]
710pub enum Operator {
711 Equals(Value),
712 NotEquals(Value),
713 Contains(Value),
714 NotContains(Value),
715 Prefix(Value),
716 Suffix(Value),
717 Matches(String),
718 In(Vec<Value>),
719 NotIn(Vec<Value>),
720 Gt(i64),
721 Gte(i64),
722 Lt(i64),
723 Lte(i64),
724 Cidr(String),
725}
726
727#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
728#[serde(untagged)]
729pub enum Value {
730 Bool(bool),
731 Int(i64),
732 Str(String),
733}
734
735impl<'de> serde::Deserialize<'de> for Predicate {
736 fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
737 let v = serde_json::Value::deserialize(de)?;
738 let serde_json::Value::Object(ref map) = v else {
739 return Err(serde::de::Error::custom("predicate must be a JSON object"));
740 };
741 if map.len() == 1 {
742 let (key, _) = map.iter().next().expect("len == 1");
743 match key.as_str() {
744 "any_of" => {
745 return serde_json::from_value::<AnyOfP>(v)
746 .map(Predicate::AnyOf)
747 .map_err(serde::de::Error::custom);
748 }
749 "all_of" => {
750 return serde_json::from_value::<AllOfP>(v)
751 .map(Predicate::AllOf)
752 .map_err(serde::de::Error::custom);
753 }
754 "not" => {
755 return serde_json::from_value::<NotP>(v)
756 .map(Predicate::Not)
757 .map_err(serde::de::Error::custom);
758 }
759 _ => {}
760 }
761 }
762 serde_json::from_value::<CheckMap>(v).map(Predicate::Check).map_err(serde::de::Error::custom)
763 }
764}
765
766impl<'de> serde::Deserialize<'de> for CheckMap {
767 fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
768 struct Visitor;
769
770 impl<'de> serde::de::Visitor<'de> for Visitor {
771 type Value = CheckMap;
772
773 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
774 f.write_str("a single-key object of the form {\"<field-path>\": {\"<operator>\": <value>}}")
775 }
776
777 fn visit_map<M: serde::de::MapAccess<'de>>(self, mut map: M) -> Result<CheckMap, M::Error> {
778 let Some(key) = map.next_key::<String>()? else {
779 return Err(serde::de::Error::invalid_length(0, &"exactly one key"));
780 };
781 let path = parse_field_path(&key).map_err(serde::de::Error::custom)?;
782 let op: Operator = map.next_value()?;
783 if map.next_key::<serde::de::IgnoredAny>()?.is_some() {
784 return Err(serde::de::Error::custom("check object must have exactly one key"));
785 }
786 validate_operator(&op).map_err(serde::de::Error::custom)?;
787 Ok(CheckMap { path, op })
788 }
789 }
790
791 de.deserialize_map(Visitor)
792 }
793}
794
795fn parse_field_path(s: &str) -> Result<FieldPath, String> {
796 if s.chars().any(|c| c.is_ascii_uppercase()) {
797 return Err(format!(
798 "field path must be lowercase: {:?} — did you mean {:?}?",
799 s,
800 s.to_ascii_lowercase(),
801 ));
802 }
803 match s {
804 "transport" => Ok(FieldPath::Transport),
805 "remote.ip" => Ok(FieldPath::RemoteIp),
806 "remote.port" => Ok(FieldPath::RemotePort),
807 "local.ip" => Ok(FieldPath::LocalIp),
808 "local.port" => Ok(FieldPath::LocalPort),
809 "peek" => Ok(FieldPath::Peek),
810 "tls.sni" => Ok(FieldPath::TlsSni),
811 "tls.alpn" => Ok(FieldPath::TlsAlpn),
812 "tls.version" => Ok(FieldPath::TlsVersion),
813 "tls.peer_cert.present" => Ok(FieldPath::TlsPeerCertPresent),
814 "tls.peer_cert.subject_cn" => Ok(FieldPath::TlsPeerCertSubjectCn),
815 "tls.peer_cert.san_dns" => Ok(FieldPath::TlsPeerCertSanDns),
816 "tls.peer_cert.fingerprint_sha256" => Ok(FieldPath::TlsPeerCertFingerprintSha256),
817 "tls.peer_cert.spki_sha256" => Ok(FieldPath::TlsPeerCertSpkiSha256),
818 "tls.peer_cert.issuer_cn" => Ok(FieldPath::TlsPeerCertIssuerCn),
819 "tls.peer_cert.serial" => Ok(FieldPath::TlsPeerCertSerial),
820 "http.method" => Ok(FieldPath::HttpMethod),
821 "http.uri.path" => Ok(FieldPath::HttpUriPath),
822 "http.uri.query" => Ok(FieldPath::HttpUriQuery),
823 "http.body" => Ok(FieldPath::HttpBody),
824 other if other.starts_with("http.header.") => {
825 let name = &other["http.header.".len()..];
826 if name.is_empty() {
827 return Err(format!("http.header.* requires a header name: {other:?}"));
828 }
829 Ok(FieldPath::HttpHeader(Arc::from(name)))
830 }
831 other => Err(format!("unknown field path: {other:?}")),
832 }
833}
834
835fn validate_operator(op: &Operator) -> Result<(), String> {
836 if let Operator::Matches(pattern) = op
837 && pattern.len() > REGEX_PATTERN_MAX_BYTES
838 {
839 return Err(format!(
840 "regex pattern source exceeds {REGEX_PATTERN_MAX_BYTES}-byte limit: got {} bytes",
841 pattern.len(),
842 ));
843 }
844 Ok(())
845}
846
847mod serde_impls {
848 use base64::Engine as _;
849 use base64::engine::general_purpose::STANDARD as B64;
850 use bytes::Bytes;
851 use std::net::IpAddr;
852 use std::sync::Arc;
853
854 use super::{CompiledOperator, CompiledValue};
855
856 pub(super) fn ser_bytes<S: serde::Serializer>(b: &Bytes, s: S) -> Result<S::Ok, S::Error> {
857 s.serialize_str(&B64.encode(b))
858 }
859
860 pub(super) fn de_bytes<'de, D: serde::Deserializer<'de>>(d: D) -> Result<Bytes, D::Error> {
861 use serde::Deserialize as _;
862 let s = String::deserialize(d)?;
863 B64.decode(s.as_bytes()).map(Bytes::from).map_err(serde::de::Error::custom)
864 }
865
866 pub(super) fn ser_regex<S: serde::Serializer>(
867 r: &fancy_regex::Regex,
868 s: S,
869 ) -> Result<S::Ok, S::Error> {
870 s.serialize_str(r.as_str())
871 }
872
873 pub(super) fn de_regex<'de, D: serde::Deserializer<'de>>(
874 d: D,
875 ) -> Result<fancy_regex::Regex, D::Error> {
876 use serde::Deserialize as _;
877 let s = String::deserialize(d)?;
878 fancy_regex::Regex::new(&s)
879 .map_err(|e| serde::de::Error::custom(format!("invalid regex {s:?}: {e}")))
880 }
881
882 #[derive(serde::Serialize, serde::Deserialize)]
884 #[serde(rename_all = "snake_case")]
885 pub(super) enum ValueShadow {
886 Str(Arc<str>),
887 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
888 Bytes(Bytes),
889 Int(i64),
890 Bool(bool),
891 Addr(IpAddr),
892 }
893
894 impl From<&CompiledValue> for ValueShadow {
895 fn from(v: &CompiledValue) -> Self {
896 match v {
897 CompiledValue::Str(s) => Self::Str(Arc::clone(s)),
898 CompiledValue::Bytes(b) => Self::Bytes(b.clone()),
899 CompiledValue::Int(i) => Self::Int(*i),
900 CompiledValue::Bool(b) => Self::Bool(*b),
901 CompiledValue::Addr(a) => Self::Addr(*a),
902 }
903 }
904 }
905
906 impl From<ValueShadow> for CompiledValue {
907 fn from(v: ValueShadow) -> Self {
908 match v {
909 ValueShadow::Str(s) => Self::Str(s),
910 ValueShadow::Bytes(b) => Self::Bytes(b),
911 ValueShadow::Int(i) => Self::Int(i),
912 ValueShadow::Bool(b) => Self::Bool(b),
913 ValueShadow::Addr(a) => Self::Addr(a),
914 }
915 }
916 }
917
918 #[derive(serde::Serialize, serde::Deserialize)]
921 #[serde(rename_all = "snake_case")]
922 pub(super) enum OperatorShadow {
923 Equals(CompiledValue),
924 NotEquals(CompiledValue),
925 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
926 Contains(Bytes),
927 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
928 NotContains(Bytes),
929 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
930 Prefix(Bytes),
931 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
932 Suffix(Bytes),
933 #[serde(serialize_with = "ser_regex", deserialize_with = "de_regex")]
934 Matches(fancy_regex::Regex),
935 In(Vec<CompiledValue>),
936 NotIn(Vec<CompiledValue>),
937 Gt(i64),
938 Gte(i64),
939 Lt(i64),
940 Lte(i64),
941 Cidr(ipnet::IpNet),
942 }
943
944 impl From<&CompiledOperator> for OperatorShadow {
945 fn from(op: &CompiledOperator) -> Self {
946 match op {
947 CompiledOperator::Equals(v) => Self::Equals(v.clone()),
948 CompiledOperator::NotEquals(v) => Self::NotEquals(v.clone()),
949 CompiledOperator::Contains(b) => Self::Contains(b.clone()),
950 CompiledOperator::NotContains(b) => Self::NotContains(b.clone()),
951 CompiledOperator::Prefix(b) => Self::Prefix(b.clone()),
952 CompiledOperator::Suffix(b) => Self::Suffix(b.clone()),
953 CompiledOperator::Matches(r) => {
954 Self::Matches(fancy_regex::Regex::new(r.as_str()).expect("round-trippable"))
955 }
956 CompiledOperator::In(vs) => Self::In(vs.clone()),
957 CompiledOperator::NotIn(vs) => Self::NotIn(vs.clone()),
958 CompiledOperator::Gt(i) => Self::Gt(*i),
959 CompiledOperator::Gte(i) => Self::Gte(*i),
960 CompiledOperator::Lt(i) => Self::Lt(*i),
961 CompiledOperator::Lte(i) => Self::Lte(*i),
962 CompiledOperator::Cidr(n) => Self::Cidr(*n),
963 }
964 }
965 }
966
967 impl From<OperatorShadow> for CompiledOperator {
968 fn from(op: OperatorShadow) -> Self {
969 match op {
970 OperatorShadow::Equals(v) => Self::Equals(v),
971 OperatorShadow::NotEquals(v) => Self::NotEquals(v),
972 OperatorShadow::Contains(b) => Self::Contains(b),
973 OperatorShadow::NotContains(b) => Self::NotContains(b),
974 OperatorShadow::Prefix(b) => Self::Prefix(b),
975 OperatorShadow::Suffix(b) => Self::Suffix(b),
976 OperatorShadow::Matches(r) => Self::Matches(r),
977 OperatorShadow::In(vs) => Self::In(vs),
978 OperatorShadow::NotIn(vs) => Self::NotIn(vs),
979 OperatorShadow::Gt(i) => Self::Gt(i),
980 OperatorShadow::Gte(i) => Self::Gte(i),
981 OperatorShadow::Lt(i) => Self::Lt(i),
982 OperatorShadow::Lte(i) => Self::Lte(i),
983 OperatorShadow::Cidr(n) => Self::Cidr(n),
984 }
985 }
986 }
987}
988
989impl serde::Serialize for CompiledValue {
990 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
991 serde_impls::ValueShadow::from(self).serialize(s)
992 }
993}
994
995impl<'de> serde::Deserialize<'de> for CompiledValue {
996 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
997 serde_impls::ValueShadow::deserialize(d).map(Self::from)
998 }
999}
1000
1001impl serde::Serialize for CompiledOperator {
1002 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
1003 serde_impls::OperatorShadow::from(self).serialize(s)
1004 }
1005}
1006
1007impl<'de> serde::Deserialize<'de> for CompiledOperator {
1008 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
1009 serde_impls::OperatorShadow::deserialize(d).map(Self::from)
1010 }
1011}
1012
1013#[cfg(test)]
1014mod tests {
1015 use std::collections::hash_map::DefaultHasher;
1016 use std::hash::Hash;
1017 use std::net::{Ipv4Addr, Ipv6Addr};
1018 use std::str::FromStr;
1019 use std::sync::OnceLock;
1020 use std::time::Instant;
1021
1022 use bytes::Bytes;
1023 use fancy_regex::Regex;
1024 use ipnet::IpNet;
1025 use parking_lot::Mutex;
1026
1027 use super::*;
1028 use crate::body::{Body, Request};
1029 use crate::conn_context::{ConnId, Transport};
1030
1031 fn hash_of<T: Hash>(v: &T) -> u64 {
1035 let mut h = DefaultHasher::new();
1036 v.hash(&mut h);
1037 h.finish()
1038 }
1039
1040 fn make_conn() -> Arc<ConnContext> {
1041 Arc::new(ConnContext {
1042 id: ConnId(1),
1043 remote: "127.0.0.1:0".parse().expect("parse remote"),
1044 local: "127.0.0.1:0".parse().expect("parse local"),
1045 transport: Transport::Tcp,
1046 entered_at: Instant::now(),
1047 tls: Mutex::new(None),
1048 http_version: OnceLock::new(),
1049 user: Mutex::new(http::Extensions::new()),
1050 })
1051 }
1052
1053 #[test]
1054 fn field_path_http_header_is_equal_by_string_content_not_arc_identity() {
1055 let a = FieldPath::HttpHeader(Arc::from("host"));
1056 let b = FieldPath::HttpHeader(Arc::from("host"));
1057 assert_eq!(a, b);
1058 assert_eq!(hash_of(&a), hash_of(&b));
1059 let upper = FieldPath::HttpHeader(Arc::from("Host"));
1064 assert_ne!(a, upper);
1065 }
1066
1067 #[test]
1068 fn field_path_simple_variants_are_self_equal_and_mutually_distinct() {
1069 let paths = [
1070 FieldPath::Transport,
1071 FieldPath::RemoteIp,
1072 FieldPath::RemotePort,
1073 FieldPath::LocalIp,
1074 FieldPath::LocalPort,
1075 FieldPath::Peek,
1076 FieldPath::TlsSni,
1077 FieldPath::TlsAlpn,
1078 FieldPath::TlsVersion,
1079 FieldPath::TlsPeerCertSubjectCn,
1080 FieldPath::HttpMethod,
1081 FieldPath::HttpUriPath,
1082 FieldPath::HttpUriQuery,
1083 FieldPath::HttpBody,
1084 ];
1085 for (i, a) in paths.iter().enumerate() {
1086 for (j, b) in paths.iter().enumerate() {
1087 if i == j {
1088 assert_eq!(a, b);
1089 } else {
1090 assert_ne!(a, b);
1091 }
1092 }
1093 }
1094 }
1095
1096 #[test]
1097 fn compiled_value_str_is_equal_by_content_not_arc_identity() {
1098 let a = CompiledValue::Str(Arc::<str>::from("x"));
1099 let b = CompiledValue::Str(Arc::<str>::from("x"));
1100 assert_eq!(a, b);
1101 assert_eq!(hash_of(&a), hash_of(&b));
1102 let c = CompiledValue::Str(Arc::<str>::from("y"));
1103 assert_ne!(a, c);
1104 }
1105
1106 #[test]
1107 fn compiled_value_cross_variant_inequality() {
1108 let s = CompiledValue::Str(Arc::<str>::from("42"));
1109 let i = CompiledValue::Int(42);
1110 assert_ne!(s, i);
1111 }
1112
1113 #[test]
1114 fn compiled_value_bytes_int_bool_addr_self_equal() {
1115 assert_eq!(
1116 CompiledValue::Bytes(Bytes::from_static(b"abc")),
1117 CompiledValue::Bytes(Bytes::copy_from_slice(b"abc")),
1118 );
1119 assert_eq!(CompiledValue::Int(7), CompiledValue::Int(7));
1120 assert_ne!(CompiledValue::Int(7), CompiledValue::Int(8));
1121 assert_eq!(CompiledValue::Bool(true), CompiledValue::Bool(true));
1122 assert_ne!(CompiledValue::Bool(true), CompiledValue::Bool(false));
1123 assert_eq!(
1124 CompiledValue::Addr(Ipv4Addr::new(10, 0, 0, 1).into()),
1125 CompiledValue::Addr(Ipv4Addr::new(10, 0, 0, 1).into()),
1126 );
1127 assert_ne!(
1128 CompiledValue::Addr(Ipv4Addr::new(10, 0, 0, 1).into()),
1129 CompiledValue::Addr(Ipv6Addr::LOCALHOST.into()),
1130 );
1131 }
1132
1133 #[test]
1134 fn compiled_operator_matches_equal_by_pattern_source() {
1135 let a = CompiledOperator::Matches(Regex::new("^/api").expect("compile a"));
1136 let b = CompiledOperator::Matches(Regex::new("^/api").expect("compile b"));
1137 assert_eq!(a, b);
1138 assert_eq!(hash_of(&a), hash_of(&b));
1139 }
1140
1141 #[test]
1142 fn compiled_operator_matches_distinct_patterns_unequal() {
1143 let a = CompiledOperator::Matches(Regex::new("a|b").expect("compile a"));
1146 let b = CompiledOperator::Matches(Regex::new("b|a").expect("compile b"));
1147 assert_ne!(a, b);
1148 }
1149
1150 #[test]
1151 fn compiled_operator_cidr_equal_by_canonical_form() {
1152 let a = CompiledOperator::Cidr(IpNet::from_str("10.0.0.0/8").expect("parse a"));
1153 let b = CompiledOperator::Cidr(IpNet::from_str("10.0.0.0/8").expect("parse b"));
1154 assert_eq!(a, b);
1155 assert_eq!(hash_of(&a), hash_of(&b));
1156 }
1157
1158 #[test]
1159 fn compiled_operator_cidr_distinct_networks_unequal() {
1160 let a = CompiledOperator::Cidr(IpNet::from_str("10.0.0.0/8").expect("parse a"));
1161 let b = CompiledOperator::Cidr(IpNet::from_str("10.0.0.0/16").expect("parse b"));
1162 assert_ne!(a, b);
1163 }
1164
1165 #[test]
1166 fn compiled_operator_in_is_order_sensitive() {
1167 let xs =
1168 vec![CompiledValue::Str(Arc::<str>::from("a")), CompiledValue::Str(Arc::<str>::from("b"))];
1169 let ys =
1170 vec![CompiledValue::Str(Arc::<str>::from("b")), CompiledValue::Str(Arc::<str>::from("a"))];
1171 assert_ne!(CompiledOperator::In(xs.clone()), CompiledOperator::In(ys.clone()));
1172 assert_ne!(CompiledOperator::NotIn(xs), CompiledOperator::NotIn(ys));
1173 }
1174
1175 #[test]
1176 fn compiled_operator_numeric_comparisons_distinct_per_variant() {
1177 let ops = [
1179 CompiledOperator::Gt(10),
1180 CompiledOperator::Gte(10),
1181 CompiledOperator::Lt(10),
1182 CompiledOperator::Lte(10),
1183 ];
1184 for (i, a) in ops.iter().enumerate() {
1185 for (j, b) in ops.iter().enumerate() {
1186 if i == j {
1187 assert_eq!(a, b);
1188 } else {
1189 assert_ne!(a, b);
1190 }
1191 }
1192 }
1193 }
1194
1195 #[test]
1196 fn compiled_operator_bytes_variants_distinguished() {
1197 let payload = Bytes::from_static(b"abc");
1198 let ops = [
1199 CompiledOperator::Contains(payload.clone()),
1200 CompiledOperator::NotContains(payload.clone()),
1201 CompiledOperator::Prefix(payload.clone()),
1202 CompiledOperator::Suffix(payload),
1203 ];
1204 for (i, a) in ops.iter().enumerate() {
1205 for (j, b) in ops.iter().enumerate() {
1206 if i == j {
1207 assert_eq!(a, b);
1208 } else {
1209 assert_ne!(a, b);
1210 }
1211 }
1212 }
1213 }
1214
1215 #[test]
1216 fn predicate_inst_equal_across_independent_construction_paths() {
1217 let lhs = PredicateInst {
1218 path: FieldPath::HttpHeader(Arc::from("host")),
1219 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from("example.com"))),
1220 };
1221 let rhs = PredicateInst {
1222 path: FieldPath::HttpHeader(Arc::from("host")),
1223 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from("example.com"))),
1224 };
1225 assert_eq!(lhs, rhs);
1226 assert_eq!(hash_of(&lhs), hash_of(&rhs));
1227 }
1228
1229 #[test]
1230 fn predicate_inst_equal_with_regex_operator_from_separate_compiles() {
1231 let lhs = PredicateInst {
1232 path: FieldPath::HttpUriPath,
1233 op: CompiledOperator::Matches(Regex::new("^/").expect("compile a")),
1234 };
1235 let rhs = PredicateInst {
1236 path: FieldPath::HttpUriPath,
1237 op: CompiledOperator::Matches(Regex::new("^/").expect("compile b")),
1238 };
1239 assert_eq!(lhs, rhs);
1240 assert_eq!(hash_of(&lhs), hash_of(&rhs));
1241 }
1242
1243 #[test]
1244 fn predicate_inst_unequal_on_path_difference() {
1245 let value = CompiledValue::Str(Arc::<str>::from("x"));
1246 let a =
1247 PredicateInst { path: FieldPath::HttpUriPath, op: CompiledOperator::Equals(value.clone()) };
1248 let b = PredicateInst { path: FieldPath::HttpUriQuery, op: CompiledOperator::Equals(value) };
1249 assert_ne!(a, b);
1250 }
1251
1252 #[test]
1253 fn predicate_view_variants_construct() {
1254 let conn = make_conn();
1255 let peek_bytes: &[u8] = b"\x16\x03\x01";
1256 let l4 = PredicateView::L4 { conn: &conn, peek: Some(peek_bytes) };
1257 match l4 {
1258 PredicateView::L4 { peek, .. } => assert_eq!(peek.map(<[u8]>::len), Some(3)),
1259 PredicateView::L7Req { .. } => panic!("wrong variant"),
1260 }
1261
1262 let conn2 = make_conn();
1263 let req: Request =
1264 http::Request::builder().method("GET").uri("/").body(Body::Empty).expect("build request");
1265 let l7 = PredicateView::L7Req { conn: &conn2, req: &req };
1266 match l7 {
1267 PredicateView::L7Req { .. } => {}
1268 PredicateView::L4 { .. } => panic!("wrong variant"),
1269 }
1270 }
1271
1272 fn parse_predicate(v: serde_json::Value) -> Result<Predicate, serde_json::Error> {
1276 serde_json::from_value(v)
1277 }
1278
1279 fn expect_check(p: &Predicate) -> &CheckMap {
1280 match p {
1281 Predicate::Check(c) => c,
1282 other => panic!("expected Predicate::Check, got {other:?}"),
1283 }
1284 }
1285
1286 #[test]
1287 fn parse_any_of_happy_path() {
1288 let raw = serde_json::json!({
1289 "any_of": [
1290 { "tls.sni": { "equals": "a" } },
1291 { "tls.sni": { "equals": "b" } },
1292 ],
1293 });
1294 let p = parse_predicate(raw).expect("parse any_of");
1295 let Predicate::AnyOf(AnyOfP { any_of }) = p else {
1296 panic!("expected AnyOf");
1297 };
1298 assert_eq!(any_of.len(), 2);
1299 let c0 = expect_check(&any_of[0]);
1300 let c1 = expect_check(&any_of[1]);
1301 assert_eq!(c0.path, FieldPath::TlsSni);
1302 assert_eq!(c1.path, FieldPath::TlsSni);
1303 match (&c0.op, &c1.op) {
1304 (Operator::Equals(Value::Str(a)), Operator::Equals(Value::Str(b))) => {
1305 assert_eq!(a, "a");
1306 assert_eq!(b, "b");
1307 }
1308 (a, b) => panic!("unexpected ops: {a:?} / {b:?}"),
1309 }
1310 }
1311
1312 #[test]
1313 fn parse_not_happy_path() {
1314 let raw = serde_json::json!({
1315 "not": { "tls.sni": { "equals": "internal" } },
1316 });
1317 let p = parse_predicate(raw).expect("parse not");
1318 let Predicate::Not(NotP { not }) = p else {
1319 panic!("expected Not");
1320 };
1321 let inner = expect_check(¬);
1322 assert_eq!(inner.path, FieldPath::TlsSni);
1323 match &inner.op {
1324 Operator::Equals(Value::Str(s)) => assert_eq!(s, "internal"),
1325 other => panic!("unexpected op: {other:?}"),
1326 }
1327 }
1328
1329 #[test]
1330 fn parse_all_of_happy_path() {
1331 let raw = serde_json::json!({
1332 "all_of": [
1333 { "http.header.upgrade": { "equals": "websocket" } },
1334 { "http.uri.path": { "prefix": "/ws" } },
1335 ],
1336 });
1337 let p = parse_predicate(raw).expect("parse all_of");
1338 let Predicate::AllOf(AllOfP { all_of }) = p else {
1339 panic!("expected AllOf");
1340 };
1341 assert_eq!(all_of.len(), 2);
1342 let c0 = expect_check(&all_of[0]);
1343 let c1 = expect_check(&all_of[1]);
1344 assert_eq!(c0.path, FieldPath::HttpHeader(Arc::from("upgrade")));
1345 assert_eq!(c1.path, FieldPath::HttpUriPath);
1346 }
1347
1348 #[test]
1349 fn parse_all_of_empty_array_parses() {
1350 let raw = serde_json::json!({ "all_of": [] });
1353 let p = parse_predicate(raw).expect("empty all_of parses");
1354 let Predicate::AllOf(AllOfP { all_of }) = p else {
1355 panic!("expected AllOf");
1356 };
1357 assert!(all_of.is_empty());
1358 }
1359
1360 #[test]
1361 fn parse_all_of_nested_with_check_and_any_of() {
1362 let raw = serde_json::json!({
1363 "all_of": [
1364 { "tls.sni": { "equals": "api.example.com" } },
1365 { "any_of": [
1366 { "remote.ip": { "cidr": "10.0.0.0/8" } },
1367 { "remote.ip": { "cidr": "192.168.0.0/16" } },
1368 ]},
1369 ],
1370 });
1371 let p = parse_predicate(raw).expect("parse nested all_of/any_of");
1372 let Predicate::AllOf(AllOfP { all_of }) = p else {
1373 panic!("expected AllOf");
1374 };
1375 assert_eq!(all_of.len(), 2);
1376 assert!(matches!(all_of[0], Predicate::Check(_)));
1377 assert!(matches!(all_of[1], Predicate::AnyOf(_)));
1378 }
1379
1380 #[test]
1381 fn parse_all_of_with_extra_key_is_rejected() {
1382 let raw = serde_json::json!({
1384 "all_of": [ { "tls.sni": { "equals": "a" } } ],
1385 "extra": "unwanted",
1386 });
1387 let err = parse_predicate(raw).expect_err("must reject extra key on all_of");
1388 let _ = err.to_string();
1389 }
1390
1391 #[test]
1392 fn parse_http_header_all_of_is_a_check_not_combinator() {
1393 let raw = serde_json::json!({ "http.header.all_of": { "equals": "x" } });
1396 let p = parse_predicate(raw).expect("parse http.header.all_of");
1397 let c = expect_check(&p);
1398 assert_eq!(c.path, FieldPath::HttpHeader(Arc::from("all_of")));
1399 }
1400
1401 #[test]
1402 fn parse_check_across_representative_paths() {
1403 let cases = [
1404 (serde_json::json!({ "tls.sni": { "equals": "api.example.com" } }), FieldPath::TlsSni),
1405 (serde_json::json!({ "remote.port": { "gt": 1024 } }), FieldPath::RemotePort),
1406 (serde_json::json!({ "http.method": { "equals": "GET" } }), FieldPath::HttpMethod),
1407 (serde_json::json!({ "http.uri.path": { "prefix": "/api" } }), FieldPath::HttpUriPath),
1408 (
1409 serde_json::json!({ "http.header.host": { "equals": "a.example.com" } }),
1410 FieldPath::HttpHeader(Arc::from("host")),
1411 ),
1412 (serde_json::json!({ "http.body": { "contains": "hello" } }), FieldPath::HttpBody),
1413 ];
1414 for (raw, expected_path) in cases {
1415 let p = parse_predicate(raw.clone()).unwrap_or_else(|e| panic!("parse {raw}: {e}"));
1416 let c = expect_check(&p);
1417 assert_eq!(c.path, expected_path, "input: {raw}");
1418 }
1419 }
1420
1421 #[test]
1422 fn parse_any_of_with_extra_key_is_rejected() {
1423 let raw = serde_json::json!({
1426 "any_of": [ { "tls.sni": { "equals": "a" } } ],
1427 "extra": true,
1428 });
1429 let err = parse_predicate(raw).expect_err("must reject extra key on any_of");
1430 let _ = err.to_string();
1431 }
1432
1433 #[test]
1434 fn parse_http_header_any_of_is_a_check_not_combinator() {
1435 let raw = serde_json::json!({ "http.header.any_of": { "equals": "x" } });
1438 let p = parse_predicate(raw).expect("parse");
1439 let c = expect_check(&p);
1440 assert_eq!(c.path, FieldPath::HttpHeader(Arc::from("any_of")));
1441 }
1442
1443 #[test]
1444 fn parse_uppercase_field_path_suggests_lowercase() {
1445 let raw = serde_json::json!({ "http.header.Host": { "equals": "x" } });
1446 let err = parse_predicate(raw).expect_err("uppercase must fail");
1447 let msg = err.to_string();
1448 assert!(msg.contains("http.header.Host"), "error mentions offending input: {msg}");
1449 assert!(msg.contains("did you mean"), "error includes suggestion phrase: {msg}");
1450 assert!(msg.contains("http.header.host"), "error contains lowercased form: {msg}");
1451 }
1452
1453 #[test]
1454 fn parse_multi_key_check_is_rejected() {
1455 let raw = serde_json::json!({
1456 "http.uri.path": { "matches": "^/" },
1457 "http.method": { "equals": "GET" },
1458 });
1459 let err = parse_predicate(raw).expect_err("multi-key check must fail");
1460 let _ = err.to_string();
1461 }
1462
1463 #[test]
1464 fn parse_empty_http_header_name_is_rejected() {
1465 let raw = serde_json::json!({ "http.header.": { "equals": "x" } });
1466 let err = parse_predicate(raw).expect_err("empty header name must fail");
1467 let _ = err.to_string();
1468 }
1469
1470 #[test]
1471 fn parse_unknown_field_path_is_rejected_with_name() {
1472 let raw = serde_json::json!({ "http.nope": { "equals": "x" } });
1473 let err = parse_predicate(raw).expect_err("unknown path must fail");
1474 let msg = err.to_string();
1475 assert!(msg.contains("http.nope"), "error mentions offending path: {msg}");
1476 }
1477
1478 fn parse_op(v: serde_json::Value) -> Operator {
1479 let mut map = serde_json::Map::new();
1480 map.insert("tls.sni".to_string(), v);
1481 let raw = serde_json::Value::Object(map);
1482 match parse_predicate(raw).expect("parse check") {
1483 Predicate::Check(c) => c.op,
1484 other => panic!("expected Check, got {other:?}"),
1485 }
1486 }
1487
1488 #[test]
1489 fn operator_equals_and_not_equals_on_string() {
1490 let eq = parse_op(serde_json::json!({ "equals": "api" }));
1491 match eq {
1492 Operator::Equals(Value::Str(s)) => assert_eq!(s, "api"),
1493 other => panic!("expected equals/str: {other:?}"),
1494 }
1495 let neq = parse_op(serde_json::json!({ "not_equals": "api" }));
1496 match neq {
1497 Operator::NotEquals(Value::Str(s)) => assert_eq!(s, "api"),
1498 other => panic!("expected not_equals/str: {other:?}"),
1499 }
1500 }
1501
1502 #[test]
1503 fn operator_contains_and_not_contains_on_string() {
1504 let c = parse_op(serde_json::json!({ "contains": "foo" }));
1505 match c {
1506 Operator::Contains(Value::Str(s)) => assert_eq!(s, "foo"),
1507 other => panic!("expected contains/str: {other:?}"),
1508 }
1509 let nc = parse_op(serde_json::json!({ "not_contains": "foo" }));
1510 match nc {
1511 Operator::NotContains(Value::Str(s)) => assert_eq!(s, "foo"),
1512 other => panic!("expected not_contains/str: {other:?}"),
1513 }
1514 }
1515
1516 #[test]
1517 fn operator_prefix_and_suffix_on_string() {
1518 let p = parse_op(serde_json::json!({ "prefix": "/api" }));
1519 match p {
1520 Operator::Prefix(Value::Str(s)) => assert_eq!(s, "/api"),
1521 other => panic!("expected prefix/str: {other:?}"),
1522 }
1523 let s = parse_op(serde_json::json!({ "suffix": ".json" }));
1524 match s {
1525 Operator::Suffix(Value::Str(v)) => assert_eq!(v, ".json"),
1526 other => panic!("expected suffix/str: {other:?}"),
1527 }
1528 }
1529
1530 #[test]
1531 fn operator_matches_carries_pattern_source() {
1532 let op = parse_op(serde_json::json!({ "matches": "^/api/v\\d+" }));
1533 match op {
1534 Operator::Matches(pattern) => assert_eq!(pattern, "^/api/v\\d+"),
1535 other => panic!("expected matches: {other:?}"),
1536 }
1537 }
1538
1539 #[test]
1540 fn operator_in_and_not_in_accept_mixed_scalar_types() {
1541 let op = parse_op(serde_json::json!({ "in": ["foo", 42] }));
1542 let Operator::In(xs) = op else {
1543 panic!("expected in");
1544 };
1545 assert_eq!(xs.len(), 2);
1546 assert_eq!(xs[0], Value::Str("foo".into()));
1547 assert_eq!(xs[1], Value::Int(42));
1548 let op2 = parse_op(serde_json::json!({ "not_in": ["bar", 7] }));
1549 let Operator::NotIn(ys) = op2 else {
1550 panic!("expected not_in");
1551 };
1552 assert_eq!(ys.len(), 2);
1553 assert_eq!(ys[0], Value::Str("bar".into()));
1554 assert_eq!(ys[1], Value::Int(7));
1555 }
1556
1557 #[test]
1558 fn operator_numeric_comparisons() {
1559 assert!(matches!(parse_op(serde_json::json!({ "gt": 10 })), Operator::Gt(10)));
1560 assert!(matches!(parse_op(serde_json::json!({ "gte": 10 })), Operator::Gte(10)));
1561 assert!(matches!(parse_op(serde_json::json!({ "lt": 10 })), Operator::Lt(10)));
1562 assert!(matches!(parse_op(serde_json::json!({ "lte": 10 })), Operator::Lte(10)));
1563 }
1564
1565 #[test]
1566 fn operator_cidr_carries_source_string() {
1567 let op = parse_op(serde_json::json!({ "cidr": "10.0.0.0/8" }));
1568 match op {
1569 Operator::Cidr(s) => assert_eq!(s, "10.0.0.0/8"),
1570 other => panic!("expected cidr: {other:?}"),
1571 }
1572 }
1573
1574 #[test]
1575 fn value_untagged_priority_bool_before_str() {
1576 let op_t = parse_op(serde_json::json!({ "equals": true }));
1579 assert!(matches!(op_t, Operator::Equals(Value::Bool(true))));
1580 let op_f = parse_op(serde_json::json!({ "equals": false }));
1581 assert!(matches!(op_f, Operator::Equals(Value::Bool(false))));
1582 }
1583
1584 #[test]
1585 fn value_untagged_priority_int_before_str() {
1586 let op = parse_op(serde_json::json!({ "equals": 42 }));
1588 assert!(matches!(op, Operator::Equals(Value::Int(42))));
1589 }
1590
1591 #[test]
1592 fn value_untagged_json_string_stays_str() {
1593 let op = parse_op(serde_json::json!({ "equals": "42" }));
1596 match op {
1597 Operator::Equals(Value::Str(s)) => assert_eq!(s, "42"),
1598 other => panic!("expected equals/str(\"42\"): {other:?}"),
1599 }
1600 }
1601
1602 #[test]
1603 fn regex_pattern_exactly_at_limit_parses() {
1604 assert_eq!(REGEX_PATTERN_MAX_BYTES, 4 * 1024);
1606 let pattern = "a".repeat(REGEX_PATTERN_MAX_BYTES);
1607 let raw = serde_json::json!({ "http.uri.path": { "matches": pattern } });
1608 let p = parse_predicate(raw).expect("4 KiB pattern parses");
1609 let c = expect_check(&p);
1610 match &c.op {
1611 Operator::Matches(src) => assert_eq!(src.len(), REGEX_PATTERN_MAX_BYTES),
1612 other => panic!("expected matches: {other:?}"),
1613 }
1614 }
1615
1616 #[test]
1617 fn regex_pattern_over_limit_rejected_with_limit_in_message() {
1618 let pattern = "a".repeat(REGEX_PATTERN_MAX_BYTES + 1);
1619 let raw = serde_json::json!({ "http.uri.path": { "matches": pattern } });
1620 let err = parse_predicate(raw).expect_err("over-limit pattern must fail");
1621 let msg = err.to_string();
1622 assert!(
1623 msg.contains(®EX_PATTERN_MAX_BYTES.to_string()),
1624 "error mentions the limit ({REGEX_PATTERN_MAX_BYTES}): {msg}",
1625 );
1626 }
1627
1628 fn value_round_trip(v: &CompiledValue) -> CompiledValue {
1637 let encoded = serde_json::to_string(v).expect("serialize value");
1638 serde_json::from_str(&encoded).expect("deserialize value")
1639 }
1640
1641 #[test]
1642 fn compiled_value_str_round_trip_including_empty() {
1643 let non_empty = CompiledValue::Str(Arc::<str>::from("x"));
1644 assert_eq!(value_round_trip(&non_empty), non_empty);
1645 let empty = CompiledValue::Str(Arc::<str>::from(""));
1646 assert_eq!(value_round_trip(&empty), empty);
1647 }
1648
1649 #[test]
1650 fn compiled_value_bytes_round_trip_including_empty_and_binary() {
1651 let hello = CompiledValue::Bytes(Bytes::from_static(b"hello"));
1652 assert_eq!(value_round_trip(&hello), hello);
1653 let empty = CompiledValue::Bytes(Bytes::new());
1654 assert_eq!(value_round_trip(&empty), empty);
1655 let binary = CompiledValue::Bytes(Bytes::from_static(&[0xff, 0x00, 0x13]));
1656 assert_eq!(value_round_trip(&binary), binary);
1657 }
1658
1659 #[test]
1660 fn compiled_value_int_round_trip_including_extremes() {
1661 for i in [0_i64, i64::MIN, i64::MAX] {
1662 let v = CompiledValue::Int(i);
1663 assert_eq!(value_round_trip(&v), v);
1664 }
1665 }
1666
1667 #[test]
1668 fn compiled_value_bool_round_trip_both_variants() {
1669 for b in [true, false] {
1670 let v = CompiledValue::Bool(b);
1671 assert_eq!(value_round_trip(&v), v);
1672 }
1673 }
1674
1675 #[test]
1676 fn compiled_value_addr_round_trip_v4_and_v6() {
1677 let v4 = CompiledValue::Addr(Ipv4Addr::LOCALHOST.into());
1678 assert_eq!(value_round_trip(&v4), v4);
1679 let v6 = CompiledValue::Addr(Ipv6Addr::LOCALHOST.into());
1680 assert_eq!(value_round_trip(&v6), v6);
1681 }
1682
1683 #[test]
1684 fn compiled_value_bytes_emits_standard_base64_literal() {
1685 let v = CompiledValue::Bytes(Bytes::from_static(b"hello"));
1689 let encoded = serde_json::to_string(&v).expect("serialize");
1690 assert_eq!(encoded, r#"{"bytes":"aGVsbG8="}"#);
1691 }
1692
1693 fn op_round_trip(op: &CompiledOperator) -> CompiledOperator {
1694 let encoded = serde_json::to_string(op).expect("serialize op");
1695 serde_json::from_str(&encoded).expect("deserialize op")
1696 }
1697
1698 #[test]
1699 fn compiled_operator_equals_and_not_equals_round_trip() {
1700 let eq = CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from("x")));
1701 assert_eq!(op_round_trip(&eq), eq);
1702 let neq = CompiledOperator::NotEquals(CompiledValue::Str(Arc::<str>::from("x")));
1703 assert_eq!(op_round_trip(&neq), neq);
1704 }
1705
1706 #[test]
1707 fn compiled_operator_bytes_variants_round_trip() {
1708 let payload = Bytes::from_static(b"hello");
1709 let ops = [
1710 CompiledOperator::Contains(payload.clone()),
1711 CompiledOperator::NotContains(payload.clone()),
1712 CompiledOperator::Prefix(payload.clone()),
1713 CompiledOperator::Suffix(payload),
1714 ];
1715 for op in ops {
1716 assert_eq!(op_round_trip(&op), op);
1717 }
1718 }
1719
1720 #[test]
1721 fn compiled_operator_matches_round_trip_preserves_pattern_source() {
1722 let op = CompiledOperator::Matches(Regex::new("^/api/v[0-9]+").expect("compile"));
1723 let decoded = op_round_trip(&op);
1724 assert_eq!(decoded, op);
1726 match decoded {
1727 CompiledOperator::Matches(r) => assert_eq!(r.as_str(), "^/api/v[0-9]+"),
1728 other => panic!("expected matches, got {other:?}"),
1729 }
1730 }
1731
1732 #[test]
1733 fn compiled_operator_in_and_not_in_round_trip_mixed_values() {
1734 let xs = vec![CompiledValue::Str(Arc::<str>::from("a")), CompiledValue::Int(42)];
1735 let in_op = CompiledOperator::In(xs.clone());
1736 assert_eq!(op_round_trip(&in_op), in_op);
1737 let not_in_op = CompiledOperator::NotIn(xs);
1738 assert_eq!(op_round_trip(¬_in_op), not_in_op);
1739 }
1740
1741 #[test]
1742 fn compiled_operator_numeric_comparisons_round_trip() {
1743 let ops = [
1744 CompiledOperator::Gt(100),
1745 CompiledOperator::Gte(100),
1746 CompiledOperator::Lt(100),
1747 CompiledOperator::Lte(100),
1748 ];
1749 for op in ops {
1750 assert_eq!(op_round_trip(&op), op);
1751 }
1752 }
1753
1754 #[test]
1755 fn compiled_operator_cidr_round_trip_preserves_canonical_form() {
1756 let op = CompiledOperator::Cidr(IpNet::from_str("10.0.0.0/8").expect("parse"));
1757 assert_eq!(op_round_trip(&op), op);
1758 }
1759
1760 #[test]
1761 fn compiled_operator_matches_with_invalid_regex_is_rejected() {
1762 let raw = r#"{"matches":"["}"#;
1766 let err = serde_json::from_str::<CompiledOperator>(raw)
1767 .expect_err("invalid regex must fail to deserialize");
1768 let msg = err.to_string();
1769 assert!(msg.contains('['), "error mentions offending regex source: {msg}");
1770 }
1771
1772 #[test]
1773 fn predicate_inst_pins_exact_wire_shape_for_http_header_equals() {
1774 let inst = PredicateInst {
1775 path: FieldPath::HttpHeader(Arc::from("host")),
1776 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from("example.com"))),
1777 };
1778 let encoded = serde_json::to_string(&inst).expect("serialize");
1779 assert_eq!(encoded, r#"{"path":{"http_header":"host"},"op":{"equals":{"str":"example.com"}}}"#,);
1780 let decoded: PredicateInst = serde_json::from_str(&encoded).expect("deserialize");
1781 assert_eq!(decoded, inst);
1782 }
1783
1784 #[test]
1785 fn predicate_inst_round_trip_with_regex_operator() {
1786 let inst = PredicateInst {
1787 path: FieldPath::HttpUriPath,
1788 op: CompiledOperator::Matches(Regex::new("^/api").expect("compile")),
1789 };
1790 let encoded = serde_json::to_string(&inst).expect("serialize");
1791 let decoded: PredicateInst = serde_json::from_str(&encoded).expect("deserialize");
1792 assert_eq!(decoded, inst);
1793 }
1794
1795 fn http_header_equals(name: &str, value: &str) -> PredicateInst {
1803 PredicateInst {
1804 path: FieldPath::HttpHeader(Arc::from(name)),
1805 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from(value))),
1806 }
1807 }
1808
1809 fn http_uri_path_equals(value: &str) -> PredicateInst {
1810 PredicateInst {
1811 path: FieldPath::HttpUriPath,
1812 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from(value))),
1813 }
1814 }
1815
1816 fn http_uri_path_prefix(value: &str) -> PredicateInst {
1817 PredicateInst {
1818 path: FieldPath::HttpUriPath,
1819 op: CompiledOperator::Prefix(Bytes::copy_from_slice(value.as_bytes())),
1820 }
1821 }
1822
1823 fn tls_sni_equals(value: &str) -> PredicateInst {
1824 PredicateInst {
1825 path: FieldPath::TlsSni,
1826 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from(value))),
1827 }
1828 }
1829
1830 fn conn_with_sni(sni: &str) -> Arc<ConnContext> {
1831 let conn = make_conn();
1832 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
1833 sni: Some(sni.to_string()),
1834 alpn: None,
1835 version: None,
1836 peer_cert: None,
1837 });
1838 conn
1839 }
1840
1841 fn req_with_header(name: &str, value: &str) -> Request {
1842 http::Request::builder()
1843 .method("GET")
1844 .uri("/")
1845 .header(name, value)
1846 .body(Body::Empty)
1847 .expect("build req")
1848 }
1849
1850 fn req_with_uri(uri: &str) -> Request {
1851 http::Request::builder().method("GET").uri(uri).body(Body::Empty).expect("build req")
1852 }
1853
1854 #[test]
1855 fn predicate_test_http_header_equals_matches_when_present_and_equal() {
1856 let conn = make_conn();
1857 let req = req_with_header("upgrade", "websocket");
1858 let view = PredicateView::L7Req { conn: &conn, req: &req };
1859 assert!(http_header_equals("upgrade", "websocket").test(&view));
1860 }
1861
1862 #[test]
1863 fn predicate_test_http_header_equals_misses_when_header_absent() {
1864 let conn = make_conn();
1865 let req = req_with_header("host", "example.com");
1866 let view = PredicateView::L7Req { conn: &conn, req: &req };
1867 assert!(!http_header_equals("upgrade", "websocket").test(&view));
1868 }
1869
1870 #[test]
1871 fn predicate_test_http_header_equals_value_is_case_sensitive() {
1872 let conn = make_conn();
1877 let req = req_with_header("upgrade", "WebSocket");
1878 let view = PredicateView::L7Req { conn: &conn, req: &req };
1879 assert!(!http_header_equals("upgrade", "websocket").test(&view));
1880 }
1881
1882 #[test]
1883 fn predicate_test_http_header_equals_name_lookup_is_case_insensitive() {
1884 let conn = make_conn();
1890 let req = req_with_header("Upgrade", "websocket");
1891 let view = PredicateView::L7Req { conn: &conn, req: &req };
1892 assert!(http_header_equals("upgrade", "websocket").test(&view));
1893 }
1894
1895 #[test]
1896 fn predicate_test_http_header_equals_misses_on_l4_view() {
1897 let conn = make_conn();
1901 let view = PredicateView::L4 { conn: &conn, peek: None };
1902 assert!(!http_header_equals("upgrade", "websocket").test(&view));
1903 }
1904
1905 #[test]
1906 fn predicate_test_http_uri_path_equals_matches_exact() {
1907 let conn = make_conn();
1908 let req = req_with_uri("/api/v1/users");
1909 let view = PredicateView::L7Req { conn: &conn, req: &req };
1910 assert!(http_uri_path_equals("/api/v1/users").test(&view));
1911 }
1912
1913 #[test]
1914 fn predicate_test_http_uri_path_equals_misses_on_substring() {
1915 let conn = make_conn();
1919 let req = req_with_uri("/api/v1/users");
1920 let view = PredicateView::L7Req { conn: &conn, req: &req };
1921 assert!(!http_uri_path_equals("/api").test(&view));
1922 }
1923
1924 #[test]
1925 fn predicate_test_http_uri_path_prefix_matches_when_path_starts_with() {
1926 let conn = make_conn();
1927 let req = req_with_uri("/api/v1/users");
1928 let view = PredicateView::L7Req { conn: &conn, req: &req };
1929 assert!(http_uri_path_prefix("/api").test(&view));
1930 }
1931
1932 #[test]
1933 fn predicate_test_http_uri_path_prefix_misses_when_no_prefix() {
1934 let conn = make_conn();
1935 let req = req_with_uri("/admin");
1936 let view = PredicateView::L7Req { conn: &conn, req: &req };
1937 assert!(!http_uri_path_prefix("/api").test(&view));
1938 }
1939
1940 #[test]
1941 fn predicate_test_tls_sni_equals_matches_when_set() {
1942 let conn = conn_with_sni("api.example.com");
1946 let req = req_with_uri("/");
1947 let view = PredicateView::L7Req { conn: &conn, req: &req };
1948 assert!(tls_sni_equals("api.example.com").test(&view));
1949 }
1950
1951 #[test]
1952 fn predicate_test_tls_sni_equals_misses_when_unset() {
1953 let conn = make_conn();
1956 let req = req_with_uri("/");
1957 let view = PredicateView::L7Req { conn: &conn, req: &req };
1958 assert!(!tls_sni_equals("api.example.com").test(&view));
1959 }
1960
1961 #[test]
1962 fn predicate_test_tls_sni_equals_works_in_l4_view_too() {
1963 let conn = conn_with_sni("api.example.com");
1969 let view = PredicateView::L4 { conn: &conn, peek: None };
1970 assert!(tls_sni_equals("api.example.com").test(&view));
1971 }
1972
1973 fn pred(path: FieldPath, op: CompiledOperator) -> PredicateInst {
1983 PredicateInst { path, op }
1984 }
1985
1986 fn str_val(s: &str) -> CompiledValue {
1987 CompiledValue::Str(Arc::<str>::from(s))
1988 }
1989
1990 fn bytes_val(b: &[u8]) -> CompiledValue {
1991 CompiledValue::Bytes(Bytes::copy_from_slice(b))
1992 }
1993
1994 fn b(b: &[u8]) -> Bytes {
1995 Bytes::copy_from_slice(b)
1996 }
1997
1998 fn make_conn_with(remote: &str, local: &str) -> Arc<ConnContext> {
1999 Arc::new(ConnContext {
2000 id: ConnId(1),
2001 remote: remote.parse().expect("parse remote"),
2002 local: local.parse().expect("parse local"),
2003 transport: Transport::Tcp,
2004 entered_at: Instant::now(),
2005 tls: Mutex::new(None),
2006 http_version: OnceLock::new(),
2007 user: Mutex::new(http::Extensions::new()),
2008 })
2009 }
2010
2011 fn make_conn_with_transport(t: Transport) -> Arc<ConnContext> {
2012 Arc::new(ConnContext {
2013 id: ConnId(1),
2014 remote: "127.0.0.1:0".parse().expect("remote"),
2015 local: "127.0.0.1:0".parse().expect("local"),
2016 transport: t,
2017 entered_at: Instant::now(),
2018 tls: Mutex::new(None),
2019 http_version: OnceLock::new(),
2020 user: Mutex::new(http::Extensions::new()),
2021 })
2022 }
2023
2024 fn conn_with_tls_alpn(alpn: &[u8]) -> Arc<ConnContext> {
2025 let conn = make_conn();
2026 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
2027 sni: None,
2028 alpn: Some(alpn.to_vec()),
2029 version: None,
2030 peer_cert: None,
2031 });
2032 conn
2033 }
2034
2035 fn conn_with_tls_version(v: crate::conn_context::TlsVersion) -> Arc<ConnContext> {
2036 let conn = make_conn();
2037 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
2038 sni: None,
2039 alpn: None,
2040 version: Some(v),
2041 peer_cert: None,
2042 });
2043 conn
2044 }
2045
2046 #[test]
2049 fn matrix_equality_str_happy_and_miss() {
2050 let conn = conn_with_sni("api.example.com");
2052 let v = PredicateView::L4 { conn: &conn, peek: None };
2053 assert!(pred(FieldPath::TlsSni, CompiledOperator::Equals(str_val("api.example.com"))).test(&v));
2054 assert!(
2055 !pred(FieldPath::TlsSni, CompiledOperator::Equals(str_val("other.example.com"))).test(&v)
2056 );
2057 assert!(
2058 pred(FieldPath::TlsSni, CompiledOperator::NotEquals(str_val("other.example.com"))).test(&v)
2059 );
2060 assert!(
2061 !pred(FieldPath::TlsSni, CompiledOperator::NotEquals(str_val("api.example.com"))).test(&v)
2062 );
2063 }
2064
2065 #[test]
2066 fn matrix_equality_bytes_happy_and_miss() {
2067 let conn = conn_with_tls_alpn(b"h2");
2069 let v = PredicateView::L4 { conn: &conn, peek: None };
2070 assert!(pred(FieldPath::TlsAlpn, CompiledOperator::Equals(bytes_val(b"h2"))).test(&v));
2071 assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::Equals(bytes_val(b"http/1.1"))).test(&v));
2072 assert!(pred(FieldPath::TlsAlpn, CompiledOperator::NotEquals(bytes_val(b"http/1.1"))).test(&v));
2073 assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::NotEquals(bytes_val(b"h2"))).test(&v));
2074 }
2075
2076 #[test]
2077 fn matrix_equality_int_happy_and_miss() {
2078 let conn = make_conn_with("127.0.0.1:9090", "127.0.0.1:80");
2079 let v = PredicateView::L4 { conn: &conn, peek: None };
2080 assert!(
2081 pred(FieldPath::RemotePort, CompiledOperator::Equals(CompiledValue::Int(9090))).test(&v)
2082 );
2083 assert!(
2084 !pred(FieldPath::RemotePort, CompiledOperator::Equals(CompiledValue::Int(81))).test(&v)
2085 );
2086 assert!(
2087 pred(FieldPath::RemotePort, CompiledOperator::NotEquals(CompiledValue::Int(81))).test(&v)
2088 );
2089 assert!(
2090 !pred(FieldPath::RemotePort, CompiledOperator::NotEquals(CompiledValue::Int(9090))).test(&v)
2091 );
2092 }
2093
2094 #[test]
2095 fn matrix_equality_addr_happy_and_miss() {
2096 let conn = make_conn_with("10.0.0.5:55555", "127.0.0.1:80");
2097 let v = PredicateView::L4 { conn: &conn, peek: None };
2098 let ten: std::net::IpAddr = "10.0.0.5".parse().unwrap();
2099 let other: std::net::IpAddr = "10.0.0.6".parse().unwrap();
2100 assert!(pred(FieldPath::RemoteIp, CompiledOperator::Equals(CompiledValue::Addr(ten))).test(&v));
2101 assert!(
2102 !pred(FieldPath::RemoteIp, CompiledOperator::Equals(CompiledValue::Addr(other))).test(&v)
2103 );
2104 assert!(
2105 pred(FieldPath::RemoteIp, CompiledOperator::NotEquals(CompiledValue::Addr(other))).test(&v)
2106 );
2107 assert!(
2108 !pred(FieldPath::RemoteIp, CompiledOperator::NotEquals(CompiledValue::Addr(ten))).test(&v)
2109 );
2110 }
2111
/// `Transport` is an enum field compared against its lowercase string form.
/// Each connection must hit its own transport name and miss the other one,
/// and `NotEquals` must be the exact complement of `Equals` (mirroring the
/// Int/Addr/TlsVersion equality matrices).
#[test]
fn matrix_equality_enum_transport_happy_and_miss() {
    let tcp = make_conn_with_transport(Transport::Tcp);
    let udp = make_conn_with_transport(Transport::Udp);
    let v_tcp = PredicateView::L4 { conn: &tcp, peek: None };
    let v_udp = PredicateView::L4 { conn: &udp, peek: None };
    assert!(pred(FieldPath::Transport, CompiledOperator::Equals(str_val("tcp"))).test(&v_tcp));
    assert!(!pred(FieldPath::Transport, CompiledOperator::Equals(str_val("udp"))).test(&v_tcp));
    assert!(pred(FieldPath::Transport, CompiledOperator::Equals(str_val("udp"))).test(&v_udp));
    assert!(!pred(FieldPath::Transport, CompiledOperator::Equals(str_val("tcp"))).test(&v_udp));
    // `NotEquals` in both directions, so an inverted implementation is caught.
    let not_udp = CompiledOperator::NotEquals(str_val("udp"));
    assert!(pred(FieldPath::Transport, not_udp).test(&v_tcp));
    let not_tcp = CompiledOperator::NotEquals(str_val("tcp"));
    assert!(!pred(FieldPath::Transport, not_tcp).test(&v_tcp));
}
2122
/// `TlsVersion` equality against its dotted string form on a TLS 1.3
/// connection, covering both the hit and the miss for `Equals` and `NotEquals`.
#[test]
fn matrix_equality_enum_tls_version_happy_and_miss() {
    let conn = conn_with_tls_version(crate::conn_context::TlsVersion::Tls13);
    let v = PredicateView::L4 { conn: &conn, peek: None };
    assert!(pred(FieldPath::TlsVersion, CompiledOperator::Equals(str_val("1.3"))).test(&v));
    assert!(!pred(FieldPath::TlsVersion, CompiledOperator::Equals(str_val("1.2"))).test(&v));
    assert!(pred(FieldPath::TlsVersion, CompiledOperator::NotEquals(str_val("1.2"))).test(&v));
    // The `NotEquals` miss completes the matrix, like the Int/Addr equality tests.
    assert!(!pred(FieldPath::TlsVersion, CompiledOperator::NotEquals(str_val("1.3"))).test(&v));
}
2131
/// When the connection from `make_conn()` carries no TLS version, every
/// `TlsVersion` comparison misses — including `NotEquals`, which must not be
/// vacuously true for an absent field.
#[test]
fn matrix_equality_enum_tls_version_misses_when_absent() {
    let conn = make_conn();
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let ops = [
        CompiledOperator::Equals(str_val("1.3")),
        CompiledOperator::NotEquals(str_val("1.3")),
    ];
    for op in ops {
        assert!(!pred(FieldPath::TlsVersion, op).test(&view));
    }
}
2141
/// `HttpMethod` equality on an L7 request view: a POST request hits "POST",
/// misses "GET", and `NotEquals` is checked in both directions.
#[test]
fn matrix_equality_enum_http_method_happy_and_miss() {
    let conn = make_conn();
    let req = http::Request::builder().method("POST").uri("/").body(Body::Empty).unwrap();
    let v = PredicateView::L7Req { conn: &conn, req: &req };
    assert!(pred(FieldPath::HttpMethod, CompiledOperator::Equals(str_val("POST"))).test(&v));
    assert!(!pred(FieldPath::HttpMethod, CompiledOperator::Equals(str_val("GET"))).test(&v));
    assert!(pred(FieldPath::HttpMethod, CompiledOperator::NotEquals(str_val("GET"))).test(&v));
    // The `NotEquals` miss: the request's own method must not satisfy it.
    assert!(!pred(FieldPath::HttpMethod, CompiledOperator::NotEquals(str_val("POST"))).test(&v));
}
2151
/// `In`/`NotIn` on the string `TlsSni` field: a list containing the SNI hits,
/// a list without it misses, and `NotIn` inverts both outcomes.
#[test]
fn matrix_in_list_str_happy_and_miss() {
    let conn = conn_with_sni("api.example.com");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let hit_list = vec![str_val("a.example.com"), str_val("api.example.com")];
    let miss_list = vec![str_val("a.example.com"), str_val("b.example.com")];
    let eval = |op: CompiledOperator| pred(FieldPath::TlsSni, op).test(&view);
    assert!(eval(CompiledOperator::In(hit_list.clone())));
    assert!(!eval(CompiledOperator::In(miss_list.clone())));
    assert!(eval(CompiledOperator::NotIn(miss_list)));
    assert!(!eval(CompiledOperator::NotIn(hit_list)));
}
2165
/// `In`/`NotIn` on the bytes-typed `TlsAlpn` field, checked in all four
/// combinations to match the string in-list matrix.
#[test]
fn matrix_in_list_bytes_happy_and_miss() {
    let conn = conn_with_tls_alpn(b"h2");
    let v = PredicateView::L4 { conn: &conn, peek: None };
    let list = vec![bytes_val(b"http/1.1"), bytes_val(b"h2")];
    assert!(pred(FieldPath::TlsAlpn, CompiledOperator::In(list.clone())).test(&v));
    let list_miss = vec![bytes_val(b"http/1.0"), bytes_val(b"http/1.1")];
    assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::In(list_miss.clone())).test(&v));
    assert!(pred(FieldPath::TlsAlpn, CompiledOperator::NotIn(list_miss)).test(&v));
    // `NotIn` must miss when the ALPN value is in the list.
    assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::NotIn(list)).test(&v));
}
2176
/// `In`/`NotIn` on the integer `RemotePort` field (peer port 443), checked in
/// all four combinations to match the string in-list matrix.
#[test]
fn matrix_in_list_int_happy_and_miss() {
    let conn = make_conn_with("127.0.0.1:443", "127.0.0.1:80");
    let v = PredicateView::L4 { conn: &conn, peek: None };
    let in_list = vec![CompiledValue::Int(80), CompiledValue::Int(443)];
    assert!(pred(FieldPath::RemotePort, CompiledOperator::In(in_list.clone())).test(&v));
    let miss_list = vec![CompiledValue::Int(80), CompiledValue::Int(81)];
    assert!(!pred(FieldPath::RemotePort, CompiledOperator::In(miss_list.clone())).test(&v));
    assert!(pred(FieldPath::RemotePort, CompiledOperator::NotIn(miss_list)).test(&v));
    // `NotIn` must miss when the port is in the list.
    assert!(!pred(FieldPath::RemotePort, CompiledOperator::NotIn(in_list)).test(&v));
}
2187
/// `In`/`NotIn` on `RemoteIp` with a list mixing IPv6 and IPv4 entries: the
/// IPv6 entry must simply not match the IPv4 peer, while the IPv4 entry hits.
#[test]
fn matrix_in_list_addr_happy_and_miss_mixed_family() {
    let conn = make_conn_with("10.0.0.5:55555", "127.0.0.1:80");
    let v = PredicateView::L4 { conn: &conn, peek: None };
    let v4: std::net::IpAddr = "10.0.0.5".parse().unwrap();
    let v6: std::net::IpAddr = "::1".parse().unwrap();
    let list = vec![CompiledValue::Addr(v6), CompiledValue::Addr(v4)];
    assert!(pred(FieldPath::RemoteIp, CompiledOperator::In(list.clone())).test(&v));
    let miss = vec![CompiledValue::Addr(v6)];
    assert!(!pred(FieldPath::RemoteIp, CompiledOperator::In(miss.clone())).test(&v));
    assert!(pred(FieldPath::RemoteIp, CompiledOperator::NotIn(miss)).test(&v));
    // `NotIn` must miss when the peer address appears in the mixed list.
    assert!(!pred(FieldPath::RemoteIp, CompiledOperator::NotIn(list)).test(&v));
}
2200
/// `In`/`NotIn` on the `Transport` enum (here UDP) against string names,
/// checked in all four combinations to match the string in-list matrix.
#[test]
fn matrix_in_list_enum_transport_happy_and_miss() {
    let conn = make_conn_with_transport(Transport::Udp);
    let v = PredicateView::L4 { conn: &conn, peek: None };
    let list = vec![str_val("tcp"), str_val("udp")];
    assert!(pred(FieldPath::Transport, CompiledOperator::In(list.clone())).test(&v));
    let miss = vec![str_val("tcp")];
    assert!(!pred(FieldPath::Transport, CompiledOperator::In(miss.clone())).test(&v));
    assert!(pred(FieldPath::Transport, CompiledOperator::NotIn(miss)).test(&v));
    // `NotIn` must miss when "udp" is in the list.
    assert!(!pred(FieldPath::Transport, CompiledOperator::NotIn(list)).test(&v));
}
2211
/// `Contains`/`NotContains` on the string `HttpUriPath` field of an L7 request
/// for `/api/v1/users`: `/v1/` is present and `/v2/` is not.
#[test]
fn matrix_substring_on_str_happy_and_miss() {
    let conn = make_conn();
    let request =
        http::Request::builder().method("GET").uri("/api/v1/users").body(Body::Empty).unwrap();
    let view = PredicateView::L7Req { conn: &conn, req: &request };
    let on_path = |op: CompiledOperator| pred(FieldPath::HttpUriPath, op).test(&view);
    assert!(on_path(CompiledOperator::Contains(b(b"/v1/"))));
    assert!(!on_path(CompiledOperator::Contains(b(b"/v2/"))));
    assert!(on_path(CompiledOperator::NotContains(b(b"/v2/"))));
    assert!(!on_path(CompiledOperator::NotContains(b(b"/v1/"))));
}
2225
/// `Contains`/`NotContains` on the bytes-typed `TlsAlpn` field ("http/1.1"),
/// checked in all four combinations like the string substring matrix.
#[test]
fn matrix_substring_on_bytes_happy_and_miss() {
    let conn = conn_with_tls_alpn(b"http/1.1");
    let v = PredicateView::L4 { conn: &conn, peek: None };
    assert!(pred(FieldPath::TlsAlpn, CompiledOperator::Contains(b(b"/1."))).test(&v));
    assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::Contains(b(b"/2."))).test(&v));
    assert!(pred(FieldPath::TlsAlpn, CompiledOperator::NotContains(b(b"/2."))).test(&v));
    // `NotContains` must miss when the needle is present.
    assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::NotContains(b(b"/1."))).test(&v));
}
2234
/// `Prefix`/`Suffix` on `HttpUriPath`. The URI carries a `?q=1` query, yet the
/// `.json` suffix still matches, showing the path value excludes the query.
#[test]
fn matrix_prefix_suffix_on_str_happy_and_miss() {
    let conn = make_conn();
    let request = http::Request::builder()
        .method("GET")
        .uri("/api/file.json?q=1")
        .body(Body::Empty)
        .unwrap();
    let view = PredicateView::L7Req { conn: &conn, req: &request };
    let cases = [
        (CompiledOperator::Prefix(b(b"/api")), true),
        (CompiledOperator::Prefix(b(b"/admin")), false),
        (CompiledOperator::Suffix(b(b".json")), true),
        (CompiledOperator::Suffix(b(b".html")), false),
    ];
    for (op, expected) in cases {
        assert_eq!(pred(FieldPath::HttpUriPath, op).test(&view), expected);
    }
}
2248
/// `Prefix`/`Suffix` on the bytes-typed `TlsAlpn` field ("http/1.1").
#[test]
fn matrix_prefix_suffix_on_bytes_happy_and_miss() {
    let conn = conn_with_tls_alpn(b"http/1.1");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let on_alpn = |op: CompiledOperator| pred(FieldPath::TlsAlpn, op).test(&view);
    assert!(on_alpn(CompiledOperator::Prefix(b(b"http"))));
    assert!(!on_alpn(CompiledOperator::Prefix(b(b"h2"))));
    assert!(on_alpn(CompiledOperator::Suffix(b(b"1.1"))));
    assert!(!on_alpn(CompiledOperator::Suffix(b(b"2.0"))));
}
2258
/// `Matches` with an anchored regex on `HttpUriPath`: a versioned-orders
/// pattern hits `/api/v3/orders`, while `^/admin` misses.
#[test]
fn matrix_regex_matches_on_str_happy_and_miss() {
    let conn = make_conn();
    let request =
        http::Request::builder().method("GET").uri("/api/v3/orders").body(Body::Empty).unwrap();
    let view = PredicateView::L7Req { conn: &conn, req: &request };
    let path_matches = |pattern: &str| {
        let re = Regex::new(pattern).expect("compile regex");
        pred(FieldPath::HttpUriPath, CompiledOperator::Matches(re)).test(&view)
    };
    assert!(path_matches(r"^/api/v\d+/orders"));
    assert!(!path_matches(r"^/admin"));
}
2272
/// `Matches` on a named request header: a case-insensitive `mozilla` pattern
/// hits the `user-agent` value, while an anchored `^curl` pattern misses.
#[test]
fn matrix_regex_matches_on_header_happy_and_miss() {
    let conn = make_conn();
    let request = http::Request::builder()
        .method("GET")
        .uri("/")
        .header("user-agent", "Mozilla/5.0 (Macintosh; Intel)")
        .body(Body::Empty)
        .unwrap();
    let view = PredicateView::L7Req { conn: &conn, req: &request };
    let user_agent_matches = |re: Regex| {
        let field = FieldPath::HttpHeader(Arc::from("user-agent"));
        pred(field, CompiledOperator::Matches(re)).test(&view)
    };
    assert!(user_agent_matches(Regex::new(r"(?i)mozilla").expect("compile")));
    assert!(!user_agent_matches(Regex::new(r"^curl").expect("compile")));
}
2293
/// Numeric comparisons on `RemotePort` (peer port 1024). Each operator is
/// checked at the boundary where it flips from hit to miss.
#[test]
fn matrix_numeric_cmp_gt_gte_lt_lte_happy_and_miss() {
    let conn = make_conn_with("127.0.0.1:1024", "127.0.0.1:443");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let cases = [
        (CompiledOperator::Gt(1023), true),
        (CompiledOperator::Gt(1024), false),
        (CompiledOperator::Gte(1024), true),
        (CompiledOperator::Gte(1025), false),
        (CompiledOperator::Lt(1025), true),
        (CompiledOperator::Lt(1024), false),
        (CompiledOperator::Lte(1024), true),
        (CompiledOperator::Lte(1023), false),
    ];
    for (op, expected) in cases {
        assert_eq!(pred(FieldPath::RemotePort, op).test(&view), expected);
    }
}
2313
/// Numeric comparisons also work on `LocalPort`. The local port is 8443 and
/// the remote port is 0, so `> 8000` hits and `> 9000` misses.
#[test]
fn matrix_numeric_cmp_local_port_too() {
    let conn = make_conn_with("127.0.0.1:0", "127.0.0.1:8443");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    for (bound, expected) in [(8000, true), (9000, false)] {
        let hit = pred(FieldPath::LocalPort, CompiledOperator::Gt(bound)).test(&view);
        assert_eq!(hit, expected);
    }
}
2322
/// `Cidr` on an IPv4 `RemoteIp` (10.0.5.7): `10.0.0.0/8` contains it, while
/// `192.168.0.0/16` does not.
#[test]
fn matrix_cidr_v4_happy_and_miss() {
    let conn = make_conn_with("10.0.5.7:0", "127.0.0.1:0");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let enclosing = IpNet::from_str("10.0.0.0/8").unwrap();
    let disjoint = IpNet::from_str("192.168.0.0/16").unwrap();
    assert!(pred(FieldPath::RemoteIp, CompiledOperator::Cidr(enclosing)).test(&view));
    assert!(!pred(FieldPath::RemoteIp, CompiledOperator::Cidr(disjoint)).test(&view));
}
2334
/// `Cidr` on an IPv6 `RemoteIp` (2001:db8::5): the documentation prefix
/// `2001:db8::/32` contains it, while `2001:dead::/32` does not.
#[test]
fn matrix_cidr_v6_happy_and_miss() {
    let conn = make_conn_with("[2001:db8::5]:0", "127.0.0.1:0");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    for (cidr, expected) in [("2001:db8::/32", true), ("2001:dead::/32", false)] {
        let net = IpNet::from_str(cidr).unwrap();
        assert_eq!(pred(FieldPath::RemoteIp, CompiledOperator::Cidr(net)).test(&view), expected);
    }
}
2344
/// An IPv4 network never matches an IPv6 address, even the catch-all
/// `0.0.0.0/0`.
#[test]
fn matrix_cidr_v4_against_v6_addr_misses() {
    let conn = make_conn_with("[2001:db8::5]:0", "127.0.0.1:0");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let any_v4 = IpNet::from_str("0.0.0.0/0").unwrap();
    let matched = pred(FieldPath::RemoteIp, CompiledOperator::Cidr(any_v4)).test(&view);
    assert!(!matched);
}
2353
/// With no `?` in the URI, `HttpUriQuery` reads as the empty string: it equals
/// `""` and does not equal a concrete query.
#[test]
fn http_uri_query_reader_returns_empty_when_query_absent() {
    let conn = make_conn();
    let request = http::Request::builder().method("GET").uri("/no-q").body(Body::Empty).unwrap();
    let view = PredicateView::L7Req { conn: &conn, req: &request };
    let query_equals = |expected: &str| {
        pred(FieldPath::HttpUriQuery, CompiledOperator::Equals(str_val(expected))).test(&view)
    };
    assert!(query_equals(""));
    assert!(!query_equals("q=1"));
}
2367
/// A present query string is exposed verbatim (without the leading `?`) and is
/// usable with both equality and substring operators.
#[test]
fn http_uri_query_reader_matches_present_query() {
    let conn = make_conn();
    let request =
        http::Request::builder().method("GET").uri("/x?a=1&b=2").body(Body::Empty).unwrap();
    let view = PredicateView::L7Req { conn: &conn, req: &request };
    let whole = CompiledOperator::Equals(str_val("a=1&b=2"));
    assert!(pred(FieldPath::HttpUriQuery, whole).test(&view));
    let part = CompiledOperator::Contains(b(b"b=2"));
    assert!(pred(FieldPath::HttpUriQuery, part).test(&view));
}
2376
/// `LocalIp` reads the local socket address (127.0.0.1), not the remote one
/// (10.0.0.5). Both directions are asserted so a reader that mistakenly
/// returned the remote address would fail.
#[test]
fn local_ip_reader_uses_local_socket() {
    let conn = make_conn_with("10.0.0.5:0", "127.0.0.1:8443");
    let v = PredicateView::L4 { conn: &conn, peek: None };
    let local: std::net::IpAddr = "127.0.0.1".parse().unwrap();
    let remote: std::net::IpAddr = "10.0.0.5".parse().unwrap();
    assert!(
        pred(FieldPath::LocalIp, CompiledOperator::Equals(CompiledValue::Addr(local))).test(&v)
    );
    assert!(
        !pred(FieldPath::LocalIp, CompiledOperator::Equals(CompiledValue::Addr(remote))).test(&v)
    );
}
2386
/// A header whose raw value bytes (`0xff 0xfe 0xfd`) are not valid UTF-8 does
/// not satisfy a string `Equals` on `HttpHeader`; the predicate misses.
#[test]
fn http_header_lookup_misses_for_non_utf8_value() {
    let conn = make_conn();
    let raw_value =
        http::HeaderValue::from_bytes(&[0xff, 0xfe, 0xfd]).expect("non-utf8 header value parses");
    let mut builder = http::Request::builder().method("GET").uri("/");
    let headers = builder.headers_mut().expect("headers");
    headers.insert("x-bad", raw_value);
    let req: Request = builder.body(Body::Empty).expect("build request");
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    let field = FieldPath::HttpHeader(Arc::from("x-bad"));
    let matched = pred(field, CompiledOperator::Equals(str_val("anything"))).test(&view);
    assert!(!matched);
}
2406
2407 fn rcgen_cert_with_cn(cn: &str) -> rustls_pki_types::CertificateDer<'static> {
2410 let mut params = rcgen::CertificateParams::default();
2411 params.distinguished_name = rcgen::DistinguishedName::new();
2412 params.distinguished_name.push(rcgen::DnType::CommonName, cn);
2413 let key = rcgen::KeyPair::generate().expect("rcgen keypair");
2414 let cert = params.self_signed(&key).expect("self-sign cert");
2415 cert.der().clone()
2416 }
2417
2418 fn rcgen_cert_no_cn() -> rustls_pki_types::CertificateDer<'static> {
2419 let params = rcgen::CertificateParams::default();
2422 let mut params = params;
2425 params.distinguished_name = rcgen::DistinguishedName::new();
2426 let key = rcgen::KeyPair::generate().expect("rcgen keypair");
2427 let cert = params.self_signed(&key).expect("self-sign cert");
2428 cert.der().clone()
2429 }
2430
2431 fn conn_with_peer_cert(cert: &rustls_pki_types::CertificateDer<'static>) -> Arc<ConnContext> {
2432 let pc = crate::conn_context::PeerCertificate::from_der(cert)
2433 .expect("rcgen-issued cert must parse via PeerCertificate::from_der");
2434 let conn = make_conn();
2435 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
2436 sni: None,
2437 alpn: None,
2438 version: None,
2439 peer_cert: Some(Arc::new(pc)),
2440 });
2441 conn
2442 }
2443
/// `PeerCertificate::from_der` exposes the subject common name.
#[test]
fn peer_cert_from_der_extracts_cn() {
    let der = rcgen_cert_with_cn("client.internal");
    let parsed = crate::conn_context::PeerCertificate::from_der(&der).expect("parse");
    let cn = parsed.subject_cn.as_deref();
    assert_eq!(cn, Some("client.internal"));
}
2450
/// Malformed input yields `None` rather than a panic. Two inputs are checked:
/// a truncated indefinite-length DER header and plain ASCII text.
#[test]
fn peer_cert_from_der_returns_none_for_malformed_der() {
    let inputs: [Vec<u8>; 2] = [vec![0x30, 0x80, 0x00, 0x00], b"not a cert at all".to_vec()];
    for bytes in inputs {
        let raw = rustls_pki_types::CertificateDer::from(bytes);
        assert!(crate::conn_context::PeerCertificate::from_der(&raw).is_none());
    }
}
2458
/// A well-formed certificate with an empty subject DN still parses, but its
/// `subject_cn` is `None`.
#[test]
fn peer_cert_from_der_returns_some_with_no_cn_when_dn_has_no_cn() {
    let der = rcgen_cert_no_cn();
    let parsed = crate::conn_context::PeerCertificate::from_der(&der).expect("parse");
    assert!(parsed.subject_cn.is_none());
}
2466
/// `Equals` on `TlsPeerCertSubjectCn`: the certificate's own CN hits and an
/// unrelated name misses.
#[test]
fn matrix_peer_cert_subject_cn_equals_happy_and_miss() {
    let cert = rcgen_cert_with_cn("ops-bot");
    let conn = conn_with_peer_cert(&cert);
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let cn_is = |name: &str| {
        pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Equals(str_val(name))).test(&view)
    };
    assert!(cn_is("ops-bot"));
    assert!(!cn_is("attacker"));
}
2480
/// The string operator families (prefix, suffix, substring, regex, in-list)
/// all apply to `TlsPeerCertSubjectCn`.
#[test]
fn matrix_peer_cert_subject_cn_string_ops_happy_and_miss() {
    let cert = rcgen_cert_with_cn("svc-payments-prod");
    let conn = conn_with_peer_cert(&cert);
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let on_cn = |op: CompiledOperator| pred(FieldPath::TlsPeerCertSubjectCn, op).test(&view);
    assert!(on_cn(CompiledOperator::Prefix(b(b"svc-"))));
    assert!(!on_cn(CompiledOperator::Prefix(b(b"client-"))));
    assert!(on_cn(CompiledOperator::Suffix(b(b"-prod"))));
    assert!(on_cn(CompiledOperator::Contains(b(b"payments"))));
    let pattern = Regex::new(r"^svc-[a-z]+-(prod|stg)$").expect("regex");
    assert!(on_cn(CompiledOperator::Matches(pattern)));
    let allowed = vec![str_val("svc-other-prod"), str_val("svc-payments-prod")];
    assert!(on_cn(CompiledOperator::In(allowed)));
}
2504
/// Without a peer certificate on the connection, a subject-CN comparison
/// misses instead of matching.
#[test]
fn peer_cert_subject_cn_misses_when_cert_absent() {
    let conn = make_conn();
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let op = CompiledOperator::Equals(str_val("anything"));
    let matched = pred(FieldPath::TlsPeerCertSubjectCn, op).test(&view);
    assert!(!matched);
}
2516
/// A peer certificate with no subject CN makes subject-CN comparisons miss.
#[test]
fn peer_cert_subject_cn_misses_when_cert_has_no_cn() {
    let cert = rcgen_cert_no_cn();
    let conn = conn_with_peer_cert(&cert);
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let op = CompiledOperator::Equals(str_val("ops-bot"));
    let matched = pred(FieldPath::TlsPeerCertSubjectCn, op).test(&view);
    assert!(!matched);
}
2528
2529 fn rcgen_cert_with_san_dns(cn: &str, dns: &[&str]) -> rustls_pki_types::CertificateDer<'static> {
2532 let san: Vec<String> = dns.iter().map(|s| (*s).to_owned()).collect();
2533 let mut params = rcgen::CertificateParams::new(san).expect("rcgen params");
2534 params.distinguished_name = rcgen::DistinguishedName::new();
2535 params.distinguished_name.push(rcgen::DnType::CommonName, cn);
2536 let key = rcgen::KeyPair::generate().expect("rcgen keypair");
2537 let cert = params.self_signed(&key).expect("self-sign cert");
2538 cert.der().clone()
2539 }
2540
/// Each of the newer peer-certificate field paths round-trips from its dotted
/// string form through `parse_field_path`.
#[test]
fn each_new_field_path_parses_from_string_form() {
    use super::parse_field_path;
    let expectations = [
        ("tls.peer_cert.present", FieldPath::TlsPeerCertPresent),
        ("tls.peer_cert.san_dns", FieldPath::TlsPeerCertSanDns),
        ("tls.peer_cert.fingerprint_sha256", FieldPath::TlsPeerCertFingerprintSha256),
        ("tls.peer_cert.spki_sha256", FieldPath::TlsPeerCertSpkiSha256),
        ("tls.peer_cert.issuer_cn", FieldPath::TlsPeerCertIssuerCn),
        ("tls.peer_cert.serial", FieldPath::TlsPeerCertSerial),
    ];
    for (text, field) in expectations {
        assert_eq!(parse_field_path(text), Ok(field), "parsing {text:?}");
    }
}
2554
/// With a certificate attached, `TlsPeerCertPresent` equals `true` and does not
/// equal `false`.
#[test]
fn peer_cert_present_true_when_cert_attached() {
    let cert = rcgen_cert_with_cn("client.internal");
    let conn = conn_with_peer_cert(&cert);
    let view = PredicateView::L4 { conn: &conn, peek: None };
    for flag in [true, false] {
        let op = CompiledOperator::Equals(CompiledValue::Bool(flag));
        assert_eq!(pred(FieldPath::TlsPeerCertPresent, op).test(&view), flag);
    }
}
2569
/// Without a certificate, `TlsPeerCertPresent` is a concrete `false` rather than
/// an absent value: it equals `false` and does not equal `true`.
#[test]
fn peer_cert_present_false_when_cert_absent() {
    let conn = make_conn();
    let view = PredicateView::L4 { conn: &conn, peek: None };
    for flag in [false, true] {
        let op = CompiledOperator::Equals(CompiledValue::Bool(flag));
        assert_eq!(pred(FieldPath::TlsPeerCertPresent, op).test(&view), !flag);
    }
}
2585
/// On the list-valued `TlsPeerCertSanDns` field, `Contains` hits for a listed
/// SAN entry and misses for an unlisted one; `NotContains` inverts the miss.
#[test]
fn peer_cert_san_dns_contains_matches_listed_element() {
    let cert = rcgen_cert_with_san_dns("svc-a", &["svc-a.internal", "svc-b.internal"]);
    let conn = conn_with_peer_cert(&cert);
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let on_sans = |op: CompiledOperator| pred(FieldPath::TlsPeerCertSanDns, op).test(&view);
    assert!(on_sans(CompiledOperator::Contains(b(b"svc-a.internal"))));
    assert!(!on_sans(CompiledOperator::Contains(b(b"svc-c.internal"))));
    assert!(on_sans(CompiledOperator::NotContains(b(b"svc-c.internal"))));
}
2603
/// Without a peer certificate, a SAN `Contains` check misses.
#[test]
fn peer_cert_san_dns_misses_when_cert_absent() {
    let conn = make_conn();
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let op = CompiledOperator::Contains(b(b"anything"));
    assert!(!pred(FieldPath::TlsPeerCertSanDns, op).test(&view));
}
2612
/// `TlsPeerCertFingerprintSha256` equals the lowercase hex encoding of the
/// SHA-256 digest of the certificate's complete DER bytes.
#[test]
fn peer_cert_fingerprint_sha256_is_lowercase_hex_of_full_der() {
    use sha2::{Digest, Sha256};
    use std::fmt::Write as _;

    let cert = rcgen_cert_with_cn("fingerprinted");
    let mut hasher = Sha256::new();
    hasher.update(cert.as_ref());
    let digest = hasher.finalize();
    // Named `byte` (not `b`) so the loop variable does not shadow the `b` helper.
    let mut expected_hex = String::with_capacity(digest.len() * 2);
    for byte in digest.iter() {
        let _ = write!(expected_hex, "{byte:02x}");
    }

    let conn = conn_with_peer_cert(&cert);
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let op = CompiledOperator::Equals(str_val(&expected_hex));
    assert!(pred(FieldPath::TlsPeerCertFingerprintSha256, op).test(&view));
}
2632
/// For a self-signed certificate, `TlsPeerCertIssuerCn` reports the same CN as
/// the subject. The parsed serial is a non-empty, lowercase hex string.
#[test]
fn peer_cert_issuer_and_serial_present_for_self_signed_cert() {
    let cert = rcgen_cert_with_cn("issuer-test");
    let conn = conn_with_peer_cert(&cert);
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let issuer_op = CompiledOperator::Equals(str_val("issuer-test"));
    assert!(pred(FieldPath::TlsPeerCertIssuerCn, issuer_op).test(&view));
    // The lock guard is a temporary, so it is released at the end of this statement.
    let peer = Arc::clone(conn.tls.lock().as_ref().unwrap().peer_cert.as_ref().unwrap());
    assert!(!peer.serial.is_empty(), "serial extracted");
    assert!(peer.serial.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f')));
}
2652
/// `TlsPeerCertPresent` is declared as a `Bool`-typed field.
#[test]
fn peer_cert_present_value_type_is_bool() {
    let declared = FieldPath::TlsPeerCertPresent.value_type();
    assert_eq!(declared, FieldValueType::Bool);
}
2657
/// `TlsPeerCertSanDns` is declared as a list-of-strings (`VecStr`) field.
#[test]
fn peer_cert_san_dns_value_type_is_vec_str() {
    let declared = FieldPath::TlsPeerCertSanDns.value_type();
    assert_eq!(declared, FieldValueType::VecStr);
}
2662
/// A `Bool` field accepts the equality family but rejects the string-oriented
/// families (prefix/suffix, substring, regex).
#[test]
fn matrix_rejects_string_pref_suf_on_bool_field() {
    let rejected =
        [OperatorFamily::StringPrefSuf, OperatorFamily::StringSubstr, OperatorFamily::RegexMatches];
    for family in rejected {
        assert!(!family.accepts(FieldValueType::Bool));
    }
    assert!(OperatorFamily::Equality.accepts(FieldValueType::Bool));
}
2673
/// A `VecStr` field rejects equality, in-list, prefix/suffix and regex, and
/// accepts only the substring (membership) family among those checked.
#[test]
fn matrix_rejects_equals_on_vec_str_field() {
    let rejected = [
        OperatorFamily::Equality,
        OperatorFamily::InList,
        OperatorFamily::StringPrefSuf,
        OperatorFamily::RegexMatches,
    ];
    for family in rejected {
        assert!(!family.accepts(FieldValueType::VecStr));
    }
    assert!(OperatorFamily::StringSubstr.accepts(FieldValueType::VecStr));
}
2684
2685 fn req_with_body(body_bytes: &[u8]) -> Request {
2693 http::Request::builder()
2694 .method("POST")
2695 .uri("/upload")
2696 .body(Body::Static(Bytes::copy_from_slice(body_bytes)))
2697 .expect("build req with body")
2698 }
2699
/// `Equals`/`NotEquals` on the bytes-typed `HttpBody` of a buffered request,
/// checked in all four combinations like the other equality matrices.
#[test]
fn matrix_http_body_equality_happy_and_miss() {
    let conn = make_conn();
    let req = req_with_body(b"hello world");
    let v = PredicateView::L7Req { conn: &conn, req: &req };
    assert!(
        pred(FieldPath::HttpBody, CompiledOperator::Equals(bytes_val(b"hello world"))).test(&v)
    );
    assert!(!pred(FieldPath::HttpBody, CompiledOperator::Equals(bytes_val(b"wrong"))).test(&v));
    assert!(pred(FieldPath::HttpBody, CompiledOperator::NotEquals(bytes_val(b"wrong"))).test(&v));
    // `NotEquals` must miss when the body equals the operand exactly.
    assert!(
        !pred(FieldPath::HttpBody, CompiledOperator::NotEquals(bytes_val(b"hello world")))
            .test(&v)
    );
}
2711
/// `Contains`/`NotContains` on a buffered `HttpBody`, checked in all four
/// combinations like the other substring matrices.
#[test]
fn matrix_http_body_substring_happy_and_miss() {
    let conn = make_conn();
    let req = req_with_body(b"prelude payload trailer");
    let v = PredicateView::L7Req { conn: &conn, req: &req };
    assert!(pred(FieldPath::HttpBody, CompiledOperator::Contains(b(b"payload"))).test(&v));
    assert!(!pred(FieldPath::HttpBody, CompiledOperator::Contains(b(b"missing"))).test(&v));
    assert!(pred(FieldPath::HttpBody, CompiledOperator::NotContains(b(b"missing"))).test(&v));
    // `NotContains` must miss when the needle is present in the body.
    assert!(!pred(FieldPath::HttpBody, CompiledOperator::NotContains(b(b"payload"))).test(&v));
}
2721
/// `Prefix`/`Suffix` on a buffered `HttpBody` ("START middle END").
#[test]
fn matrix_http_body_prefix_suffix_happy_and_miss() {
    let conn = make_conn();
    let req = req_with_body(b"START middle END");
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    let on_body = |op: CompiledOperator| pred(FieldPath::HttpBody, op).test(&view);
    assert!(on_body(CompiledOperator::Prefix(b(b"START"))));
    assert!(!on_body(CompiledOperator::Prefix(b(b"BEGIN"))));
    assert!(on_body(CompiledOperator::Suffix(b(b"END"))));
    assert!(!on_body(CompiledOperator::Suffix(b(b"FIN"))));
}
2732
/// `In`/`NotIn` on a buffered `HttpBody`, checked in all four combinations
/// like the string in-list matrix.
#[test]
fn matrix_http_body_in_list_happy_and_miss() {
    let conn = make_conn();
    let req = req_with_body(b"one");
    let v = PredicateView::L7Req { conn: &conn, req: &req };
    let list = vec![bytes_val(b"two"), bytes_val(b"one")];
    assert!(pred(FieldPath::HttpBody, CompiledOperator::In(list.clone())).test(&v));
    let miss = vec![bytes_val(b"two"), bytes_val(b"three")];
    assert!(!pred(FieldPath::HttpBody, CompiledOperator::In(miss.clone())).test(&v));
    assert!(pred(FieldPath::HttpBody, CompiledOperator::NotIn(miss)).test(&v));
    // `NotIn` must miss when the body is one of the listed values.
    assert!(!pred(FieldPath::HttpBody, CompiledOperator::NotIn(list)).test(&v));
}
2744
/// `HttpBody` is not available on an L4 view, so any body predicate misses.
#[test]
fn http_body_misses_on_l4_view() {
    let conn = make_conn();
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let matched = pred(FieldPath::HttpBody, CompiledOperator::Contains(b(b"x"))).test(&view);
    assert!(!matched);
}
2753
/// Reading `HttpBody` from an L7 request whose body is still `Body::Empty`
/// must panic with a "lazy-buffer invariant" message.
///
/// NOTE(review): the body is presumably expected to be buffered into a
/// static form before body predicates run; confirm against the reader.
#[test]
#[should_panic(expected = "lazy-buffer invariant")]
fn http_body_panics_when_lazy_buffer_invariant_violated() {
    let conn = make_conn();
    let unbuffered = http::Request::builder().method("POST").uri("/").body(Body::Empty).unwrap();
    let view = PredicateView::L7Req { conn: &conn, req: &unbuffered };
    let body_predicate = pred(FieldPath::HttpBody, CompiledOperator::Contains(b(b"x")));
    let _ = body_predicate.test(&view);
}
2769
/// Prefix and substring operators on the raw `Peek` buffer. The buffer is
/// shaped like the start of a TLS handshake record (`0x16 0x03 0x01 ...`).
#[test]
fn matrix_peek_substring_happy_and_miss() {
    let buf: &[u8] = &[0x16, 0x03, 0x01, 0x00, 0x40, 0x01];
    let conn = make_conn();
    let view = PredicateView::L4 { conn: &conn, peek: Some(buf) };
    let on_peek = |op: CompiledOperator| pred(FieldPath::Peek, op).test(&view);
    assert!(on_peek(CompiledOperator::Prefix(b(b"\x16\x03"))));
    assert!(!on_peek(CompiledOperator::Prefix(b(b"\x14\x03"))));
    assert!(on_peek(CompiledOperator::Contains(b(b"\x03\x01"))));
    assert!(!on_peek(CompiledOperator::Contains(b(b"\xff\xff"))));
}
2789
/// `Equals`/`NotEquals` on the raw `Peek` buffer, checked in all four
/// combinations like the other equality matrices.
#[test]
fn matrix_peek_equality_happy_and_miss() {
    let buf: &[u8] = b"GET";
    let conn = make_conn();
    let v = PredicateView::L4 { conn: &conn, peek: Some(buf) };
    assert!(pred(FieldPath::Peek, CompiledOperator::Equals(bytes_val(b"GET"))).test(&v));
    assert!(!pred(FieldPath::Peek, CompiledOperator::Equals(bytes_val(b"PUT"))).test(&v));
    assert!(pred(FieldPath::Peek, CompiledOperator::NotEquals(bytes_val(b"PUT"))).test(&v));
    // `NotEquals` must miss when the peeked bytes equal the operand.
    assert!(!pred(FieldPath::Peek, CompiledOperator::NotEquals(bytes_val(b"GET"))).test(&v));
}
2799
/// `In`/`NotIn` on the raw `Peek` buffer ("PRI "), checked in all four
/// combinations like the string in-list matrix.
#[test]
fn matrix_peek_in_list_happy_and_miss() {
    let buf: &[u8] = b"PRI ";
    let conn = make_conn();
    let v = PredicateView::L4 { conn: &conn, peek: Some(buf) };
    let list = vec![bytes_val(b"GET "), bytes_val(b"PRI ")];
    assert!(pred(FieldPath::Peek, CompiledOperator::In(list.clone())).test(&v));
    let miss = vec![bytes_val(b"POST"), bytes_val(b"HEAD")];
    assert!(!pred(FieldPath::Peek, CompiledOperator::In(miss.clone())).test(&v));
    assert!(pred(FieldPath::Peek, CompiledOperator::NotIn(miss)).test(&v));
    // `NotIn` must miss when the peeked bytes are in the list.
    assert!(!pred(FieldPath::Peek, CompiledOperator::NotIn(list)).test(&v));
}
2811 }
2812
2813 #[test]
2814 fn peek_misses_when_buffer_absent_on_l4_view() {
2815 let conn = make_conn();
2818 let v = PredicateView::L4 { conn: &conn, peek: None };
2819 assert!(!pred(FieldPath::Peek, CompiledOperator::Prefix(b(b"\x16"))).test(&v));
2820 let req = http::Request::builder().method("GET").uri("/").body(Body::Empty).unwrap();
2822 let v7 = PredicateView::L7Req { conn: &conn, req: &req };
2823 assert!(!pred(FieldPath::Peek, CompiledOperator::Prefix(b(b"\x16"))).test(&v7));
2824 }
2825}