1use std::hash::{Hash, Hasher};
2use std::net::IpAddr;
3use std::sync::Arc;
4
5use bytes::Bytes;
6use ipnet::IpNet;
7
8use crate::body::Request;
9use crate::conn_context::ConnContext;
10
/// Addressable attribute of a connection or HTTP request that a predicate can test.
///
/// In predicate source a path is written in dotted form (see `parse_field_path` and
/// `FieldPath::display_name`), e.g. `remote.ip` or `http.header.host`. Each path resolves to a
/// single [`FieldValueType`] via `FieldPath::value_type`.
///
/// NOTE(review): the derived serde representation uses the snake_case *variant* names
/// (`"remote_ip"`, `{"http_header": "host"}`), not the dotted form. It round-trips `FieldPath`
/// itself (e.g. as a field of `PredicateInst`) but is not the predicate-source syntax.
#[derive(Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FieldPath {
    /// `transport` (enum): evaluated as `"tcp"` or `"udp"`.
    Transport,
    /// `remote.ip` (IpAddr): IP of the connection's `remote` address.
    RemoteIp,
    /// `remote.port` (Int): port of the connection's `remote` address.
    RemotePort,
    /// `local.ip` (IpAddr): IP of the connection's `local` address.
    LocalIp,
    /// `local.port` (Int): port of the connection's `local` address.
    LocalPort,
    /// `peek` (Bytes): caller-supplied peek buffer; only present in an L4 view.
    Peek,
    /// `tls.sni` (Str).
    TlsSni,
    /// `tls.alpn` (Bytes).
    TlsAlpn,
    /// `tls.version` (enum): evaluated as `"1.2"` or `"1.3"`.
    TlsVersion,
    /// `tls.peer_cert.present` (Bool).
    TlsPeerCertPresent,
    /// `tls.peer_cert.subject_cn` (Str).
    TlsPeerCertSubjectCn,
    /// `tls.peer_cert.san_dns` (Vec<Str>).
    TlsPeerCertSanDns,
    /// `tls.peer_cert.fingerprint_sha256` (Str).
    TlsPeerCertFingerprintSha256,
    /// `tls.peer_cert.spki_sha256` (Str).
    TlsPeerCertSpkiSha256,
    /// `tls.peer_cert.issuer_cn` (Str).
    TlsPeerCertIssuerCn,
    /// `tls.peer_cert.serial` (Str).
    TlsPeerCertSerial,
    /// `http.method` (enum).
    HttpMethod,
    /// `http.uri.path` (Str).
    HttpUriPath,
    /// `http.uri.query` (Str).
    HttpUriQuery,
    /// `http.header.<name>` (Str). The name is compared and hashed by content, so `"host"` and
    /// `"Host"` are distinct paths.
    HttpHeader(Arc<str>),
    /// `http.body` (Bytes).
    HttpBody,
}
49
/// Type of the value a field path resolves to; used to check operator compatibility.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum FieldValueType {
    Str,
    Bytes,
    Int,
    IpAddr,
    Enum,
    Bool,
    VecStr,
}

impl FieldValueType {
    /// Human-readable type name used in diagnostics.
    ///
    /// `Enum` renders in lowercase (`"enum"`) and `VecStr` as `"Vec<Str>"`; every other
    /// variant renders as its own name.
    #[must_use]
    pub fn name(self) -> &'static str {
        use FieldValueType as T;
        match self {
            T::VecStr => "Vec<Str>",
            T::Bool => "Bool",
            T::Enum => "enum",
            T::IpAddr => "IpAddr",
            T::Int => "Int",
            T::Bytes => "Bytes",
            T::Str => "Str",
        }
    }
}
79
80impl FieldPath {
81 #[must_use]
85 pub fn value_type(&self) -> FieldValueType {
86 match self {
87 Self::Transport | Self::TlsVersion | Self::HttpMethod => FieldValueType::Enum,
88 Self::RemoteIp | Self::LocalIp => FieldValueType::IpAddr,
89 Self::RemotePort | Self::LocalPort => FieldValueType::Int,
90 Self::Peek | Self::TlsAlpn | Self::HttpBody => FieldValueType::Bytes,
91 Self::TlsPeerCertPresent => FieldValueType::Bool,
92 Self::TlsPeerCertSanDns => FieldValueType::VecStr,
93 Self::TlsSni
94 | Self::TlsPeerCertSubjectCn
95 | Self::TlsPeerCertFingerprintSha256
96 | Self::TlsPeerCertSpkiSha256
97 | Self::TlsPeerCertIssuerCn
98 | Self::TlsPeerCertSerial
99 | Self::HttpUriPath
100 | Self::HttpUriQuery
101 | Self::HttpHeader(_) => FieldValueType::Str,
102 }
103 }
104
105 #[must_use]
107 pub fn display_name(&self) -> String {
108 match self {
109 Self::Transport => "transport".to_string(),
110 Self::RemoteIp => "remote.ip".to_string(),
111 Self::RemotePort => "remote.port".to_string(),
112 Self::LocalIp => "local.ip".to_string(),
113 Self::LocalPort => "local.port".to_string(),
114 Self::Peek => "peek".to_string(),
115 Self::TlsSni => "tls.sni".to_string(),
116 Self::TlsAlpn => "tls.alpn".to_string(),
117 Self::TlsVersion => "tls.version".to_string(),
118 Self::TlsPeerCertPresent => "tls.peer_cert.present".to_string(),
119 Self::TlsPeerCertSubjectCn => "tls.peer_cert.subject_cn".to_string(),
120 Self::TlsPeerCertSanDns => "tls.peer_cert.san_dns".to_string(),
121 Self::TlsPeerCertFingerprintSha256 => "tls.peer_cert.fingerprint_sha256".to_string(),
122 Self::TlsPeerCertSpkiSha256 => "tls.peer_cert.spki_sha256".to_string(),
123 Self::TlsPeerCertIssuerCn => "tls.peer_cert.issuer_cn".to_string(),
124 Self::TlsPeerCertSerial => "tls.peer_cert.serial".to_string(),
125 Self::HttpMethod => "http.method".to_string(),
126 Self::HttpUriPath => "http.uri.path".to_string(),
127 Self::HttpUriQuery => "http.uri.query".to_string(),
128 Self::HttpHeader(name) => format!("http.header.{name}"),
129 Self::HttpBody => "http.body".to_string(),
130 }
131 }
132}
133
/// Groups [`Operator`] variants that share type-compatibility rules.
///
/// `Operator::family` maps each operator to its family, and `OperatorFamily::accepts` decides
/// which [`FieldValueType`]s a family may be applied to.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum OperatorFamily {
    /// `equals` / `not_equals`.
    Equality,
    /// `contains` / `not_contains`.
    StringSubstr,
    /// `prefix` / `suffix`.
    StringPrefSuf,
    /// `matches`.
    RegexMatches,
    /// `in` / `not_in`.
    InList,
    /// `gt` / `gte` / `lt` / `lte`.
    NumericCmp,
    /// `cidr`.
    CidrMatch,
}
148
149impl Operator {
150 #[must_use]
151 pub fn family(&self) -> OperatorFamily {
152 match self {
153 Self::Equals(_) | Self::NotEquals(_) => OperatorFamily::Equality,
154 Self::Contains(_) | Self::NotContains(_) => OperatorFamily::StringSubstr,
155 Self::Prefix(_) | Self::Suffix(_) => OperatorFamily::StringPrefSuf,
156 Self::Matches(_) => OperatorFamily::RegexMatches,
157 Self::In(_) | Self::NotIn(_) => OperatorFamily::InList,
158 Self::Gt(_) | Self::Gte(_) | Self::Lt(_) | Self::Lte(_) => OperatorFamily::NumericCmp,
159 Self::Cidr(_) => OperatorFamily::CidrMatch,
160 }
161 }
162
163 #[must_use]
164 pub fn name(&self) -> &'static str {
165 match self {
166 Self::Equals(_) => "equals",
167 Self::NotEquals(_) => "not_equals",
168 Self::Contains(_) => "contains",
169 Self::NotContains(_) => "not_contains",
170 Self::Prefix(_) => "prefix",
171 Self::Suffix(_) => "suffix",
172 Self::Matches(_) => "matches",
173 Self::In(_) => "in",
174 Self::NotIn(_) => "not_in",
175 Self::Gt(_) => "gt",
176 Self::Gte(_) => "gte",
177 Self::Lt(_) => "lt",
178 Self::Lte(_) => "lte",
179 Self::Cidr(_) => "cidr",
180 }
181 }
182}
183
184impl OperatorFamily {
185 #[must_use]
190 pub fn accepts(self, vt: FieldValueType) -> bool {
191 use FieldValueType as V;
192 use OperatorFamily as F;
193 matches!(
194 (self, vt),
195 (F::Equality, V::Str | V::Bytes | V::Int | V::IpAddr | V::Enum | V::Bool)
200 | (F::InList, V::Str | V::Bytes | V::Int | V::IpAddr | V::Enum)
203 | (F::StringSubstr, V::Str | V::Bytes | V::VecStr)
204 | (F::StringPrefSuf, V::Str | V::Bytes)
205 | (F::RegexMatches, V::Str)
206 | (F::NumericCmp, V::Int)
207 | (F::CidrMatch, V::IpAddr),
208 )
209 }
210
211 #[must_use]
213 pub fn family_expectation(self) -> &'static str {
214 match self {
215 Self::Equality => "any of Str/Bytes/Int/IpAddr/enum/Bool",
216 Self::InList => "any of Str/Bytes/Int/IpAddr/enum",
217 Self::StringSubstr => "Str, Bytes, or Vec<Str>",
218 Self::StringPrefSuf => "Str or Bytes",
219 Self::RegexMatches => "Str",
220 Self::NumericCmp => "numeric",
221 Self::CidrMatch => "IpAddr",
222 }
223 }
224}
225
/// Literal operand of a [`CompiledOperator`].
///
/// `PartialEq`/`Eq`/`Hash` are implemented by hand: `Str` compares and hashes by string content
/// (not `Arc` identity), and values of different variants are never equal. Serde goes through
/// `serde_impls::ValueShadow`.
#[derive(Clone, Debug)]
pub enum CompiledValue {
    Str(Arc<str>),
    Bytes(Bytes),
    Int(i64),
    Bool(bool),
    Addr(IpAddr),
}
234
235impl PartialEq for CompiledValue {
236 fn eq(&self, other: &Self) -> bool {
237 match (self, other) {
238 (Self::Str(a), Self::Str(b)) => a.as_ref() == b.as_ref(),
239 (Self::Bytes(a), Self::Bytes(b)) => a == b,
240 (Self::Int(a), Self::Int(b)) => a == b,
241 (Self::Bool(a), Self::Bool(b)) => a == b,
242 (Self::Addr(a), Self::Addr(b)) => a == b,
243 _ => false,
244 }
245 }
246}
247
248impl Eq for CompiledValue {}
249
250impl Hash for CompiledValue {
251 fn hash<H: Hasher>(&self, state: &mut H) {
252 std::mem::discriminant(self).hash(state);
253 match self {
254 Self::Str(s) => s.as_ref().hash(state),
255 Self::Bytes(b) => b.hash(state),
256 Self::Int(i) => i.hash(state),
257 Self::Bool(b) => b.hash(state),
258 Self::Addr(a) => a.hash(state),
259 }
260 }
261}
262
/// Operator with its operand(s) in evaluation-ready form.
///
/// `PartialEq`/`Eq`/`Hash` are implemented by hand. `Matches` compares and hashes by pattern
/// source (`Regex::as_str`), and `In`/`NotIn` compare their lists in order. Serde goes through
/// `serde_impls::OperatorShadow`.
#[derive(Clone, Debug)]
pub enum CompiledOperator {
    Equals(CompiledValue),
    NotEquals(CompiledValue),
    /// Raw needle bytes; also used for string fields (compared against their UTF-8 bytes).
    Contains(Bytes),
    NotContains(Bytes),
    Prefix(Bytes),
    Suffix(Bytes),
    Matches(fancy_regex::Regex),
    In(Vec<CompiledValue>),
    NotIn(Vec<CompiledValue>),
    Gt(i64),
    Gte(i64),
    Lt(i64),
    Lte(i64),
    Cidr(IpNet),
}
280
281impl PartialEq for CompiledOperator {
282 fn eq(&self, other: &Self) -> bool {
283 match (self, other) {
284 (Self::Equals(a), Self::Equals(b)) | (Self::NotEquals(a), Self::NotEquals(b)) => a == b,
285 (Self::Contains(a), Self::Contains(b))
286 | (Self::NotContains(a), Self::NotContains(b))
287 | (Self::Prefix(a), Self::Prefix(b))
288 | (Self::Suffix(a), Self::Suffix(b)) => a == b,
289 (Self::Matches(a), Self::Matches(b)) => a.as_str() == b.as_str(),
290 (Self::In(a), Self::In(b)) | (Self::NotIn(a), Self::NotIn(b)) => a == b,
291 (Self::Gt(a), Self::Gt(b))
292 | (Self::Gte(a), Self::Gte(b))
293 | (Self::Lt(a), Self::Lt(b))
294 | (Self::Lte(a), Self::Lte(b)) => a == b,
295 (Self::Cidr(a), Self::Cidr(b)) => a == b,
296 _ => false,
297 }
298 }
299}
300
301impl Eq for CompiledOperator {}
302
303impl Hash for CompiledOperator {
304 fn hash<H: Hasher>(&self, state: &mut H) {
305 std::mem::discriminant(self).hash(state);
306 match self {
307 Self::Equals(v) | Self::NotEquals(v) => v.hash(state),
308 Self::Contains(b) | Self::NotContains(b) | Self::Prefix(b) | Self::Suffix(b) => {
309 b.hash(state);
310 }
311 Self::Matches(r) => r.as_str().hash(state),
312 Self::In(v) | Self::NotIn(v) => v.hash(state),
313 Self::Gt(i) | Self::Gte(i) | Self::Lt(i) | Self::Lte(i) => i.hash(state),
314 Self::Cidr(n) => n.hash(state),
315 }
316 }
317}
318
/// A compiled check: the field to read and the operator to apply to it.
///
/// Evaluated with `PredicateInst::test`. Equality and hashing rely on the hand-written impls of
/// [`CompiledOperator`] (regexes compare by pattern source).
#[derive(Clone, Debug, Eq, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
pub struct PredicateInst {
    /// Field whose value is tested.
    pub path: FieldPath,
    /// Operator (with compiled operand) applied to that value.
    pub op: CompiledOperator,
}
324
/// What a predicate is evaluated against: a connection-level (L4) view or an HTTP request (L7)
/// view. Both carry the connection context.
pub enum PredicateView<'a> {
    /// Connection-level view; `peek` is the caller-supplied peek buffer, if any.
    L4 { conn: &'a Arc<ConnContext>, peek: Option<&'a [u8]> },
    /// Request-level view; no peek buffer is available here.
    L7Req { conn: &'a Arc<ConnContext>, req: &'a Request },
}
329
330impl<'a> PredicateView<'a> {
331 #[must_use]
341 pub fn build(
342 conn: &'a Arc<ConnContext>,
343 req: Option<&'a Request>,
344 _l4: Option<&'a crate::l4::L4Conn>,
345 peek: Option<&'a [u8]>,
346 ) -> Self {
347 match req {
348 Some(r) => Self::L7Req { conn, req: r },
349 None => Self::L4 { conn, peek },
350 }
351 }
352
353 fn conn(&self) -> &Arc<ConnContext> {
354 match self {
355 Self::L4 { conn, .. } | Self::L7Req { conn, .. } => conn,
356 }
357 }
358
359 fn request(&self) -> Option<&Request> {
360 match self {
361 Self::L7Req { req, .. } => Some(req),
362 Self::L4 { .. } => None,
363 }
364 }
365
366 fn peek_buffer(&self) -> Option<&[u8]> {
367 match self {
368 Self::L4 { peek, .. } => *peek,
369 Self::L7Req { .. } => None,
370 }
371 }
372}
373
impl PredicateInst {
    /// Evaluates this check against `view`.
    ///
    /// Reads the value addressed by `self.path` from the connection context or request in
    /// `view`, then hands it to the matcher for that value type (`test_str`, `test_bytes`,
    /// `test_int`, `test_addr`, `test_bool`, `test_vec_str`), which interprets `self.op`.
    /// Operator/value combinations a matcher does not handle evaluate to `false`.
    ///
    /// When the addressed value is unavailable the result is `false`, even for negated
    /// operators such as `not_equals` or `not_in`:
    /// - `peek` outside an L4 view, or when no peek buffer was supplied;
    /// - `http.*` fields outside an L7 request view;
    /// - TLS fields when there is no TLS info or the specific attribute is `None`;
    /// - `http.header.<name>` when the header is missing or `to_str()` on its value fails.
    ///
    /// Exceptions: `tls.peer_cert.present` tests the value `false` when there is no TLS info
    /// or no peer certificate; `tls.peer_cert.san_dns` treats a missing certificate as an empty
    /// list (so `not_contains` is `true`); `http.uri.query` treats a missing query as `""`.
    ///
    /// # Panics
    ///
    /// `http.body` panics with `"lazy-buffer invariant"` if `req.body().as_static()` returns
    /// `None`. Presumably callers must buffer the body before evaluating body predicates;
    /// confirm at call sites.
    #[must_use]
    // One flat dispatch arm per `FieldPath` variant; splitting it up would only scatter them.
    #[allow(clippy::too_many_lines)]
    pub fn test(&self, view: &PredicateView<'_>) -> bool {
        match &self.path {
            FieldPath::Transport => {
                let s = match view.conn().transport {
                    crate::conn_context::Transport::Tcp => "tcp",
                    crate::conn_context::Transport::Udp => "udp",
                };
                test_str(&self.op, s)
            }
            FieldPath::RemoteIp => test_addr(&self.op, view.conn().remote.ip()),
            FieldPath::RemotePort => test_int(&self.op, i64::from(view.conn().remote.port())),
            FieldPath::LocalIp => test_addr(&self.op, view.conn().local.ip()),
            FieldPath::LocalPort => test_int(&self.op, i64::from(view.conn().local.port())),
            FieldPath::Peek => view.peek_buffer().is_some_and(|b| test_bytes(&self.op, b)),
            // NOTE(review): in the TLS arms below the `lock()` guard is a temporary of the arm
            // expression, so it most likely stays held until the matcher (including any regex)
            // returns. The `.clone()`s therefore copy data without shortening the critical
            // section; consider borrowing (e.g. `as_deref()`) once the field types are confirmed.
            FieldPath::TlsSni => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.sni.clone())
                .is_some_and(|got| test_str(&self.op, got.as_str())),
            FieldPath::TlsAlpn => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.alpn.clone())
                .is_some_and(|got| test_bytes(&self.op, got.as_slice())),
            FieldPath::TlsVersion => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.version)
                .is_some_and(|v| test_str(&self.op, tls_version_str(v))),
            FieldPath::TlsPeerCertPresent => {
                // Bound with `let`, so the guard is released before `test_bool` runs.
                let present = view.conn().tls.lock().as_ref().is_some_and(|t| t.peer_cert.is_some());
                test_bool(&self.op, present)
            }
            FieldPath::TlsPeerCertSubjectCn => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.peer_cert.as_ref().and_then(|p| p.subject_cn.clone()))
                .is_some_and(|cn| test_str(&self.op, cn.as_str())),
            FieldPath::TlsPeerCertSanDns => {
                // No TLS info / no peer cert => empty list (not "unavailable").
                let dns_list: Vec<String> = view
                    .conn()
                    .tls
                    .lock()
                    .as_ref()
                    .and_then(|t| t.peer_cert.as_ref().map(|p| p.san_dns.clone()))
                    .unwrap_or_default();
                test_vec_str(&self.op, &dns_list)
            }
            FieldPath::TlsPeerCertFingerprintSha256 => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.peer_cert.as_ref().map(|p| p.fingerprint_sha256.clone()))
                .is_some_and(|s| test_str(&self.op, s.as_str())),
            FieldPath::TlsPeerCertSpkiSha256 => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.peer_cert.as_ref().map(|p| p.spki_sha256.clone()))
                .is_some_and(|s| test_str(&self.op, s.as_str())),
            FieldPath::TlsPeerCertIssuerCn => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.peer_cert.as_ref().and_then(|p| p.issuer_cn.clone()))
                .is_some_and(|s| test_str(&self.op, s.as_str())),
            FieldPath::TlsPeerCertSerial => view
                .conn()
                .tls
                .lock()
                .as_ref()
                .and_then(|t| t.peer_cert.as_ref().map(|p| p.serial.clone()))
                .is_some_and(|s| test_str(&self.op, s.as_str())),
            FieldPath::HttpMethod => {
                let Some(req) = view.request() else { return false };
                test_str(&self.op, req.method().as_str())
            }
            FieldPath::HttpUriPath => {
                let Some(req) = view.request() else { return false };
                test_str(&self.op, req.uri().path())
            }
            FieldPath::HttpUriQuery => {
                let Some(req) = view.request() else { return false };
                // A missing query string is tested as "".
                test_str(&self.op, req.uri().query().unwrap_or(""))
            }
            FieldPath::HttpHeader(name) => {
                // NOTE(review): assuming `headers()` is an `http::HeaderMap`, `get` returns only
                // the first value of a repeated header, and `to_str()` fails for values with
                // bytes outside visible ASCII; confirm that this is the intended behavior.
                let Some(req) = view.request() else { return false };
                let Some(value) = req.headers().get(name.as_ref()) else { return false };
                let Ok(s) = value.to_str() else {
                    return false;
                };
                test_str(&self.op, s)
            }
            FieldPath::HttpBody => {
                let Some(req) = view.request() else { return false };
                // Panics if the body is not available as a static buffer; see "# Panics".
                let bytes = req.body().as_static().expect("lazy-buffer invariant");
                test_bytes(&self.op, bytes.as_ref())
            }
        }
    }
}
539
540fn tls_version_str(v: crate::conn_context::TlsVersion) -> &'static str {
541 match v {
542 crate::conn_context::TlsVersion::Tls12 => "1.2",
543 crate::conn_context::TlsVersion::Tls13 => "1.3",
544 }
545}
546
547fn test_bool(op: &CompiledOperator, value: bool) -> bool {
559 match op {
560 CompiledOperator::Equals(CompiledValue::Bool(expected)) => value == *expected,
561 CompiledOperator::NotEquals(CompiledValue::Bool(expected)) => value != *expected,
562 _ => false,
563 }
564}
565
566fn test_vec_str(op: &CompiledOperator, values: &[String]) -> bool {
572 match op {
573 CompiledOperator::Contains(needle) => values.iter().any(|v| v.as_bytes() == needle.as_ref()),
574 CompiledOperator::NotContains(needle) => {
575 !values.iter().any(|v| v.as_bytes() == needle.as_ref())
576 }
577 _ => false,
578 }
579}
580
581fn test_str(op: &CompiledOperator, value: &str) -> bool {
586 match op {
587 CompiledOperator::Equals(CompiledValue::Str(expected)) => value == expected.as_ref(),
588 CompiledOperator::NotEquals(CompiledValue::Str(expected)) => value != expected.as_ref(),
589 CompiledOperator::Contains(b) => contains_bytes(value.as_bytes(), b),
590 CompiledOperator::NotContains(b) => !contains_bytes(value.as_bytes(), b),
591 CompiledOperator::Prefix(b) => value.as_bytes().starts_with(b.as_ref()),
592 CompiledOperator::Suffix(b) => value.as_bytes().ends_with(b.as_ref()),
593 CompiledOperator::Matches(re) => re.is_match(value).unwrap_or(false),
594 CompiledOperator::In(values) => {
595 values.iter().any(|v| matches!(v, CompiledValue::Str(s) if value == s.as_ref()))
596 }
597 CompiledOperator::NotIn(values) => {
598 !values.iter().any(|v| matches!(v, CompiledValue::Str(s) if value == s.as_ref()))
599 }
600 _ => false,
601 }
602}
603
604fn test_bytes(op: &CompiledOperator, value: &[u8]) -> bool {
608 match op {
609 CompiledOperator::Equals(CompiledValue::Bytes(expected)) => value == expected.as_ref(),
610 CompiledOperator::NotEquals(CompiledValue::Bytes(expected)) => value != expected.as_ref(),
611 CompiledOperator::Contains(b) => contains_bytes(value, b),
612 CompiledOperator::NotContains(b) => !contains_bytes(value, b),
613 CompiledOperator::Prefix(b) => value.starts_with(b.as_ref()),
614 CompiledOperator::Suffix(b) => value.ends_with(b.as_ref()),
615 CompiledOperator::In(values) => {
616 values.iter().any(|v| matches!(v, CompiledValue::Bytes(b) if value == b.as_ref()))
617 }
618 CompiledOperator::NotIn(values) => {
619 !values.iter().any(|v| matches!(v, CompiledValue::Bytes(b) if value == b.as_ref()))
620 }
621 _ => false,
622 }
623}
624
625fn test_int(op: &CompiledOperator, value: i64) -> bool {
628 match op {
629 CompiledOperator::Equals(CompiledValue::Int(expected)) => value == *expected,
630 CompiledOperator::NotEquals(CompiledValue::Int(expected)) => value != *expected,
631 CompiledOperator::Gt(n) => value > *n,
632 CompiledOperator::Gte(n) => value >= *n,
633 CompiledOperator::Lt(n) => value < *n,
634 CompiledOperator::Lte(n) => value <= *n,
635 CompiledOperator::In(values) => {
636 values.iter().any(|v| matches!(v, CompiledValue::Int(i) if value == *i))
637 }
638 CompiledOperator::NotIn(values) => {
639 !values.iter().any(|v| matches!(v, CompiledValue::Int(i) if value == *i))
640 }
641 _ => false,
642 }
643}
644
645fn test_addr(op: &CompiledOperator, value: std::net::IpAddr) -> bool {
649 match op {
650 CompiledOperator::Equals(CompiledValue::Addr(expected)) => value == *expected,
651 CompiledOperator::NotEquals(CompiledValue::Addr(expected)) => value != *expected,
652 CompiledOperator::Cidr(net) => net.contains(&value),
653 CompiledOperator::In(values) => {
654 values.iter().any(|v| matches!(v, CompiledValue::Addr(a) if value == *a))
655 }
656 CompiledOperator::NotIn(values) => {
657 !values.iter().any(|v| matches!(v, CompiledValue::Addr(a) if value == *a))
658 }
659 _ => false,
660 }
661}
662
/// Reports whether `needle` occurs as a contiguous run inside `haystack`.
///
/// An empty needle matches every haystack, including an empty one. This is a naive
/// sliding-window scan.
fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
    match needle.len() {
        0 => true,
        width if width > haystack.len() => false,
        width => haystack.windows(width).any(|window| window == needle),
    }
}
672
/// Upper bound, in bytes, on the source text of a `matches` regex pattern (4 KiB). Longer
/// patterns are rejected by `validate_operator` while a check is being parsed.
pub const REGEX_PATTERN_MAX_BYTES: usize = 4 * 1024;
674
/// Uncompiled predicate tree as written in predicate source.
///
/// Parsed by the hand-written `Deserialize` impl from one of `{"any_of": [..]}`,
/// `{"all_of": [..]}`, `{"not": <predicate>}`, or a single-key check
/// `{"<field-path>": {"<operator>": <value>}}`.
///
/// NOTE(review): `Serialize` is derived with serde's default (externally tagged) enum
/// representation, so the output looks like `{"AnyOf": {"any_of": [..]}}` or
/// `{"Check": {"path": .., "op": ..}}`. The custom `Deserialize` does not accept that shape.
/// Confirm that serialization is only used for diagnostics, or add a matching `Serialize`.
#[derive(Debug, Clone, serde::Serialize)]
pub enum Predicate {
    AnyOf(AnyOfP),
    AllOf(AllOfP),
    Not(NotP),
    Check(CheckMap),
}
682
/// Body of an `{"any_of": [...]}` node; any other key is rejected (`deny_unknown_fields`).
/// Presumably matches when at least one child matches (evaluation is not in this section).
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(deny_unknown_fields)]
pub struct AnyOfP {
    pub any_of: Vec<Predicate>,
}
688
/// Body of an `{"all_of": [...]}` node; any other key is rejected (`deny_unknown_fields`).
/// Presumably matches when every child matches (evaluation is not in this section).
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(deny_unknown_fields)]
pub struct AllOfP {
    pub all_of: Vec<Predicate>,
}
694
/// Body of a `{"not": <predicate>}` node; any other key is rejected (`deny_unknown_fields`).
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(deny_unknown_fields)]
pub struct NotP {
    pub not: Box<Predicate>,
}
700
/// Leaf check `{"<field-path>": {"<operator>": <value>}}`, deserialized by a hand-written impl
/// that resolves the key with `parse_field_path` and validates the operator.
///
/// NOTE(review): `Serialize` is derived, producing `{"path": .., "op": ..}` with the path in
/// `FieldPath`'s snake_case variant form. That is not the single-key form accepted on input.
#[derive(Debug, Clone, serde::Serialize)]
pub struct CheckMap {
    /// Field being tested.
    pub path: FieldPath,
    /// Uncompiled operator and operand.
    pub op: Operator,
}
706
/// Uncompiled operator as written in a check, externally tagged with snake_case keys,
/// e.g. `{"equals": "x"}`, `{"in": [1, 2]}`, `{"cidr": "10.0.0.0/8"}`.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Operator {
    Equals(Value),
    NotEquals(Value),
    Contains(Value),
    NotContains(Value),
    Prefix(Value),
    Suffix(Value),
    /// Regex source; its length is capped by `REGEX_PATTERN_MAX_BYTES` in `validate_operator`.
    Matches(String),
    In(Vec<Value>),
    NotIn(Vec<Value>),
    Gt(i64),
    Gte(i64),
    Lt(i64),
    Lte(i64),
    /// CIDR text; it is not parsed at this stage (presumably during compilation).
    Cidr(String),
}
725
/// Scalar operand literal. `#[serde(untagged)]`: serde tries `Bool`, then `Int`, then `Str`,
/// and values that fit none of them (e.g. null, arrays, objects) are rejected.
#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
#[serde(untagged)]
pub enum Value {
    Bool(bool),
    Int(i64),
    Str(String),
}
733
734impl<'de> serde::Deserialize<'de> for Predicate {
735 fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
736 let v = serde_json::Value::deserialize(de)?;
737 let serde_json::Value::Object(ref map) = v else {
738 return Err(serde::de::Error::custom("predicate must be a JSON object"));
739 };
740 if map.len() == 1 {
741 let (key, _) = map.iter().next().expect("len == 1");
742 match key.as_str() {
743 "any_of" => {
744 return serde_json::from_value::<AnyOfP>(v)
745 .map(Predicate::AnyOf)
746 .map_err(serde::de::Error::custom);
747 }
748 "all_of" => {
749 return serde_json::from_value::<AllOfP>(v)
750 .map(Predicate::AllOf)
751 .map_err(serde::de::Error::custom);
752 }
753 "not" => {
754 return serde_json::from_value::<NotP>(v)
755 .map(Predicate::Not)
756 .map_err(serde::de::Error::custom);
757 }
758 _ => {}
759 }
760 }
761 serde_json::from_value::<CheckMap>(v).map(Predicate::Check).map_err(serde::de::Error::custom)
762 }
763}
764
765impl<'de> serde::Deserialize<'de> for CheckMap {
766 fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
767 struct Visitor;
768
769 impl<'de> serde::de::Visitor<'de> for Visitor {
770 type Value = CheckMap;
771
772 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
773 f.write_str("a single-key object of the form {\"<field-path>\": {\"<operator>\": <value>}}")
774 }
775
776 fn visit_map<M: serde::de::MapAccess<'de>>(self, mut map: M) -> Result<CheckMap, M::Error> {
777 let Some(key) = map.next_key::<String>()? else {
778 return Err(serde::de::Error::invalid_length(0, &"exactly one key"));
779 };
780 let path = parse_field_path(&key).map_err(serde::de::Error::custom)?;
781 let op: Operator = map.next_value()?;
782 if map.next_key::<serde::de::IgnoredAny>()?.is_some() {
783 return Err(serde::de::Error::custom("check object must have exactly one key"));
784 }
785 validate_operator(&op).map_err(serde::de::Error::custom)?;
786 Ok(CheckMap { path, op })
787 }
788 }
789
790 de.deserialize_map(Visitor)
791 }
792}
793
794fn parse_field_path(s: &str) -> Result<FieldPath, String> {
795 if s.chars().any(|c| c.is_ascii_uppercase()) {
796 return Err(format!(
797 "field path must be lowercase: {:?} — did you mean {:?}?",
798 s,
799 s.to_ascii_lowercase(),
800 ));
801 }
802 match s {
803 "transport" => Ok(FieldPath::Transport),
804 "remote.ip" => Ok(FieldPath::RemoteIp),
805 "remote.port" => Ok(FieldPath::RemotePort),
806 "local.ip" => Ok(FieldPath::LocalIp),
807 "local.port" => Ok(FieldPath::LocalPort),
808 "peek" => Ok(FieldPath::Peek),
809 "tls.sni" => Ok(FieldPath::TlsSni),
810 "tls.alpn" => Ok(FieldPath::TlsAlpn),
811 "tls.version" => Ok(FieldPath::TlsVersion),
812 "tls.peer_cert.present" => Ok(FieldPath::TlsPeerCertPresent),
813 "tls.peer_cert.subject_cn" => Ok(FieldPath::TlsPeerCertSubjectCn),
814 "tls.peer_cert.san_dns" => Ok(FieldPath::TlsPeerCertSanDns),
815 "tls.peer_cert.fingerprint_sha256" => Ok(FieldPath::TlsPeerCertFingerprintSha256),
816 "tls.peer_cert.spki_sha256" => Ok(FieldPath::TlsPeerCertSpkiSha256),
817 "tls.peer_cert.issuer_cn" => Ok(FieldPath::TlsPeerCertIssuerCn),
818 "tls.peer_cert.serial" => Ok(FieldPath::TlsPeerCertSerial),
819 "http.method" => Ok(FieldPath::HttpMethod),
820 "http.uri.path" => Ok(FieldPath::HttpUriPath),
821 "http.uri.query" => Ok(FieldPath::HttpUriQuery),
822 "http.body" => Ok(FieldPath::HttpBody),
823 other if other.starts_with("http.header.") => {
824 let name = &other["http.header.".len()..];
825 if name.is_empty() {
826 return Err(format!("http.header.* requires a header name: {other:?}"));
827 }
828 Ok(FieldPath::HttpHeader(Arc::from(name)))
829 }
830 other => Err(format!("unknown field path: {other:?}")),
831 }
832}
833
834fn validate_operator(op: &Operator) -> Result<(), String> {
835 if let Operator::Matches(pattern) = op
836 && pattern.len() > REGEX_PATTERN_MAX_BYTES
837 {
838 return Err(format!(
839 "regex pattern source exceeds {REGEX_PATTERN_MAX_BYTES}-byte limit: got {} bytes",
840 pattern.len(),
841 ));
842 }
843 Ok(())
844}
845
846mod serde_impls {
847 use base64::Engine as _;
848 use base64::engine::general_purpose::STANDARD as B64;
849 use bytes::Bytes;
850 use std::net::IpAddr;
851 use std::sync::Arc;
852
853 use super::{CompiledOperator, CompiledValue};
854
855 pub(super) fn ser_bytes<S: serde::Serializer>(b: &Bytes, s: S) -> Result<S::Ok, S::Error> {
856 s.serialize_str(&B64.encode(b))
857 }
858
859 pub(super) fn de_bytes<'de, D: serde::Deserializer<'de>>(d: D) -> Result<Bytes, D::Error> {
860 use serde::Deserialize as _;
861 let s = String::deserialize(d)?;
862 B64.decode(s.as_bytes()).map(Bytes::from).map_err(serde::de::Error::custom)
863 }
864
865 pub(super) fn ser_regex<S: serde::Serializer>(
866 r: &fancy_regex::Regex,
867 s: S,
868 ) -> Result<S::Ok, S::Error> {
869 s.serialize_str(r.as_str())
870 }
871
872 pub(super) fn de_regex<'de, D: serde::Deserializer<'de>>(
873 d: D,
874 ) -> Result<fancy_regex::Regex, D::Error> {
875 use serde::Deserialize as _;
876 let s = String::deserialize(d)?;
877 fancy_regex::Regex::new(&s)
878 .map_err(|e| serde::de::Error::custom(format!("invalid regex {s:?}: {e}")))
879 }
880
881 #[derive(serde::Serialize, serde::Deserialize)]
883 #[serde(rename_all = "snake_case")]
884 pub(super) enum ValueShadow {
885 Str(Arc<str>),
886 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
887 Bytes(Bytes),
888 Int(i64),
889 Bool(bool),
890 Addr(IpAddr),
891 }
892
893 impl From<&CompiledValue> for ValueShadow {
894 fn from(v: &CompiledValue) -> Self {
895 match v {
896 CompiledValue::Str(s) => Self::Str(Arc::clone(s)),
897 CompiledValue::Bytes(b) => Self::Bytes(b.clone()),
898 CompiledValue::Int(i) => Self::Int(*i),
899 CompiledValue::Bool(b) => Self::Bool(*b),
900 CompiledValue::Addr(a) => Self::Addr(*a),
901 }
902 }
903 }
904
905 impl From<ValueShadow> for CompiledValue {
906 fn from(v: ValueShadow) -> Self {
907 match v {
908 ValueShadow::Str(s) => Self::Str(s),
909 ValueShadow::Bytes(b) => Self::Bytes(b),
910 ValueShadow::Int(i) => Self::Int(i),
911 ValueShadow::Bool(b) => Self::Bool(b),
912 ValueShadow::Addr(a) => Self::Addr(a),
913 }
914 }
915 }
916
917 #[derive(serde::Serialize, serde::Deserialize)]
920 #[serde(rename_all = "snake_case")]
921 pub(super) enum OperatorShadow {
922 Equals(CompiledValue),
923 NotEquals(CompiledValue),
924 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
925 Contains(Bytes),
926 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
927 NotContains(Bytes),
928 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
929 Prefix(Bytes),
930 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
931 Suffix(Bytes),
932 #[serde(serialize_with = "ser_regex", deserialize_with = "de_regex")]
933 Matches(fancy_regex::Regex),
934 In(Vec<CompiledValue>),
935 NotIn(Vec<CompiledValue>),
936 Gt(i64),
937 Gte(i64),
938 Lt(i64),
939 Lte(i64),
940 Cidr(ipnet::IpNet),
941 }
942
943 impl From<&CompiledOperator> for OperatorShadow {
944 fn from(op: &CompiledOperator) -> Self {
945 match op {
946 CompiledOperator::Equals(v) => Self::Equals(v.clone()),
947 CompiledOperator::NotEquals(v) => Self::NotEquals(v.clone()),
948 CompiledOperator::Contains(b) => Self::Contains(b.clone()),
949 CompiledOperator::NotContains(b) => Self::NotContains(b.clone()),
950 CompiledOperator::Prefix(b) => Self::Prefix(b.clone()),
951 CompiledOperator::Suffix(b) => Self::Suffix(b.clone()),
952 CompiledOperator::Matches(r) => {
953 Self::Matches(fancy_regex::Regex::new(r.as_str()).expect("round-trippable"))
954 }
955 CompiledOperator::In(vs) => Self::In(vs.clone()),
956 CompiledOperator::NotIn(vs) => Self::NotIn(vs.clone()),
957 CompiledOperator::Gt(i) => Self::Gt(*i),
958 CompiledOperator::Gte(i) => Self::Gte(*i),
959 CompiledOperator::Lt(i) => Self::Lt(*i),
960 CompiledOperator::Lte(i) => Self::Lte(*i),
961 CompiledOperator::Cidr(n) => Self::Cidr(*n),
962 }
963 }
964 }
965
966 impl From<OperatorShadow> for CompiledOperator {
967 fn from(op: OperatorShadow) -> Self {
968 match op {
969 OperatorShadow::Equals(v) => Self::Equals(v),
970 OperatorShadow::NotEquals(v) => Self::NotEquals(v),
971 OperatorShadow::Contains(b) => Self::Contains(b),
972 OperatorShadow::NotContains(b) => Self::NotContains(b),
973 OperatorShadow::Prefix(b) => Self::Prefix(b),
974 OperatorShadow::Suffix(b) => Self::Suffix(b),
975 OperatorShadow::Matches(r) => Self::Matches(r),
976 OperatorShadow::In(vs) => Self::In(vs),
977 OperatorShadow::NotIn(vs) => Self::NotIn(vs),
978 OperatorShadow::Gt(i) => Self::Gt(i),
979 OperatorShadow::Gte(i) => Self::Gte(i),
980 OperatorShadow::Lt(i) => Self::Lt(i),
981 OperatorShadow::Lte(i) => Self::Lte(i),
982 OperatorShadow::Cidr(n) => Self::Cidr(n),
983 }
984 }
985 }
986}
987
988impl serde::Serialize for CompiledValue {
989 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
990 serde_impls::ValueShadow::from(self).serialize(s)
991 }
992}
993
994impl<'de> serde::Deserialize<'de> for CompiledValue {
995 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
996 serde_impls::ValueShadow::deserialize(d).map(Self::from)
997 }
998}
999
1000impl serde::Serialize for CompiledOperator {
1001 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
1002 serde_impls::OperatorShadow::from(self).serialize(s)
1003 }
1004}
1005
1006impl<'de> serde::Deserialize<'de> for CompiledOperator {
1007 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
1008 serde_impls::OperatorShadow::deserialize(d).map(Self::from)
1009 }
1010}
1011
1012#[cfg(test)]
1013mod tests {
1014 use std::collections::hash_map::DefaultHasher;
1015 use std::hash::Hash;
1016 use std::net::{Ipv4Addr, Ipv6Addr};
1017 use std::str::FromStr;
1018 use std::sync::OnceLock;
1019 use std::time::Instant;
1020
1021 use bytes::Bytes;
1022 use fancy_regex::Regex;
1023 use ipnet::IpNet;
1024 use parking_lot::Mutex;
1025
1026 use super::*;
1027 use crate::body::{Body, Request};
1028 use crate::conn_context::{ConnId, Transport};
1029
1030 fn hash_of<T: Hash>(v: &T) -> u64 {
1035 let mut h = DefaultHasher::new();
1036 v.hash(&mut h);
1037 h.finish()
1038 }
1039
1040 fn make_conn() -> Arc<ConnContext> {
1041 Arc::new(ConnContext {
1042 id: ConnId(1),
1043 remote: "127.0.0.1:0".parse().expect("parse remote"),
1044 local: "127.0.0.1:0".parse().expect("parse local"),
1045 transport: Transport::Tcp,
1046 entered_at: Instant::now(),
1047 tls: Mutex::new(None),
1048 http_version: OnceLock::new(),
1049 user: Mutex::new(http::Extensions::new()),
1050 })
1051 }
1052
1053 #[test]
1054 fn field_path_http_header_is_equal_by_string_content_not_arc_identity() {
1055 let a = FieldPath::HttpHeader(Arc::from("host"));
1056 let b = FieldPath::HttpHeader(Arc::from("host"));
1057 assert_eq!(a, b);
1058 assert_eq!(hash_of(&a), hash_of(&b));
1059 let upper = FieldPath::HttpHeader(Arc::from("Host"));
1064 assert_ne!(a, upper);
1065 }
1066
1067 #[test]
1068 fn field_path_simple_variants_are_self_equal_and_mutually_distinct() {
1069 let paths = [
1070 FieldPath::Transport,
1071 FieldPath::RemoteIp,
1072 FieldPath::RemotePort,
1073 FieldPath::LocalIp,
1074 FieldPath::LocalPort,
1075 FieldPath::Peek,
1076 FieldPath::TlsSni,
1077 FieldPath::TlsAlpn,
1078 FieldPath::TlsVersion,
1079 FieldPath::TlsPeerCertSubjectCn,
1080 FieldPath::HttpMethod,
1081 FieldPath::HttpUriPath,
1082 FieldPath::HttpUriQuery,
1083 FieldPath::HttpBody,
1084 ];
1085 for (i, a) in paths.iter().enumerate() {
1086 for (j, b) in paths.iter().enumerate() {
1087 if i == j {
1088 assert_eq!(a, b);
1089 } else {
1090 assert_ne!(a, b);
1091 }
1092 }
1093 }
1094 }
1095
1096 #[test]
1097 fn compiled_value_str_is_equal_by_content_not_arc_identity() {
1098 let a = CompiledValue::Str(Arc::<str>::from("x"));
1099 let b = CompiledValue::Str(Arc::<str>::from("x"));
1100 assert_eq!(a, b);
1101 assert_eq!(hash_of(&a), hash_of(&b));
1102 let c = CompiledValue::Str(Arc::<str>::from("y"));
1103 assert_ne!(a, c);
1104 }
1105
1106 #[test]
1107 fn compiled_value_cross_variant_inequality() {
1108 let s = CompiledValue::Str(Arc::<str>::from("42"));
1109 let i = CompiledValue::Int(42);
1110 assert_ne!(s, i);
1111 }
1112
1113 #[test]
1114 fn compiled_value_bytes_int_bool_addr_self_equal() {
1115 assert_eq!(
1116 CompiledValue::Bytes(Bytes::from_static(b"abc")),
1117 CompiledValue::Bytes(Bytes::copy_from_slice(b"abc")),
1118 );
1119 assert_eq!(CompiledValue::Int(7), CompiledValue::Int(7));
1120 assert_ne!(CompiledValue::Int(7), CompiledValue::Int(8));
1121 assert_eq!(CompiledValue::Bool(true), CompiledValue::Bool(true));
1122 assert_ne!(CompiledValue::Bool(true), CompiledValue::Bool(false));
1123 assert_eq!(
1124 CompiledValue::Addr(Ipv4Addr::new(10, 0, 0, 1).into()),
1125 CompiledValue::Addr(Ipv4Addr::new(10, 0, 0, 1).into()),
1126 );
1127 assert_ne!(
1128 CompiledValue::Addr(Ipv4Addr::new(10, 0, 0, 1).into()),
1129 CompiledValue::Addr(Ipv6Addr::LOCALHOST.into()),
1130 );
1131 }
1132
1133 #[test]
1134 fn compiled_operator_matches_equal_by_pattern_source() {
1135 let a = CompiledOperator::Matches(Regex::new("^/api").expect("compile a"));
1136 let b = CompiledOperator::Matches(Regex::new("^/api").expect("compile b"));
1137 assert_eq!(a, b);
1138 assert_eq!(hash_of(&a), hash_of(&b));
1139 }
1140
#[test]
fn compiled_operator_matches_distinct_patterns_unequal() {
    // These patterns accept the same inputs but differ textually, so the operators must differ.
    let (first_src, second_src) = ("a|b", "b|a");
    let lhs = CompiledOperator::Matches(Regex::new(first_src).expect("compile a"));
    let rhs = CompiledOperator::Matches(Regex::new(second_src).expect("compile b"));
    assert_ne!(lhs, rhs);
}
1149
#[test]
fn compiled_operator_cidr_equal_by_canonical_form() {
    // Two parses of the same network text must produce equal, equally-hashing operators.
    let parse = |label: &str| CompiledOperator::Cidr(IpNet::from_str("10.0.0.0/8").expect(label));
    let first = parse("parse a");
    let second = parse("parse b");
    assert_eq!(first, second);
    assert_eq!(hash_of(&first), hash_of(&second));
}
1157
#[test]
fn compiled_operator_cidr_distinct_networks_unequal() {
    // Same base address with a different prefix length must not collide.
    let wide = IpNet::from_str("10.0.0.0/8").expect("parse a");
    let narrow = IpNet::from_str("10.0.0.0/16").expect("parse b");
    assert_ne!(CompiledOperator::Cidr(wide), CompiledOperator::Cidr(narrow));
}
1164
#[test]
fn compiled_operator_in_is_order_sensitive() {
    // Equality on `In` / `NotIn` is order-sensitive: the same members reordered form a different operator.
    let s = |text: &str| CompiledValue::Str(Arc::<str>::from(text));
    let forward = vec![s("a"), s("b")];
    let reversed = vec![s("b"), s("a")];
    assert_ne!(CompiledOperator::In(forward.clone()), CompiledOperator::In(reversed.clone()));
    assert_ne!(CompiledOperator::NotIn(forward), CompiledOperator::NotIn(reversed));
}
1174
#[test]
fn compiled_operator_numeric_comparisons_distinct_per_variant() {
    /// Asserts each item equals itself and differs from every other item in `items`.
    fn assert_pairwise_distinct<T: PartialEq + std::fmt::Debug>(items: &[T]) {
        let pairs = items.iter().enumerate().flat_map(|(row, lhs)| {
            items.iter().enumerate().map(move |(col, rhs)| (row == col, lhs, rhs))
        });
        for (same_slot, lhs, rhs) in pairs {
            if same_slot {
                assert_eq!(lhs, rhs);
            } else {
                assert_ne!(lhs, rhs);
            }
        }
    }

    // The threshold is shared, so only the comparison kind tells the operators apart.
    let threshold = 10;
    assert_pairwise_distinct(&[
        CompiledOperator::Gt(threshold),
        CompiledOperator::Gte(threshold),
        CompiledOperator::Lt(threshold),
        CompiledOperator::Lte(threshold),
    ]);
}
1194
#[test]
fn compiled_operator_bytes_variants_distinguished() {
    // The payload is shared, so only the variant distinguishes these operators.
    let needle = Bytes::from_static(b"abc");
    let ops = [
        CompiledOperator::Contains(needle.clone()),
        CompiledOperator::NotContains(needle.clone()),
        CompiledOperator::Prefix(needle.clone()),
        CompiledOperator::Suffix(needle),
    ];
    let pairs = ops.iter().enumerate().flat_map(|(row, lhs)| {
        ops.iter().enumerate().map(move |(col, rhs)| (row == col, lhs, rhs))
    });
    for (same_slot, lhs, rhs) in pairs {
        if same_slot {
            assert_eq!(lhs, rhs);
        } else {
            assert_ne!(lhs, rhs);
        }
    }
}
1214
#[test]
fn predicate_inst_equal_across_independent_construction_paths() {
    // Each call allocates fresh Arcs, so equality and hashing must be structural.
    let build = || PredicateInst {
        path: FieldPath::HttpHeader(Arc::from("host")),
        op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from("example.com"))),
    };
    let (lhs, rhs) = (build(), build());
    assert_eq!(lhs, rhs);
    assert_eq!(hash_of(&lhs), hash_of(&rhs));
}
1228
#[test]
fn predicate_inst_equal_with_regex_operator_from_separate_compiles() {
    // Each predicate gets its own compiled Regex. They must still compare and hash equal.
    let build = |label: &str| PredicateInst {
        path: FieldPath::HttpUriPath,
        op: CompiledOperator::Matches(Regex::new("^/").expect(label)),
    };
    let lhs = build("compile a");
    let rhs = build("compile b");
    assert_eq!(lhs, rhs);
    assert_eq!(hash_of(&lhs), hash_of(&rhs));
}
1242
#[test]
fn predicate_inst_unequal_on_path_difference() {
    // Identical operators on different field paths are different predicates.
    let value = CompiledValue::Str(Arc::<str>::from("x"));
    let on_path = PredicateInst {
        path: FieldPath::HttpUriPath,
        op: CompiledOperator::Equals(value.clone()),
    };
    let on_query = PredicateInst {
        path: FieldPath::HttpUriQuery,
        op: CompiledOperator::Equals(value),
    };
    assert_ne!(on_path, on_query);
}
1251
#[test]
fn predicate_view_variants_construct() {
    // L4 view: a connection plus an optional peek buffer.
    let l4_conn = make_conn();
    let peeked: &[u8] = b"\x16\x03\x01";
    let l4_view = PredicateView::L4 { conn: &l4_conn, peek: Some(peeked) };
    let peek_len = match l4_view {
        PredicateView::L7Req { .. } => panic!("wrong variant"),
        PredicateView::L4 { peek, .. } => peek.map(<[u8]>::len),
    };
    assert_eq!(peek_len, Some(3));

    // L7 request view: a connection plus a borrowed HTTP request.
    let l7_conn = make_conn();
    let request: Request =
        http::Request::builder().method("GET").uri("/").body(Body::Empty).expect("build request");
    let l7_view = PredicateView::L7Req { conn: &l7_conn, req: &request };
    match l7_view {
        PredicateView::L4 { .. } => panic!("wrong variant"),
        PredicateView::L7Req { .. } => {}
    }
}
1271
1272 fn parse_predicate(v: serde_json::Value) -> Result<Predicate, serde_json::Error> {
1276 serde_json::from_value(v)
1277 }
1278
1279 fn expect_check(p: &Predicate) -> &CheckMap {
1280 match p {
1281 Predicate::Check(c) => c,
1282 other => panic!("expected Predicate::Check, got {other:?}"),
1283 }
1284 }
1285
#[test]
fn parse_any_of_happy_path() {
    let raw = serde_json::json!({
        "any_of": [
            { "tls.sni": { "equals": "a" } },
            { "tls.sni": { "equals": "b" } },
        ],
    });
    let Predicate::AnyOf(AnyOfP { any_of: branches }) = parse_predicate(raw).expect("parse any_of")
    else {
        panic!("expected AnyOf");
    };
    assert_eq!(branches.len(), 2);

    // Both branches are leaf checks on the same field, kept in their original order.
    let first = expect_check(&branches[0]);
    let second = expect_check(&branches[1]);
    for check in [first, second] {
        assert_eq!(check.path, FieldPath::TlsSni);
    }

    let (Operator::Equals(Value::Str(a)), Operator::Equals(Value::Str(b))) =
        (&first.op, &second.op)
    else {
        panic!("unexpected ops: {:?} / {:?}", first.op, second.op);
    };
    assert_eq!(a, "a");
    assert_eq!(b, "b");
}
1311
#[test]
fn parse_not_happy_path() {
    // A `not` object wraps exactly one inner predicate.
    let raw = serde_json::json!({
        "not": { "tls.sni": { "equals": "internal" } },
    });
    let p = parse_predicate(raw).expect("parse not");
    let Predicate::Not(NotP { not }) = p else {
        panic!("expected Not");
    };
    // Borrow the inner predicate. The source had `¬` here, a mis-encoding of `&not`.
    let inner = expect_check(&not);
    assert_eq!(inner.path, FieldPath::TlsSni);
    match &inner.op {
        Operator::Equals(Value::Str(s)) => assert_eq!(s, "internal"),
        other => panic!("unexpected op: {other:?}"),
    }
}
1328
#[test]
fn parse_all_of_happy_path() {
    let raw = serde_json::json!({
        "all_of": [
            { "http.header.upgrade": { "equals": "websocket" } },
            { "http.uri.path": { "prefix": "/ws" } },
        ],
    });
    let Predicate::AllOf(AllOfP { all_of: clauses }) = parse_predicate(raw).expect("parse all_of")
    else {
        panic!("expected AllOf");
    };
    assert_eq!(clauses.len(), 2);

    let header_check = expect_check(&clauses[0]);
    let path_check = expect_check(&clauses[1]);
    // The text after `http.header.` becomes the `HttpHeader` payload.
    assert_eq!(header_check.path, FieldPath::HttpHeader(Arc::from("upgrade")));
    assert_eq!(path_check.path, FieldPath::HttpUriPath);
}
1347
#[test]
fn parse_all_of_empty_array_parses() {
    // An empty conjunction is accepted at parse time.
    let parsed =
        parse_predicate(serde_json::json!({ "all_of": [] })).expect("empty all_of parses");
    let Predicate::AllOf(AllOfP { all_of: members }) = parsed else {
        panic!("expected AllOf");
    };
    assert!(members.is_empty());
}
1359
#[test]
fn parse_all_of_nested_with_check_and_any_of() {
    let raw = serde_json::json!({
        "all_of": [
            { "tls.sni": { "equals": "api.example.com" } },
            { "any_of": [
                { "remote.ip": { "cidr": "10.0.0.0/8" } },
                { "remote.ip": { "cidr": "192.168.0.0/16" } },
            ]},
        ],
    });
    let parsed = parse_predicate(raw).expect("parse nested all_of/any_of");
    let Predicate::AllOf(AllOfP { all_of: members }) = parsed else {
        panic!("expected AllOf");
    };
    assert_eq!(members.len(), 2);
    // A leaf check and a nested combinator can sit side by side in the same list.
    assert!(matches!(members.first(), Some(Predicate::Check(_))));
    assert!(matches!(members.get(1), Some(Predicate::AnyOf(_))));
}
1379
#[test]
fn parse_all_of_with_extra_key_is_rejected() {
    // A stray key next to `all_of` must make the whole predicate fail to parse.
    let raw = serde_json::json!({
        "all_of": [ { "tls.sni": { "equals": "a" } } ],
        "extra": "unwanted",
    });
    // Only checks that the error renders. Its wording is not pinned here.
    let _message =
        parse_predicate(raw).expect_err("must reject extra key on all_of").to_string();
}
1390
#[test]
fn parse_http_header_all_of_is_a_check_not_combinator() {
    // After the `http.header.` prefix, `all_of` is a header name, not the combinator keyword.
    let parsed = parse_predicate(serde_json::json!({ "http.header.all_of": { "equals": "x" } }))
        .expect("parse http.header.all_of");
    assert_eq!(expect_check(&parsed).path, FieldPath::HttpHeader(Arc::from("all_of")));
}
1400
#[test]
fn parse_check_across_representative_paths() {
    // One leaf check per field family, each paired with the path it should resolve to.
    let cases = [
        (serde_json::json!({ "tls.sni": { "equals": "api.example.com" } }), FieldPath::TlsSni),
        (serde_json::json!({ "remote.port": { "gt": 1024 } }), FieldPath::RemotePort),
        (serde_json::json!({ "http.method": { "equals": "GET" } }), FieldPath::HttpMethod),
        (serde_json::json!({ "http.uri.path": { "prefix": "/api" } }), FieldPath::HttpUriPath),
        (
            serde_json::json!({ "http.header.host": { "equals": "a.example.com" } }),
            FieldPath::HttpHeader(Arc::from("host")),
        ),
        (serde_json::json!({ "http.body": { "contains": "hello" } }), FieldPath::HttpBody),
    ];
    for (raw, expected_path) in cases {
        let parsed = match parse_predicate(raw.clone()) {
            Ok(predicate) => predicate,
            Err(e) => panic!("parse {raw}: {e}"),
        };
        assert_eq!(expect_check(&parsed).path, expected_path, "input: {raw}");
    }
}
1420
#[test]
fn parse_any_of_with_extra_key_is_rejected() {
    // A stray key next to `any_of` must make the whole predicate fail to parse.
    let raw = serde_json::json!({
        "any_of": [ { "tls.sni": { "equals": "a" } } ],
        "extra": true,
    });
    // Only checks that the error renders. Its wording is not pinned here.
    let _message =
        parse_predicate(raw).expect_err("must reject extra key on any_of").to_string();
}
1432
#[test]
fn parse_http_header_any_of_is_a_check_not_combinator() {
    // After the `http.header.` prefix, `any_of` is a header name, not the combinator keyword.
    let parsed = parse_predicate(serde_json::json!({ "http.header.any_of": { "equals": "x" } }))
        .expect("parse");
    assert_eq!(expect_check(&parsed).path, FieldPath::HttpHeader(Arc::from("any_of")));
}
1442
#[test]
fn parse_uppercase_field_path_suggests_lowercase() {
    let raw = serde_json::json!({ "http.header.Host": { "equals": "x" } });
    let msg = parse_predicate(raw).expect_err("uppercase must fail").to_string();
    // The error must echo the bad input, offer a suggestion, and spell out the lowercase fix.
    for (needle, what) in [
        ("http.header.Host", "error mentions offending input"),
        ("did you mean", "error includes suggestion phrase"),
        ("http.header.host", "error contains lowercased form"),
    ] {
        assert!(msg.contains(needle), "{what}: {msg}");
    }
}
1452
#[test]
fn parse_multi_key_check_is_rejected() {
    // A check object with two field-path keys must fail to parse.
    let two_keys = serde_json::json!({
        "http.uri.path": { "matches": "^/" },
        "http.method": { "equals": "GET" },
    });
    let _message =
        parse_predicate(two_keys).expect_err("multi-key check must fail").to_string();
}
1462
#[test]
fn parse_empty_http_header_name_is_rejected() {
    // `http.header.` with nothing after the dot names no header and must be rejected.
    let nameless = serde_json::json!({ "http.header.": { "equals": "x" } });
    let _message =
        parse_predicate(nameless).expect_err("empty header name must fail").to_string();
}
1469
#[test]
fn parse_unknown_field_path_is_rejected_with_name() {
    // The error for an unknown path must quote that path back to the user.
    let msg = parse_predicate(serde_json::json!({ "http.nope": { "equals": "x" } }))
        .expect_err("unknown path must fail")
        .to_string();
    assert!(msg.contains("http.nope"), "error mentions offending path: {msg}");
}
1477
1478 fn parse_op(v: serde_json::Value) -> Operator {
1479 let mut map = serde_json::Map::new();
1480 map.insert("tls.sni".to_string(), v);
1481 let raw = serde_json::Value::Object(map);
1482 match parse_predicate(raw).expect("parse check") {
1483 Predicate::Check(c) => c.op,
1484 other => panic!("expected Check, got {other:?}"),
1485 }
1486 }
1487
#[test]
fn operator_equals_and_not_equals_on_string() {
    match parse_op(serde_json::json!({ "equals": "api" })) {
        Operator::Equals(Value::Str(s)) => assert_eq!(s, "api"),
        other => panic!("expected equals/str: {other:?}"),
    }
    match parse_op(serde_json::json!({ "not_equals": "api" })) {
        Operator::NotEquals(Value::Str(s)) => assert_eq!(s, "api"),
        other => panic!("expected not_equals/str: {other:?}"),
    }
}
1501
#[test]
fn operator_contains_and_not_contains_on_string() {
    match parse_op(serde_json::json!({ "contains": "foo" })) {
        Operator::Contains(Value::Str(s)) => assert_eq!(s, "foo"),
        other => panic!("expected contains/str: {other:?}"),
    }
    match parse_op(serde_json::json!({ "not_contains": "foo" })) {
        Operator::NotContains(Value::Str(s)) => assert_eq!(s, "foo"),
        other => panic!("expected not_contains/str: {other:?}"),
    }
}
1515
#[test]
fn operator_prefix_and_suffix_on_string() {
    match parse_op(serde_json::json!({ "prefix": "/api" })) {
        Operator::Prefix(Value::Str(s)) => assert_eq!(s, "/api"),
        other => panic!("expected prefix/str: {other:?}"),
    }
    match parse_op(serde_json::json!({ "suffix": ".json" })) {
        Operator::Suffix(Value::Str(v)) => assert_eq!(v, ".json"),
        other => panic!("expected suffix/str: {other:?}"),
    }
}
1529
#[test]
fn operator_matches_carries_pattern_source() {
    // At the uncompiled stage, `matches` keeps the raw pattern text unchanged.
    match parse_op(serde_json::json!({ "matches": "^/api/v\\d+" })) {
        Operator::Matches(pattern) => assert_eq!(pattern, "^/api/v\\d+"),
        other => panic!("expected matches: {other:?}"),
    }
}
1538
#[test]
fn operator_in_and_not_in_accept_mixed_scalar_types() {
    // Lists can mix strings and integers. Each element keeps its own scalar type.
    let Operator::In(xs) = parse_op(serde_json::json!({ "in": ["foo", 42] })) else {
        panic!("expected in");
    };
    assert_eq!(xs.len(), 2);
    assert_eq!(&xs[..], &[Value::Str("foo".into()), Value::Int(42)]);

    let Operator::NotIn(ys) = parse_op(serde_json::json!({ "not_in": ["bar", 7] })) else {
        panic!("expected not_in");
    };
    assert_eq!(ys.len(), 2);
    assert_eq!(&ys[..], &[Value::Str("bar".into()), Value::Int(7)]);
}
1556
#[test]
fn operator_numeric_comparisons() {
    // Each comparison keyword maps to its own variant and carries the integer unchanged.
    let parsed = [
        parse_op(serde_json::json!({ "gt": 10 })),
        parse_op(serde_json::json!({ "gte": 10 })),
        parse_op(serde_json::json!({ "lt": 10 })),
        parse_op(serde_json::json!({ "lte": 10 })),
    ];
    assert!(matches!(
        parsed,
        [Operator::Gt(10), Operator::Gte(10), Operator::Lt(10), Operator::Lte(10)]
    ));
}
1564
#[test]
fn operator_cidr_carries_source_string() {
    // Before compilation, `cidr` holds the network as written.
    match parse_op(serde_json::json!({ "cidr": "10.0.0.0/8" })) {
        Operator::Cidr(s) => assert_eq!(s, "10.0.0.0/8"),
        other => panic!("expected cidr: {other:?}"),
    }
}
1573
#[test]
fn value_untagged_priority_bool_before_str() {
    // A JSON boolean must decode as `Value::Bool`, for both `true` and `false`.
    for flag in [true, false] {
        let op = parse_op(serde_json::json!({ "equals": flag }));
        assert!(matches!(op, Operator::Equals(Value::Bool(got)) if got == flag));
    }
}
1583
#[test]
fn value_untagged_priority_int_before_str() {
    // A bare JSON number must decode as `Value::Int`.
    let parsed = parse_op(serde_json::json!({ "equals": 42 }));
    assert!(matches!(parsed, Operator::Equals(Value::Int(42))));
}
1590
#[test]
fn value_untagged_json_string_stays_str() {
    // A quoted number is a string and must not be coerced into `Value::Int`.
    let parsed = parse_op(serde_json::json!({ "equals": "42" }));
    let Operator::Equals(Value::Str(s)) = &parsed else {
        panic!("expected equals/str(\"42\"): {parsed:?}");
    };
    assert_eq!(s, "42");
}
1601
#[test]
fn regex_pattern_exactly_at_limit_parses() {
    // Pin the limit itself, so an accidental change to the constant shows up here.
    assert_eq!(REGEX_PATTERN_MAX_BYTES, 4 * 1024);

    let at_limit = "a".repeat(REGEX_PATTERN_MAX_BYTES);
    let parsed = parse_predicate(serde_json::json!({ "http.uri.path": { "matches": at_limit } }))
        .expect("4 KiB pattern parses");
    match &expect_check(&parsed).op {
        Operator::Matches(src) => assert_eq!(src.len(), REGEX_PATTERN_MAX_BYTES),
        other => panic!("expected matches: {other:?}"),
    }
}
1615
#[test]
fn regex_pattern_over_limit_rejected_with_limit_in_message() {
    // A pattern one byte over the limit must be rejected at parse time.
    let pattern = "a".repeat(REGEX_PATTERN_MAX_BYTES + 1);
    let raw = serde_json::json!({ "http.uri.path": { "matches": pattern } });
    let err = parse_predicate(raw).expect_err("over-limit pattern must fail");
    let msg = err.to_string();
    // The error must state the numeric limit. The search needle is borrowed as `&String`;
    // the source had `®...`, a mis-encoding of `&REGEX_PATTERN_MAX_BYTES`.
    assert!(
        msg.contains(&REGEX_PATTERN_MAX_BYTES.to_string()),
        "error mentions the limit ({REGEX_PATTERN_MAX_BYTES}): {msg}",
    );
}
1627
1628 fn value_round_trip(v: &CompiledValue) -> CompiledValue {
1635 let encoded = serde_json::to_string(v).expect("serialize value");
1636 serde_json::from_str(&encoded).expect("deserialize value")
1637 }
1638
#[test]
fn compiled_value_str_round_trip_including_empty() {
    // The empty string is included to catch encodings that treat "" as absent.
    for text in ["x", ""] {
        let value = CompiledValue::Str(Arc::<str>::from(text));
        assert_eq!(value_round_trip(&value), value);
    }
}
1646
#[test]
fn compiled_value_bytes_round_trip_including_empty_and_binary() {
    // Cases: ASCII, empty, and bytes that are not valid UTF-8.
    let payloads = [
        Bytes::from_static(b"hello"),
        Bytes::new(),
        Bytes::from_static(&[0xff, 0x00, 0x13]),
    ];
    for value in payloads.map(CompiledValue::Bytes) {
        assert_eq!(value_round_trip(&value), value);
    }
}
1656
#[test]
fn compiled_value_int_round_trip_including_extremes() {
    // The i64 extremes check that no precision is lost through the JSON number path.
    for value in [0_i64, i64::MIN, i64::MAX].map(CompiledValue::Int) {
        assert_eq!(value_round_trip(&value), value);
    }
}
1664
#[test]
fn compiled_value_bool_round_trip_both_variants() {
    for value in [true, false].map(CompiledValue::Bool) {
        assert_eq!(value_round_trip(&value), value);
    }
}
1672
#[test]
fn compiled_value_addr_round_trip_v4_and_v6() {
    // Each address family must come back as the same family.
    let addrs = [
        CompiledValue::Addr(Ipv4Addr::LOCALHOST.into()),
        CompiledValue::Addr(Ipv6Addr::LOCALHOST.into()),
    ];
    for addr in addrs {
        assert_eq!(value_round_trip(&addr), addr);
    }
}
1680
#[test]
fn compiled_value_bytes_emits_standard_base64_literal() {
    // Pin the exact JSON text: a `bytes` tag holding padded base64 of b"hello".
    let encoded = serde_json::to_string(&CompiledValue::Bytes(Bytes::from_static(b"hello")))
        .expect("serialize");
    assert_eq!(encoded, r#"{"bytes":"aGVsbG8="}"#);
}
1690
1691 fn op_round_trip(op: &CompiledOperator) -> CompiledOperator {
1692 let encoded = serde_json::to_string(op).expect("serialize op");
1693 serde_json::from_str(&encoded).expect("deserialize op")
1694 }
1695
#[test]
fn compiled_operator_equals_and_not_equals_round_trip() {
    let x = || CompiledValue::Str(Arc::<str>::from("x"));
    for op in [CompiledOperator::Equals(x()), CompiledOperator::NotEquals(x())] {
        assert_eq!(op_round_trip(&op), op);
    }
}
1703
#[test]
fn compiled_operator_bytes_variants_round_trip() {
    let needle = Bytes::from_static(b"hello");
    let ops = [
        CompiledOperator::Contains(needle.clone()),
        CompiledOperator::NotContains(needle.clone()),
        CompiledOperator::Prefix(needle.clone()),
        CompiledOperator::Suffix(needle),
    ];
    // Each variant must come back as the same variant with the same payload.
    for op in &ops {
        assert_eq!(op_round_trip(op), *op);
    }
}
1717
#[test]
fn compiled_operator_matches_round_trip_preserves_pattern_source() {
    const PATTERN: &str = "^/api/v[0-9]+";
    let original = CompiledOperator::Matches(Regex::new(PATTERN).expect("compile"));
    let decoded = op_round_trip(&original);
    assert_eq!(decoded, original);

    // Also check the decoded regex's source text directly.
    let regex = match decoded {
        CompiledOperator::Matches(r) => r,
        other => panic!("expected matches, got {other:?}"),
    };
    assert_eq!(regex.as_str(), PATTERN);
}
1729
#[test]
fn compiled_operator_in_and_not_in_round_trip_mixed_values() {
    // Lists can mix value kinds. Each element must keep its variant through JSON.
    let xs = vec![CompiledValue::Str(Arc::<str>::from("a")), CompiledValue::Int(42)];
    let in_op = CompiledOperator::In(xs.clone());
    assert_eq!(op_round_trip(&in_op), in_op);
    let not_in_op = CompiledOperator::NotIn(xs);
    // `op_round_trip` takes `&CompiledOperator`. The source had `¬_in_op`, a mis-encoding of `&not_in_op`.
    assert_eq!(op_round_trip(&not_in_op), not_in_op);
}
1738
#[test]
fn compiled_operator_numeric_comparisons_round_trip() {
    let threshold = 100;
    for op in [
        CompiledOperator::Gt(threshold),
        CompiledOperator::Gte(threshold),
        CompiledOperator::Lt(threshold),
        CompiledOperator::Lte(threshold),
    ] {
        assert_eq!(op_round_trip(&op), op);
    }
}
1751
#[test]
fn compiled_operator_cidr_round_trip_preserves_canonical_form() {
    let network = IpNet::from_str("10.0.0.0/8").expect("parse");
    let original = CompiledOperator::Cidr(network);
    let decoded = op_round_trip(&original);
    assert_eq!(decoded, original);
}
1757
#[test]
fn compiled_operator_matches_with_invalid_regex_is_rejected() {
    // An unbalanced `[` cannot compile, so deserializing the operator itself must fail.
    let decoded = serde_json::from_str::<CompiledOperator>(r#"{"matches":"["}"#);
    let msg = decoded.expect_err("invalid regex must fail to deserialize").to_string();
    assert!(msg.contains('['), "error mentions offending regex source: {msg}");
}
1769
#[test]
fn predicate_inst_pins_exact_wire_shape_for_http_header_equals() {
    // The exact JSON text a header-equals predicate must serialize to.
    const WIRE: &str =
        r#"{"path":{"http_header":"host"},"op":{"equals":{"str":"example.com"}}}"#;
    let inst = PredicateInst {
        path: FieldPath::HttpHeader(Arc::from("host")),
        op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from("example.com"))),
    };
    let encoded = serde_json::to_string(&inst).expect("serialize");
    assert_eq!(encoded, WIRE);
    let decoded = serde_json::from_str::<PredicateInst>(&encoded).expect("deserialize");
    assert_eq!(decoded, inst);
}
1781
#[test]
fn predicate_inst_round_trip_with_regex_operator() {
    let original = PredicateInst {
        path: FieldPath::HttpUriPath,
        op: CompiledOperator::Matches(Regex::new("^/api").expect("compile")),
    };
    let json_text = serde_json::to_string(&original).expect("serialize");
    let restored = serde_json::from_str::<PredicateInst>(&json_text).expect("deserialize");
    assert_eq!(restored, original);
}
1792
1793 fn http_header_equals(name: &str, value: &str) -> PredicateInst {
1800 PredicateInst {
1801 path: FieldPath::HttpHeader(Arc::from(name)),
1802 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from(value))),
1803 }
1804 }
1805
1806 fn http_uri_path_equals(value: &str) -> PredicateInst {
1807 PredicateInst {
1808 path: FieldPath::HttpUriPath,
1809 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from(value))),
1810 }
1811 }
1812
1813 fn http_uri_path_prefix(value: &str) -> PredicateInst {
1814 PredicateInst {
1815 path: FieldPath::HttpUriPath,
1816 op: CompiledOperator::Prefix(Bytes::copy_from_slice(value.as_bytes())),
1817 }
1818 }
1819
1820 fn tls_sni_equals(value: &str) -> PredicateInst {
1821 PredicateInst {
1822 path: FieldPath::TlsSni,
1823 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from(value))),
1824 }
1825 }
1826
1827 fn conn_with_sni(sni: &str) -> Arc<ConnContext> {
1828 let conn = make_conn();
1829 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
1830 sni: Some(sni.to_string()),
1831 alpn: None,
1832 version: None,
1833 peer_cert: None,
1834 zero_rtt_used: false,
1835 });
1836 conn
1837 }
1838
1839 fn req_with_header(name: &str, value: &str) -> Request {
1840 http::Request::builder()
1841 .method("GET")
1842 .uri("/")
1843 .header(name, value)
1844 .body(Body::Empty)
1845 .expect("build req")
1846 }
1847
1848 fn req_with_uri(uri: &str) -> Request {
1849 http::Request::builder().method("GET").uri(uri).body(Body::Empty).expect("build req")
1850 }
1851
#[test]
fn predicate_test_http_header_equals_matches_when_present_and_equal() {
    let conn = make_conn();
    let request = req_with_header("upgrade", "websocket");
    let predicate = http_header_equals("upgrade", "websocket");
    assert!(predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1859
#[test]
fn predicate_test_http_header_equals_misses_when_header_absent() {
    // The request carries some other header, but not the one the predicate checks.
    let conn = make_conn();
    let request = req_with_header("host", "example.com");
    let predicate = http_header_equals("upgrade", "websocket");
    assert!(!predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1867
#[test]
fn predicate_test_http_header_equals_value_is_case_sensitive() {
    // Header values are compared byte-for-byte, so "WebSocket" does not equal "websocket".
    let conn = make_conn();
    let request = req_with_header("upgrade", "WebSocket");
    let predicate = http_header_equals("upgrade", "websocket");
    assert!(!predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1879
#[test]
fn predicate_test_http_header_equals_name_lookup_is_case_insensitive() {
    // A header set as "Upgrade" must still be found by the lowercase predicate name.
    let conn = make_conn();
    let request = req_with_header("Upgrade", "websocket");
    let predicate = http_header_equals("upgrade", "websocket");
    assert!(predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1892
#[test]
fn predicate_test_http_header_equals_misses_on_l4_view() {
    // An L4 view carries no request, so this header predicate must evaluate to false.
    let conn = make_conn();
    let l4 = PredicateView::L4 { conn: &conn, peek: None };
    assert!(!http_header_equals("upgrade", "websocket").test(&l4));
}
1902
#[test]
fn predicate_test_http_uri_path_equals_matches_exact() {
    let conn = make_conn();
    let request = req_with_uri("/api/v1/users");
    let predicate = http_uri_path_equals("/api/v1/users");
    assert!(predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1910
#[test]
fn predicate_test_http_uri_path_equals_misses_on_substring() {
    // `equals` is whole-string. A leading fragment of the path is not a match.
    let conn = make_conn();
    let request = req_with_uri("/api/v1/users");
    let predicate = http_uri_path_equals("/api");
    assert!(!predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1921
#[test]
fn predicate_test_http_uri_path_prefix_matches_when_path_starts_with() {
    let conn = make_conn();
    let request = req_with_uri("/api/v1/users");
    let predicate = http_uri_path_prefix("/api");
    assert!(predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1929
#[test]
fn predicate_test_http_uri_path_prefix_misses_when_no_prefix() {
    let conn = make_conn();
    let request = req_with_uri("/admin");
    let predicate = http_uri_path_prefix("/api");
    assert!(!predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1937
#[test]
fn predicate_test_tls_sni_equals_matches_when_set() {
    // SNI comes from the connection, so an L7 view over that connection also sees it.
    let conn = conn_with_sni("api.example.com");
    let request = req_with_uri("/");
    let predicate = tls_sni_equals("api.example.com");
    assert!(predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1948
#[test]
fn predicate_test_tls_sni_equals_misses_when_unset() {
    // A plain `make_conn()` context: the predicate must not match.
    let conn = make_conn();
    let request = req_with_uri("/");
    let predicate = tls_sni_equals("api.example.com");
    assert!(!predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1958
#[test]
fn predicate_test_tls_sni_equals_works_in_l4_view_too() {
    // SNI needs no HTTP request, so it is visible from a bare L4 view.
    let conn = conn_with_sni("api.example.com");
    let l4 = PredicateView::L4 { conn: &conn, peek: None };
    assert!(tls_sni_equals("api.example.com").test(&l4));
}
1970
1971 fn pred(path: FieldPath, op: CompiledOperator) -> PredicateInst {
1978 PredicateInst { path, op }
1979 }
1980
1981 fn str_val(s: &str) -> CompiledValue {
1982 CompiledValue::Str(Arc::<str>::from(s))
1983 }
1984
1985 fn bytes_val(b: &[u8]) -> CompiledValue {
1986 CompiledValue::Bytes(Bytes::copy_from_slice(b))
1987 }
1988
1989 fn b(b: &[u8]) -> Bytes {
1990 Bytes::copy_from_slice(b)
1991 }
1992
1993 fn make_conn_with(remote: &str, local: &str) -> Arc<ConnContext> {
1994 Arc::new(ConnContext {
1995 id: ConnId(1),
1996 remote: remote.parse().expect("parse remote"),
1997 local: local.parse().expect("parse local"),
1998 transport: Transport::Tcp,
1999 entered_at: Instant::now(),
2000 tls: Mutex::new(None),
2001 http_version: OnceLock::new(),
2002 user: Mutex::new(http::Extensions::new()),
2003 })
2004 }
2005
2006 fn make_conn_with_transport(t: Transport) -> Arc<ConnContext> {
2007 Arc::new(ConnContext {
2008 id: ConnId(1),
2009 remote: "127.0.0.1:0".parse().expect("remote"),
2010 local: "127.0.0.1:0".parse().expect("local"),
2011 transport: t,
2012 entered_at: Instant::now(),
2013 tls: Mutex::new(None),
2014 http_version: OnceLock::new(),
2015 user: Mutex::new(http::Extensions::new()),
2016 })
2017 }
2018
2019 fn conn_with_tls_alpn(alpn: &[u8]) -> Arc<ConnContext> {
2020 let conn = make_conn();
2021 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
2022 sni: None,
2023 alpn: Some(alpn.to_vec()),
2024 version: None,
2025 peer_cert: None,
2026 zero_rtt_used: false,
2027 });
2028 conn
2029 }
2030
2031 fn conn_with_tls_version(v: crate::conn_context::TlsVersion) -> Arc<ConnContext> {
2032 let conn = make_conn();
2033 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
2034 sni: None,
2035 alpn: None,
2036 version: Some(v),
2037 peer_cert: None,
2038 zero_rtt_used: false,
2039 });
2040 conn
2041 }
2042
#[test]
fn matrix_equality_str_happy_and_miss() {
    let conn = conn_with_sni("api.example.com");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let sni = |op: CompiledOperator| pred(FieldPath::TlsSni, op).test(&view);
    assert!(sni(CompiledOperator::Equals(str_val("api.example.com"))));
    assert!(!sni(CompiledOperator::Equals(str_val("other.example.com"))));
    assert!(sni(CompiledOperator::NotEquals(str_val("other.example.com"))));
    assert!(!sni(CompiledOperator::NotEquals(str_val("api.example.com"))));
}
2060
#[test]
fn matrix_equality_bytes_happy_and_miss() {
    let conn = conn_with_tls_alpn(b"h2");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let alpn = |op: CompiledOperator| pred(FieldPath::TlsAlpn, op).test(&view);
    assert!(alpn(CompiledOperator::Equals(bytes_val(b"h2"))));
    assert!(!alpn(CompiledOperator::Equals(bytes_val(b"http/1.1"))));
    assert!(alpn(CompiledOperator::NotEquals(bytes_val(b"http/1.1"))));
    assert!(!alpn(CompiledOperator::NotEquals(bytes_val(b"h2"))));
}
2071
#[test]
fn matrix_equality_int_happy_and_miss() {
    let conn = make_conn_with("127.0.0.1:9090", "127.0.0.1:80");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let remote_port = |op: CompiledOperator| pred(FieldPath::RemotePort, op).test(&view);
    assert!(remote_port(CompiledOperator::Equals(CompiledValue::Int(9090))));
    assert!(!remote_port(CompiledOperator::Equals(CompiledValue::Int(81))));
    assert!(remote_port(CompiledOperator::NotEquals(CompiledValue::Int(81))));
    assert!(!remote_port(CompiledOperator::NotEquals(CompiledValue::Int(9090))));
}
2089
#[test]
fn matrix_equality_addr_happy_and_miss() {
    let conn = make_conn_with("10.0.0.5:55555", "127.0.0.1:80");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let actual: std::net::IpAddr = "10.0.0.5".parse().unwrap();
    let neighbour: std::net::IpAddr = "10.0.0.6".parse().unwrap();
    let remote_ip = |op: CompiledOperator| pred(FieldPath::RemoteIp, op).test(&view);
    assert!(remote_ip(CompiledOperator::Equals(CompiledValue::Addr(actual))));
    assert!(!remote_ip(CompiledOperator::Equals(CompiledValue::Addr(neighbour))));
    assert!(remote_ip(CompiledOperator::NotEquals(CompiledValue::Addr(neighbour))));
    assert!(!remote_ip(CompiledOperator::NotEquals(CompiledValue::Addr(actual))));
}
2107
#[test]
fn matrix_equality_enum_transport_happy_and_miss() {
    // The transport enum is compared against its lowercase string name.
    let tcp_conn = make_conn_with_transport(Transport::Tcp);
    let udp_conn = make_conn_with_transport(Transport::Udp);
    let on_tcp = PredicateView::L4 { conn: &tcp_conn, peek: None };
    let on_udp = PredicateView::L4 { conn: &udp_conn, peek: None };
    let transport_is =
        |name: &str| pred(FieldPath::Transport, CompiledOperator::Equals(str_val(name)));
    assert!(transport_is("tcp").test(&on_tcp));
    assert!(!transport_is("udp").test(&on_tcp));
    assert!(transport_is("udp").test(&on_udp));
}
2118
2119 #[test]
2120 fn matrix_equality_enum_tls_version_happy_and_miss() {
2121 let conn = conn_with_tls_version(crate::conn_context::TlsVersion::Tls13);
2122 let v = PredicateView::L4 { conn: &conn, peek: None };
2123 assert!(pred(FieldPath::TlsVersion, CompiledOperator::Equals(str_val("1.3"))).test(&v));
2124 assert!(!pred(FieldPath::TlsVersion, CompiledOperator::Equals(str_val("1.2"))).test(&v));
2125 assert!(pred(FieldPath::TlsVersion, CompiledOperator::NotEquals(str_val("1.2"))).test(&v));
2126 }
2127
2128 #[test]
2129 fn matrix_equality_enum_tls_version_misses_when_absent() {
2130 let conn = make_conn();
2132 let v = PredicateView::L4 { conn: &conn, peek: None };
2133 assert!(!pred(FieldPath::TlsVersion, CompiledOperator::Equals(str_val("1.3"))).test(&v));
2134 assert!(!pred(FieldPath::TlsVersion, CompiledOperator::NotEquals(str_val("1.3"))).test(&v));
2136 }
2137
2138 #[test]
2139 fn matrix_equality_enum_http_method_happy_and_miss() {
2140 let conn = make_conn();
2141 let req = http::Request::builder().method("POST").uri("/").body(Body::Empty).unwrap();
2142 let v = PredicateView::L7Req { conn: &conn, req: &req };
2143 assert!(pred(FieldPath::HttpMethod, CompiledOperator::Equals(str_val("POST"))).test(&v));
2144 assert!(!pred(FieldPath::HttpMethod, CompiledOperator::Equals(str_val("GET"))).test(&v));
2145 assert!(pred(FieldPath::HttpMethod, CompiledOperator::NotEquals(str_val("GET"))).test(&v));
2146 }
2147
// In / NotIn for every scalar value type. Each test covers all four quadrants:
// In(hit), !In(miss), NotIn(miss) and !NotIn(hit).

#[test]
fn matrix_in_list_str_happy_and_miss() {
    let conn = conn_with_sni("api.example.com");
    let v = PredicateView::L4 { conn: &conn, peek: None };
    let list = vec![str_val("a.example.com"), str_val("api.example.com")];
    assert!(pred(FieldPath::TlsSni, CompiledOperator::In(list.clone())).test(&v));
    let list_miss = vec![str_val("a.example.com"), str_val("b.example.com")];
    assert!(!pred(FieldPath::TlsSni, CompiledOperator::In(list_miss.clone())).test(&v));
    assert!(pred(FieldPath::TlsSni, CompiledOperator::NotIn(list_miss)).test(&v));
    assert!(!pred(FieldPath::TlsSni, CompiledOperator::NotIn(list)).test(&v));
}

#[test]
fn matrix_in_list_bytes_happy_and_miss() {
    let conn = conn_with_tls_alpn(b"h2");
    let v = PredicateView::L4 { conn: &conn, peek: None };
    let list = vec![bytes_val(b"http/1.1"), bytes_val(b"h2")];
    assert!(pred(FieldPath::TlsAlpn, CompiledOperator::In(list.clone())).test(&v));
    let list_miss = vec![bytes_val(b"http/1.0"), bytes_val(b"http/1.1")];
    assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::In(list_miss.clone())).test(&v));
    assert!(pred(FieldPath::TlsAlpn, CompiledOperator::NotIn(list_miss)).test(&v));
    assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::NotIn(list)).test(&v));
}

#[test]
fn matrix_in_list_int_happy_and_miss() {
    let conn = make_conn_with("127.0.0.1:443", "127.0.0.1:80");
    let v = PredicateView::L4 { conn: &conn, peek: None };
    let in_list = vec![CompiledValue::Int(80), CompiledValue::Int(443)];
    assert!(pred(FieldPath::RemotePort, CompiledOperator::In(in_list.clone())).test(&v));
    let miss_list = vec![CompiledValue::Int(80), CompiledValue::Int(81)];
    assert!(!pred(FieldPath::RemotePort, CompiledOperator::In(miss_list.clone())).test(&v));
    assert!(pred(FieldPath::RemotePort, CompiledOperator::NotIn(miss_list)).test(&v));
    assert!(!pred(FieldPath::RemotePort, CompiledOperator::NotIn(in_list)).test(&v));
}

#[test]
fn matrix_in_list_addr_happy_and_miss_mixed_family() {
    // The list mixes IPv6 and IPv4 entries; a v4 remote must still find its entry.
    let conn = make_conn_with("10.0.0.5:55555", "127.0.0.1:80");
    let v = PredicateView::L4 { conn: &conn, peek: None };
    let v4: std::net::IpAddr = "10.0.0.5".parse().unwrap();
    let v6: std::net::IpAddr = "::1".parse().unwrap();
    let list = vec![CompiledValue::Addr(v6), CompiledValue::Addr(v4)];
    assert!(pred(FieldPath::RemoteIp, CompiledOperator::In(list.clone())).test(&v));
    let miss = vec![CompiledValue::Addr(v6)];
    assert!(!pred(FieldPath::RemoteIp, CompiledOperator::In(miss.clone())).test(&v));
    assert!(pred(FieldPath::RemoteIp, CompiledOperator::NotIn(miss)).test(&v));
    assert!(!pred(FieldPath::RemoteIp, CompiledOperator::NotIn(list)).test(&v));
}

#[test]
fn matrix_in_list_enum_transport_happy_and_miss() {
    let conn = make_conn_with_transport(Transport::Udp);
    let v = PredicateView::L4 { conn: &conn, peek: None };
    let list = vec![str_val("tcp"), str_val("udp")];
    assert!(pred(FieldPath::Transport, CompiledOperator::In(list.clone())).test(&v));
    let miss = vec![str_val("tcp")];
    assert!(!pred(FieldPath::Transport, CompiledOperator::In(miss.clone())).test(&v));
    assert!(pred(FieldPath::Transport, CompiledOperator::NotIn(miss)).test(&v));
    assert!(!pred(FieldPath::Transport, CompiledOperator::NotIn(list)).test(&v));
}
2206
#[test]
fn matrix_substring_on_str_happy_and_miss() {
    let conn = make_conn();
    let req =
        http::Request::builder().method("GET").uri("/api/v1/users").body(Body::Empty).unwrap();
    let v = PredicateView::L7Req { conn: &conn, req: &req };
    assert!(pred(FieldPath::HttpUriPath, CompiledOperator::Contains(b(b"/v1/"))).test(&v));
    assert!(!pred(FieldPath::HttpUriPath, CompiledOperator::Contains(b(b"/v2/"))).test(&v));
    assert!(pred(FieldPath::HttpUriPath, CompiledOperator::NotContains(b(b"/v2/"))).test(&v));
    assert!(!pred(FieldPath::HttpUriPath, CompiledOperator::NotContains(b(b"/v1/"))).test(&v));
}

#[test]
fn matrix_substring_on_bytes_happy_and_miss() {
    // Same four quadrants as the Str variant, on a Bytes-typed field.
    let conn = conn_with_tls_alpn(b"http/1.1");
    let v = PredicateView::L4 { conn: &conn, peek: None };
    assert!(pred(FieldPath::TlsAlpn, CompiledOperator::Contains(b(b"/1."))).test(&v));
    assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::Contains(b(b"/2."))).test(&v));
    assert!(pred(FieldPath::TlsAlpn, CompiledOperator::NotContains(b(b"/2."))).test(&v));
    assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::NotContains(b(b"/1."))).test(&v));
}

#[test]
fn matrix_prefix_suffix_on_str_happy_and_miss() {
    // The query string must not leak into HttpUriPath: ".json" is only a suffix
    // once "?q=1" is excluded.
    let conn = make_conn();
    let req =
        http::Request::builder().method("GET").uri("/api/file.json?q=1").body(Body::Empty).unwrap();
    let v = PredicateView::L7Req { conn: &conn, req: &req };
    assert!(pred(FieldPath::HttpUriPath, CompiledOperator::Prefix(b(b"/api"))).test(&v));
    assert!(!pred(FieldPath::HttpUriPath, CompiledOperator::Prefix(b(b"/admin"))).test(&v));
    assert!(pred(FieldPath::HttpUriPath, CompiledOperator::Suffix(b(b".json"))).test(&v));
    assert!(!pred(FieldPath::HttpUriPath, CompiledOperator::Suffix(b(b".html"))).test(&v));
}

#[test]
fn matrix_prefix_suffix_on_bytes_happy_and_miss() {
    let conn = conn_with_tls_alpn(b"http/1.1");
    let v = PredicateView::L4 { conn: &conn, peek: None };
    assert!(pred(FieldPath::TlsAlpn, CompiledOperator::Prefix(b(b"http"))).test(&v));
    assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::Prefix(b(b"h2"))).test(&v));
    assert!(pred(FieldPath::TlsAlpn, CompiledOperator::Suffix(b(b"1.1"))).test(&v));
    assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::Suffix(b(b"2.0"))).test(&v));
}
2251
#[test]
fn matrix_regex_matches_on_str_happy_and_miss() {
    let conn = make_conn();
    let req = http::Request::builder()
        .method("GET")
        .uri("/api/v3/orders")
        .body(Body::Empty)
        .unwrap();
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    // Compile `pattern` and test it against the request path.
    let path_matches = |pattern: &str| {
        let re = Regex::new(pattern).expect("compile regex");
        pred(FieldPath::HttpUriPath, CompiledOperator::Matches(re)).test(&view)
    };
    assert!(path_matches(r"^/api/v\d+/orders"));
    assert!(!path_matches(r"^/admin"));
}

#[test]
fn matrix_regex_matches_on_header_happy_and_miss() {
    let conn = make_conn();
    let req = http::Request::builder()
        .method("GET")
        .uri("/")
        .header("user-agent", "Mozilla/5.0 (Macintosh; Intel)")
        .body(Body::Empty)
        .unwrap();
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    // Compile `pattern` and test it against the user-agent header value.
    let user_agent_matches = |pattern: &str| {
        let re = Regex::new(pattern).expect("compile");
        pred(FieldPath::HttpHeader(Arc::from("user-agent")), CompiledOperator::Matches(re))
            .test(&view)
    };
    // The inline (?i) flag lets the lowercase pattern hit "Mozilla".
    assert!(user_agent_matches(r"(?i)mozilla"));
    assert!(!user_agent_matches(r"^curl"));
}
2285
#[test]
fn matrix_numeric_cmp_gt_gte_lt_lte_happy_and_miss() {
    // RemotePort is 1024. Each operator is probed on both sides of that boundary.
    let conn = make_conn_with("127.0.0.1:1024", "127.0.0.1:443");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let cases = vec![
        ("> 1023", CompiledOperator::Gt(1023), true),
        ("> 1024", CompiledOperator::Gt(1024), false),
        (">= 1024", CompiledOperator::Gte(1024), true),
        (">= 1025", CompiledOperator::Gte(1025), false),
        ("< 1025", CompiledOperator::Lt(1025), true),
        ("< 1024", CompiledOperator::Lt(1024), false),
        ("<= 1024", CompiledOperator::Lte(1024), true),
        ("<= 1023", CompiledOperator::Lte(1023), false),
    ];
    for (label, op, expected) in cases {
        assert_eq!(pred(FieldPath::RemotePort, op).test(&view), expected, "RemotePort {label}");
    }
}

#[test]
fn matrix_numeric_cmp_local_port_too() {
    // The remote port is 0, so "> 8000" can only pass by reading the local port (8443).
    let conn = make_conn_with("127.0.0.1:0", "127.0.0.1:8443");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let local_port_above = |bound| pred(FieldPath::LocalPort, CompiledOperator::Gt(bound)).test(&view);
    assert!(local_port_above(8000));
    assert!(!local_port_above(9000));
}
2313
#[test]
fn matrix_cidr_v4_happy_and_miss() {
    let conn = make_conn_with("10.0.5.7:0", "127.0.0.1:0");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    // Parse `net` and test whether the remote address falls inside it.
    let remote_in = |net: &str| {
        let parsed = IpNet::from_str(net).unwrap();
        pred(FieldPath::RemoteIp, CompiledOperator::Cidr(parsed)).test(&view)
    };
    assert!(remote_in("10.0.0.0/8"));
    assert!(!remote_in("192.168.0.0/16"));
}

#[test]
fn matrix_cidr_v6_happy_and_miss() {
    let conn = make_conn_with("[2001:db8::5]:0", "127.0.0.1:0");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let remote_in = |net: &str| {
        let parsed = IpNet::from_str(net).unwrap();
        pred(FieldPath::RemoteIp, CompiledOperator::Cidr(parsed)).test(&view)
    };
    assert!(remote_in("2001:db8::/32"));
    assert!(!remote_in("2001:dead::/32"));
}

#[test]
fn matrix_cidr_v4_against_v6_addr_misses() {
    // Even the catch-all IPv4 range must not match an IPv6 remote address.
    let conn = make_conn_with("[2001:db8::5]:0", "127.0.0.1:0");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let any_v4 = IpNet::from_str("0.0.0.0/0").unwrap();
    assert!(!pred(FieldPath::RemoteIp, CompiledOperator::Cidr(any_v4)).test(&view));
}
2343
#[test]
fn http_uri_query_reader_returns_empty_when_query_absent() {
    // With no `?` in the URI, the query reads as the empty string, so Equals("") hits.
    let conn = make_conn();
    let req = http::Request::builder().method("GET").uri("/no-q").body(Body::Empty).unwrap();
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    let query_equals = |text: &str| {
        pred(FieldPath::HttpUriQuery, CompiledOperator::Equals(str_val(text))).test(&view)
    };
    assert!(query_equals(""));
    assert!(!query_equals("q=1"));
}

#[test]
fn http_uri_query_reader_matches_present_query() {
    let conn = make_conn();
    let req = http::Request::builder().method("GET").uri("/x?a=1&b=2").body(Body::Empty).unwrap();
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    // The query is read without the leading '?'.
    let whole_query = CompiledOperator::Equals(str_val("a=1&b=2"));
    let one_param = CompiledOperator::Contains(b(b"b=2"));
    assert!(pred(FieldPath::HttpUriQuery, whole_query).test(&view));
    assert!(pred(FieldPath::HttpUriQuery, one_param).test(&view));
}

#[test]
fn local_ip_reader_uses_local_socket() {
    // Remote and local addresses differ, so a hit proves the local side was read.
    let conn = make_conn_with("10.0.0.5:0", "127.0.0.1:8443");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let loopback = "127.0.0.1".parse::<std::net::IpAddr>().unwrap();
    let local_is_loopback = CompiledOperator::Equals(CompiledValue::Addr(loopback));
    assert!(pred(FieldPath::LocalIp, local_is_loopback).test(&view));
}
2376
#[test]
fn http_header_lookup_misses_for_non_utf8_value() {
    // 0xff 0xfe 0xfd are accepted by HeaderValue::from_bytes but are not valid
    // UTF-8. A Str comparison against such a header must miss rather than panic.
    let conn = make_conn();
    let invalid_utf8 =
        http::HeaderValue::from_bytes(&[0xff, 0xfe, 0xfd]).expect("non-utf8 header value parses");
    let mut builder = http::Request::builder().method("GET").uri("/");
    builder.headers_mut().expect("headers").insert("x-bad", invalid_utf8);
    let req: Request = builder.body(Body::Empty).expect("build request");
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    let header = FieldPath::HttpHeader(Arc::from("x-bad"));
    let matched = pred(header, CompiledOperator::Equals(str_val("anything"))).test(&view);
    assert!(!matched);
}
2396
2397 fn rcgen_cert_with_cn(cn: &str) -> rustls_pki_types::CertificateDer<'static> {
2399 let mut params = rcgen::CertificateParams::default();
2400 params.distinguished_name = rcgen::DistinguishedName::new();
2401 params.distinguished_name.push(rcgen::DnType::CommonName, cn);
2402 let key = rcgen::KeyPair::generate().expect("rcgen keypair");
2403 let cert = params.self_signed(&key).expect("self-sign cert");
2404 cert.der().clone()
2405 }
2406
2407 fn rcgen_cert_no_cn() -> rustls_pki_types::CertificateDer<'static> {
2408 let params = rcgen::CertificateParams::default();
2411 let mut params = params;
2414 params.distinguished_name = rcgen::DistinguishedName::new();
2415 let key = rcgen::KeyPair::generate().expect("rcgen keypair");
2416 let cert = params.self_signed(&key).expect("self-sign cert");
2417 cert.der().clone()
2418 }
2419
2420 fn conn_with_peer_cert(cert: &rustls_pki_types::CertificateDer<'static>) -> Arc<ConnContext> {
2421 let pc = crate::conn_context::PeerCertificate::from_der(cert)
2422 .expect("rcgen-issued cert must parse via PeerCertificate::from_der");
2423 let conn = make_conn();
2424 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
2425 sni: None,
2426 alpn: None,
2427 version: None,
2428 peer_cert: Some(Arc::new(pc)),
2429 zero_rtt_used: false,
2430 });
2431 conn
2432 }
2433
#[test]
fn peer_cert_from_der_extracts_cn() {
    let der = rcgen_cert_with_cn("client.internal");
    let parsed = crate::conn_context::PeerCertificate::from_der(&der).expect("parse");
    assert_eq!(parsed.subject_cn.as_deref(), Some("client.internal"));
}

#[test]
fn peer_cert_from_der_returns_none_for_malformed_der() {
    // The first input is a SEQUENCE tag with an indefinite-length octet, truncated.
    // The second is plain ASCII that is not DER at all.
    let malformed_inputs = vec![vec![0x30u8, 0x80, 0x00, 0x00], b"not a cert at all".to_vec()];
    for raw in malformed_inputs {
        let der = rustls_pki_types::CertificateDer::from(raw);
        assert!(crate::conn_context::PeerCertificate::from_der(&der).is_none());
    }
}

#[test]
fn peer_cert_from_der_returns_some_with_no_cn_when_dn_has_no_cn() {
    // A missing CN is not a parse failure: the certificate is still returned.
    let der = rcgen_cert_no_cn();
    let parsed = crate::conn_context::PeerCertificate::from_der(&der).expect("parse");
    assert!(parsed.subject_cn.is_none());
}
2456
#[test]
fn matrix_peer_cert_subject_cn_equals_happy_and_miss() {
    let cert = rcgen_cert_with_cn("ops-bot");
    let conn = conn_with_peer_cert(&cert);
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let cn_equals = |name: &str| {
        pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Equals(str_val(name))).test(&view)
    };
    assert!(cn_equals("ops-bot"));
    assert!(!cn_equals("attacker"));
}

#[test]
fn matrix_peer_cert_subject_cn_string_ops_happy_and_miss() {
    // The subject CN is a Str field, so the full set of string operators applies.
    let cert = rcgen_cert_with_cn("svc-payments-prod");
    let conn = conn_with_peer_cert(&cert);
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let on_cn = |op: CompiledOperator| pred(FieldPath::TlsPeerCertSubjectCn, op).test(&view);
    assert!(on_cn(CompiledOperator::Prefix(b(b"svc-"))));
    assert!(!on_cn(CompiledOperator::Prefix(b(b"client-"))));
    assert!(on_cn(CompiledOperator::Suffix(b(b"-prod"))));
    assert!(on_cn(CompiledOperator::Contains(b(b"payments"))));
    let naming_scheme = Regex::new(r"^svc-[a-z]+-(prod|stg)$").expect("regex");
    assert!(on_cn(CompiledOperator::Matches(naming_scheme)));
    let allowed = vec![str_val("svc-other-prod"), str_val("svc-payments-prod")];
    assert!(on_cn(CompiledOperator::In(allowed)));
}

#[test]
fn peer_cert_subject_cn_misses_when_cert_absent() {
    // No peer certificate is attached to make_conn(), so any CN comparison misses.
    let conn = make_conn();
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let op = CompiledOperator::Equals(str_val("anything"));
    assert!(!pred(FieldPath::TlsPeerCertSubjectCn, op).test(&view));
}

#[test]
fn peer_cert_subject_cn_misses_when_cert_has_no_cn() {
    // A certificate is present but its subject has no CN: the comparison must miss.
    let cert = rcgen_cert_no_cn();
    let conn = conn_with_peer_cert(&cert);
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let op = CompiledOperator::Equals(str_val("ops-bot"));
    assert!(!pred(FieldPath::TlsPeerCertSubjectCn, op).test(&view));
}
2518
2519 fn rcgen_cert_with_san_dns(cn: &str, dns: &[&str]) -> rustls_pki_types::CertificateDer<'static> {
2521 let san: Vec<String> = dns.iter().map(|s| (*s).to_owned()).collect();
2522 let mut params = rcgen::CertificateParams::new(san).expect("rcgen params");
2523 params.distinguished_name = rcgen::DistinguishedName::new();
2524 params.distinguished_name.push(rcgen::DnType::CommonName, cn);
2525 let key = rcgen::KeyPair::generate().expect("rcgen keypair");
2526 let cert = params.self_signed(&key).expect("self-sign cert");
2527 cert.der().clone()
2528 }
2529
#[test]
fn each_new_field_path_parses_from_string_form() {
    use super::parse_field_path;
    // Dotted config spelling paired with the FieldPath it must parse to.
    let table = vec![
        ("tls.peer_cert.present", FieldPath::TlsPeerCertPresent),
        ("tls.peer_cert.san_dns", FieldPath::TlsPeerCertSanDns),
        ("tls.peer_cert.fingerprint_sha256", FieldPath::TlsPeerCertFingerprintSha256),
        ("tls.peer_cert.spki_sha256", FieldPath::TlsPeerCertSpkiSha256),
        ("tls.peer_cert.issuer_cn", FieldPath::TlsPeerCertIssuerCn),
        ("tls.peer_cert.serial", FieldPath::TlsPeerCertSerial),
    ];
    for (text, expected) in table {
        assert_eq!(parse_field_path(text), Ok(expected), "parsing {text:?}");
    }
}
2543
#[test]
fn peer_cert_present_true_when_cert_attached() {
    let cert = rcgen_cert_with_cn("client.internal");
    let conn = conn_with_peer_cert(&cert);
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let present_is = |flag: bool| {
        pred(FieldPath::TlsPeerCertPresent, CompiledOperator::Equals(CompiledValue::Bool(flag)))
            .test(&view)
    };
    assert!(present_is(true));
    assert!(!present_is(false));
}

#[test]
fn peer_cert_present_false_when_cert_absent() {
    // Unlike the other peer_cert fields, "present" still yields a value (false)
    // when no certificate is attached, so Equals(false) hits.
    let conn = make_conn();
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let present_is = |flag: bool| {
        pred(FieldPath::TlsPeerCertPresent, CompiledOperator::Equals(CompiledValue::Bool(flag)))
            .test(&view)
    };
    assert!(present_is(false));
    assert!(!present_is(true));
}
2574
#[test]
fn peer_cert_san_dns_contains_matches_listed_element() {
    let cert = rcgen_cert_with_san_dns("svc-a", &["svc-a.internal", "svc-b.internal"]);
    let conn = conn_with_peer_cert(&cert);
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let on_san = |op: CompiledOperator| pred(FieldPath::TlsPeerCertSanDns, op).test(&view);
    assert!(on_san(CompiledOperator::Contains(b(b"svc-a.internal"))));
    assert!(!on_san(CompiledOperator::Contains(b(b"svc-c.internal"))));
    assert!(on_san(CompiledOperator::NotContains(b(b"svc-c.internal"))));
}

#[test]
fn peer_cert_san_dns_misses_when_cert_absent() {
    let conn = make_conn();
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let op = CompiledOperator::Contains(b(b"anything"));
    assert!(!pred(FieldPath::TlsPeerCertSanDns, op).test(&view));
}
2601
#[test]
fn peer_cert_fingerprint_sha256_is_lowercase_hex_of_full_der() {
    use sha2::{Digest, Sha256};
    use std::fmt::Write as _;

    let cert = rcgen_cert_with_cn("fingerprinted");

    // Independently compute SHA-256 over the whole DER encoding, rendered as lowercase hex.
    let mut hasher = Sha256::new();
    hasher.update(cert.as_ref());
    let digest = hasher.finalize();
    let mut expected = String::with_capacity(64);
    for byte in digest.iter() {
        let _ = write!(expected, "{byte:02x}");
    }

    let conn = conn_with_peer_cert(&cert);
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let op = CompiledOperator::Equals(str_val(&expected));
    assert!(pred(FieldPath::TlsPeerCertFingerprintSha256, op).test(&view));
}
2621
#[test]
fn peer_cert_issuer_and_serial_present_for_self_signed_cert() {
    // A self-signed certificate is its own issuer, so the issuer CN equals the subject CN.
    let cert = rcgen_cert_with_cn("issuer-test");
    let conn = conn_with_peer_cert(&cert);
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let issuer_op = CompiledOperator::Equals(str_val("issuer-test"));
    assert!(pred(FieldPath::TlsPeerCertIssuerCn, issuer_op).test(&view));

    let peer = conn.tls.lock().as_ref().and_then(|tls| tls.peer_cert.clone()).unwrap();
    assert!(!peer.serial.is_empty(), "serial extracted");
    // The serial must be rendered as lowercase hex digits only.
    assert!(peer.serial.chars().all(|ch| matches!(ch, '0'..='9' | 'a'..='f')));
}
2641
#[test]
fn peer_cert_present_value_type_is_bool() {
    let ty = FieldPath::TlsPeerCertPresent.value_type();
    assert_eq!(ty, FieldValueType::Bool);
}

#[test]
fn peer_cert_san_dns_value_type_is_vec_str() {
    let ty = FieldPath::TlsPeerCertSanDns.value_type();
    assert_eq!(ty, FieldValueType::VecStr);
}

#[test]
fn matrix_rejects_string_pref_suf_on_bool_field() {
    // The string-shaped families are refused on Bool fields, while Equality is accepted.
    let refused = vec![
        OperatorFamily::StringPrefSuf,
        OperatorFamily::StringSubstr,
        OperatorFamily::RegexMatches,
    ];
    for family in refused {
        assert!(!family.accepts(FieldValueType::Bool));
    }
    assert!(OperatorFamily::Equality.accepts(FieldValueType::Bool));
}

#[test]
fn matrix_rejects_equals_on_vec_str_field() {
    // Of these families, only StringSubstr is accepted on Vec<Str> fields.
    let refused = vec![
        OperatorFamily::Equality,
        OperatorFamily::InList,
        OperatorFamily::StringPrefSuf,
        OperatorFamily::RegexMatches,
    ];
    for family in refused {
        assert!(!family.accepts(FieldValueType::VecStr));
    }
    assert!(OperatorFamily::StringSubstr.accepts(FieldValueType::VecStr));
}
2673
2674 fn req_with_body(body_bytes: &[u8]) -> Request {
2681 http::Request::builder()
2682 .method("POST")
2683 .uri("/upload")
2684 .body(Body::Static(Bytes::copy_from_slice(body_bytes)))
2685 .expect("build req with body")
2686 }
2687
#[test]
fn matrix_http_body_equality_happy_and_miss() {
    let conn = make_conn();
    let req = req_with_body(b"hello world");
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    let on_body = |op: CompiledOperator| pred(FieldPath::HttpBody, op).test(&view);
    assert!(on_body(CompiledOperator::Equals(bytes_val(b"hello world"))));
    assert!(!on_body(CompiledOperator::Equals(bytes_val(b"wrong"))));
    assert!(on_body(CompiledOperator::NotEquals(bytes_val(b"wrong"))));
}

#[test]
fn matrix_http_body_substring_happy_and_miss() {
    let conn = make_conn();
    let req = req_with_body(b"prelude payload trailer");
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    let on_body = |op: CompiledOperator| pred(FieldPath::HttpBody, op).test(&view);
    assert!(on_body(CompiledOperator::Contains(b(b"payload"))));
    assert!(!on_body(CompiledOperator::Contains(b(b"missing"))));
    assert!(on_body(CompiledOperator::NotContains(b(b"missing"))));
}

#[test]
fn matrix_http_body_prefix_suffix_happy_and_miss() {
    let conn = make_conn();
    let req = req_with_body(b"START middle END");
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    let on_body = |op: CompiledOperator| pred(FieldPath::HttpBody, op).test(&view);
    assert!(on_body(CompiledOperator::Prefix(b(b"START"))));
    assert!(!on_body(CompiledOperator::Prefix(b(b"BEGIN"))));
    assert!(on_body(CompiledOperator::Suffix(b(b"END"))));
    assert!(!on_body(CompiledOperator::Suffix(b(b"FIN"))));
}

#[test]
fn matrix_http_body_in_list_happy_and_miss() {
    let conn = make_conn();
    let req = req_with_body(b"one");
    let view = PredicateView::L7Req { conn: &conn, req: &req };
    let on_body = |op: CompiledOperator| pred(FieldPath::HttpBody, op).test(&view);
    let containing = vec![bytes_val(b"two"), bytes_val(b"one")];
    assert!(on_body(CompiledOperator::In(containing)));
    let excluding = vec![bytes_val(b"two"), bytes_val(b"three")];
    assert!(!on_body(CompiledOperator::In(excluding.clone())));
    assert!(on_body(CompiledOperator::NotIn(excluding)));
}

#[test]
fn http_body_misses_on_l4_view() {
    // An L4 view has no request, so a body predicate simply misses.
    let conn = make_conn();
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let op = CompiledOperator::Contains(b(b"x"));
    assert!(!pred(FieldPath::HttpBody, op).test(&view));
}

#[test]
#[should_panic(expected = "lazy-buffer invariant")]
fn http_body_panics_when_lazy_buffer_invariant_violated() {
    // A request whose body is Body::Empty rather than a buffered Body::Static
    // is expected to trip the "lazy-buffer invariant" panic when the body is read.
    let conn = make_conn();
    let unbuffered = http::Request::builder().method("POST").uri("/").body(Body::Empty).unwrap();
    let view = PredicateView::L7Req { conn: &conn, req: &unbuffered };
    let contains_x = pred(FieldPath::HttpBody, CompiledOperator::Contains(b(b"x")));
    let _ = contains_x.test(&view);
}
2757
#[test]
fn matrix_peek_substring_happy_and_miss() {
    // Bytes shaped like the start of a TLS record: content type 0x16, version
    // 0x03 0x01, length 0x00 0x40, then handshake type 0x01.
    let peeked: &[u8] = &[0x16, 0x03, 0x01, 0x00, 0x40, 0x01];
    let conn = make_conn();
    let view = PredicateView::L4 { conn: &conn, peek: Some(peeked) };
    let on_peek = |op: CompiledOperator| pred(FieldPath::Peek, op).test(&view);
    assert!(on_peek(CompiledOperator::Prefix(b(b"\x16\x03"))));
    assert!(!on_peek(CompiledOperator::Prefix(b(b"\x14\x03"))));
    assert!(on_peek(CompiledOperator::Contains(b(b"\x03\x01"))));
    assert!(!on_peek(CompiledOperator::Contains(b(b"\xff\xff"))));
}

#[test]
fn matrix_peek_equality_happy_and_miss() {
    let peeked: &[u8] = b"GET";
    let conn = make_conn();
    let view = PredicateView::L4 { conn: &conn, peek: Some(peeked) };
    let on_peek = |op: CompiledOperator| pred(FieldPath::Peek, op).test(&view);
    assert!(on_peek(CompiledOperator::Equals(bytes_val(b"GET"))));
    assert!(!on_peek(CompiledOperator::Equals(bytes_val(b"PUT"))));
    assert!(on_peek(CompiledOperator::NotEquals(bytes_val(b"PUT"))));
}

#[test]
fn matrix_peek_in_list_happy_and_miss() {
    let peeked: &[u8] = b"PRI ";
    let conn = make_conn();
    let view = PredicateView::L4 { conn: &conn, peek: Some(peeked) };
    let on_peek = |op: CompiledOperator| pred(FieldPath::Peek, op).test(&view);
    let with_pri = vec![bytes_val(b"GET "), bytes_val(b"PRI ")];
    assert!(on_peek(CompiledOperator::In(with_pri)));
    let without_pri = vec![bytes_val(b"POST"), bytes_val(b"HEAD")];
    assert!(!on_peek(CompiledOperator::In(without_pri.clone())));
    assert!(on_peek(CompiledOperator::NotIn(without_pri)));
}

#[test]
fn peek_misses_when_buffer_absent_on_l4_view() {
    let conn = make_conn();
    let handshake_prefix = || CompiledOperator::Prefix(b(b"\x16"));
    // An L4 view without a peek buffer misses.
    let l4 = PredicateView::L4 { conn: &conn, peek: None };
    assert!(!pred(FieldPath::Peek, handshake_prefix()).test(&l4));
    // An L7 view has no peek field at all, so Peek misses there too.
    let req = http::Request::builder().method("GET").uri("/").body(Body::Empty).unwrap();
    let l7 = PredicateView::L7Req { conn: &conn, req: &req };
    assert!(!pred(FieldPath::Peek, handshake_prefix()).test(&l7));
}
2812}