1use std::hash::{Hash, Hasher};
2use std::net::IpAddr;
3use std::sync::Arc;
4
5use bytes::Bytes;
6use ipnet::IpNet;
7
8use crate::body::Request;
9use crate::conn_context::ConnContext;
10
/// A connection or request field that a predicate can read.
///
/// The dotted config spelling noted on each variant is what [`FieldPath::display_name`]
/// returns and what `parse_field_path` accepts. The `serde` derives below use the
/// snake_case variant name instead (e.g. `remote_ip`). The type each field is compared as
/// is given by [`FieldPath::value_type`].
#[derive(Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FieldPath {
    /// `transport`: the connection transport, compared as `"tcp"` or `"udp"`.
    Transport,
    /// `remote.ip`: IP of the connection's remote address.
    RemoteIp,
    /// `remote.port`: port of the connection's remote address.
    RemotePort,
    /// `local.ip`: IP of the connection's local address.
    LocalIp,
    /// `local.port`: port of the connection's local address.
    LocalPort,
    /// `peek`: the peeked bytes; only readable on the L4 view.
    Peek,
    /// `tls.sni`: the TLS server name, when recorded.
    TlsSni,
    /// `tls.alpn`: the ALPN protocol bytes, when recorded.
    TlsAlpn,
    /// `tls.version`: the TLS version, compared as `"1.2"` or `"1.3"`.
    TlsVersion,
    /// `tls.peer_cert.present`: whether a peer certificate is recorded.
    TlsPeerCertPresent,
    /// `tls.peer_cert.subject_cn`: the peer certificate's subject CN, when recorded.
    TlsPeerCertSubjectCn,
    /// `tls.peer_cert.san_dns`: the peer certificate's DNS SAN entries.
    TlsPeerCertSanDns,
    /// `tls.peer_cert.fingerprint_sha256`.
    TlsPeerCertFingerprintSha256,
    /// `tls.peer_cert.spki_sha256`.
    TlsPeerCertSpkiSha256,
    /// `tls.peer_cert.issuer_cn`: the peer certificate's issuer CN, when recorded.
    TlsPeerCertIssuerCn,
    /// `tls.peer_cert.serial`.
    TlsPeerCertSerial,
    /// `http.method`: the request method.
    HttpMethod,
    /// `http.uri.path`: the request URI path.
    HttpUriPath,
    /// `http.uri.query`: the request URI query (tested as `""` when absent).
    HttpUriQuery,
    /// `http.header.<name>`: the named request header. The name is stored as written;
    /// `parse_field_path` rejects names containing uppercase ASCII.
    HttpHeader(Arc<str>),
    /// `http.body`: the request body bytes.
    HttpBody,
}
49
/// The type a [`FieldPath`] is compared as (see [`FieldPath::value_type`]). It is checked
/// against an operator's family with [`OperatorFamily::accepts`].
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum FieldValueType {
    /// A string.
    Str,
    /// Raw bytes.
    Bytes,
    /// An integer (ports are widened to `i64` for comparison).
    Int,
    /// An IP address.
    IpAddr,
    /// A string from a fixed set (transport, TLS version, HTTP method).
    Enum,
    /// A boolean.
    Bool,
    /// A list of strings.
    VecStr,
}
64
65impl FieldValueType {
66 #[must_use]
67 pub fn name(self) -> &'static str {
68 match self {
69 Self::Str => "Str",
70 Self::Bytes => "Bytes",
71 Self::Int => "Int",
72 Self::IpAddr => "IpAddr",
73 Self::Enum => "enum",
74 Self::Bool => "Bool",
75 Self::VecStr => "Vec<Str>",
76 }
77 }
78}
79
80impl FieldPath {
81 #[must_use]
85 pub fn value_type(&self) -> FieldValueType {
86 match self {
87 Self::Transport | Self::TlsVersion | Self::HttpMethod => FieldValueType::Enum,
88 Self::RemoteIp | Self::LocalIp => FieldValueType::IpAddr,
89 Self::RemotePort | Self::LocalPort => FieldValueType::Int,
90 Self::Peek | Self::TlsAlpn | Self::HttpBody => FieldValueType::Bytes,
91 Self::TlsPeerCertPresent => FieldValueType::Bool,
92 Self::TlsPeerCertSanDns => FieldValueType::VecStr,
93 Self::TlsSni
94 | Self::TlsPeerCertSubjectCn
95 | Self::TlsPeerCertFingerprintSha256
96 | Self::TlsPeerCertSpkiSha256
97 | Self::TlsPeerCertIssuerCn
98 | Self::TlsPeerCertSerial
99 | Self::HttpUriPath
100 | Self::HttpUriQuery
101 | Self::HttpHeader(_) => FieldValueType::Str,
102 }
103 }
104
105 #[must_use]
107 pub fn display_name(&self) -> String {
108 match self {
109 Self::Transport => "transport".to_string(),
110 Self::RemoteIp => "remote.ip".to_string(),
111 Self::RemotePort => "remote.port".to_string(),
112 Self::LocalIp => "local.ip".to_string(),
113 Self::LocalPort => "local.port".to_string(),
114 Self::Peek => "peek".to_string(),
115 Self::TlsSni => "tls.sni".to_string(),
116 Self::TlsAlpn => "tls.alpn".to_string(),
117 Self::TlsVersion => "tls.version".to_string(),
118 Self::TlsPeerCertPresent => "tls.peer_cert.present".to_string(),
119 Self::TlsPeerCertSubjectCn => "tls.peer_cert.subject_cn".to_string(),
120 Self::TlsPeerCertSanDns => "tls.peer_cert.san_dns".to_string(),
121 Self::TlsPeerCertFingerprintSha256 => "tls.peer_cert.fingerprint_sha256".to_string(),
122 Self::TlsPeerCertSpkiSha256 => "tls.peer_cert.spki_sha256".to_string(),
123 Self::TlsPeerCertIssuerCn => "tls.peer_cert.issuer_cn".to_string(),
124 Self::TlsPeerCertSerial => "tls.peer_cert.serial".to_string(),
125 Self::HttpMethod => "http.method".to_string(),
126 Self::HttpUriPath => "http.uri.path".to_string(),
127 Self::HttpUriQuery => "http.uri.query".to_string(),
128 Self::HttpHeader(name) => format!("http.header.{name}"),
129 Self::HttpBody => "http.body".to_string(),
130 }
131 }
132}
133
/// Groups [`Operator`]s that share field-type requirements (see [`Operator::family`] and
/// [`OperatorFamily::accepts`]).
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum OperatorFamily {
    /// `equals` / `not_equals`.
    Equality,
    /// `contains` / `not_contains`.
    StringSubstr,
    /// `prefix` / `suffix`.
    StringPrefSuf,
    /// `matches`.
    RegexMatches,
    /// `in` / `not_in`.
    InList,
    /// `gt` / `gte` / `lt` / `lte`.
    NumericCmp,
    /// `cidr`.
    CidrMatch,
}
148
149impl Operator {
150 #[must_use]
151 pub fn family(&self) -> OperatorFamily {
152 match self {
153 Self::Equals(_) | Self::NotEquals(_) => OperatorFamily::Equality,
154 Self::Contains(_) | Self::NotContains(_) => OperatorFamily::StringSubstr,
155 Self::Prefix(_) | Self::Suffix(_) => OperatorFamily::StringPrefSuf,
156 Self::Matches(_) => OperatorFamily::RegexMatches,
157 Self::In(_) | Self::NotIn(_) => OperatorFamily::InList,
158 Self::Gt(_) | Self::Gte(_) | Self::Lt(_) | Self::Lte(_) => OperatorFamily::NumericCmp,
159 Self::Cidr(_) => OperatorFamily::CidrMatch,
160 }
161 }
162
163 #[must_use]
164 pub fn name(&self) -> &'static str {
165 match self {
166 Self::Equals(_) => "equals",
167 Self::NotEquals(_) => "not_equals",
168 Self::Contains(_) => "contains",
169 Self::NotContains(_) => "not_contains",
170 Self::Prefix(_) => "prefix",
171 Self::Suffix(_) => "suffix",
172 Self::Matches(_) => "matches",
173 Self::In(_) => "in",
174 Self::NotIn(_) => "not_in",
175 Self::Gt(_) => "gt",
176 Self::Gte(_) => "gte",
177 Self::Lt(_) => "lt",
178 Self::Lte(_) => "lte",
179 Self::Cidr(_) => "cidr",
180 }
181 }
182}
183
184impl OperatorFamily {
185 #[must_use]
190 pub fn accepts(self, vt: FieldValueType) -> bool {
191 use FieldValueType as V;
192 use OperatorFamily as F;
193 matches!(
194 (self, vt),
195 (F::Equality, V::Str | V::Bytes | V::Int | V::IpAddr | V::Enum | V::Bool)
200 | (F::InList, V::Str | V::Bytes | V::Int | V::IpAddr | V::Enum)
203 | (F::StringSubstr, V::Str | V::Bytes | V::VecStr)
204 | (F::StringPrefSuf, V::Str | V::Bytes)
205 | (F::RegexMatches, V::Str)
206 | (F::NumericCmp, V::Int)
207 | (F::CidrMatch, V::IpAddr),
208 )
209 }
210
211 #[must_use]
213 pub fn family_expectation(self) -> &'static str {
214 match self {
215 Self::Equality => "any of Str/Bytes/Int/IpAddr/enum/Bool",
216 Self::InList => "any of Str/Bytes/Int/IpAddr/enum",
217 Self::StringSubstr => "Str, Bytes, or Vec<Str>",
218 Self::StringPrefSuf => "Str or Bytes",
219 Self::RegexMatches => "Str",
220 Self::NumericCmp => "numeric",
221 Self::CidrMatch => "IpAddr",
222 }
223 }
224}
225
/// An operand value as held by a [`CompiledOperator`].
///
/// Equality and hashing are implemented by hand below and compare by content (`Str`
/// compares string contents, not `Arc` identity); values of different variants are never
/// equal.
#[derive(Clone, Debug)]
pub enum CompiledValue {
    Str(Arc<str>),
    Bytes(Bytes),
    Int(i64),
    Bool(bool),
    Addr(IpAddr),
}
234
235impl PartialEq for CompiledValue {
236 fn eq(&self, other: &Self) -> bool {
237 match (self, other) {
238 (Self::Str(a), Self::Str(b)) => a.as_ref() == b.as_ref(),
239 (Self::Bytes(a), Self::Bytes(b)) => a == b,
240 (Self::Int(a), Self::Int(b)) => a == b,
241 (Self::Bool(a), Self::Bool(b)) => a == b,
242 (Self::Addr(a), Self::Addr(b)) => a == b,
243 _ => false,
244 }
245 }
246}
247
248impl Eq for CompiledValue {}
249
250impl Hash for CompiledValue {
251 fn hash<H: Hasher>(&self, state: &mut H) {
252 std::mem::discriminant(self).hash(state);
253 match self {
254 Self::Str(s) => s.as_ref().hash(state),
255 Self::Bytes(b) => b.hash(state),
256 Self::Int(i) => i.hash(state),
257 Self::Bool(b) => b.hash(state),
258 Self::Addr(a) => a.hash(state),
259 }
260 }
261}
262
/// The evaluation-time counterpart of [`Operator`].
///
/// Substring / prefix / suffix operands are held as raw bytes, `Matches` holds a compiled
/// [`fancy_regex::Regex`], and `Cidr` a parsed [`IpNet`]. Equality and hashing are
/// implemented by hand below; regexes compare and hash by their pattern source
/// (`as_str`).
#[derive(Clone, Debug)]
pub enum CompiledOperator {
    Equals(CompiledValue),
    NotEquals(CompiledValue),
    Contains(Bytes),
    NotContains(Bytes),
    Prefix(Bytes),
    Suffix(Bytes),
    Matches(fancy_regex::Regex),
    In(Vec<CompiledValue>),
    NotIn(Vec<CompiledValue>),
    Gt(i64),
    Gte(i64),
    Lt(i64),
    Lte(i64),
    Cidr(IpNet),
}
280
281impl PartialEq for CompiledOperator {
282 fn eq(&self, other: &Self) -> bool {
283 match (self, other) {
284 (Self::Equals(a), Self::Equals(b)) | (Self::NotEquals(a), Self::NotEquals(b)) => a == b,
285 (Self::Contains(a), Self::Contains(b))
286 | (Self::NotContains(a), Self::NotContains(b))
287 | (Self::Prefix(a), Self::Prefix(b))
288 | (Self::Suffix(a), Self::Suffix(b)) => a == b,
289 (Self::Matches(a), Self::Matches(b)) => a.as_str() == b.as_str(),
290 (Self::In(a), Self::In(b)) | (Self::NotIn(a), Self::NotIn(b)) => a == b,
291 (Self::Gt(a), Self::Gt(b))
292 | (Self::Gte(a), Self::Gte(b))
293 | (Self::Lt(a), Self::Lt(b))
294 | (Self::Lte(a), Self::Lte(b)) => a == b,
295 (Self::Cidr(a), Self::Cidr(b)) => a == b,
296 _ => false,
297 }
298 }
299}
300
301impl Eq for CompiledOperator {}
302
303impl Hash for CompiledOperator {
304 fn hash<H: Hasher>(&self, state: &mut H) {
305 std::mem::discriminant(self).hash(state);
306 match self {
307 Self::Equals(v) | Self::NotEquals(v) => v.hash(state),
308 Self::Contains(b) | Self::NotContains(b) | Self::Prefix(b) | Self::Suffix(b) => {
309 b.hash(state);
310 }
311 Self::Matches(r) => r.as_str().hash(state),
312 Self::In(v) | Self::NotIn(v) => v.hash(state),
313 Self::Gt(i) | Self::Gte(i) | Self::Lt(i) | Self::Lte(i) => i.hash(state),
314 Self::Cidr(n) => n.hash(state),
315 }
316 }
317}
318
/// A compiled predicate leaf: one [`FieldPath`] tested with one [`CompiledOperator`].
/// Evaluate it with [`PredicateInst::test`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
pub struct PredicateInst {
    pub path: FieldPath,
    pub op: CompiledOperator,
}
324
/// The data a [`PredicateInst`] is evaluated against.
pub enum PredicateView<'a> {
    /// Connection-level view, with the peeked bytes if any. `http.*` fields read as absent.
    L4 { conn: &'a Arc<ConnContext>, peek: Option<&'a [u8]> },
    /// View of an HTTP request on the connection. `peek` reads as absent.
    L7Req { conn: &'a Arc<ConnContext>, req: &'a Request },
}
329
330impl<'a> PredicateView<'a> {
331 #[must_use]
341 pub fn build(
342 conn: &'a Arc<ConnContext>,
343 req: Option<&'a Request>,
344 _l4: Option<&'a crate::l4::L4Conn>,
345 peek: Option<&'a [u8]>,
346 ) -> Self {
347 match req {
348 Some(r) => Self::L7Req { conn, req: r },
349 None => Self::L4 { conn, peek },
350 }
351 }
352
353 fn conn(&self) -> &Arc<ConnContext> {
354 match self {
355 Self::L4 { conn, .. } | Self::L7Req { conn, .. } => conn,
356 }
357 }
358
359 fn request(&self) -> Option<&Request> {
360 match self {
361 Self::L7Req { req, .. } => Some(req),
362 Self::L4 { .. } => None,
363 }
364 }
365
366 fn peek_buffer(&self) -> Option<&[u8]> {
367 match self {
368 Self::L4 { peek, .. } => *peek,
369 Self::L7Req { .. } => None,
370 }
371 }
372}
373
374impl PredicateInst {
375 #[must_use]
398 #[allow(
399 clippy::too_many_lines,
400 reason = "truth-table dispatch over the full FieldPath enum: each arm reads one connection / request field per spec/crates/core.md § Predicate. Splitting by sub-arm scatters the field-by-field reading across helpers and adds nothing beyond satisfying the lint"
401 )]
402 pub fn test(&self, view: &PredicateView<'_>) -> bool {
403 match &self.path {
404 FieldPath::Transport => {
405 let s = match view.conn().transport {
406 crate::conn_context::Transport::Tcp => "tcp",
407 crate::conn_context::Transport::Udp => "udp",
408 };
409 test_str(&self.op, s)
410 }
411 FieldPath::RemoteIp => test_addr(&self.op, view.conn().remote.ip()),
412 FieldPath::RemotePort => test_int(&self.op, i64::from(view.conn().remote.port())),
413 FieldPath::LocalIp => test_addr(&self.op, view.conn().local.ip()),
414 FieldPath::LocalPort => test_int(&self.op, i64::from(view.conn().local.port())),
415 FieldPath::Peek => view.peek_buffer().is_some_and(|b| test_bytes(&self.op, b)),
416 FieldPath::TlsSni => view
417 .conn()
418 .tls
419 .lock()
420 .as_ref()
421 .and_then(|t| t.sni.clone())
422 .is_some_and(|got| test_str(&self.op, got.as_str())),
423 FieldPath::TlsAlpn => view
424 .conn()
425 .tls
426 .lock()
427 .as_ref()
428 .and_then(|t| t.alpn.clone())
429 .is_some_and(|got| test_bytes(&self.op, got.as_slice())),
430 FieldPath::TlsVersion => view
431 .conn()
432 .tls
433 .lock()
434 .as_ref()
435 .and_then(|t| t.version)
436 .is_some_and(|v| test_str(&self.op, tls_version_str(v))),
437 FieldPath::TlsPeerCertPresent => {
447 let present = view.conn().tls.lock().as_ref().is_some_and(|t| t.peer_cert.is_some());
448 test_bool(&self.op, present)
449 }
450 FieldPath::TlsPeerCertSubjectCn => view
451 .conn()
452 .tls
453 .lock()
454 .as_ref()
455 .and_then(|t| t.peer_cert.as_ref().and_then(|p| p.subject_cn.clone()))
456 .is_some_and(|cn| test_str(&self.op, cn.as_str())),
457 FieldPath::TlsPeerCertSanDns => {
458 let dns_list: Vec<String> = view
459 .conn()
460 .tls
461 .lock()
462 .as_ref()
463 .and_then(|t| t.peer_cert.as_ref().map(|p| p.san_dns.clone()))
464 .unwrap_or_default();
465 test_vec_str(&self.op, &dns_list)
466 }
467 FieldPath::TlsPeerCertFingerprintSha256 => view
468 .conn()
469 .tls
470 .lock()
471 .as_ref()
472 .and_then(|t| t.peer_cert.as_ref().map(|p| p.fingerprint_sha256.clone()))
473 .is_some_and(|s| test_str(&self.op, s.as_str())),
474 FieldPath::TlsPeerCertSpkiSha256 => view
475 .conn()
476 .tls
477 .lock()
478 .as_ref()
479 .and_then(|t| t.peer_cert.as_ref().map(|p| p.spki_sha256.clone()))
480 .is_some_and(|s| test_str(&self.op, s.as_str())),
481 FieldPath::TlsPeerCertIssuerCn => view
482 .conn()
483 .tls
484 .lock()
485 .as_ref()
486 .and_then(|t| t.peer_cert.as_ref().and_then(|p| p.issuer_cn.clone()))
487 .is_some_and(|s| test_str(&self.op, s.as_str())),
488 FieldPath::TlsPeerCertSerial => view
489 .conn()
490 .tls
491 .lock()
492 .as_ref()
493 .and_then(|t| t.peer_cert.as_ref().map(|p| p.serial.clone()))
494 .is_some_and(|s| test_str(&self.op, s.as_str())),
495 FieldPath::HttpMethod => {
496 let Some(req) = view.request() else { return false };
497 test_str(&self.op, req.method().as_str())
498 }
499 FieldPath::HttpUriPath => {
500 let Some(req) = view.request() else { return false };
501 test_str(&self.op, req.uri().path())
502 }
503 FieldPath::HttpUriQuery => {
504 let Some(req) = view.request() else { return false };
505 test_str(&self.op, req.uri().query().unwrap_or(""))
506 }
507 FieldPath::HttpHeader(name) => {
515 let Some(req) = view.request() else { return false };
516 let Some(value) = req.headers().get(name.as_ref()) else { return false };
517 let Ok(s) = value.to_str() else {
518 return false;
523 };
524 test_str(&self.op, s)
525 }
526 FieldPath::HttpBody => {
535 let Some(req) = view.request() else { return false };
536 let bytes = req.body().as_static().expect("lazy-buffer invariant");
537 test_bytes(&self.op, bytes.as_ref())
538 }
539 }
540 }
541}
542
543fn tls_version_str(v: crate::conn_context::TlsVersion) -> &'static str {
544 match v {
545 crate::conn_context::TlsVersion::Tls12 => "1.2",
546 crate::conn_context::TlsVersion::Tls13 => "1.3",
547 }
548}
549
550fn test_bool(op: &CompiledOperator, value: bool) -> bool {
562 match op {
563 CompiledOperator::Equals(CompiledValue::Bool(expected)) => value == *expected,
564 CompiledOperator::NotEquals(CompiledValue::Bool(expected)) => value != *expected,
565 _ => false,
566 }
567}
568
569fn test_vec_str(op: &CompiledOperator, values: &[String]) -> bool {
575 match op {
576 CompiledOperator::Contains(needle) => values.iter().any(|v| v.as_bytes() == needle.as_ref()),
577 CompiledOperator::NotContains(needle) => {
578 !values.iter().any(|v| v.as_bytes() == needle.as_ref())
579 }
580 _ => false,
581 }
582}
583
584fn test_str(op: &CompiledOperator, value: &str) -> bool {
589 match op {
590 CompiledOperator::Equals(CompiledValue::Str(expected)) => value == expected.as_ref(),
591 CompiledOperator::NotEquals(CompiledValue::Str(expected)) => value != expected.as_ref(),
592 CompiledOperator::Contains(b) => contains_bytes(value.as_bytes(), b),
593 CompiledOperator::NotContains(b) => !contains_bytes(value.as_bytes(), b),
594 CompiledOperator::Prefix(b) => value.as_bytes().starts_with(b.as_ref()),
595 CompiledOperator::Suffix(b) => value.as_bytes().ends_with(b.as_ref()),
596 CompiledOperator::Matches(re) => re.is_match(value).unwrap_or(false),
597 CompiledOperator::In(values) => {
598 values.iter().any(|v| matches!(v, CompiledValue::Str(s) if value == s.as_ref()))
599 }
600 CompiledOperator::NotIn(values) => {
601 !values.iter().any(|v| matches!(v, CompiledValue::Str(s) if value == s.as_ref()))
602 }
603 _ => false,
604 }
605}
606
607fn test_bytes(op: &CompiledOperator, value: &[u8]) -> bool {
611 match op {
612 CompiledOperator::Equals(CompiledValue::Bytes(expected)) => value == expected.as_ref(),
613 CompiledOperator::NotEquals(CompiledValue::Bytes(expected)) => value != expected.as_ref(),
614 CompiledOperator::Contains(b) => contains_bytes(value, b),
615 CompiledOperator::NotContains(b) => !contains_bytes(value, b),
616 CompiledOperator::Prefix(b) => value.starts_with(b.as_ref()),
617 CompiledOperator::Suffix(b) => value.ends_with(b.as_ref()),
618 CompiledOperator::In(values) => {
619 values.iter().any(|v| matches!(v, CompiledValue::Bytes(b) if value == b.as_ref()))
620 }
621 CompiledOperator::NotIn(values) => {
622 !values.iter().any(|v| matches!(v, CompiledValue::Bytes(b) if value == b.as_ref()))
623 }
624 _ => false,
625 }
626}
627
628fn test_int(op: &CompiledOperator, value: i64) -> bool {
631 match op {
632 CompiledOperator::Equals(CompiledValue::Int(expected)) => value == *expected,
633 CompiledOperator::NotEquals(CompiledValue::Int(expected)) => value != *expected,
634 CompiledOperator::Gt(n) => value > *n,
635 CompiledOperator::Gte(n) => value >= *n,
636 CompiledOperator::Lt(n) => value < *n,
637 CompiledOperator::Lte(n) => value <= *n,
638 CompiledOperator::In(values) => {
639 values.iter().any(|v| matches!(v, CompiledValue::Int(i) if value == *i))
640 }
641 CompiledOperator::NotIn(values) => {
642 !values.iter().any(|v| matches!(v, CompiledValue::Int(i) if value == *i))
643 }
644 _ => false,
645 }
646}
647
648fn test_addr(op: &CompiledOperator, value: std::net::IpAddr) -> bool {
652 match op {
653 CompiledOperator::Equals(CompiledValue::Addr(expected)) => value == *expected,
654 CompiledOperator::NotEquals(CompiledValue::Addr(expected)) => value != *expected,
655 CompiledOperator::Cidr(net) => net.contains(&value),
656 CompiledOperator::In(values) => {
657 values.iter().any(|v| matches!(v, CompiledValue::Addr(a) if value == *a))
658 }
659 CompiledOperator::NotIn(values) => {
660 !values.iter().any(|v| matches!(v, CompiledValue::Addr(a) if value == *a))
661 }
662 _ => false,
663 }
664}
665
/// Returns whether `needle` occurs as a contiguous subslice of `haystack`.
///
/// An empty needle is contained in every haystack, including an empty one.
fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
    match needle.len() {
        // `windows(0)` panics, so the empty needle is answered up front.
        0 => true,
        n if n > haystack.len() => false,
        n => haystack.windows(n).any(|window| window == needle),
    }
}
675
/// Maximum length, in bytes, of a `matches` regex pattern source. `validate_operator`
/// enforces it when a [`CheckMap`] is deserialized.
pub const REGEX_PATTERN_MAX_BYTES: usize = 4 * 1024;
677
678#[derive(Debug, Clone, serde::Serialize)]
679pub enum Predicate {
680 AnyOf(AnyOfP),
681 AllOf(AllOfP),
682 Not(NotP),
683 Check(CheckMap),
684}
685
686#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
687#[serde(deny_unknown_fields)]
688pub struct AnyOfP {
689 pub any_of: Vec<Predicate>,
690}
691
692#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
693#[serde(deny_unknown_fields)]
694pub struct AllOfP {
695 pub all_of: Vec<Predicate>,
696}
697
698#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
699#[serde(deny_unknown_fields)]
700pub struct NotP {
701 pub not: Box<Predicate>,
702}
703
704#[derive(Debug, Clone, serde::Serialize)]
705pub struct CheckMap {
706 pub path: FieldPath,
707 pub op: Operator,
708}
709
/// A check operator as written in config.
///
/// The `serde` derive is externally tagged with snake_case keys, so each operator is a
/// single-key object such as `{"equals": "GET"}`, `{"in": [80, 443]}` or
/// `{"cidr": "10.0.0.0/8"}`. Which field types each operator may be applied to is decided
/// by [`OperatorFamily::accepts`] on [`Operator::family`].
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Operator {
    /// `equals`
    Equals(Value),
    /// `not_equals`
    NotEquals(Value),
    /// `contains`
    Contains(Value),
    /// `not_contains`
    NotContains(Value),
    /// `prefix`
    Prefix(Value),
    /// `suffix`
    Suffix(Value),
    /// `matches`: regex pattern source, limited to [`REGEX_PATTERN_MAX_BYTES`].
    Matches(String),
    /// `in`
    In(Vec<Value>),
    /// `not_in`
    NotIn(Vec<Value>),
    /// `gt`
    Gt(i64),
    /// `gte`
    Gte(i64),
    /// `lt`
    Lt(i64),
    /// `lte`
    Lte(i64),
    /// `cidr`: the network in string form.
    Cidr(String),
}
728
/// An operand literal as written in config: a boolean, integer, or string.
///
/// The enum is `untagged`: serde tries the variants in declaration order and picks the
/// first that parses.
#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
#[serde(untagged)]
pub enum Value {
    Bool(bool),
    Int(i64),
    Str(String),
}
736
737impl<'de> serde::Deserialize<'de> for Predicate {
738 fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
739 let v = serde_json::Value::deserialize(de)?;
740 let serde_json::Value::Object(ref map) = v else {
741 return Err(serde::de::Error::custom("predicate must be a JSON object"));
742 };
743 if map.len() == 1 {
744 let (key, _) = map.iter().next().expect("len == 1");
745 match key.as_str() {
746 "any_of" => {
747 return serde_json::from_value::<AnyOfP>(v)
748 .map(Predicate::AnyOf)
749 .map_err(serde::de::Error::custom);
750 }
751 "all_of" => {
752 return serde_json::from_value::<AllOfP>(v)
753 .map(Predicate::AllOf)
754 .map_err(serde::de::Error::custom);
755 }
756 "not" => {
757 return serde_json::from_value::<NotP>(v)
758 .map(Predicate::Not)
759 .map_err(serde::de::Error::custom);
760 }
761 _ => {}
762 }
763 }
764 serde_json::from_value::<CheckMap>(v).map(Predicate::Check).map_err(serde::de::Error::custom)
765 }
766}
767
768impl<'de> serde::Deserialize<'de> for CheckMap {
769 fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
770 struct Visitor;
771
772 impl<'de> serde::de::Visitor<'de> for Visitor {
773 type Value = CheckMap;
774
775 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
776 f.write_str("a single-key object of the form {\"<field-path>\": {\"<operator>\": <value>}}")
777 }
778
779 fn visit_map<M: serde::de::MapAccess<'de>>(self, mut map: M) -> Result<CheckMap, M::Error> {
780 let Some(key) = map.next_key::<String>()? else {
781 return Err(serde::de::Error::invalid_length(0, &"exactly one key"));
782 };
783 let path = parse_field_path(&key).map_err(serde::de::Error::custom)?;
784 let op: Operator = map.next_value()?;
785 if map.next_key::<serde::de::IgnoredAny>()?.is_some() {
786 return Err(serde::de::Error::custom("check object must have exactly one key"));
787 }
788 validate_operator(&op).map_err(serde::de::Error::custom)?;
789 Ok(CheckMap { path, op })
790 }
791 }
792
793 de.deserialize_map(Visitor)
794 }
795}
796
797fn parse_field_path(s: &str) -> Result<FieldPath, String> {
798 if s.chars().any(|c| c.is_ascii_uppercase()) {
799 return Err(format!(
800 "field path must be lowercase: {:?} — did you mean {:?}?",
801 s,
802 s.to_ascii_lowercase(),
803 ));
804 }
805 match s {
806 "transport" => Ok(FieldPath::Transport),
807 "remote.ip" => Ok(FieldPath::RemoteIp),
808 "remote.port" => Ok(FieldPath::RemotePort),
809 "local.ip" => Ok(FieldPath::LocalIp),
810 "local.port" => Ok(FieldPath::LocalPort),
811 "peek" => Ok(FieldPath::Peek),
812 "tls.sni" => Ok(FieldPath::TlsSni),
813 "tls.alpn" => Ok(FieldPath::TlsAlpn),
814 "tls.version" => Ok(FieldPath::TlsVersion),
815 "tls.peer_cert.present" => Ok(FieldPath::TlsPeerCertPresent),
816 "tls.peer_cert.subject_cn" => Ok(FieldPath::TlsPeerCertSubjectCn),
817 "tls.peer_cert.san_dns" => Ok(FieldPath::TlsPeerCertSanDns),
818 "tls.peer_cert.fingerprint_sha256" => Ok(FieldPath::TlsPeerCertFingerprintSha256),
819 "tls.peer_cert.spki_sha256" => Ok(FieldPath::TlsPeerCertSpkiSha256),
820 "tls.peer_cert.issuer_cn" => Ok(FieldPath::TlsPeerCertIssuerCn),
821 "tls.peer_cert.serial" => Ok(FieldPath::TlsPeerCertSerial),
822 "http.method" => Ok(FieldPath::HttpMethod),
823 "http.uri.path" => Ok(FieldPath::HttpUriPath),
824 "http.uri.query" => Ok(FieldPath::HttpUriQuery),
825 "http.body" => Ok(FieldPath::HttpBody),
826 other if other.starts_with("http.header.") => {
827 let name = &other["http.header.".len()..];
828 if name.is_empty() {
829 return Err(format!("http.header.* requires a header name: {other:?}"));
830 }
831 Ok(FieldPath::HttpHeader(Arc::from(name)))
832 }
833 other => Err(format!("unknown field path: {other:?}")),
834 }
835}
836
837fn validate_operator(op: &Operator) -> Result<(), String> {
838 if let Operator::Matches(pattern) = op
839 && pattern.len() > REGEX_PATTERN_MAX_BYTES
840 {
841 return Err(format!(
842 "regex pattern source exceeds {REGEX_PATTERN_MAX_BYTES}-byte limit: got {} bytes",
843 pattern.len(),
844 ));
845 }
846 Ok(())
847}
848
849mod serde_impls {
850 use base64::Engine as _;
851 use base64::engine::general_purpose::STANDARD as B64;
852 use bytes::Bytes;
853 use std::net::IpAddr;
854 use std::sync::Arc;
855
856 use super::{CompiledOperator, CompiledValue};
857
858 pub(super) fn ser_bytes<S: serde::Serializer>(b: &Bytes, s: S) -> Result<S::Ok, S::Error> {
859 s.serialize_str(&B64.encode(b))
860 }
861
862 pub(super) fn de_bytes<'de, D: serde::Deserializer<'de>>(d: D) -> Result<Bytes, D::Error> {
863 use serde::Deserialize as _;
864 let s = String::deserialize(d)?;
865 B64.decode(s.as_bytes()).map(Bytes::from).map_err(serde::de::Error::custom)
866 }
867
868 pub(super) fn ser_regex<S: serde::Serializer>(
869 r: &fancy_regex::Regex,
870 s: S,
871 ) -> Result<S::Ok, S::Error> {
872 s.serialize_str(r.as_str())
873 }
874
875 pub(super) fn de_regex<'de, D: serde::Deserializer<'de>>(
876 d: D,
877 ) -> Result<fancy_regex::Regex, D::Error> {
878 use serde::Deserialize as _;
879 let s = String::deserialize(d)?;
880 fancy_regex::Regex::new(&s)
881 .map_err(|e| serde::de::Error::custom(format!("invalid regex {s:?}: {e}")))
882 }
883
884 #[derive(serde::Serialize, serde::Deserialize)]
886 #[serde(rename_all = "snake_case")]
887 pub(super) enum ValueShadow {
888 Str(Arc<str>),
889 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
890 Bytes(Bytes),
891 Int(i64),
892 Bool(bool),
893 Addr(IpAddr),
894 }
895
896 impl From<&CompiledValue> for ValueShadow {
897 fn from(v: &CompiledValue) -> Self {
898 match v {
899 CompiledValue::Str(s) => Self::Str(Arc::clone(s)),
900 CompiledValue::Bytes(b) => Self::Bytes(b.clone()),
901 CompiledValue::Int(i) => Self::Int(*i),
902 CompiledValue::Bool(b) => Self::Bool(*b),
903 CompiledValue::Addr(a) => Self::Addr(*a),
904 }
905 }
906 }
907
908 impl From<ValueShadow> for CompiledValue {
909 fn from(v: ValueShadow) -> Self {
910 match v {
911 ValueShadow::Str(s) => Self::Str(s),
912 ValueShadow::Bytes(b) => Self::Bytes(b),
913 ValueShadow::Int(i) => Self::Int(i),
914 ValueShadow::Bool(b) => Self::Bool(b),
915 ValueShadow::Addr(a) => Self::Addr(a),
916 }
917 }
918 }
919
920 #[derive(serde::Serialize, serde::Deserialize)]
923 #[serde(rename_all = "snake_case")]
924 pub(super) enum OperatorShadow {
925 Equals(CompiledValue),
926 NotEquals(CompiledValue),
927 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
928 Contains(Bytes),
929 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
930 NotContains(Bytes),
931 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
932 Prefix(Bytes),
933 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
934 Suffix(Bytes),
935 #[serde(serialize_with = "ser_regex", deserialize_with = "de_regex")]
936 Matches(fancy_regex::Regex),
937 In(Vec<CompiledValue>),
938 NotIn(Vec<CompiledValue>),
939 Gt(i64),
940 Gte(i64),
941 Lt(i64),
942 Lte(i64),
943 Cidr(ipnet::IpNet),
944 }
945
946 impl From<&CompiledOperator> for OperatorShadow {
947 fn from(op: &CompiledOperator) -> Self {
948 match op {
949 CompiledOperator::Equals(v) => Self::Equals(v.clone()),
950 CompiledOperator::NotEquals(v) => Self::NotEquals(v.clone()),
951 CompiledOperator::Contains(b) => Self::Contains(b.clone()),
952 CompiledOperator::NotContains(b) => Self::NotContains(b.clone()),
953 CompiledOperator::Prefix(b) => Self::Prefix(b.clone()),
954 CompiledOperator::Suffix(b) => Self::Suffix(b.clone()),
955 CompiledOperator::Matches(r) => {
956 Self::Matches(fancy_regex::Regex::new(r.as_str()).expect("round-trippable"))
957 }
958 CompiledOperator::In(vs) => Self::In(vs.clone()),
959 CompiledOperator::NotIn(vs) => Self::NotIn(vs.clone()),
960 CompiledOperator::Gt(i) => Self::Gt(*i),
961 CompiledOperator::Gte(i) => Self::Gte(*i),
962 CompiledOperator::Lt(i) => Self::Lt(*i),
963 CompiledOperator::Lte(i) => Self::Lte(*i),
964 CompiledOperator::Cidr(n) => Self::Cidr(*n),
965 }
966 }
967 }
968
969 impl From<OperatorShadow> for CompiledOperator {
970 fn from(op: OperatorShadow) -> Self {
971 match op {
972 OperatorShadow::Equals(v) => Self::Equals(v),
973 OperatorShadow::NotEquals(v) => Self::NotEquals(v),
974 OperatorShadow::Contains(b) => Self::Contains(b),
975 OperatorShadow::NotContains(b) => Self::NotContains(b),
976 OperatorShadow::Prefix(b) => Self::Prefix(b),
977 OperatorShadow::Suffix(b) => Self::Suffix(b),
978 OperatorShadow::Matches(r) => Self::Matches(r),
979 OperatorShadow::In(vs) => Self::In(vs),
980 OperatorShadow::NotIn(vs) => Self::NotIn(vs),
981 OperatorShadow::Gt(i) => Self::Gt(i),
982 OperatorShadow::Gte(i) => Self::Gte(i),
983 OperatorShadow::Lt(i) => Self::Lt(i),
984 OperatorShadow::Lte(i) => Self::Lte(i),
985 OperatorShadow::Cidr(n) => Self::Cidr(n),
986 }
987 }
988 }
989}
990
991impl serde::Serialize for CompiledValue {
992 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
993 serde_impls::ValueShadow::from(self).serialize(s)
994 }
995}
996
997impl<'de> serde::Deserialize<'de> for CompiledValue {
998 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
999 serde_impls::ValueShadow::deserialize(d).map(Self::from)
1000 }
1001}
1002
1003impl serde::Serialize for CompiledOperator {
1004 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
1005 serde_impls::OperatorShadow::from(self).serialize(s)
1006 }
1007}
1008
1009impl<'de> serde::Deserialize<'de> for CompiledOperator {
1010 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
1011 serde_impls::OperatorShadow::deserialize(d).map(Self::from)
1012 }
1013}
1014
1015#[cfg(test)]
1016mod tests {
1017 use std::collections::hash_map::DefaultHasher;
1018 use std::hash::Hash;
1019 use std::net::{Ipv4Addr, Ipv6Addr};
1020 use std::str::FromStr;
1021 use std::sync::OnceLock;
1022 use std::time::Instant;
1023
1024 use bytes::Bytes;
1025 use fancy_regex::Regex;
1026 use ipnet::IpNet;
1027 use parking_lot::Mutex;
1028
1029 use super::*;
1030 use crate::body::{Body, Request};
1031 use crate::conn_context::{ConnId, Transport};
1032
1033 fn hash_of<T: Hash>(v: &T) -> u64 {
1038 let mut h = DefaultHasher::new();
1039 v.hash(&mut h);
1040 h.finish()
1041 }
1042
1043 fn make_conn() -> Arc<ConnContext> {
1044 Arc::new(ConnContext {
1045 id: ConnId(1),
1046 remote: "127.0.0.1:0".parse().expect("parse remote"),
1047 local: "127.0.0.1:0".parse().expect("parse local"),
1048 transport: Transport::Tcp,
1049 entered_at: Instant::now(),
1050 tls: Mutex::new(None),
1051 http_version: OnceLock::new(),
1052 user: Mutex::new(http::Extensions::new()),
1053 })
1054 }
1055
1056 #[test]
1057 fn field_path_http_header_is_equal_by_string_content_not_arc_identity() {
1058 let a = FieldPath::HttpHeader(Arc::from("host"));
1059 let b = FieldPath::HttpHeader(Arc::from("host"));
1060 assert_eq!(a, b);
1061 assert_eq!(hash_of(&a), hash_of(&b));
1062 let upper = FieldPath::HttpHeader(Arc::from("Host"));
1067 assert_ne!(a, upper);
1068 }
1069
1070 #[test]
1071 fn field_path_simple_variants_are_self_equal_and_mutually_distinct() {
1072 let paths = [
1073 FieldPath::Transport,
1074 FieldPath::RemoteIp,
1075 FieldPath::RemotePort,
1076 FieldPath::LocalIp,
1077 FieldPath::LocalPort,
1078 FieldPath::Peek,
1079 FieldPath::TlsSni,
1080 FieldPath::TlsAlpn,
1081 FieldPath::TlsVersion,
1082 FieldPath::TlsPeerCertSubjectCn,
1083 FieldPath::HttpMethod,
1084 FieldPath::HttpUriPath,
1085 FieldPath::HttpUriQuery,
1086 FieldPath::HttpBody,
1087 ];
1088 for (i, a) in paths.iter().enumerate() {
1089 for (j, b) in paths.iter().enumerate() {
1090 if i == j {
1091 assert_eq!(a, b);
1092 } else {
1093 assert_ne!(a, b);
1094 }
1095 }
1096 }
1097 }
1098
1099 #[test]
1100 fn compiled_value_str_is_equal_by_content_not_arc_identity() {
1101 let a = CompiledValue::Str(Arc::<str>::from("x"));
1102 let b = CompiledValue::Str(Arc::<str>::from("x"));
1103 assert_eq!(a, b);
1104 assert_eq!(hash_of(&a), hash_of(&b));
1105 let c = CompiledValue::Str(Arc::<str>::from("y"));
1106 assert_ne!(a, c);
1107 }
1108
1109 #[test]
1110 fn compiled_value_cross_variant_inequality() {
1111 let s = CompiledValue::Str(Arc::<str>::from("42"));
1112 let i = CompiledValue::Int(42);
1113 assert_ne!(s, i);
1114 }
1115
1116 #[test]
1117 fn compiled_value_bytes_int_bool_addr_self_equal() {
1118 assert_eq!(
1119 CompiledValue::Bytes(Bytes::from_static(b"abc")),
1120 CompiledValue::Bytes(Bytes::copy_from_slice(b"abc")),
1121 );
1122 assert_eq!(CompiledValue::Int(7), CompiledValue::Int(7));
1123 assert_ne!(CompiledValue::Int(7), CompiledValue::Int(8));
1124 assert_eq!(CompiledValue::Bool(true), CompiledValue::Bool(true));
1125 assert_ne!(CompiledValue::Bool(true), CompiledValue::Bool(false));
1126 assert_eq!(
1127 CompiledValue::Addr(Ipv4Addr::new(10, 0, 0, 1).into()),
1128 CompiledValue::Addr(Ipv4Addr::new(10, 0, 0, 1).into()),
1129 );
1130 assert_ne!(
1131 CompiledValue::Addr(Ipv4Addr::new(10, 0, 0, 1).into()),
1132 CompiledValue::Addr(Ipv6Addr::LOCALHOST.into()),
1133 );
1134 }
1135
/// Two independently compiled regexes with the same source are equal and hash alike.
#[test]
fn compiled_operator_matches_equal_by_pattern_source() {
    let compile = |why: &str| CompiledOperator::Matches(Regex::new("^/api").expect(why));
    let a = compile("compile a");
    let b = compile("compile b");
    assert_eq!(a, b);
    assert_eq!(hash_of(&a), hash_of(&b));
}
1143
/// Regex operators compare by source text, so `a|b` and `b|a` are unequal
/// even though they accept the same inputs.
#[test]
fn compiled_operator_matches_distinct_patterns_unequal() {
    let left = CompiledOperator::Matches(Regex::new("a|b").expect("compile a"));
    let right = CompiledOperator::Matches(Regex::new("b|a").expect("compile b"));
    assert_ne!(left, right);
}
1152
/// Two CIDR operators parsed from the same network are equal and hash alike.
#[test]
fn compiled_operator_cidr_equal_by_canonical_form() {
    let parse = |why: &str| CompiledOperator::Cidr(IpNet::from_str("10.0.0.0/8").expect(why));
    let a = parse("parse a");
    let b = parse("parse b");
    assert_eq!(a, b);
    assert_eq!(hash_of(&a), hash_of(&b));
}
1160
/// The same base address with different prefix lengths gives unequal CIDR operators.
#[test]
fn compiled_operator_cidr_distinct_networks_unequal() {
    let slash8 = CompiledOperator::Cidr(IpNet::from_str("10.0.0.0/8").expect("parse a"));
    let slash16 = CompiledOperator::Cidr(IpNet::from_str("10.0.0.0/16").expect("parse b"));
    assert_ne!(slash8, slash16);
}
1167
/// `In` / `NotIn` compare their value lists positionally, so element order matters.
#[test]
fn compiled_operator_in_is_order_sensitive() {
    let strs = |items: [&str; 2]| -> Vec<CompiledValue> {
        items.iter().map(|s| CompiledValue::Str(Arc::<str>::from(*s))).collect()
    };
    let forward = strs(["a", "b"]);
    let backward = strs(["b", "a"]);
    assert_ne!(CompiledOperator::In(forward.clone()), CompiledOperator::In(backward.clone()));
    assert_ne!(CompiledOperator::NotIn(forward), CompiledOperator::NotIn(backward));
}
1177
/// `Gt`/`Gte`/`Lt`/`Lte` with the same threshold are pairwise distinct.
#[test]
fn compiled_operator_numeric_comparisons_distinct_per_variant() {
    let ops = [
        CompiledOperator::Gt(10),
        CompiledOperator::Gte(10),
        CompiledOperator::Lt(10),
        CompiledOperator::Lte(10),
    ];
    // Each operator equals only itself (the diagonal of the comparison grid).
    for (row, lhs) in ops.iter().enumerate() {
        for (col, rhs) in ops.iter().enumerate() {
            if row != col {
                assert_ne!(lhs, rhs);
            } else {
                assert_eq!(lhs, rhs);
            }
        }
    }
}
1197
/// Byte-payload operators sharing the same payload are distinguished by variant.
#[test]
fn compiled_operator_bytes_variants_distinguished() {
    let payload = Bytes::from_static(b"abc");
    let ops = [
        CompiledOperator::Contains(payload.clone()),
        CompiledOperator::NotContains(payload.clone()),
        CompiledOperator::Prefix(payload.clone()),
        CompiledOperator::Suffix(payload),
    ];
    // Visit every ordered pair, diagonal included; only the diagonal may be equal.
    let pairs = ops
        .iter()
        .enumerate()
        .flat_map(|(i, lhs)| ops.iter().enumerate().map(move |(j, rhs)| (i == j, lhs, rhs)));
    for (same_slot, lhs, rhs) in pairs {
        if same_slot {
            assert_eq!(lhs, rhs);
        } else {
            assert_ne!(lhs, rhs);
        }
    }
}
1217
/// Two `PredicateInst`s built independently from identical inputs are equal and hash alike.
#[test]
fn predicate_inst_equal_across_independent_construction_paths() {
    let build = || PredicateInst {
        path: FieldPath::HttpHeader(Arc::from("host")),
        op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from("example.com"))),
    };
    let (lhs, rhs) = (build(), build());
    assert_eq!(lhs, rhs);
    assert_eq!(hash_of(&lhs), hash_of(&rhs));
}
1231
/// Regex-bearing predicates from separate compilations of one pattern are equal
/// and hash alike.
#[test]
fn predicate_inst_equal_with_regex_operator_from_separate_compiles() {
    let build = |why: &str| PredicateInst {
        path: FieldPath::HttpUriPath,
        op: CompiledOperator::Matches(Regex::new("^/").expect(why)),
    };
    let lhs = build("compile a");
    let rhs = build("compile b");
    assert_eq!(lhs, rhs);
    assert_eq!(hash_of(&lhs), hash_of(&rhs));
}
1245
/// The same operator applied to different field paths gives unequal predicates.
#[test]
fn predicate_inst_unequal_on_path_difference() {
    let value = CompiledValue::Str(Arc::<str>::from("x"));
    let on_path =
        PredicateInst { path: FieldPath::HttpUriPath, op: CompiledOperator::Equals(value.clone()) };
    let on_query = PredicateInst { path: FieldPath::HttpUriQuery, op: CompiledOperator::Equals(value) };
    assert_ne!(on_path, on_query);
}
1254
/// Both `PredicateView` variants can be constructed and destructured: `L4` carries
/// optional peeked bytes, while `L7Req` carries an HTTP request.
#[test]
fn predicate_view_variants_construct() {
    let conn = make_conn();
    let peek_bytes: &[u8] = b"\x16\x03\x01";
    let l4 = PredicateView::L4 { conn: &conn, peek: Some(peek_bytes) };
    let PredicateView::L4 { peek, .. } = l4 else {
        panic!("wrong variant");
    };
    assert_eq!(peek.map(<[u8]>::len), Some(3));

    let conn2 = make_conn();
    let req: Request =
        http::Request::builder().method("GET").uri("/").body(Body::Empty).expect("build request");
    let l7 = PredicateView::L7Req { conn: &conn2, req: &req };
    let PredicateView::L7Req { .. } = l7 else {
        panic!("wrong variant");
    };
}
1274
1275 fn parse_predicate(v: serde_json::Value) -> Result<Predicate, serde_json::Error> {
1279 serde_json::from_value(v)
1280 }
1281
1282 fn expect_check(p: &Predicate) -> &CheckMap {
1283 match p {
1284 Predicate::Check(c) => c,
1285 other => panic!("expected Predicate::Check, got {other:?}"),
1286 }
1287 }
1288
/// `{"any_of": [...]}` parses into `Predicate::AnyOf`, with its children in order.
#[test]
fn parse_any_of_happy_path() {
    let raw = serde_json::json!({
        "any_of": [
            { "tls.sni": { "equals": "a" } },
            { "tls.sni": { "equals": "b" } },
        ],
    });
    let parsed = parse_predicate(raw).expect("parse any_of");
    let Predicate::AnyOf(AnyOfP { any_of: children }) = parsed else {
        panic!("expected AnyOf");
    };
    assert_eq!(children.len(), 2);
    let first = expect_check(&children[0]);
    let second = expect_check(&children[1]);
    for check in [first, second] {
        assert_eq!(check.path, FieldPath::TlsSni);
    }
    match (&first.op, &second.op) {
        (Operator::Equals(Value::Str(lhs)), Operator::Equals(Value::Str(rhs))) => {
            assert_eq!(lhs, "a");
            assert_eq!(rhs, "b");
        }
        (a, b) => panic!("unexpected ops: {a:?} / {b:?}"),
    }
}
1314
/// `{"not": <check>}` parses into `Predicate::Not` wrapping the inner check.
#[test]
fn parse_not_happy_path() {
    let raw = serde_json::json!({
        "not": { "tls.sni": { "equals": "internal" } },
    });
    let p = parse_predicate(raw).expect("parse not");
    let Predicate::Not(NotP { not }) = p else {
        panic!("expected Not");
    };
    // Borrow the wrapped predicate; `expect_check` takes `&Predicate`.
    let inner = expect_check(&not);
    assert_eq!(inner.path, FieldPath::TlsSni);
    match &inner.op {
        Operator::Equals(Value::Str(s)) => assert_eq!(s, "internal"),
        other => panic!("unexpected op: {other:?}"),
    }
}
1331
/// `{"all_of": [...]}` parses into `Predicate::AllOf`, with its children in order.
/// An `http.header.<name>` key yields a header path with that name.
#[test]
fn parse_all_of_happy_path() {
    let raw = serde_json::json!({
        "all_of": [
            { "http.header.upgrade": { "equals": "websocket" } },
            { "http.uri.path": { "prefix": "/ws" } },
        ],
    });
    let parsed = parse_predicate(raw).expect("parse all_of");
    let Predicate::AllOf(AllOfP { all_of: children }) = parsed else {
        panic!("expected AllOf");
    };
    assert_eq!(children.len(), 2);
    let header_check = expect_check(&children[0]);
    let path_check = expect_check(&children[1]);
    assert_eq!(header_check.path, FieldPath::HttpHeader(Arc::from("upgrade")));
    assert_eq!(path_check.path, FieldPath::HttpUriPath);
}
1350
/// An empty `all_of` list is accepted at parse time.
#[test]
fn parse_all_of_empty_array_parses() {
    let parsed = parse_predicate(serde_json::json!({ "all_of": [] })).expect("empty all_of parses");
    match parsed {
        Predicate::AllOf(AllOfP { all_of }) => assert!(all_of.is_empty()),
        _ => panic!("expected AllOf"),
    }
}
1362
/// Combinators nest: an `all_of` can hold a plain check and an `any_of` side by side.
#[test]
fn parse_all_of_nested_with_check_and_any_of() {
    let raw = serde_json::json!({
        "all_of": [
            { "tls.sni": { "equals": "api.example.com" } },
            { "any_of": [
                { "remote.ip": { "cidr": "10.0.0.0/8" } },
                { "remote.ip": { "cidr": "192.168.0.0/16" } },
            ]},
        ],
    });
    let parsed = parse_predicate(raw).expect("parse nested all_of/any_of");
    let Predicate::AllOf(AllOfP { all_of: children }) = parsed else {
        panic!("expected AllOf");
    };
    assert_eq!(children.len(), 2);
    assert!(matches!(children[0], Predicate::Check(_)));
    assert!(matches!(children[1], Predicate::AnyOf(_)));
}
1382
/// An `all_of` object carrying an extra sibling key fails to parse.
#[test]
fn parse_all_of_with_extra_key_is_rejected() {
    let raw = serde_json::json!({
        "all_of": [ { "tls.sni": { "equals": "a" } } ],
        "extra": "unwanted",
    });
    let rejection = parse_predicate(raw).expect_err("must reject extra key on all_of");
    // Render the error once so its Display implementation is exercised too.
    let _rendered: String = rejection.to_string();
}
1393
/// `http.header.all_of` names a header called `all_of`; it is not the combinator.
#[test]
fn parse_http_header_all_of_is_a_check_not_combinator() {
    let parsed = parse_predicate(serde_json::json!({ "http.header.all_of": { "equals": "x" } }))
        .expect("parse http.header.all_of");
    assert_eq!(expect_check(&parsed).path, FieldPath::HttpHeader(Arc::from("all_of")));
}
1403
/// A single-key check object parses into the `FieldPath` its key names, for a
/// sample of L4, TLS and HTTP paths.
#[test]
fn parse_check_across_representative_paths() {
    let cases = [
        (serde_json::json!({ "tls.sni": { "equals": "api.example.com" } }), FieldPath::TlsSni),
        (serde_json::json!({ "remote.port": { "gt": 1024 } }), FieldPath::RemotePort),
        (serde_json::json!({ "http.method": { "equals": "GET" } }), FieldPath::HttpMethod),
        (serde_json::json!({ "http.uri.path": { "prefix": "/api" } }), FieldPath::HttpUriPath),
        (
            serde_json::json!({ "http.header.host": { "equals": "a.example.com" } }),
            FieldPath::HttpHeader(Arc::from("host")),
        ),
        (serde_json::json!({ "http.body": { "contains": "hello" } }), FieldPath::HttpBody),
    ];
    for (raw, expected_path) in &cases {
        let parsed = parse_predicate(raw.clone()).unwrap_or_else(|e| panic!("parse {raw}: {e}"));
        assert_eq!(&expect_check(&parsed).path, expected_path, "input: {raw}");
    }
}
1423
/// An `any_of` object carrying an extra sibling key fails to parse.
#[test]
fn parse_any_of_with_extra_key_is_rejected() {
    let raw = serde_json::json!({
        "any_of": [ { "tls.sni": { "equals": "a" } } ],
        "extra": true,
    });
    let rejection = parse_predicate(raw).expect_err("must reject extra key on any_of");
    // Render the error once so its Display implementation is exercised too.
    let _rendered: String = rejection.to_string();
}
1435
/// `http.header.any_of` names a header called `any_of`; it is not the combinator.
#[test]
fn parse_http_header_any_of_is_a_check_not_combinator() {
    let parsed = parse_predicate(serde_json::json!({ "http.header.any_of": { "equals": "x" } }))
        .expect("parse");
    assert_eq!(expect_check(&parsed).path, FieldPath::HttpHeader(Arc::from("any_of")));
}
1445
/// An uppercase header name is rejected. The error quotes the input and suggests
/// the lowercase spelling.
#[test]
fn parse_uppercase_field_path_suggests_lowercase() {
    let raw = serde_json::json!({ "http.header.Host": { "equals": "x" } });
    let msg = parse_predicate(raw).expect_err("uppercase must fail").to_string();
    let expectations = [
        ("http.header.Host", "error mentions offending input"),
        ("did you mean", "error includes suggestion phrase"),
        ("http.header.host", "error contains lowercased form"),
    ];
    for (needle, what) in expectations {
        assert!(msg.contains(needle), "{what}: {msg}");
    }
}
1455
/// A check object must contain exactly one field-path key; two keys are rejected.
#[test]
fn parse_multi_key_check_is_rejected() {
    let raw = serde_json::json!({
        "http.uri.path": { "matches": "^/" },
        "http.method": { "equals": "GET" },
    });
    let rejection = parse_predicate(raw).expect_err("multi-key check must fail");
    // Render the error once so its Display implementation is exercised too.
    let _rendered: String = rejection.to_string();
}
1465
/// `http.header.` with nothing after the prefix is rejected.
#[test]
fn parse_empty_http_header_name_is_rejected() {
    let rejection = parse_predicate(serde_json::json!({ "http.header.": { "equals": "x" } }))
        .expect_err("empty header name must fail");
    // Render the error once so its Display implementation is exercised too.
    let _rendered: String = rejection.to_string();
}
1472
/// An unknown field path is rejected, and the error names the offending path.
#[test]
fn parse_unknown_field_path_is_rejected_with_name() {
    let message = parse_predicate(serde_json::json!({ "http.nope": { "equals": "x" } }))
        .expect_err("unknown path must fail")
        .to_string();
    assert!(message.contains("http.nope"), "error mentions offending path: {message}");
}
1480
1481 fn parse_op(v: serde_json::Value) -> Operator {
1482 let mut map = serde_json::Map::new();
1483 map.insert("tls.sni".to_string(), v);
1484 let raw = serde_json::Value::Object(map);
1485 match parse_predicate(raw).expect("parse check") {
1486 Predicate::Check(c) => c.op,
1487 other => panic!("expected Check, got {other:?}"),
1488 }
1489 }
1490
/// `equals` / `not_equals` with a JSON string parse to `Value::Str` payloads.
#[test]
fn operator_equals_and_not_equals_on_string() {
    match parse_op(serde_json::json!({ "equals": "api" })) {
        Operator::Equals(Value::Str(text)) => assert_eq!(text, "api"),
        other => panic!("expected equals/str: {other:?}"),
    }
    match parse_op(serde_json::json!({ "not_equals": "api" })) {
        Operator::NotEquals(Value::Str(text)) => assert_eq!(text, "api"),
        other => panic!("expected not_equals/str: {other:?}"),
    }
}
1504
/// `contains` / `not_contains` with a JSON string parse to `Value::Str` payloads.
#[test]
fn operator_contains_and_not_contains_on_string() {
    match parse_op(serde_json::json!({ "contains": "foo" })) {
        Operator::Contains(Value::Str(needle)) => assert_eq!(needle, "foo"),
        other => panic!("expected contains/str: {other:?}"),
    }
    match parse_op(serde_json::json!({ "not_contains": "foo" })) {
        Operator::NotContains(Value::Str(needle)) => assert_eq!(needle, "foo"),
        other => panic!("expected not_contains/str: {other:?}"),
    }
}
1518
/// `prefix` / `suffix` with a JSON string parse to `Value::Str` payloads.
#[test]
fn operator_prefix_and_suffix_on_string() {
    match parse_op(serde_json::json!({ "prefix": "/api" })) {
        Operator::Prefix(Value::Str(head)) => assert_eq!(head, "/api"),
        other => panic!("expected prefix/str: {other:?}"),
    }
    match parse_op(serde_json::json!({ "suffix": ".json" })) {
        Operator::Suffix(Value::Str(tail)) => assert_eq!(tail, ".json"),
        other => panic!("expected suffix/str: {other:?}"),
    }
}
1532
/// At parse level, `matches` keeps the raw pattern source unchanged.
#[test]
fn operator_matches_carries_pattern_source() {
    let parsed = parse_op(serde_json::json!({ "matches": "^/api/v\\d+" }));
    if let Operator::Matches(pattern) = &parsed {
        assert_eq!(pattern, "^/api/v\\d+");
    } else {
        panic!("expected matches: {parsed:?}");
    }
}
1541
/// `in` / `not_in` lists may mix strings and integers; order is preserved.
#[test]
fn operator_in_and_not_in_accept_mixed_scalar_types() {
    let Operator::In(members) = parse_op(serde_json::json!({ "in": ["foo", 42] })) else {
        panic!("expected in");
    };
    assert_eq!(members.len(), 2);
    assert_eq!(members[0], Value::Str("foo".into()));
    assert_eq!(members[1], Value::Int(42));

    let Operator::NotIn(excluded) = parse_op(serde_json::json!({ "not_in": ["bar", 7] })) else {
        panic!("expected not_in");
    };
    assert_eq!(excluded.len(), 2);
    assert_eq!(excluded[0], Value::Str("bar".into()));
    assert_eq!(excluded[1], Value::Int(7));
}
1559
/// Each numeric comparison keyword maps to its own operator variant.
#[test]
fn operator_numeric_comparisons() {
    let gt = parse_op(serde_json::json!({ "gt": 10 }));
    assert!(matches!(gt, Operator::Gt(10)));
    let gte = parse_op(serde_json::json!({ "gte": 10 }));
    assert!(matches!(gte, Operator::Gte(10)));
    let lt = parse_op(serde_json::json!({ "lt": 10 }));
    assert!(matches!(lt, Operator::Lt(10)));
    let lte = parse_op(serde_json::json!({ "lte": 10 }));
    assert!(matches!(lte, Operator::Lte(10)));
}
1567
/// At parse level, `cidr` keeps the network as its source string.
#[test]
fn operator_cidr_carries_source_string() {
    let parsed = parse_op(serde_json::json!({ "cidr": "10.0.0.0/8" }));
    if let Operator::Cidr(network) = &parsed {
        assert_eq!(network, "10.0.0.0/8");
    } else {
        panic!("expected cidr: {parsed:?}");
    }
}
1576
/// JSON booleans parse as `Value::Bool`, not as strings.
#[test]
fn value_untagged_priority_bool_before_str() {
    for flag in [true, false] {
        let op = parse_op(serde_json::json!({ "equals": flag }));
        assert!(matches!(op, Operator::Equals(Value::Bool(parsed)) if parsed == flag));
    }
}
1586
/// JSON integers parse as `Value::Int`, not as strings.
#[test]
fn value_untagged_priority_int_before_str() {
    let parsed = parse_op(serde_json::json!({ "equals": 42 }));
    assert!(matches!(parsed, Operator::Equals(Value::Int(42))));
}
1593
/// A numeric-looking JSON string stays a `Value::Str`; it is not coerced to an integer.
#[test]
fn value_untagged_json_string_stays_str() {
    let parsed = parse_op(serde_json::json!({ "equals": "42" }));
    if let Operator::Equals(Value::Str(text)) = &parsed {
        assert_eq!(text, "42");
    } else {
        panic!("expected equals/str(\"42\"): {parsed:?}");
    }
}
1604
/// A regex pattern exactly `REGEX_PATTERN_MAX_BYTES` long is accepted and kept intact.
#[test]
fn regex_pattern_exactly_at_limit_parses() {
    // Pin the limit so any change to it shows up as a deliberate test update.
    assert_eq!(REGEX_PATTERN_MAX_BYTES, 4 * 1024);
    let at_limit = "a".repeat(REGEX_PATTERN_MAX_BYTES);
    let raw = serde_json::json!({ "http.uri.path": { "matches": at_limit } });
    let parsed = parse_predicate(raw).expect("4 KiB pattern parses");
    match &expect_check(&parsed).op {
        Operator::Matches(src) => assert_eq!(src.len(), REGEX_PATTERN_MAX_BYTES),
        other => panic!("expected matches: {other:?}"),
    }
}
1618
/// A pattern one byte over `REGEX_PATTERN_MAX_BYTES` is rejected, and the error
/// message states the numeric limit.
#[test]
fn regex_pattern_over_limit_rejected_with_limit_in_message() {
    let pattern = "a".repeat(REGEX_PATTERN_MAX_BYTES + 1);
    let raw = serde_json::json!({ "http.uri.path": { "matches": pattern } });
    let err = parse_predicate(raw).expect_err("over-limit pattern must fail");
    let msg = err.to_string();
    assert!(
        msg.contains(&REGEX_PATTERN_MAX_BYTES.to_string()),
        "error mentions the limit ({REGEX_PATTERN_MAX_BYTES}): {msg}",
    );
}
1630
1631 fn value_round_trip(v: &CompiledValue) -> CompiledValue {
1638 let encoded = serde_json::to_string(v).expect("serialize value");
1639 serde_json::from_str(&encoded).expect("deserialize value")
1640 }
1641
/// `Str` survives a JSON round trip, including the empty string.
#[test]
fn compiled_value_str_round_trip_including_empty() {
    for text in ["x", ""] {
        let original = CompiledValue::Str(Arc::<str>::from(text));
        assert_eq!(value_round_trip(&original), original);
    }
}
1649
/// `Bytes` survives a JSON round trip for text, empty, and non-UTF-8 payloads.
#[test]
fn compiled_value_bytes_round_trip_including_empty_and_binary() {
    let payloads = [
        Bytes::from_static(b"hello"),
        Bytes::new(),
        Bytes::from_static(&[0xff, 0x00, 0x13]),
    ];
    for payload in payloads {
        let original = CompiledValue::Bytes(payload);
        assert_eq!(value_round_trip(&original), original);
    }
}
1659
/// `Int` survives a JSON round trip at zero and at both `i64` extremes.
#[test]
fn compiled_value_int_round_trip_including_extremes() {
    for original in [0_i64, i64::MIN, i64::MAX].map(CompiledValue::Int) {
        assert_eq!(value_round_trip(&original), original);
    }
}
1667
/// `Bool` survives a JSON round trip for both truth values.
#[test]
fn compiled_value_bool_round_trip_both_variants() {
    for original in [true, false].map(CompiledValue::Bool) {
        assert_eq!(value_round_trip(&original), original);
    }
}
1675
/// `Addr` survives a JSON round trip for both IPv4 and IPv6 loopback.
#[test]
fn compiled_value_addr_round_trip_v4_and_v6() {
    let loopbacks: [std::net::IpAddr; 2] = [Ipv4Addr::LOCALHOST.into(), Ipv6Addr::LOCALHOST.into()];
    for original in loopbacks.map(CompiledValue::Addr) {
        assert_eq!(value_round_trip(&original), original);
    }
}
1683
/// Pins the wire form of `Bytes`: a `bytes` key holding padded standard base64.
#[test]
fn compiled_value_bytes_emits_standard_base64_literal() {
    const EXPECTED: &str = r#"{"bytes":"aGVsbG8="}"#;
    let hello = CompiledValue::Bytes(Bytes::from_static(b"hello"));
    assert_eq!(serde_json::to_string(&hello).expect("serialize"), EXPECTED);
}
1693
1694 fn op_round_trip(op: &CompiledOperator) -> CompiledOperator {
1695 let encoded = serde_json::to_string(op).expect("serialize op");
1696 serde_json::from_str(&encoded).expect("deserialize op")
1697 }
1698
/// `Equals` / `NotEquals` survive a JSON round trip.
#[test]
fn compiled_operator_equals_and_not_equals_round_trip() {
    let operators = [
        CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from("x"))),
        CompiledOperator::NotEquals(CompiledValue::Str(Arc::<str>::from("x"))),
    ];
    for op in operators {
        assert_eq!(op_round_trip(&op), op);
    }
}
1706
/// Every byte-payload operator variant survives a JSON round trip.
#[test]
fn compiled_operator_bytes_variants_round_trip() {
    let payload = Bytes::from_static(b"hello");
    let constructors: [fn(Bytes) -> CompiledOperator; 4] = [
        CompiledOperator::Contains,
        CompiledOperator::NotContains,
        CompiledOperator::Prefix,
        CompiledOperator::Suffix,
    ];
    for make in constructors {
        let op = make(payload.clone());
        assert_eq!(op_round_trip(&op), op);
    }
}
1720
/// `Matches` survives a JSON round trip, and the recompiled regex keeps the
/// original pattern text.
#[test]
fn compiled_operator_matches_round_trip_preserves_pattern_source() {
    let original = CompiledOperator::Matches(Regex::new("^/api/v[0-9]+").expect("compile"));
    let decoded = op_round_trip(&original);
    assert_eq!(decoded, original);
    // Also check the source text directly, independent of the PartialEq impl.
    match decoded {
        CompiledOperator::Matches(regex) => assert_eq!(regex.as_str(), "^/api/v[0-9]+"),
        other => panic!("expected matches, got {other:?}"),
    }
}
1732
/// `In` / `NotIn` with a mixed `Str`/`Int` list survive a JSON round trip.
#[test]
fn compiled_operator_in_and_not_in_round_trip_mixed_values() {
    let xs = vec![CompiledValue::Str(Arc::<str>::from("a")), CompiledValue::Int(42)];
    let in_op = CompiledOperator::In(xs.clone());
    assert_eq!(op_round_trip(&in_op), in_op);
    let not_in_op = CompiledOperator::NotIn(xs);
    assert_eq!(op_round_trip(&not_in_op), not_in_op);
}
1741
/// Every numeric comparison operator survives a JSON round trip.
#[test]
fn compiled_operator_numeric_comparisons_round_trip() {
    let thresholds = [
        CompiledOperator::Gt(100),
        CompiledOperator::Gte(100),
        CompiledOperator::Lt(100),
        CompiledOperator::Lte(100),
    ];
    thresholds.iter().for_each(|op| assert_eq!(&op_round_trip(op), op));
}
1754
/// `Cidr` survives a JSON round trip.
#[test]
fn compiled_operator_cidr_round_trip_preserves_canonical_form() {
    let network = IpNet::from_str("10.0.0.0/8").expect("parse");
    let original = CompiledOperator::Cidr(network);
    assert_eq!(op_round_trip(&original), original);
}
1760
/// Deserializing `Matches` compiles the regex, so an invalid pattern is rejected.
/// The error quotes the bad source.
#[test]
fn compiled_operator_matches_with_invalid_regex_is_rejected() {
    let rendered = serde_json::from_str::<CompiledOperator>(r#"{"matches":"["}"#)
        .expect_err("invalid regex must fail to deserialize")
        .to_string();
    assert!(rendered.contains('['), "error mentions offending regex source: {rendered}");
}
1772
/// Pins the exact serialized shape of a header-equals `PredicateInst` and checks
/// that it decodes back to the same value.
#[test]
fn predicate_inst_pins_exact_wire_shape_for_http_header_equals() {
    const EXPECTED: &str =
        r#"{"path":{"http_header":"host"},"op":{"equals":{"str":"example.com"}}}"#;
    let original = PredicateInst {
        path: FieldPath::HttpHeader(Arc::from("host")),
        op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from("example.com"))),
    };
    let wire = serde_json::to_string(&original).expect("serialize");
    assert_eq!(wire, EXPECTED);
    let back: PredicateInst = serde_json::from_str(&wire).expect("deserialize");
    assert_eq!(back, original);
}
1784
/// A regex-bearing `PredicateInst` survives a JSON round trip.
#[test]
fn predicate_inst_round_trip_with_regex_operator() {
    let original = PredicateInst {
        path: FieldPath::HttpUriPath,
        op: CompiledOperator::Matches(Regex::new("^/api").expect("compile")),
    };
    let wire = serde_json::to_string(&original).expect("serialize");
    let back = serde_json::from_str::<PredicateInst>(&wire).expect("deserialize");
    assert_eq!(back, original);
}
1795
1796 fn http_header_equals(name: &str, value: &str) -> PredicateInst {
1803 PredicateInst {
1804 path: FieldPath::HttpHeader(Arc::from(name)),
1805 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from(value))),
1806 }
1807 }
1808
1809 fn http_uri_path_equals(value: &str) -> PredicateInst {
1810 PredicateInst {
1811 path: FieldPath::HttpUriPath,
1812 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from(value))),
1813 }
1814 }
1815
1816 fn http_uri_path_prefix(value: &str) -> PredicateInst {
1817 PredicateInst {
1818 path: FieldPath::HttpUriPath,
1819 op: CompiledOperator::Prefix(Bytes::copy_from_slice(value.as_bytes())),
1820 }
1821 }
1822
1823 fn tls_sni_equals(value: &str) -> PredicateInst {
1824 PredicateInst {
1825 path: FieldPath::TlsSni,
1826 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from(value))),
1827 }
1828 }
1829
1830 fn conn_with_sni(sni: &str) -> Arc<ConnContext> {
1831 let conn = make_conn();
1832 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
1833 sni: Some(sni.to_string()),
1834 alpn: None,
1835 version: None,
1836 peer_cert: None,
1837 zero_rtt_used: false,
1838 });
1839 conn
1840 }
1841
1842 fn req_with_header(name: &str, value: &str) -> Request {
1843 http::Request::builder()
1844 .method("GET")
1845 .uri("/")
1846 .header(name, value)
1847 .body(Body::Empty)
1848 .expect("build req")
1849 }
1850
1851 fn req_with_uri(uri: &str) -> Request {
1852 http::Request::builder().method("GET").uri(uri).body(Body::Empty).expect("build req")
1853 }
1854
/// A header that is present with exactly the expected value matches.
#[test]
fn predicate_test_http_header_equals_matches_when_present_and_equal() {
    let predicate = http_header_equals("upgrade", "websocket");
    let (conn, req) = (make_conn(), req_with_header("upgrade", "websocket"));
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    assert!(predicate.test(&view));
}
1862
/// A request that lacks the tested header does not match.
#[test]
fn predicate_test_http_header_equals_misses_when_header_absent() {
    let predicate = http_header_equals("upgrade", "websocket");
    let (conn, req) = (make_conn(), req_with_header("host", "example.com"));
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    assert!(!predicate.test(&view));
}
1870
/// Header values are compared case-sensitively: `WebSocket` does not equal `websocket`.
#[test]
fn predicate_test_http_header_equals_value_is_case_sensitive() {
    let predicate = http_header_equals("upgrade", "websocket");
    let (conn, req) = (make_conn(), req_with_header("upgrade", "WebSocket"));
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    assert!(!predicate.test(&view));
}
1882
/// Header names are looked up case-insensitively: `Upgrade` satisfies a predicate
/// on `upgrade`.
#[test]
fn predicate_test_http_header_equals_name_lookup_is_case_insensitive() {
    let predicate = http_header_equals("upgrade", "websocket");
    let (conn, req) = (make_conn(), req_with_header("Upgrade", "websocket"));
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    assert!(predicate.test(&view));
}
1895
1896 #[test]
1897 fn predicate_test_http_header_equals_misses_on_l4_view() {
1898 let conn = make_conn();
1902 let view = PredicateView::L4 { conn: &conn, peek: None };
1903 assert!(!http_header_equals("upgrade", "websocket").test(&view));
1904 }
1905
/// `http.uri.path equals` matches when the request path is identical.
#[test]
fn predicate_test_http_uri_path_equals_matches_exact() {
    let predicate = http_uri_path_equals("/api/v1/users");
    let (conn, req) = (make_conn(), req_with_uri("/api/v1/users"));
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    assert!(predicate.test(&view));
}
1913
/// `equals` means whole-path equality, so a leading fragment does not match.
#[test]
fn predicate_test_http_uri_path_equals_misses_on_substring() {
    let predicate = http_uri_path_equals("/api");
    let (conn, req) = (make_conn(), req_with_uri("/api/v1/users"));
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    assert!(!predicate.test(&view));
}
1924
/// `http.uri.path prefix` matches when the path begins with the prefix.
#[test]
fn predicate_test_http_uri_path_prefix_matches_when_path_starts_with() {
    let predicate = http_uri_path_prefix("/api");
    let (conn, req) = (make_conn(), req_with_uri("/api/v1/users"));
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    assert!(predicate.test(&view));
}
1932
/// `http.uri.path prefix` does not match a path with a different start.
#[test]
fn predicate_test_http_uri_path_prefix_misses_when_no_prefix() {
    let predicate = http_uri_path_prefix("/api");
    let (conn, req) = (make_conn(), req_with_uri("/admin"));
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    assert!(!predicate.test(&view));
}
1940
/// In an L7 view, `tls.sni equals` reads the SNI from the connection's TLS info.
#[test]
fn predicate_test_tls_sni_equals_matches_when_set() {
    let predicate = tls_sni_equals("api.example.com");
    let (conn, req) = (conn_with_sni("api.example.com"), req_with_uri("/"));
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    assert!(predicate.test(&view));
}
1951
/// Without TLS info on the connection, `tls.sni equals` does not match.
#[test]
fn predicate_test_tls_sni_equals_misses_when_unset() {
    let predicate = tls_sni_equals("api.example.com");
    let (conn, req) = (make_conn(), req_with_uri("/"));
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    assert!(!predicate.test(&view));
}
1961
/// SNI lives on the connection, so it can be evaluated without an HTTP request.
#[test]
fn predicate_test_tls_sni_equals_works_in_l4_view_too() {
    let predicate = tls_sni_equals("api.example.com");
    let conn = conn_with_sni("api.example.com");
    assert!(predicate.test(&PredicateView::L4 { conn: &conn, peek: None }));
}
1973
1974 fn pred(path: FieldPath, op: CompiledOperator) -> PredicateInst {
1981 PredicateInst { path, op }
1982 }
1983
1984 fn str_val(s: &str) -> CompiledValue {
1985 CompiledValue::Str(Arc::<str>::from(s))
1986 }
1987
1988 fn bytes_val(b: &[u8]) -> CompiledValue {
1989 CompiledValue::Bytes(Bytes::copy_from_slice(b))
1990 }
1991
1992 fn b(b: &[u8]) -> Bytes {
1993 Bytes::copy_from_slice(b)
1994 }
1995
1996 fn make_conn_with(remote: &str, local: &str) -> Arc<ConnContext> {
1997 Arc::new(ConnContext {
1998 id: ConnId(1),
1999 remote: remote.parse().expect("parse remote"),
2000 local: local.parse().expect("parse local"),
2001 transport: Transport::Tcp,
2002 entered_at: Instant::now(),
2003 tls: Mutex::new(None),
2004 http_version: OnceLock::new(),
2005 user: Mutex::new(http::Extensions::new()),
2006 })
2007 }
2008
2009 fn make_conn_with_transport(t: Transport) -> Arc<ConnContext> {
2010 Arc::new(ConnContext {
2011 id: ConnId(1),
2012 remote: "127.0.0.1:0".parse().expect("remote"),
2013 local: "127.0.0.1:0".parse().expect("local"),
2014 transport: t,
2015 entered_at: Instant::now(),
2016 tls: Mutex::new(None),
2017 http_version: OnceLock::new(),
2018 user: Mutex::new(http::Extensions::new()),
2019 })
2020 }
2021
2022 fn conn_with_tls_alpn(alpn: &[u8]) -> Arc<ConnContext> {
2023 let conn = make_conn();
2024 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
2025 sni: None,
2026 alpn: Some(alpn.to_vec()),
2027 version: None,
2028 peer_cert: None,
2029 zero_rtt_used: false,
2030 });
2031 conn
2032 }
2033
2034 fn conn_with_tls_version(v: crate::conn_context::TlsVersion) -> Arc<ConnContext> {
2035 let conn = make_conn();
2036 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
2037 sni: None,
2038 alpn: None,
2039 version: Some(v),
2040 peer_cert: None,
2041 zero_rtt_used: false,
2042 });
2043 conn
2044 }
2045
/// `Equals` / `NotEquals` on a string field (`tls.sni`): hits and misses.
#[test]
fn matrix_equality_str_happy_and_miss() {
    let conn = conn_with_sni("api.example.com");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let expectations = [
        (CompiledOperator::Equals(str_val("api.example.com")), true),
        (CompiledOperator::Equals(str_val("other.example.com")), false),
        (CompiledOperator::NotEquals(str_val("other.example.com")), true),
        (CompiledOperator::NotEquals(str_val("api.example.com")), false),
    ];
    for (op, expected) in expectations {
        assert_eq!(pred(FieldPath::TlsSni, op).test(&view), expected);
    }
}
2063
/// `Equals` / `NotEquals` on a bytes field (`tls.alpn`): hits and misses.
#[test]
fn matrix_equality_bytes_happy_and_miss() {
    let conn = conn_with_tls_alpn(b"h2");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let expectations = [
        (CompiledOperator::Equals(bytes_val(b"h2")), true),
        (CompiledOperator::Equals(bytes_val(b"http/1.1")), false),
        (CompiledOperator::NotEquals(bytes_val(b"http/1.1")), true),
        (CompiledOperator::NotEquals(bytes_val(b"h2")), false),
    ];
    for (op, expected) in expectations {
        assert_eq!(pred(FieldPath::TlsAlpn, op).test(&view), expected);
    }
}
2074
/// `Equals` / `NotEquals` on an integer field (`remote.port`): hits and misses.
#[test]
fn matrix_equality_int_happy_and_miss() {
    let conn = make_conn_with("127.0.0.1:9090", "127.0.0.1:80");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let expectations = [
        (CompiledOperator::Equals(CompiledValue::Int(9090)), true),
        (CompiledOperator::Equals(CompiledValue::Int(81)), false),
        (CompiledOperator::NotEquals(CompiledValue::Int(81)), true),
        (CompiledOperator::NotEquals(CompiledValue::Int(9090)), false),
    ];
    for (op, expected) in expectations {
        assert_eq!(pred(FieldPath::RemotePort, op).test(&view), expected);
    }
}
2092
/// `Equals` / `NotEquals` on an IP-address field (`remote.ip`): hits and misses.
#[test]
fn matrix_equality_addr_happy_and_miss() {
    let conn = make_conn_with("10.0.0.5:55555", "127.0.0.1:80");
    let v = PredicateView::L4 { conn: &conn, peek: None };
    let ten: std::net::IpAddr = "10.0.0.5".parse().expect("parse matching addr");
    let other: std::net::IpAddr = "10.0.0.6".parse().expect("parse non-matching addr");
    assert!(pred(FieldPath::RemoteIp, CompiledOperator::Equals(CompiledValue::Addr(ten))).test(&v));
    assert!(
        !pred(FieldPath::RemoteIp, CompiledOperator::Equals(CompiledValue::Addr(other))).test(&v)
    );
    assert!(
        pred(FieldPath::RemoteIp, CompiledOperator::NotEquals(CompiledValue::Addr(other))).test(&v)
    );
    assert!(
        !pred(FieldPath::RemoteIp, CompiledOperator::NotEquals(CompiledValue::Addr(ten))).test(&v)
    );
}
2110
/// Enum field `transport` compares against its lowercase name (`"tcp"` / `"udp"`).
#[test]
fn matrix_equality_enum_transport_happy_and_miss() {
    let tcp = make_conn_with_transport(Transport::Tcp);
    let udp = make_conn_with_transport(Transport::Udp);
    let v_tcp = PredicateView::L4 { conn: &tcp, peek: None };
    let v_udp = PredicateView::L4 { conn: &udp, peek: None };
    let transport_is =
        |name: &str| pred(FieldPath::Transport, CompiledOperator::Equals(str_val(name)));
    assert!(transport_is("tcp").test(&v_tcp));
    assert!(!transport_is("udp").test(&v_tcp));
    assert!(transport_is("udp").test(&v_udp));
}
2121
2122 #[test]
2123 fn matrix_equality_enum_tls_version_happy_and_miss() {
2124 let conn = conn_with_tls_version(crate::conn_context::TlsVersion::Tls13);
2125 let v = PredicateView::L4 { conn: &conn, peek: None };
2126 assert!(pred(FieldPath::TlsVersion, CompiledOperator::Equals(str_val("1.3"))).test(&v));
2127 assert!(!pred(FieldPath::TlsVersion, CompiledOperator::Equals(str_val("1.2"))).test(&v));
2128 assert!(pred(FieldPath::TlsVersion, CompiledOperator::NotEquals(str_val("1.2"))).test(&v));
2129 }
2130
2131 #[test]
2132 fn matrix_equality_enum_tls_version_misses_when_absent() {
2133 let conn = make_conn();
2135 let v = PredicateView::L4 { conn: &conn, peek: None };
2136 assert!(!pred(FieldPath::TlsVersion, CompiledOperator::Equals(str_val("1.3"))).test(&v));
2137 assert!(!pred(FieldPath::TlsVersion, CompiledOperator::NotEquals(str_val("1.3"))).test(&v));
2139 }
2140
2141 #[test]
2142 fn matrix_equality_enum_http_method_happy_and_miss() {
2143 let conn = make_conn();
2144 let req = http::Request::builder().method("POST").uri("/").body(Body::Empty).unwrap();
2145 let v = PredicateView::L7Req { conn: &conn, req: &req };
2146 assert!(pred(FieldPath::HttpMethod, CompiledOperator::Equals(str_val("POST"))).test(&v));
2147 assert!(!pred(FieldPath::HttpMethod, CompiledOperator::Equals(str_val("GET"))).test(&v));
2148 assert!(pred(FieldPath::HttpMethod, CompiledOperator::NotEquals(str_val("GET"))).test(&v));
2149 }
2150
2151 #[test]
2153 fn matrix_in_list_str_happy_and_miss() {
2154 let conn = conn_with_sni("api.example.com");
2155 let v = PredicateView::L4 { conn: &conn, peek: None };
2156 let list = vec![str_val("a.example.com"), str_val("api.example.com")];
2157 assert!(pred(FieldPath::TlsSni, CompiledOperator::In(list.clone())).test(&v));
2158 let list_miss = vec![str_val("a.example.com"), str_val("b.example.com")];
2159 assert!(!pred(FieldPath::TlsSni, CompiledOperator::In(list_miss.clone())).test(&v));
2160 assert!(pred(FieldPath::TlsSni, CompiledOperator::NotIn(list_miss)).test(&v));
2161 assert!(!pred(FieldPath::TlsSni, CompiledOperator::NotIn(list)).test(&v));
2162 }
2163
2164 #[test]
2165 fn matrix_in_list_bytes_happy_and_miss() {
2166 let conn = conn_with_tls_alpn(b"h2");
2167 let v = PredicateView::L4 { conn: &conn, peek: None };
2168 let list = vec![bytes_val(b"http/1.1"), bytes_val(b"h2")];
2169 assert!(pred(FieldPath::TlsAlpn, CompiledOperator::In(list.clone())).test(&v));
2170 let list_miss = vec![bytes_val(b"http/1.0"), bytes_val(b"http/1.1")];
2171 assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::In(list_miss.clone())).test(&v));
2172 assert!(pred(FieldPath::TlsAlpn, CompiledOperator::NotIn(list_miss)).test(&v));
2173 }
2174
2175 #[test]
2176 fn matrix_in_list_int_happy_and_miss() {
2177 let conn = make_conn_with("127.0.0.1:443", "127.0.0.1:80");
2178 let v = PredicateView::L4 { conn: &conn, peek: None };
2179 let in_list = vec![CompiledValue::Int(80), CompiledValue::Int(443)];
2180 assert!(pred(FieldPath::RemotePort, CompiledOperator::In(in_list.clone())).test(&v));
2181 let miss_list = vec![CompiledValue::Int(80), CompiledValue::Int(81)];
2182 assert!(!pred(FieldPath::RemotePort, CompiledOperator::In(miss_list.clone())).test(&v));
2183 assert!(pred(FieldPath::RemotePort, CompiledOperator::NotIn(miss_list)).test(&v));
2184 }
2185
2186 #[test]
2187 fn matrix_in_list_addr_happy_and_miss_mixed_family() {
2188 let conn = make_conn_with("10.0.0.5:55555", "127.0.0.1:80");
2189 let v = PredicateView::L4 { conn: &conn, peek: None };
2190 let v4: std::net::IpAddr = "10.0.0.5".parse().unwrap();
2191 let v6: std::net::IpAddr = "::1".parse().unwrap();
2192 let list = vec![CompiledValue::Addr(v6), CompiledValue::Addr(v4)];
2193 assert!(pred(FieldPath::RemoteIp, CompiledOperator::In(list.clone())).test(&v));
2194 let miss = vec![CompiledValue::Addr(v6)];
2195 assert!(!pred(FieldPath::RemoteIp, CompiledOperator::In(miss.clone())).test(&v));
2196 assert!(pred(FieldPath::RemoteIp, CompiledOperator::NotIn(miss)).test(&v));
2197 }
2198
2199 #[test]
2200 fn matrix_in_list_enum_transport_happy_and_miss() {
2201 let conn = make_conn_with_transport(Transport::Udp);
2202 let v = PredicateView::L4 { conn: &conn, peek: None };
2203 let list = vec![str_val("tcp"), str_val("udp")];
2204 assert!(pred(FieldPath::Transport, CompiledOperator::In(list)).test(&v));
2205 let miss = vec![str_val("tcp")];
2206 assert!(!pred(FieldPath::Transport, CompiledOperator::In(miss.clone())).test(&v));
2207 assert!(pred(FieldPath::Transport, CompiledOperator::NotIn(miss)).test(&v));
2208 }
2209
2210 #[test]
2212 fn matrix_substring_on_str_happy_and_miss() {
2213 let conn = make_conn();
2214 let req =
2215 http::Request::builder().method("GET").uri("/api/v1/users").body(Body::Empty).unwrap();
2216 let v = PredicateView::L7Req { conn: &conn, req: &req };
2217 assert!(pred(FieldPath::HttpUriPath, CompiledOperator::Contains(b(b"/v1/"))).test(&v));
2218 assert!(!pred(FieldPath::HttpUriPath, CompiledOperator::Contains(b(b"/v2/"))).test(&v));
2219 assert!(pred(FieldPath::HttpUriPath, CompiledOperator::NotContains(b(b"/v2/"))).test(&v));
2220 assert!(!pred(FieldPath::HttpUriPath, CompiledOperator::NotContains(b(b"/v1/"))).test(&v));
2221 }
2222
2223 #[test]
2224 fn matrix_substring_on_bytes_happy_and_miss() {
2225 let conn = conn_with_tls_alpn(b"http/1.1");
2226 let v = PredicateView::L4 { conn: &conn, peek: None };
2227 assert!(pred(FieldPath::TlsAlpn, CompiledOperator::Contains(b(b"/1."))).test(&v));
2228 assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::Contains(b(b"/2."))).test(&v));
2229 assert!(pred(FieldPath::TlsAlpn, CompiledOperator::NotContains(b(b"/2."))).test(&v));
2230 }
2231
2232 #[test]
2234 fn matrix_prefix_suffix_on_str_happy_and_miss() {
2235 let conn = make_conn();
2236 let req =
2237 http::Request::builder().method("GET").uri("/api/file.json?q=1").body(Body::Empty).unwrap();
2238 let v = PredicateView::L7Req { conn: &conn, req: &req };
2239 assert!(pred(FieldPath::HttpUriPath, CompiledOperator::Prefix(b(b"/api"))).test(&v));
2240 assert!(!pred(FieldPath::HttpUriPath, CompiledOperator::Prefix(b(b"/admin"))).test(&v));
2241 assert!(pred(FieldPath::HttpUriPath, CompiledOperator::Suffix(b(b".json"))).test(&v));
2242 assert!(!pred(FieldPath::HttpUriPath, CompiledOperator::Suffix(b(b".html"))).test(&v));
2243 }
2244
2245 #[test]
2246 fn matrix_prefix_suffix_on_bytes_happy_and_miss() {
2247 let conn = conn_with_tls_alpn(b"http/1.1");
2248 let v = PredicateView::L4 { conn: &conn, peek: None };
2249 assert!(pred(FieldPath::TlsAlpn, CompiledOperator::Prefix(b(b"http"))).test(&v));
2250 assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::Prefix(b(b"h2"))).test(&v));
2251 assert!(pred(FieldPath::TlsAlpn, CompiledOperator::Suffix(b(b"1.1"))).test(&v));
2252 assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::Suffix(b(b"2.0"))).test(&v));
2253 }
2254
2255 #[test]
2257 fn matrix_regex_matches_on_str_happy_and_miss() {
2258 let conn = make_conn();
2259 let req =
2260 http::Request::builder().method("GET").uri("/api/v3/orders").body(Body::Empty).unwrap();
2261 let v = PredicateView::L7Req { conn: &conn, req: &req };
2262 let re = Regex::new(r"^/api/v\d+/orders").expect("compile regex");
2263 assert!(pred(FieldPath::HttpUriPath, CompiledOperator::Matches(re)).test(&v));
2264 let re_miss = Regex::new(r"^/admin").expect("compile regex");
2265 assert!(!pred(FieldPath::HttpUriPath, CompiledOperator::Matches(re_miss)).test(&v));
2266 }
2267
2268 #[test]
2269 fn matrix_regex_matches_on_header_happy_and_miss() {
2270 let conn = make_conn();
2271 let req = http::Request::builder()
2272 .method("GET")
2273 .uri("/")
2274 .header("user-agent", "Mozilla/5.0 (Macintosh; Intel)")
2275 .body(Body::Empty)
2276 .unwrap();
2277 let v = PredicateView::L7Req { conn: &conn, req: &req };
2278 let re = Regex::new(r"(?i)mozilla").expect("compile");
2279 assert!(
2280 pred(FieldPath::HttpHeader(Arc::from("user-agent")), CompiledOperator::Matches(re)).test(&v)
2281 );
2282 let re_miss = Regex::new(r"^curl").expect("compile");
2283 assert!(
2284 !pred(FieldPath::HttpHeader(Arc::from("user-agent")), CompiledOperator::Matches(re_miss))
2285 .test(&v)
2286 );
2287 }
2288
2289 #[test]
2291 fn matrix_numeric_cmp_gt_gte_lt_lte_happy_and_miss() {
2292 let conn = make_conn_with("127.0.0.1:1024", "127.0.0.1:443");
2293 let v = PredicateView::L4 { conn: &conn, peek: None };
2294 assert!(pred(FieldPath::RemotePort, CompiledOperator::Gt(1023)).test(&v));
2296 assert!(!pred(FieldPath::RemotePort, CompiledOperator::Gt(1024)).test(&v));
2297 assert!(pred(FieldPath::RemotePort, CompiledOperator::Gte(1024)).test(&v));
2299 assert!(!pred(FieldPath::RemotePort, CompiledOperator::Gte(1025)).test(&v));
2300 assert!(pred(FieldPath::RemotePort, CompiledOperator::Lt(1025)).test(&v));
2302 assert!(!pred(FieldPath::RemotePort, CompiledOperator::Lt(1024)).test(&v));
2303 assert!(pred(FieldPath::RemotePort, CompiledOperator::Lte(1024)).test(&v));
2305 assert!(!pred(FieldPath::RemotePort, CompiledOperator::Lte(1023)).test(&v));
2306 }
2307
2308 #[test]
2309 fn matrix_numeric_cmp_local_port_too() {
2310 let conn = make_conn_with("127.0.0.1:0", "127.0.0.1:8443");
2312 let v = PredicateView::L4 { conn: &conn, peek: None };
2313 assert!(pred(FieldPath::LocalPort, CompiledOperator::Gt(8000)).test(&v));
2314 assert!(!pred(FieldPath::LocalPort, CompiledOperator::Gt(9000)).test(&v));
2315 }
2316
2317 #[test]
2319 fn matrix_cidr_v4_happy_and_miss() {
2320 let conn = make_conn_with("10.0.5.7:0", "127.0.0.1:0");
2321 let v = PredicateView::L4 { conn: &conn, peek: None };
2322 let ten = IpNet::from_str("10.0.0.0/8").unwrap();
2323 let nineteen2 = IpNet::from_str("192.168.0.0/16").unwrap();
2324 assert!(pred(FieldPath::RemoteIp, CompiledOperator::Cidr(ten)).test(&v));
2325 assert!(!pred(FieldPath::RemoteIp, CompiledOperator::Cidr(nineteen2)).test(&v));
2326 }
2327
2328 #[test]
2329 fn matrix_cidr_v6_happy_and_miss() {
2330 let conn = make_conn_with("[2001:db8::5]:0", "127.0.0.1:0");
2331 let v = PredicateView::L4 { conn: &conn, peek: None };
2332 let net = IpNet::from_str("2001:db8::/32").unwrap();
2333 let other = IpNet::from_str("2001:dead::/32").unwrap();
2334 assert!(pred(FieldPath::RemoteIp, CompiledOperator::Cidr(net)).test(&v));
2335 assert!(!pred(FieldPath::RemoteIp, CompiledOperator::Cidr(other)).test(&v));
2336 }
2337
2338 #[test]
2339 fn matrix_cidr_v4_against_v6_addr_misses() {
2340 let conn = make_conn_with("[2001:db8::5]:0", "127.0.0.1:0");
2342 let v = PredicateView::L4 { conn: &conn, peek: None };
2343 let v4 = IpNet::from_str("0.0.0.0/0").unwrap();
2344 assert!(!pred(FieldPath::RemoteIp, CompiledOperator::Cidr(v4)).test(&v));
2345 }
2346
2347 #[test]
2351 fn http_uri_query_reader_returns_empty_when_query_absent() {
2352 let conn = make_conn();
2355 let req = http::Request::builder().method("GET").uri("/no-q").body(Body::Empty).unwrap();
2356 let v = PredicateView::L7Req { conn: &conn, req: &req };
2357 assert!(pred(FieldPath::HttpUriQuery, CompiledOperator::Equals(str_val(""))).test(&v));
2358 assert!(!pred(FieldPath::HttpUriQuery, CompiledOperator::Equals(str_val("q=1"))).test(&v));
2359 }
2360
2361 #[test]
2362 fn http_uri_query_reader_matches_present_query() {
2363 let conn = make_conn();
2364 let req = http::Request::builder().method("GET").uri("/x?a=1&b=2").body(Body::Empty).unwrap();
2365 let v = PredicateView::L7Req { conn: &conn, req: &req };
2366 assert!(pred(FieldPath::HttpUriQuery, CompiledOperator::Equals(str_val("a=1&b=2"))).test(&v));
2367 assert!(pred(FieldPath::HttpUriQuery, CompiledOperator::Contains(b(b"b=2"))).test(&v));
2368 }
2369
2370 #[test]
2371 fn local_ip_reader_uses_local_socket() {
2372 let conn = make_conn_with("10.0.0.5:0", "127.0.0.1:8443");
2373 let v = PredicateView::L4 { conn: &conn, peek: None };
2374 let local: std::net::IpAddr = "127.0.0.1".parse().unwrap();
2375 assert!(
2376 pred(FieldPath::LocalIp, CompiledOperator::Equals(CompiledValue::Addr(local))).test(&v)
2377 );
2378 }
2379
2380 #[test]
2381 fn http_header_lookup_misses_for_non_utf8_value() {
2382 let conn = make_conn();
2385 let bad =
2386 http::HeaderValue::from_bytes(&[0xff, 0xfe, 0xfd]).expect("non-utf8 header value parses");
2387 let mut builder = http::Request::builder().method("GET").uri("/");
2388 builder.headers_mut().expect("headers").insert("x-bad", bad);
2389 let req: Request = builder.body(Body::Empty).expect("build request");
2390 let v = PredicateView::L7Req { conn: &conn, req: &req };
2391 assert!(
2392 !pred(
2393 FieldPath::HttpHeader(Arc::from("x-bad")),
2394 CompiledOperator::Equals(str_val("anything")),
2395 )
2396 .test(&v)
2397 );
2398 }
2399
2400 fn rcgen_cert_with_cn(cn: &str) -> rustls_pki_types::CertificateDer<'static> {
2402 let mut params = rcgen::CertificateParams::default();
2403 params.distinguished_name = rcgen::DistinguishedName::new();
2404 params.distinguished_name.push(rcgen::DnType::CommonName, cn);
2405 let key = rcgen::KeyPair::generate().expect("rcgen keypair");
2406 let cert = params.self_signed(&key).expect("self-sign cert");
2407 cert.der().clone()
2408 }
2409
2410 fn rcgen_cert_no_cn() -> rustls_pki_types::CertificateDer<'static> {
2411 let params = rcgen::CertificateParams::default();
2414 let mut params = params;
2417 params.distinguished_name = rcgen::DistinguishedName::new();
2418 let key = rcgen::KeyPair::generate().expect("rcgen keypair");
2419 let cert = params.self_signed(&key).expect("self-sign cert");
2420 cert.der().clone()
2421 }
2422
2423 fn conn_with_peer_cert(cert: &rustls_pki_types::CertificateDer<'static>) -> Arc<ConnContext> {
2424 let pc = crate::conn_context::PeerCertificate::from_der(cert)
2425 .expect("rcgen-issued cert must parse via PeerCertificate::from_der");
2426 let conn = make_conn();
2427 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
2428 sni: None,
2429 alpn: None,
2430 version: None,
2431 peer_cert: Some(Arc::new(pc)),
2432 zero_rtt_used: false,
2433 });
2434 conn
2435 }
2436
2437 #[test]
2438 fn peer_cert_from_der_extracts_cn() {
2439 let cert = rcgen_cert_with_cn("client.internal");
2440 let pc = crate::conn_context::PeerCertificate::from_der(&cert).expect("parse");
2441 assert_eq!(pc.subject_cn.as_deref(), Some("client.internal"));
2442 }
2443
2444 #[test]
2445 fn peer_cert_from_der_returns_none_for_malformed_der() {
2446 let raw = rustls_pki_types::CertificateDer::from(vec![0x30, 0x80, 0x00, 0x00]);
2447 assert!(crate::conn_context::PeerCertificate::from_der(&raw).is_none());
2448 let raw = rustls_pki_types::CertificateDer::from(b"not a cert at all".to_vec());
2449 assert!(crate::conn_context::PeerCertificate::from_der(&raw).is_none());
2450 }
2451
2452 #[test]
2453 fn peer_cert_from_der_returns_some_with_no_cn_when_dn_has_no_cn() {
2454 let cert = rcgen_cert_no_cn();
2456 let pc = crate::conn_context::PeerCertificate::from_der(&cert).expect("parse");
2457 assert!(pc.subject_cn.is_none());
2458 }
2459
2460 #[test]
2461 fn matrix_peer_cert_subject_cn_equals_happy_and_miss() {
2462 let cert = rcgen_cert_with_cn("ops-bot");
2463 let conn = conn_with_peer_cert(&cert);
2464 let v = PredicateView::L4 { conn: &conn, peek: None };
2465 assert!(
2466 pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Equals(str_val("ops-bot"))).test(&v)
2467 );
2468 assert!(
2469 !pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Equals(str_val("attacker")))
2470 .test(&v)
2471 );
2472 }
2473
2474 #[test]
2475 fn matrix_peer_cert_subject_cn_string_ops_happy_and_miss() {
2476 let cert = rcgen_cert_with_cn("svc-payments-prod");
2477 let conn = conn_with_peer_cert(&cert);
2478 let v = PredicateView::L4 { conn: &conn, peek: None };
2479 assert!(pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Prefix(b(b"svc-"))).test(&v));
2481 assert!(
2482 !pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Prefix(b(b"client-"))).test(&v)
2483 );
2484 assert!(pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Suffix(b(b"-prod"))).test(&v));
2486 assert!(
2488 pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Contains(b(b"payments"))).test(&v)
2489 );
2490 let re = Regex::new(r"^svc-[a-z]+-(prod|stg)$").expect("regex");
2492 assert!(pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Matches(re)).test(&v));
2493 let list = vec![str_val("svc-other-prod"), str_val("svc-payments-prod")];
2495 assert!(pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::In(list)).test(&v));
2496 }
2497
2498 #[test]
2499 fn peer_cert_subject_cn_misses_when_cert_absent() {
2500 let conn = make_conn();
2503 let v = PredicateView::L4 { conn: &conn, peek: None };
2504 assert!(
2505 !pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Equals(str_val("anything")))
2506 .test(&v)
2507 );
2508 }
2509
2510 #[test]
2511 fn peer_cert_subject_cn_misses_when_cert_has_no_cn() {
2512 let cert = rcgen_cert_no_cn();
2515 let conn = conn_with_peer_cert(&cert);
2516 let v = PredicateView::L4 { conn: &conn, peek: None };
2517 assert!(
2518 !pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Equals(str_val("ops-bot"))).test(&v)
2519 );
2520 }
2521
2522 fn rcgen_cert_with_san_dns(cn: &str, dns: &[&str]) -> rustls_pki_types::CertificateDer<'static> {
2524 let san: Vec<String> = dns.iter().map(|s| (*s).to_owned()).collect();
2525 let mut params = rcgen::CertificateParams::new(san).expect("rcgen params");
2526 params.distinguished_name = rcgen::DistinguishedName::new();
2527 params.distinguished_name.push(rcgen::DnType::CommonName, cn);
2528 let key = rcgen::KeyPair::generate().expect("rcgen keypair");
2529 let cert = params.self_signed(&key).expect("self-sign cert");
2530 cert.der().clone()
2531 }
2532
2533 #[test]
2534 fn each_new_field_path_parses_from_string_form() {
2535 use super::parse_field_path;
2536 assert_eq!(parse_field_path("tls.peer_cert.present"), Ok(FieldPath::TlsPeerCertPresent));
2537 assert_eq!(parse_field_path("tls.peer_cert.san_dns"), Ok(FieldPath::TlsPeerCertSanDns));
2538 assert_eq!(
2539 parse_field_path("tls.peer_cert.fingerprint_sha256"),
2540 Ok(FieldPath::TlsPeerCertFingerprintSha256),
2541 );
2542 assert_eq!(parse_field_path("tls.peer_cert.spki_sha256"), Ok(FieldPath::TlsPeerCertSpkiSha256),);
2543 assert_eq!(parse_field_path("tls.peer_cert.issuer_cn"), Ok(FieldPath::TlsPeerCertIssuerCn));
2544 assert_eq!(parse_field_path("tls.peer_cert.serial"), Ok(FieldPath::TlsPeerCertSerial));
2545 }
2546
2547 #[test]
2548 fn peer_cert_present_true_when_cert_attached() {
2549 let cert = rcgen_cert_with_cn("client.internal");
2550 let conn = conn_with_peer_cert(&cert);
2551 let v = PredicateView::L4 { conn: &conn, peek: None };
2552 assert!(
2553 pred(FieldPath::TlsPeerCertPresent, CompiledOperator::Equals(CompiledValue::Bool(true)))
2554 .test(&v)
2555 );
2556 assert!(
2557 !pred(FieldPath::TlsPeerCertPresent, CompiledOperator::Equals(CompiledValue::Bool(false)))
2558 .test(&v)
2559 );
2560 }
2561
2562 #[test]
2563 fn peer_cert_present_false_when_cert_absent() {
2564 let conn = make_conn();
2567 let v = PredicateView::L4 { conn: &conn, peek: None };
2568 assert!(
2569 pred(FieldPath::TlsPeerCertPresent, CompiledOperator::Equals(CompiledValue::Bool(false)))
2570 .test(&v)
2571 );
2572 assert!(
2573 !pred(FieldPath::TlsPeerCertPresent, CompiledOperator::Equals(CompiledValue::Bool(true)))
2574 .test(&v)
2575 );
2576 }
2577
2578 #[test]
2579 fn peer_cert_san_dns_contains_matches_listed_element() {
2580 let cert = rcgen_cert_with_san_dns("svc-a", &["svc-a.internal", "svc-b.internal"]);
2581 let conn = conn_with_peer_cert(&cert);
2582 let v = PredicateView::L4 { conn: &conn, peek: None };
2583 assert!(
2584 pred(FieldPath::TlsPeerCertSanDns, CompiledOperator::Contains(b(b"svc-a.internal"))).test(&v)
2585 );
2586 assert!(
2587 !pred(FieldPath::TlsPeerCertSanDns, CompiledOperator::Contains(b(b"svc-c.internal")))
2588 .test(&v),
2589 );
2590 assert!(
2591 pred(FieldPath::TlsPeerCertSanDns, CompiledOperator::NotContains(b(b"svc-c.internal")))
2592 .test(&v),
2593 );
2594 }
2595
2596 #[test]
2597 fn peer_cert_san_dns_misses_when_cert_absent() {
2598 let conn = make_conn();
2599 let v = PredicateView::L4 { conn: &conn, peek: None };
2600 assert!(
2601 !pred(FieldPath::TlsPeerCertSanDns, CompiledOperator::Contains(b(b"anything"))).test(&v)
2602 );
2603 }
2604
2605 #[test]
2606 fn peer_cert_fingerprint_sha256_is_lowercase_hex_of_full_der() {
2607 use sha2::{Digest, Sha256};
2608 let cert = rcgen_cert_with_cn("fingerprinted");
2609 let mut h = Sha256::new();
2610 h.update(cert.as_ref());
2611 let want = h.finalize().iter().fold(String::new(), |mut s, b| {
2612 use std::fmt::Write as _;
2613 let _ = write!(s, "{b:02x}");
2614 s
2615 });
2616
2617 let conn = conn_with_peer_cert(&cert);
2618 let v = PredicateView::L4 { conn: &conn, peek: None };
2619 assert!(
2620 pred(FieldPath::TlsPeerCertFingerprintSha256, CompiledOperator::Equals(str_val(&want)),)
2621 .test(&v),
2622 );
2623 }
2624
2625 #[test]
2626 fn peer_cert_issuer_and_serial_present_for_self_signed_cert() {
2627 let cert = rcgen_cert_with_cn("issuer-test");
2630 let conn = conn_with_peer_cert(&cert);
2631 let v = PredicateView::L4 { conn: &conn, peek: None };
2632 assert!(
2634 pred(FieldPath::TlsPeerCertIssuerCn, CompiledOperator::Equals(str_val("issuer-test")))
2635 .test(&v)
2636 );
2637 let pc = conn.tls.lock().as_ref().unwrap().peer_cert.as_ref().unwrap().clone();
2641 assert!(!pc.serial.is_empty(), "serial extracted");
2642 assert!(pc.serial.chars().all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase()));
2643 }
2644
2645 #[test]
2646 fn peer_cert_present_value_type_is_bool() {
2647 assert_eq!(FieldPath::TlsPeerCertPresent.value_type(), FieldValueType::Bool);
2648 }
2649
2650 #[test]
2651 fn peer_cert_san_dns_value_type_is_vec_str() {
2652 assert_eq!(FieldPath::TlsPeerCertSanDns.value_type(), FieldValueType::VecStr);
2653 }
2654
2655 #[test]
2656 fn matrix_rejects_string_pref_suf_on_bool_field() {
2657 assert!(!OperatorFamily::StringPrefSuf.accepts(FieldValueType::Bool));
2660 assert!(!OperatorFamily::StringSubstr.accepts(FieldValueType::Bool));
2661 assert!(!OperatorFamily::RegexMatches.accepts(FieldValueType::Bool));
2662 assert!(OperatorFamily::Equality.accepts(FieldValueType::Bool));
2664 }
2665
2666 #[test]
2667 fn matrix_rejects_equals_on_vec_str_field() {
2668 assert!(!OperatorFamily::Equality.accepts(FieldValueType::VecStr));
2671 assert!(!OperatorFamily::InList.accepts(FieldValueType::VecStr));
2672 assert!(!OperatorFamily::StringPrefSuf.accepts(FieldValueType::VecStr));
2673 assert!(!OperatorFamily::RegexMatches.accepts(FieldValueType::VecStr));
2674 assert!(OperatorFamily::StringSubstr.accepts(FieldValueType::VecStr));
2675 }
2676
2677 fn req_with_body(body_bytes: &[u8]) -> Request {
2684 http::Request::builder()
2685 .method("POST")
2686 .uri("/upload")
2687 .body(Body::Static(Bytes::copy_from_slice(body_bytes)))
2688 .expect("build req with body")
2689 }
2690
2691 #[test]
2692 fn matrix_http_body_equality_happy_and_miss() {
2693 let conn = make_conn();
2694 let req = req_with_body(b"hello world");
2695 let v = PredicateView::L7Req { conn: &conn, req: &req };
2696 assert!(
2697 pred(FieldPath::HttpBody, CompiledOperator::Equals(bytes_val(b"hello world"))).test(&v)
2698 );
2699 assert!(!pred(FieldPath::HttpBody, CompiledOperator::Equals(bytes_val(b"wrong"))).test(&v));
2700 assert!(pred(FieldPath::HttpBody, CompiledOperator::NotEquals(bytes_val(b"wrong"))).test(&v));
2701 }
2702
2703 #[test]
2704 fn matrix_http_body_substring_happy_and_miss() {
2705 let conn = make_conn();
2706 let req = req_with_body(b"prelude payload trailer");
2707 let v = PredicateView::L7Req { conn: &conn, req: &req };
2708 assert!(pred(FieldPath::HttpBody, CompiledOperator::Contains(b(b"payload"))).test(&v));
2709 assert!(!pred(FieldPath::HttpBody, CompiledOperator::Contains(b(b"missing"))).test(&v));
2710 assert!(pred(FieldPath::HttpBody, CompiledOperator::NotContains(b(b"missing"))).test(&v));
2711 }
2712
2713 #[test]
2714 fn matrix_http_body_prefix_suffix_happy_and_miss() {
2715 let conn = make_conn();
2716 let req = req_with_body(b"START middle END");
2717 let v = PredicateView::L7Req { conn: &conn, req: &req };
2718 assert!(pred(FieldPath::HttpBody, CompiledOperator::Prefix(b(b"START"))).test(&v));
2719 assert!(!pred(FieldPath::HttpBody, CompiledOperator::Prefix(b(b"BEGIN"))).test(&v));
2720 assert!(pred(FieldPath::HttpBody, CompiledOperator::Suffix(b(b"END"))).test(&v));
2721 assert!(!pred(FieldPath::HttpBody, CompiledOperator::Suffix(b(b"FIN"))).test(&v));
2722 }
2723
2724 #[test]
2725 fn matrix_http_body_in_list_happy_and_miss() {
2726 let conn = make_conn();
2727 let req = req_with_body(b"one");
2728 let v = PredicateView::L7Req { conn: &conn, req: &req };
2729 let list = vec![bytes_val(b"two"), bytes_val(b"one")];
2730 assert!(pred(FieldPath::HttpBody, CompiledOperator::In(list)).test(&v));
2731 let miss = vec![bytes_val(b"two"), bytes_val(b"three")];
2732 assert!(!pred(FieldPath::HttpBody, CompiledOperator::In(miss.clone())).test(&v));
2733 assert!(pred(FieldPath::HttpBody, CompiledOperator::NotIn(miss)).test(&v));
2734 }
2735
2736 #[test]
2737 fn http_body_misses_on_l4_view() {
2738 let conn = make_conn();
2741 let v = PredicateView::L4 { conn: &conn, peek: None };
2742 assert!(!pred(FieldPath::HttpBody, CompiledOperator::Contains(b(b"x"))).test(&v));
2743 }
2744
2745 #[test]
2746 #[should_panic(expected = "lazy-buffer invariant")]
2747 fn http_body_panics_when_lazy_buffer_invariant_violated() {
2748 let conn = make_conn();
2756 let req = http::Request::builder().method("POST").uri("/").body(Body::Empty).unwrap();
2757 let v = PredicateView::L7Req { conn: &conn, req: &req };
2758 let _ = pred(FieldPath::HttpBody, CompiledOperator::Contains(b(b"x"))).test(&v);
2759 }
2760
2761 #[test]
2769 fn matrix_peek_substring_happy_and_miss() {
2770 let buf: &[u8] = &[0x16, 0x03, 0x01, 0x00, 0x40, 0x01];
2772 let conn = make_conn();
2773 let v = PredicateView::L4 { conn: &conn, peek: Some(buf) };
2774 assert!(pred(FieldPath::Peek, CompiledOperator::Prefix(b(b"\x16\x03"))).test(&v));
2775 assert!(!pred(FieldPath::Peek, CompiledOperator::Prefix(b(b"\x14\x03"))).test(&v));
2776 assert!(pred(FieldPath::Peek, CompiledOperator::Contains(b(b"\x03\x01"))).test(&v));
2777 assert!(!pred(FieldPath::Peek, CompiledOperator::Contains(b(b"\xff\xff"))).test(&v));
2778 }
2779
2780 #[test]
2781 fn matrix_peek_equality_happy_and_miss() {
2782 let buf: &[u8] = b"GET";
2783 let conn = make_conn();
2784 let v = PredicateView::L4 { conn: &conn, peek: Some(buf) };
2785 assert!(pred(FieldPath::Peek, CompiledOperator::Equals(bytes_val(b"GET"))).test(&v));
2786 assert!(!pred(FieldPath::Peek, CompiledOperator::Equals(bytes_val(b"PUT"))).test(&v));
2787 assert!(pred(FieldPath::Peek, CompiledOperator::NotEquals(bytes_val(b"PUT"))).test(&v));
2788 }
2789
2790 #[test]
2791 fn matrix_peek_in_list_happy_and_miss() {
2792 let buf: &[u8] = b"PRI ";
2793 let conn = make_conn();
2794 let v = PredicateView::L4 { conn: &conn, peek: Some(buf) };
2795 let list = vec![bytes_val(b"GET "), bytes_val(b"PRI ")];
2797 assert!(pred(FieldPath::Peek, CompiledOperator::In(list)).test(&v));
2798 let miss = vec![bytes_val(b"POST"), bytes_val(b"HEAD")];
2799 assert!(!pred(FieldPath::Peek, CompiledOperator::In(miss.clone())).test(&v));
2800 assert!(pred(FieldPath::Peek, CompiledOperator::NotIn(miss)).test(&v));
2801 }
2802
2803 #[test]
2804 fn peek_misses_when_buffer_absent_on_l4_view() {
2805 let conn = make_conn();
2808 let v = PredicateView::L4 { conn: &conn, peek: None };
2809 assert!(!pred(FieldPath::Peek, CompiledOperator::Prefix(b(b"\x16"))).test(&v));
2810 let req = http::Request::builder().method("GET").uri("/").body(Body::Empty).unwrap();
2812 let v7 = PredicateView::L7Req { conn: &conn, req: &req };
2813 assert!(!pred(FieldPath::Peek, CompiledOperator::Prefix(b(b"\x16"))).test(&v7));
2814 }
2815}