1use std::hash::{Hash, Hasher};
2use std::net::IpAddr;
3use std::sync::Arc;
4
5use bytes::Bytes;
6use ipnet::IpNet;
7
8use crate::body::Request;
9use crate::conn_context::ConnContext;
10
11#[derive(Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum FieldPath {
14 Transport,
15 RemoteIp,
16 RemotePort,
17 LocalIp,
18 LocalPort,
19 Peek,
20 TlsSni,
21 TlsAlpn,
22 TlsVersion,
23 TlsPeerCertSubjectCn,
24 HttpMethod,
25 HttpUriPath,
26 HttpUriQuery,
27 HttpHeader(Arc<str>),
28 HttpBody,
29}
30
/// The kind of value a [`FieldPath`] produces when read.
///
/// `FieldPath::value_type` maps each field to one of these; `OperatorFamily::accepts`
/// uses it to reject operator/field combinations such as `gt` on a string field.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum FieldValueType {
    /// UTF-8 text.
    Str,
    /// Arbitrary bytes.
    Bytes,
    /// An integer (ports; compared as `i64`).
    Int,
    /// An IPv4 or IPv6 address.
    IpAddr,
    /// Text drawn from a fixed set of spellings (e.g. transport `"tcp"` / `"udp"`).
    Enum,
}

impl FieldValueType {
    /// Display name of this value type.
    ///
    /// Every variant is reported by its variant name except `Enum`, which is spelled
    /// lowercase (`"enum"`), matching the type list in `OperatorFamily::family_expectation`.
    #[must_use]
    pub fn name(self) -> &'static str {
        use FieldValueType as T;
        match self {
            T::Enum => "enum",
            T::IpAddr => "IpAddr",
            T::Int => "Int",
            T::Bytes => "Bytes",
            T::Str => "Str",
        }
    }
}
56
57impl FieldPath {
58 #[must_use]
62 pub fn value_type(&self) -> FieldValueType {
63 match self {
64 Self::Transport | Self::TlsVersion | Self::HttpMethod => FieldValueType::Enum,
65 Self::RemoteIp | Self::LocalIp => FieldValueType::IpAddr,
66 Self::RemotePort | Self::LocalPort => FieldValueType::Int,
67 Self::Peek | Self::TlsAlpn | Self::HttpBody => FieldValueType::Bytes,
68 Self::TlsSni
69 | Self::TlsPeerCertSubjectCn
70 | Self::HttpUriPath
71 | Self::HttpUriQuery
72 | Self::HttpHeader(_) => FieldValueType::Str,
73 }
74 }
75
76 #[must_use]
78 pub fn display_name(&self) -> String {
79 match self {
80 Self::Transport => "transport".to_string(),
81 Self::RemoteIp => "remote.ip".to_string(),
82 Self::RemotePort => "remote.port".to_string(),
83 Self::LocalIp => "local.ip".to_string(),
84 Self::LocalPort => "local.port".to_string(),
85 Self::Peek => "peek".to_string(),
86 Self::TlsSni => "tls.sni".to_string(),
87 Self::TlsAlpn => "tls.alpn".to_string(),
88 Self::TlsVersion => "tls.version".to_string(),
89 Self::TlsPeerCertSubjectCn => "tls.peer_cert.subject_cn".to_string(),
90 Self::HttpMethod => "http.method".to_string(),
91 Self::HttpUriPath => "http.uri.path".to_string(),
92 Self::HttpUriQuery => "http.uri.query".to_string(),
93 Self::HttpHeader(name) => format!("http.header.{name}"),
94 Self::HttpBody => "http.body".to_string(),
95 }
96 }
97}
98
/// Groups [`Operator`] variants by the field value types they can be applied to.
///
/// Obtained from `Operator::family`; `OperatorFamily::accepts` decides whether a family
/// is compatible with a field's [`FieldValueType`].
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum OperatorFamily {
    /// `equals` / `not_equals`.
    Equality,
    /// `contains` / `not_contains`.
    StringSubstr,
    /// `prefix` / `suffix`.
    StringPrefSuf,
    /// `matches` (regex).
    RegexMatches,
    /// `in` / `not_in`.
    InList,
    /// `gt` / `gte` / `lt` / `lte`.
    NumericCmp,
    /// `cidr`.
    CidrMatch,
}
113
114impl Operator {
115 #[must_use]
116 pub fn family(&self) -> OperatorFamily {
117 match self {
118 Self::Equals(_) | Self::NotEquals(_) => OperatorFamily::Equality,
119 Self::Contains(_) | Self::NotContains(_) => OperatorFamily::StringSubstr,
120 Self::Prefix(_) | Self::Suffix(_) => OperatorFamily::StringPrefSuf,
121 Self::Matches(_) => OperatorFamily::RegexMatches,
122 Self::In(_) | Self::NotIn(_) => OperatorFamily::InList,
123 Self::Gt(_) | Self::Gte(_) | Self::Lt(_) | Self::Lte(_) => OperatorFamily::NumericCmp,
124 Self::Cidr(_) => OperatorFamily::CidrMatch,
125 }
126 }
127
128 #[must_use]
129 pub fn name(&self) -> &'static str {
130 match self {
131 Self::Equals(_) => "equals",
132 Self::NotEquals(_) => "not_equals",
133 Self::Contains(_) => "contains",
134 Self::NotContains(_) => "not_contains",
135 Self::Prefix(_) => "prefix",
136 Self::Suffix(_) => "suffix",
137 Self::Matches(_) => "matches",
138 Self::In(_) => "in",
139 Self::NotIn(_) => "not_in",
140 Self::Gt(_) => "gt",
141 Self::Gte(_) => "gte",
142 Self::Lt(_) => "lt",
143 Self::Lte(_) => "lte",
144 Self::Cidr(_) => "cidr",
145 }
146 }
147}
148
149impl OperatorFamily {
150 #[must_use]
155 pub fn accepts(self, vt: FieldValueType) -> bool {
156 use FieldValueType as V;
157 use OperatorFamily as F;
158 matches!(
159 (self, vt),
160 (F::Equality | F::InList, _)
161 | (F::StringSubstr | F::StringPrefSuf, V::Str | V::Bytes)
162 | (F::RegexMatches, V::Str)
163 | (F::NumericCmp, V::Int)
164 | (F::CidrMatch, V::IpAddr),
165 )
166 }
167
168 #[must_use]
170 pub fn family_expectation(self) -> &'static str {
171 match self {
172 Self::Equality | Self::InList => "any of Str/Bytes/Int/IpAddr/enum",
173 Self::StringSubstr | Self::StringPrefSuf => "Str or Bytes",
174 Self::RegexMatches => "Str",
175 Self::NumericCmp => "numeric",
176 Self::CidrMatch => "IpAddr",
177 }
178 }
179}
180
/// A literal operand in compiled form.
///
/// `PartialEq`/`Eq`/`Hash` are implemented by hand below: strings compare by content
/// (not `Arc` identity), and values of different variants are never equal (so
/// `Str("42") != Int(42)`).
#[derive(Clone, Debug)]
pub enum CompiledValue {
    Str(Arc<str>),
    Bytes(Bytes),
    Int(i64),
    Bool(bool),
    Addr(IpAddr),
}
189
190impl PartialEq for CompiledValue {
191 fn eq(&self, other: &Self) -> bool {
192 match (self, other) {
193 (Self::Str(a), Self::Str(b)) => a.as_ref() == b.as_ref(),
194 (Self::Bytes(a), Self::Bytes(b)) => a == b,
195 (Self::Int(a), Self::Int(b)) => a == b,
196 (Self::Bool(a), Self::Bool(b)) => a == b,
197 (Self::Addr(a), Self::Addr(b)) => a == b,
198 _ => false,
199 }
200 }
201}
202
203impl Eq for CompiledValue {}
204
205impl Hash for CompiledValue {
206 fn hash<H: Hasher>(&self, state: &mut H) {
207 std::mem::discriminant(self).hash(state);
208 match self {
209 Self::Str(s) => s.as_ref().hash(state),
210 Self::Bytes(b) => b.hash(state),
211 Self::Int(i) => i.hash(state),
212 Self::Bool(b) => b.hash(state),
213 Self::Addr(a) => a.hash(state),
214 }
215 }
216}
217
/// Runtime form of an [`Operator`].
///
/// Substring and prefix/suffix needles are held as `Bytes` so the same operator can be
/// applied to string fields (via their UTF-8 bytes) and byte fields. Equality and
/// hashing are implemented by hand below; `Matches` compares by pattern source
/// (`Regex::as_str`), so separately compiled identical patterns are equal.
#[derive(Clone, Debug)]
pub enum CompiledOperator {
    Equals(CompiledValue),
    NotEquals(CompiledValue),
    Contains(Bytes),
    NotContains(Bytes),
    Prefix(Bytes),
    Suffix(Bytes),
    Matches(fancy_regex::Regex),
    /// Order-sensitive for equality/hashing (the `Vec` is compared element-wise).
    In(Vec<CompiledValue>),
    NotIn(Vec<CompiledValue>),
    Gt(i64),
    Gte(i64),
    Lt(i64),
    Lte(i64),
    Cidr(IpNet),
}
235
236impl PartialEq for CompiledOperator {
237 fn eq(&self, other: &Self) -> bool {
238 match (self, other) {
239 (Self::Equals(a), Self::Equals(b)) | (Self::NotEquals(a), Self::NotEquals(b)) => a == b,
240 (Self::Contains(a), Self::Contains(b))
241 | (Self::NotContains(a), Self::NotContains(b))
242 | (Self::Prefix(a), Self::Prefix(b))
243 | (Self::Suffix(a), Self::Suffix(b)) => a == b,
244 (Self::Matches(a), Self::Matches(b)) => a.as_str() == b.as_str(),
245 (Self::In(a), Self::In(b)) | (Self::NotIn(a), Self::NotIn(b)) => a == b,
246 (Self::Gt(a), Self::Gt(b))
247 | (Self::Gte(a), Self::Gte(b))
248 | (Self::Lt(a), Self::Lt(b))
249 | (Self::Lte(a), Self::Lte(b)) => a == b,
250 (Self::Cidr(a), Self::Cidr(b)) => a == b,
251 _ => false,
252 }
253 }
254}
255
256impl Eq for CompiledOperator {}
257
258impl Hash for CompiledOperator {
259 fn hash<H: Hasher>(&self, state: &mut H) {
260 std::mem::discriminant(self).hash(state);
261 match self {
262 Self::Equals(v) | Self::NotEquals(v) => v.hash(state),
263 Self::Contains(b) | Self::NotContains(b) | Self::Prefix(b) | Self::Suffix(b) => {
264 b.hash(state);
265 }
266 Self::Matches(r) => r.as_str().hash(state),
267 Self::In(v) | Self::NotIn(v) => v.hash(state),
268 Self::Gt(i) | Self::Gte(i) | Self::Lt(i) | Self::Lte(i) => i.hash(state),
269 Self::Cidr(n) => n.hash(state),
270 }
271 }
272}
273
/// A compiled check: read the field named by `path` and apply `op` to it
/// (see `PredicateInst::test`).
///
/// Equality and hashing are by content (via the hand-written `CompiledOperator` impls),
/// so identical checks compiled independently compare and hash equal.
#[derive(Clone, Debug, Eq, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
pub struct PredicateInst {
    pub path: FieldPath,
    pub op: CompiledOperator,
}
279
/// The data a predicate is evaluated against.
///
/// `L4` is a connection-level view with an optional buffer of peeked bytes; `L7Req`
/// adds a parsed HTTP request (and has no peek buffer). Connection fields — addresses,
/// transport, TLS info — are read from `conn` in both.
pub enum PredicateView<'a> {
    L4 { conn: &'a Arc<ConnContext>, peek: Option<&'a [u8]> },
    L7Req { conn: &'a Arc<ConnContext>, req: &'a Request },
}
284
285impl<'a> PredicateView<'a> {
286 #[must_use]
296 pub fn build(
297 conn: &'a Arc<ConnContext>,
298 req: Option<&'a Request>,
299 _l4: Option<&'a crate::l4::L4Conn>,
300 peek: Option<&'a [u8]>,
301 ) -> Self {
302 match req {
303 Some(r) => Self::L7Req { conn, req: r },
304 None => Self::L4 { conn, peek },
305 }
306 }
307
308 fn conn(&self) -> &Arc<ConnContext> {
309 match self {
310 Self::L4 { conn, .. } | Self::L7Req { conn, .. } => conn,
311 }
312 }
313
314 fn request(&self) -> Option<&Request> {
315 match self {
316 Self::L7Req { req, .. } => Some(req),
317 Self::L4 { .. } => None,
318 }
319 }
320
321 fn peek_buffer(&self) -> Option<&[u8]> {
322 match self {
323 Self::L4 { peek, .. } => *peek,
324 Self::L7Req { .. } => None,
325 }
326 }
327}
328
329impl PredicateInst {
330 #[must_use]
353 pub fn test(&self, view: &PredicateView<'_>) -> bool {
354 match &self.path {
355 FieldPath::Transport => {
356 let s = match view.conn().transport {
357 crate::conn_context::Transport::Tcp => "tcp",
358 crate::conn_context::Transport::Udp => "udp",
359 };
360 test_str(&self.op, s)
361 }
362 FieldPath::RemoteIp => test_addr(&self.op, view.conn().remote.ip()),
363 FieldPath::RemotePort => test_int(&self.op, i64::from(view.conn().remote.port())),
364 FieldPath::LocalIp => test_addr(&self.op, view.conn().local.ip()),
365 FieldPath::LocalPort => test_int(&self.op, i64::from(view.conn().local.port())),
366 FieldPath::Peek => view.peek_buffer().is_some_and(|b| test_bytes(&self.op, b)),
367 FieldPath::TlsSni => view
368 .conn()
369 .tls
370 .lock()
371 .as_ref()
372 .and_then(|t| t.sni.clone())
373 .is_some_and(|got| test_str(&self.op, got.as_str())),
374 FieldPath::TlsAlpn => view
375 .conn()
376 .tls
377 .lock()
378 .as_ref()
379 .and_then(|t| t.alpn.clone())
380 .is_some_and(|got| test_bytes(&self.op, got.as_slice())),
381 FieldPath::TlsVersion => view
382 .conn()
383 .tls
384 .lock()
385 .as_ref()
386 .and_then(|t| t.version)
387 .is_some_and(|v| test_str(&self.op, tls_version_str(v))),
388 FieldPath::TlsPeerCertSubjectCn => {
396 let cert_der = view.conn().tls.lock().as_ref().and_then(|t| t.peer_cert.clone());
397 let Some(cert) = cert_der else { return false };
398 match peer_cert_subject_cn(cert.as_ref()) {
399 Some(cn) => test_str(&self.op, cn.as_str()),
400 None => false,
401 }
402 }
403 FieldPath::HttpMethod => {
404 let Some(req) = view.request() else { return false };
405 test_str(&self.op, req.method().as_str())
406 }
407 FieldPath::HttpUriPath => {
408 let Some(req) = view.request() else { return false };
409 test_str(&self.op, req.uri().path())
410 }
411 FieldPath::HttpUriQuery => {
412 let Some(req) = view.request() else { return false };
413 test_str(&self.op, req.uri().query().unwrap_or(""))
414 }
415 FieldPath::HttpHeader(name) => {
423 let Some(req) = view.request() else { return false };
424 let Some(value) = req.headers().get(name.as_ref()) else { return false };
425 let Ok(s) = value.to_str() else {
426 return false;
431 };
432 test_str(&self.op, s)
433 }
434 FieldPath::HttpBody => {
443 let Some(req) = view.request() else { return false };
444 let bytes = req.body().as_static().expect("lazy-buffer invariant");
445 test_bytes(&self.op, bytes.as_ref())
446 }
447 }
448 }
449}
450
451fn tls_version_str(v: crate::conn_context::TlsVersion) -> &'static str {
452 match v {
453 crate::conn_context::TlsVersion::Tls12 => "1.2",
454 crate::conn_context::TlsVersion::Tls13 => "1.3",
455 }
456}
457
458fn peer_cert_subject_cn(der: &[u8]) -> Option<String> {
470 use x509_parser::prelude::*;
471 let (_, cert) = X509Certificate::from_der(der).ok()?;
472 let subject = cert.tbs_certificate.subject();
473 let cn_attr = subject.iter_common_name().next()?;
474 cn_attr.as_str().ok().map(ToString::to_string)
475}
476
477fn test_str(op: &CompiledOperator, value: &str) -> bool {
482 match op {
483 CompiledOperator::Equals(CompiledValue::Str(expected)) => value == expected.as_ref(),
484 CompiledOperator::NotEquals(CompiledValue::Str(expected)) => value != expected.as_ref(),
485 CompiledOperator::Contains(b) => contains_bytes(value.as_bytes(), b),
486 CompiledOperator::NotContains(b) => !contains_bytes(value.as_bytes(), b),
487 CompiledOperator::Prefix(b) => value.as_bytes().starts_with(b.as_ref()),
488 CompiledOperator::Suffix(b) => value.as_bytes().ends_with(b.as_ref()),
489 CompiledOperator::Matches(re) => re.is_match(value).unwrap_or(false),
490 CompiledOperator::In(values) => {
491 values.iter().any(|v| matches!(v, CompiledValue::Str(s) if value == s.as_ref()))
492 }
493 CompiledOperator::NotIn(values) => {
494 !values.iter().any(|v| matches!(v, CompiledValue::Str(s) if value == s.as_ref()))
495 }
496 _ => false,
497 }
498}
499
500fn test_bytes(op: &CompiledOperator, value: &[u8]) -> bool {
504 match op {
505 CompiledOperator::Equals(CompiledValue::Bytes(expected)) => value == expected.as_ref(),
506 CompiledOperator::NotEquals(CompiledValue::Bytes(expected)) => value != expected.as_ref(),
507 CompiledOperator::Contains(b) => contains_bytes(value, b),
508 CompiledOperator::NotContains(b) => !contains_bytes(value, b),
509 CompiledOperator::Prefix(b) => value.starts_with(b.as_ref()),
510 CompiledOperator::Suffix(b) => value.ends_with(b.as_ref()),
511 CompiledOperator::In(values) => {
512 values.iter().any(|v| matches!(v, CompiledValue::Bytes(b) if value == b.as_ref()))
513 }
514 CompiledOperator::NotIn(values) => {
515 !values.iter().any(|v| matches!(v, CompiledValue::Bytes(b) if value == b.as_ref()))
516 }
517 _ => false,
518 }
519}
520
521fn test_int(op: &CompiledOperator, value: i64) -> bool {
524 match op {
525 CompiledOperator::Equals(CompiledValue::Int(expected)) => value == *expected,
526 CompiledOperator::NotEquals(CompiledValue::Int(expected)) => value != *expected,
527 CompiledOperator::Gt(n) => value > *n,
528 CompiledOperator::Gte(n) => value >= *n,
529 CompiledOperator::Lt(n) => value < *n,
530 CompiledOperator::Lte(n) => value <= *n,
531 CompiledOperator::In(values) => {
532 values.iter().any(|v| matches!(v, CompiledValue::Int(i) if value == *i))
533 }
534 CompiledOperator::NotIn(values) => {
535 !values.iter().any(|v| matches!(v, CompiledValue::Int(i) if value == *i))
536 }
537 _ => false,
538 }
539}
540
541fn test_addr(op: &CompiledOperator, value: std::net::IpAddr) -> bool {
545 match op {
546 CompiledOperator::Equals(CompiledValue::Addr(expected)) => value == *expected,
547 CompiledOperator::NotEquals(CompiledValue::Addr(expected)) => value != *expected,
548 CompiledOperator::Cidr(net) => net.contains(&value),
549 CompiledOperator::In(values) => {
550 values.iter().any(|v| matches!(v, CompiledValue::Addr(a) if value == *a))
551 }
552 CompiledOperator::NotIn(values) => {
553 !values.iter().any(|v| matches!(v, CompiledValue::Addr(a) if value == *a))
554 }
555 _ => false,
556 }
557}
558
/// Returns whether `needle` occurs as a contiguous run inside `haystack`.
///
/// An empty needle is contained in every haystack, including an empty one.
fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
    match needle.len() {
        // `windows(0)` would panic, so the empty needle is answered up front.
        0 => true,
        width if width > haystack.len() => false,
        width => haystack.windows(width).any(|window| window == needle),
    }
}
568
/// Maximum length, in bytes, of a `matches` regex pattern source accepted while
/// parsing a check (enforced by `validate_operator`).
pub const REGEX_PATTERN_MAX_BYTES: usize = 4 * 1024;
570
/// A predicate tree as parsed from JSON.
///
/// Accepted input shapes (see the custom `Deserialize` impl):
/// - `{"any_of": [<predicate>, ...]}`
/// - `{"all_of": [<predicate>, ...]}`
/// - `{"not": <predicate>}`
/// - `{"<field-path>": {"<operator>": <value>}}` (a single check)
///
/// NOTE(review): `Serialize` is derived, so it emits serde's default externally-tagged
/// form (e.g. `{"AnyOf": {"any_of": [...]}}`, with `CheckMap` as
/// `{"path": ..., "op": ...}`), which the custom `Deserialize` does not accept. Confirm
/// that no caller relies on a serialize/deserialize round-trip.
#[derive(Debug, Clone, serde::Serialize)]
pub enum Predicate {
    AnyOf(AnyOfP),
    AllOf(AllOfP),
    Not(NotP),
    Check(CheckMap),
}

/// Body of an `{"any_of": [...]}` node.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(deny_unknown_fields)]
pub struct AnyOfP {
    pub any_of: Vec<Predicate>,
}

/// Body of an `{"all_of": [...]}` node.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(deny_unknown_fields)]
pub struct AllOfP {
    pub all_of: Vec<Predicate>,
}

/// Body of a `{"not": <predicate>}` node.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(deny_unknown_fields)]
pub struct NotP {
    pub not: Box<Predicate>,
}

/// A single check, written as `{"<field-path>": {"<operator>": <value>}}`.
///
/// Deserialized by a hand-written impl that parses the key with `parse_field_path` and
/// checks the operator with `validate_operator`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct CheckMap {
    pub path: FieldPath,
    pub op: Operator,
}
602
/// An operator as written in predicate JSON: a single snake_case key holding its
/// operand, e.g. `{"equals": "a"}`, `{"gt": 1024}`, `{"cidr": "10.0.0.0/8"}`.
///
/// This is the uncompiled form: operands are kept as parsed, and regex/CIDR operands
/// as strings. While parsing, only the `matches` pattern length is checked
/// (`validate_operator`). NOTE(review): presumably lowered to [`CompiledOperator`]
/// elsewhere.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Operator {
    Equals(Value),
    NotEquals(Value),
    Contains(Value),
    NotContains(Value),
    Prefix(Value),
    Suffix(Value),
    /// Regex pattern source.
    Matches(String),
    In(Vec<Value>),
    NotIn(Vec<Value>),
    Gt(i64),
    Gte(i64),
    Lt(i64),
    Lte(i64),
    /// CIDR network in textual form.
    Cidr(String),
}

/// A JSON scalar operand.
///
/// `#[serde(untagged)]` tries the variants in declaration order (`Bool`, then `Int`,
/// then `Str`) and keeps the first that fits, so the order is significant.
#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
#[serde(untagged)]
pub enum Value {
    Bool(bool),
    Int(i64),
    Str(String),
}
629
630impl<'de> serde::Deserialize<'de> for Predicate {
631 fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
632 let v = serde_json::Value::deserialize(de)?;
633 let serde_json::Value::Object(ref map) = v else {
634 return Err(serde::de::Error::custom("predicate must be a JSON object"));
635 };
636 if map.len() == 1 {
637 let (key, _) = map.iter().next().expect("len == 1");
638 match key.as_str() {
639 "any_of" => {
640 return serde_json::from_value::<AnyOfP>(v)
641 .map(Predicate::AnyOf)
642 .map_err(serde::de::Error::custom);
643 }
644 "all_of" => {
645 return serde_json::from_value::<AllOfP>(v)
646 .map(Predicate::AllOf)
647 .map_err(serde::de::Error::custom);
648 }
649 "not" => {
650 return serde_json::from_value::<NotP>(v)
651 .map(Predicate::Not)
652 .map_err(serde::de::Error::custom);
653 }
654 _ => {}
655 }
656 }
657 serde_json::from_value::<CheckMap>(v).map(Predicate::Check).map_err(serde::de::Error::custom)
658 }
659}
660
661impl<'de> serde::Deserialize<'de> for CheckMap {
662 fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
663 struct Visitor;
664
665 impl<'de> serde::de::Visitor<'de> for Visitor {
666 type Value = CheckMap;
667
668 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
669 f.write_str("a single-key object of the form {\"<field-path>\": {\"<operator>\": <value>}}")
670 }
671
672 fn visit_map<M: serde::de::MapAccess<'de>>(self, mut map: M) -> Result<CheckMap, M::Error> {
673 let Some(key) = map.next_key::<String>()? else {
674 return Err(serde::de::Error::invalid_length(0, &"exactly one key"));
675 };
676 let path = parse_field_path(&key).map_err(serde::de::Error::custom)?;
677 let op: Operator = map.next_value()?;
678 if map.next_key::<serde::de::IgnoredAny>()?.is_some() {
679 return Err(serde::de::Error::custom("check object must have exactly one key"));
680 }
681 validate_operator(&op).map_err(serde::de::Error::custom)?;
682 Ok(CheckMap { path, op })
683 }
684 }
685
686 de.deserialize_map(Visitor)
687 }
688}
689
690fn parse_field_path(s: &str) -> Result<FieldPath, String> {
691 if s.chars().any(|c| c.is_ascii_uppercase()) {
692 return Err(format!(
693 "field path must be lowercase: {:?} — did you mean {:?}?",
694 s,
695 s.to_ascii_lowercase(),
696 ));
697 }
698 match s {
699 "transport" => Ok(FieldPath::Transport),
700 "remote.ip" => Ok(FieldPath::RemoteIp),
701 "remote.port" => Ok(FieldPath::RemotePort),
702 "local.ip" => Ok(FieldPath::LocalIp),
703 "local.port" => Ok(FieldPath::LocalPort),
704 "peek" => Ok(FieldPath::Peek),
705 "tls.sni" => Ok(FieldPath::TlsSni),
706 "tls.alpn" => Ok(FieldPath::TlsAlpn),
707 "tls.version" => Ok(FieldPath::TlsVersion),
708 "tls.peer_cert.subject_cn" => Ok(FieldPath::TlsPeerCertSubjectCn),
709 "http.method" => Ok(FieldPath::HttpMethod),
710 "http.uri.path" => Ok(FieldPath::HttpUriPath),
711 "http.uri.query" => Ok(FieldPath::HttpUriQuery),
712 "http.body" => Ok(FieldPath::HttpBody),
713 other if other.starts_with("http.header.") => {
714 let name = &other["http.header.".len()..];
715 if name.is_empty() {
716 return Err(format!("http.header.* requires a header name: {other:?}"));
717 }
718 Ok(FieldPath::HttpHeader(Arc::from(name)))
719 }
720 other => Err(format!("unknown field path: {other:?}")),
721 }
722}
723
724fn validate_operator(op: &Operator) -> Result<(), String> {
725 if let Operator::Matches(pattern) = op
726 && pattern.len() > REGEX_PATTERN_MAX_BYTES
727 {
728 return Err(format!(
729 "regex pattern source exceeds {REGEX_PATTERN_MAX_BYTES}-byte limit: got {} bytes",
730 pattern.len(),
731 ));
732 }
733 Ok(())
734}
735
736mod serde_impls {
737 use base64::Engine as _;
738 use base64::engine::general_purpose::STANDARD as B64;
739 use bytes::Bytes;
740 use std::net::IpAddr;
741 use std::sync::Arc;
742
743 use super::{CompiledOperator, CompiledValue};
744
745 pub(super) fn ser_bytes<S: serde::Serializer>(b: &Bytes, s: S) -> Result<S::Ok, S::Error> {
746 s.serialize_str(&B64.encode(b))
747 }
748
749 pub(super) fn de_bytes<'de, D: serde::Deserializer<'de>>(d: D) -> Result<Bytes, D::Error> {
750 use serde::Deserialize as _;
751 let s = String::deserialize(d)?;
752 B64.decode(s.as_bytes()).map(Bytes::from).map_err(serde::de::Error::custom)
753 }
754
755 pub(super) fn ser_regex<S: serde::Serializer>(
756 r: &fancy_regex::Regex,
757 s: S,
758 ) -> Result<S::Ok, S::Error> {
759 s.serialize_str(r.as_str())
760 }
761
762 pub(super) fn de_regex<'de, D: serde::Deserializer<'de>>(
763 d: D,
764 ) -> Result<fancy_regex::Regex, D::Error> {
765 use serde::Deserialize as _;
766 let s = String::deserialize(d)?;
767 fancy_regex::Regex::new(&s)
768 .map_err(|e| serde::de::Error::custom(format!("invalid regex {s:?}: {e}")))
769 }
770
771 #[derive(serde::Serialize, serde::Deserialize)]
773 #[serde(rename_all = "snake_case")]
774 pub(super) enum ValueShadow {
775 Str(Arc<str>),
776 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
777 Bytes(Bytes),
778 Int(i64),
779 Bool(bool),
780 Addr(IpAddr),
781 }
782
783 impl From<&CompiledValue> for ValueShadow {
784 fn from(v: &CompiledValue) -> Self {
785 match v {
786 CompiledValue::Str(s) => Self::Str(Arc::clone(s)),
787 CompiledValue::Bytes(b) => Self::Bytes(b.clone()),
788 CompiledValue::Int(i) => Self::Int(*i),
789 CompiledValue::Bool(b) => Self::Bool(*b),
790 CompiledValue::Addr(a) => Self::Addr(*a),
791 }
792 }
793 }
794
795 impl From<ValueShadow> for CompiledValue {
796 fn from(v: ValueShadow) -> Self {
797 match v {
798 ValueShadow::Str(s) => Self::Str(s),
799 ValueShadow::Bytes(b) => Self::Bytes(b),
800 ValueShadow::Int(i) => Self::Int(i),
801 ValueShadow::Bool(b) => Self::Bool(b),
802 ValueShadow::Addr(a) => Self::Addr(a),
803 }
804 }
805 }
806
807 #[derive(serde::Serialize, serde::Deserialize)]
810 #[serde(rename_all = "snake_case")]
811 pub(super) enum OperatorShadow {
812 Equals(CompiledValue),
813 NotEquals(CompiledValue),
814 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
815 Contains(Bytes),
816 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
817 NotContains(Bytes),
818 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
819 Prefix(Bytes),
820 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
821 Suffix(Bytes),
822 #[serde(serialize_with = "ser_regex", deserialize_with = "de_regex")]
823 Matches(fancy_regex::Regex),
824 In(Vec<CompiledValue>),
825 NotIn(Vec<CompiledValue>),
826 Gt(i64),
827 Gte(i64),
828 Lt(i64),
829 Lte(i64),
830 Cidr(ipnet::IpNet),
831 }
832
833 impl From<&CompiledOperator> for OperatorShadow {
834 fn from(op: &CompiledOperator) -> Self {
835 match op {
836 CompiledOperator::Equals(v) => Self::Equals(v.clone()),
837 CompiledOperator::NotEquals(v) => Self::NotEquals(v.clone()),
838 CompiledOperator::Contains(b) => Self::Contains(b.clone()),
839 CompiledOperator::NotContains(b) => Self::NotContains(b.clone()),
840 CompiledOperator::Prefix(b) => Self::Prefix(b.clone()),
841 CompiledOperator::Suffix(b) => Self::Suffix(b.clone()),
842 CompiledOperator::Matches(r) => {
843 Self::Matches(fancy_regex::Regex::new(r.as_str()).expect("round-trippable"))
844 }
845 CompiledOperator::In(vs) => Self::In(vs.clone()),
846 CompiledOperator::NotIn(vs) => Self::NotIn(vs.clone()),
847 CompiledOperator::Gt(i) => Self::Gt(*i),
848 CompiledOperator::Gte(i) => Self::Gte(*i),
849 CompiledOperator::Lt(i) => Self::Lt(*i),
850 CompiledOperator::Lte(i) => Self::Lte(*i),
851 CompiledOperator::Cidr(n) => Self::Cidr(*n),
852 }
853 }
854 }
855
856 impl From<OperatorShadow> for CompiledOperator {
857 fn from(op: OperatorShadow) -> Self {
858 match op {
859 OperatorShadow::Equals(v) => Self::Equals(v),
860 OperatorShadow::NotEquals(v) => Self::NotEquals(v),
861 OperatorShadow::Contains(b) => Self::Contains(b),
862 OperatorShadow::NotContains(b) => Self::NotContains(b),
863 OperatorShadow::Prefix(b) => Self::Prefix(b),
864 OperatorShadow::Suffix(b) => Self::Suffix(b),
865 OperatorShadow::Matches(r) => Self::Matches(r),
866 OperatorShadow::In(vs) => Self::In(vs),
867 OperatorShadow::NotIn(vs) => Self::NotIn(vs),
868 OperatorShadow::Gt(i) => Self::Gt(i),
869 OperatorShadow::Gte(i) => Self::Gte(i),
870 OperatorShadow::Lt(i) => Self::Lt(i),
871 OperatorShadow::Lte(i) => Self::Lte(i),
872 OperatorShadow::Cidr(n) => Self::Cidr(n),
873 }
874 }
875 }
876}
877
878impl serde::Serialize for CompiledValue {
879 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
880 serde_impls::ValueShadow::from(self).serialize(s)
881 }
882}
883
884impl<'de> serde::Deserialize<'de> for CompiledValue {
885 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
886 serde_impls::ValueShadow::deserialize(d).map(Self::from)
887 }
888}
889
890impl serde::Serialize for CompiledOperator {
891 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
892 serde_impls::OperatorShadow::from(self).serialize(s)
893 }
894}
895
896impl<'de> serde::Deserialize<'de> for CompiledOperator {
897 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
898 serde_impls::OperatorShadow::deserialize(d).map(Self::from)
899 }
900}
901
902#[cfg(test)]
903mod tests {
904 use std::collections::hash_map::DefaultHasher;
905 use std::hash::Hash;
906 use std::net::{Ipv4Addr, Ipv6Addr};
907 use std::str::FromStr;
908 use std::sync::OnceLock;
909 use std::time::Instant;
910
911 use bytes::Bytes;
912 use fancy_regex::Regex;
913 use ipnet::IpNet;
914 use parking_lot::Mutex;
915
916 use super::*;
917 use crate::body::{Body, Request};
918 use crate::conn_context::{ConnId, Transport};
919
920 fn hash_of<T: Hash>(v: &T) -> u64 {
924 let mut h = DefaultHasher::new();
925 v.hash(&mut h);
926 h.finish()
927 }
928
929 fn make_conn() -> Arc<ConnContext> {
930 Arc::new(ConnContext {
931 id: ConnId(1),
932 remote: "127.0.0.1:0".parse().expect("parse remote"),
933 local: "127.0.0.1:0".parse().expect("parse local"),
934 transport: Transport::Tcp,
935 entered_at: Instant::now(),
936 tls: Mutex::new(None),
937 http_version: OnceLock::new(),
938 user: Mutex::new(http::Extensions::new()),
939 })
940 }
941
942 #[test]
943 fn field_path_http_header_is_equal_by_string_content_not_arc_identity() {
944 let a = FieldPath::HttpHeader(Arc::from("host"));
945 let b = FieldPath::HttpHeader(Arc::from("host"));
946 assert_eq!(a, b);
947 assert_eq!(hash_of(&a), hash_of(&b));
948 let upper = FieldPath::HttpHeader(Arc::from("Host"));
953 assert_ne!(a, upper);
954 }
955
956 #[test]
957 fn field_path_simple_variants_are_self_equal_and_mutually_distinct() {
958 let paths = [
959 FieldPath::Transport,
960 FieldPath::RemoteIp,
961 FieldPath::RemotePort,
962 FieldPath::LocalIp,
963 FieldPath::LocalPort,
964 FieldPath::Peek,
965 FieldPath::TlsSni,
966 FieldPath::TlsAlpn,
967 FieldPath::TlsVersion,
968 FieldPath::TlsPeerCertSubjectCn,
969 FieldPath::HttpMethod,
970 FieldPath::HttpUriPath,
971 FieldPath::HttpUriQuery,
972 FieldPath::HttpBody,
973 ];
974 for (i, a) in paths.iter().enumerate() {
975 for (j, b) in paths.iter().enumerate() {
976 if i == j {
977 assert_eq!(a, b);
978 } else {
979 assert_ne!(a, b);
980 }
981 }
982 }
983 }
984
985 #[test]
986 fn compiled_value_str_is_equal_by_content_not_arc_identity() {
987 let a = CompiledValue::Str(Arc::<str>::from("x"));
988 let b = CompiledValue::Str(Arc::<str>::from("x"));
989 assert_eq!(a, b);
990 assert_eq!(hash_of(&a), hash_of(&b));
991 let c = CompiledValue::Str(Arc::<str>::from("y"));
992 assert_ne!(a, c);
993 }
994
995 #[test]
996 fn compiled_value_cross_variant_inequality() {
997 let s = CompiledValue::Str(Arc::<str>::from("42"));
998 let i = CompiledValue::Int(42);
999 assert_ne!(s, i);
1000 }
1001
1002 #[test]
1003 fn compiled_value_bytes_int_bool_addr_self_equal() {
1004 assert_eq!(
1005 CompiledValue::Bytes(Bytes::from_static(b"abc")),
1006 CompiledValue::Bytes(Bytes::copy_from_slice(b"abc")),
1007 );
1008 assert_eq!(CompiledValue::Int(7), CompiledValue::Int(7));
1009 assert_ne!(CompiledValue::Int(7), CompiledValue::Int(8));
1010 assert_eq!(CompiledValue::Bool(true), CompiledValue::Bool(true));
1011 assert_ne!(CompiledValue::Bool(true), CompiledValue::Bool(false));
1012 assert_eq!(
1013 CompiledValue::Addr(Ipv4Addr::new(10, 0, 0, 1).into()),
1014 CompiledValue::Addr(Ipv4Addr::new(10, 0, 0, 1).into()),
1015 );
1016 assert_ne!(
1017 CompiledValue::Addr(Ipv4Addr::new(10, 0, 0, 1).into()),
1018 CompiledValue::Addr(Ipv6Addr::LOCALHOST.into()),
1019 );
1020 }
1021
1022 #[test]
1023 fn compiled_operator_matches_equal_by_pattern_source() {
1024 let a = CompiledOperator::Matches(Regex::new("^/api").expect("compile a"));
1025 let b = CompiledOperator::Matches(Regex::new("^/api").expect("compile b"));
1026 assert_eq!(a, b);
1027 assert_eq!(hash_of(&a), hash_of(&b));
1028 }
1029
1030 #[test]
1031 fn compiled_operator_matches_distinct_patterns_unequal() {
1032 let a = CompiledOperator::Matches(Regex::new("a|b").expect("compile a"));
1035 let b = CompiledOperator::Matches(Regex::new("b|a").expect("compile b"));
1036 assert_ne!(a, b);
1037 }
1038
1039 #[test]
1040 fn compiled_operator_cidr_equal_by_canonical_form() {
1041 let a = CompiledOperator::Cidr(IpNet::from_str("10.0.0.0/8").expect("parse a"));
1042 let b = CompiledOperator::Cidr(IpNet::from_str("10.0.0.0/8").expect("parse b"));
1043 assert_eq!(a, b);
1044 assert_eq!(hash_of(&a), hash_of(&b));
1045 }
1046
1047 #[test]
1048 fn compiled_operator_cidr_distinct_networks_unequal() {
1049 let a = CompiledOperator::Cidr(IpNet::from_str("10.0.0.0/8").expect("parse a"));
1050 let b = CompiledOperator::Cidr(IpNet::from_str("10.0.0.0/16").expect("parse b"));
1051 assert_ne!(a, b);
1052 }
1053
1054 #[test]
1055 fn compiled_operator_in_is_order_sensitive() {
1056 let xs =
1057 vec![CompiledValue::Str(Arc::<str>::from("a")), CompiledValue::Str(Arc::<str>::from("b"))];
1058 let ys =
1059 vec![CompiledValue::Str(Arc::<str>::from("b")), CompiledValue::Str(Arc::<str>::from("a"))];
1060 assert_ne!(CompiledOperator::In(xs.clone()), CompiledOperator::In(ys.clone()));
1061 assert_ne!(CompiledOperator::NotIn(xs), CompiledOperator::NotIn(ys));
1062 }
1063
1064 #[test]
1065 fn compiled_operator_numeric_comparisons_distinct_per_variant() {
1066 let ops = [
1068 CompiledOperator::Gt(10),
1069 CompiledOperator::Gte(10),
1070 CompiledOperator::Lt(10),
1071 CompiledOperator::Lte(10),
1072 ];
1073 for (i, a) in ops.iter().enumerate() {
1074 for (j, b) in ops.iter().enumerate() {
1075 if i == j {
1076 assert_eq!(a, b);
1077 } else {
1078 assert_ne!(a, b);
1079 }
1080 }
1081 }
1082 }
1083
1084 #[test]
1085 fn compiled_operator_bytes_variants_distinguished() {
1086 let payload = Bytes::from_static(b"abc");
1087 let ops = [
1088 CompiledOperator::Contains(payload.clone()),
1089 CompiledOperator::NotContains(payload.clone()),
1090 CompiledOperator::Prefix(payload.clone()),
1091 CompiledOperator::Suffix(payload),
1092 ];
1093 for (i, a) in ops.iter().enumerate() {
1094 for (j, b) in ops.iter().enumerate() {
1095 if i == j {
1096 assert_eq!(a, b);
1097 } else {
1098 assert_ne!(a, b);
1099 }
1100 }
1101 }
1102 }
1103
1104 #[test]
1105 fn predicate_inst_equal_across_independent_construction_paths() {
1106 let lhs = PredicateInst {
1107 path: FieldPath::HttpHeader(Arc::from("host")),
1108 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from("example.com"))),
1109 };
1110 let rhs = PredicateInst {
1111 path: FieldPath::HttpHeader(Arc::from("host")),
1112 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from("example.com"))),
1113 };
1114 assert_eq!(lhs, rhs);
1115 assert_eq!(hash_of(&lhs), hash_of(&rhs));
1116 }
1117
1118 #[test]
1119 fn predicate_inst_equal_with_regex_operator_from_separate_compiles() {
1120 let lhs = PredicateInst {
1121 path: FieldPath::HttpUriPath,
1122 op: CompiledOperator::Matches(Regex::new("^/").expect("compile a")),
1123 };
1124 let rhs = PredicateInst {
1125 path: FieldPath::HttpUriPath,
1126 op: CompiledOperator::Matches(Regex::new("^/").expect("compile b")),
1127 };
1128 assert_eq!(lhs, rhs);
1129 assert_eq!(hash_of(&lhs), hash_of(&rhs));
1130 }
1131
/// Identical operators on different field paths are different predicates.
#[test]
fn predicate_inst_unequal_on_path_difference() {
    let operand = CompiledValue::Str(Arc::<str>::from("x"));
    let on_path =
        PredicateInst { path: FieldPath::HttpUriPath, op: CompiledOperator::Equals(operand.clone()) };
    let on_query =
        PredicateInst { path: FieldPath::HttpUriQuery, op: CompiledOperator::Equals(operand) };
    assert_ne!(on_path, on_query);
}
1140
/// Both `PredicateView` variants can be built and destructured back into the
/// variant they were constructed as.
#[test]
fn predicate_view_variants_construct() {
    let conn = make_conn();
    let peeked: &[u8] = b"\x16\x03\x01";
    let l4 = PredicateView::L4 { conn: &conn, peek: Some(peeked) };
    let PredicateView::L4 { peek, .. } = l4 else {
        panic!("wrong variant");
    };
    assert_eq!(peek.map(<[u8]>::len), Some(3));

    let conn2 = make_conn();
    let req: Request =
        http::Request::builder().method("GET").uri("/").body(Body::Empty).expect("build request");
    let l7 = PredicateView::L7Req { conn: &conn2, req: &req };
    assert!(matches!(l7, PredicateView::L7Req { .. }), "wrong variant");
}
1160
1161 fn parse_predicate(v: serde_json::Value) -> Result<Predicate, serde_json::Error> {
1165 serde_json::from_value(v)
1166 }
1167
1168 fn expect_check(p: &Predicate) -> &CheckMap {
1169 match p {
1170 Predicate::Check(c) => c,
1171 other => panic!("expected Predicate::Check, got {other:?}"),
1172 }
1173 }
1174
/// `any_of` over two `tls.sni` checks parses into `Predicate::AnyOf`, keeping
/// each branch's path and operand in input order.
#[test]
fn parse_any_of_happy_path() {
    let parsed = parse_predicate(serde_json::json!({
        "any_of": [
            { "tls.sni": { "equals": "a" } },
            { "tls.sni": { "equals": "b" } },
        ],
    }))
    .expect("parse any_of");
    let Predicate::AnyOf(AnyOfP { any_of: branches }) = parsed else {
        panic!("expected AnyOf");
    };
    assert_eq!(branches.len(), 2);
    let (first, second) = (expect_check(&branches[0]), expect_check(&branches[1]));
    for branch in [first, second] {
        assert_eq!(branch.path, FieldPath::TlsSni);
    }
    match (&first.op, &second.op) {
        (Operator::Equals(Value::Str(a)), Operator::Equals(Value::Str(b))) => {
            assert_eq!(a, "a");
            assert_eq!(b, "b");
        }
        (a, b) => panic!("unexpected ops: {a:?} / {b:?}"),
    }
}
1200
/// `{"not": <check>}` parses into `Predicate::Not` wrapping the inner check,
/// whose path and operand survive intact.
#[test]
fn parse_not_happy_path() {
    let raw = serde_json::json!({
        "not": { "tls.sni": { "equals": "internal" } },
    });
    let p = parse_predicate(raw).expect("parse not");
    // Bind the inner predicate under a distinct name. The previous text had the
    // borrow of the `not` field mis-decoded as an HTML entity, which broke the build.
    let Predicate::Not(NotP { not: negated }) = p else {
        panic!("expected Not");
    };
    let inner = expect_check(&negated);
    assert_eq!(inner.path, FieldPath::TlsSni);
    match &inner.op {
        Operator::Equals(Value::Str(s)) => assert_eq!(s, "internal"),
        other => panic!("unexpected op: {other:?}"),
    }
}
1217
/// `all_of` with a header check and a path check parses into
/// `Predicate::AllOf` with both children in order.
#[test]
fn parse_all_of_happy_path() {
    let parsed = parse_predicate(serde_json::json!({
        "all_of": [
            { "http.header.upgrade": { "equals": "websocket" } },
            { "http.uri.path": { "prefix": "/ws" } },
        ],
    }))
    .expect("parse all_of");
    let Predicate::AllOf(AllOfP { all_of: children }) = parsed else {
        panic!("expected AllOf");
    };
    assert_eq!(children.len(), 2);
    let header_check = expect_check(&children[0]);
    let path_check = expect_check(&children[1]);
    assert_eq!(header_check.path, FieldPath::HttpHeader(Arc::from("upgrade")));
    assert_eq!(path_check.path, FieldPath::HttpUriPath);
}
1236
/// An empty `all_of` list is accepted at parse time and yields no children.
#[test]
fn parse_all_of_empty_array_parses() {
    let parsed =
        parse_predicate(serde_json::json!({ "all_of": [] })).expect("empty all_of parses");
    match parsed {
        Predicate::AllOf(AllOfP { all_of }) => assert!(all_of.is_empty()),
        _ => panic!("expected AllOf"),
    }
}
1248
/// Combinators nest: an `all_of` may hold a plain check next to an `any_of`.
#[test]
fn parse_all_of_nested_with_check_and_any_of() {
    let input = serde_json::json!({
        "all_of": [
            { "tls.sni": { "equals": "api.example.com" } },
            { "any_of": [
                { "remote.ip": { "cidr": "10.0.0.0/8" } },
                { "remote.ip": { "cidr": "192.168.0.0/16" } },
            ]},
        ],
    });
    let Predicate::AllOf(AllOfP { all_of: members }) =
        parse_predicate(input).expect("parse nested all_of/any_of")
    else {
        panic!("expected AllOf");
    };
    assert_eq!(members.len(), 2);
    assert!(matches!(members[0], Predicate::Check(_)));
    assert!(matches!(members[1], Predicate::AnyOf(_)));
}
1268
/// An `all_of` object carrying any sibling key is a parse error.
#[test]
fn parse_all_of_with_extra_key_is_rejected() {
    let outcome = parse_predicate(serde_json::json!({
        "all_of": [ { "tls.sni": { "equals": "a" } } ],
        "extra": "unwanted",
    }));
    // Render the error to exercise its Display impl; its wording is not pinned here.
    let _ = outcome.expect_err("must reject extra key on all_of").to_string();
}
1279
/// A header literally named `all_of` is an ordinary check, not the combinator.
#[test]
fn parse_http_header_all_of_is_a_check_not_combinator() {
    let parsed = parse_predicate(serde_json::json!({ "http.header.all_of": { "equals": "x" } }))
        .expect("parse http.header.all_of");
    assert_eq!(expect_check(&parsed).path, FieldPath::HttpHeader(Arc::from("all_of")));
}
1289
/// One check per representative field-path spelling maps to the expected
/// `FieldPath` variant.
#[test]
fn parse_check_across_representative_paths() {
    let cases = [
        (serde_json::json!({ "tls.sni": { "equals": "api.example.com" } }), FieldPath::TlsSni),
        (serde_json::json!({ "remote.port": { "gt": 1024 } }), FieldPath::RemotePort),
        (serde_json::json!({ "http.method": { "equals": "GET" } }), FieldPath::HttpMethod),
        (serde_json::json!({ "http.uri.path": { "prefix": "/api" } }), FieldPath::HttpUriPath),
        (
            serde_json::json!({ "http.header.host": { "equals": "a.example.com" } }),
            FieldPath::HttpHeader(Arc::from("host")),
        ),
        (serde_json::json!({ "http.body": { "contains": "hello" } }), FieldPath::HttpBody),
    ];
    cases.into_iter().for_each(|(raw, expected_path)| {
        let p = parse_predicate(raw.clone()).unwrap_or_else(|e| panic!("parse {raw}: {e}"));
        assert_eq!(expect_check(&p).path, expected_path, "input: {raw}");
    });
}
1309
/// An `any_of` object carrying any sibling key is a parse error.
#[test]
fn parse_any_of_with_extra_key_is_rejected() {
    let outcome = parse_predicate(serde_json::json!({
        "any_of": [ { "tls.sni": { "equals": "a" } } ],
        "extra": true,
    }));
    // Render the error to exercise its Display impl; its wording is not pinned here.
    let _ = outcome.expect_err("must reject extra key on any_of").to_string();
}
1321
/// A header literally named `any_of` is an ordinary check, not the combinator.
#[test]
fn parse_http_header_any_of_is_a_check_not_combinator() {
    let parsed =
        parse_predicate(serde_json::json!({ "http.header.any_of": { "equals": "x" } }))
            .expect("parse");
    assert_eq!(expect_check(&parsed).path, FieldPath::HttpHeader(Arc::from("any_of")));
}
1331
/// Uppercase header names are rejected. The error quotes the input and
/// suggests the lowercase spelling.
#[test]
fn parse_uppercase_field_path_suggests_lowercase() {
    let msg = parse_predicate(serde_json::json!({ "http.header.Host": { "equals": "x" } }))
        .expect_err("uppercase must fail")
        .to_string();
    assert!(msg.contains("http.header.Host"), "error mentions offending input: {msg}");
    assert!(msg.contains("did you mean"), "error includes suggestion phrase: {msg}");
    assert!(msg.contains("http.header.host"), "error contains lowercased form: {msg}");
}
1341
/// A check object must name exactly one field path; two keys is an error.
#[test]
fn parse_multi_key_check_is_rejected() {
    let outcome = parse_predicate(serde_json::json!({
        "http.uri.path": { "matches": "^/" },
        "http.method": { "equals": "GET" },
    }));
    let _ = outcome.expect_err("multi-key check must fail").to_string();
}
1351
/// `http.header.` with nothing after the dot is not a valid header path.
#[test]
fn parse_empty_http_header_name_is_rejected() {
    let outcome = parse_predicate(serde_json::json!({ "http.header.": { "equals": "x" } }));
    let _ = outcome.expect_err("empty header name must fail").to_string();
}
1358
/// Unknown field paths fail, and the error names the offending path.
#[test]
fn parse_unknown_field_path_is_rejected_with_name() {
    let msg = parse_predicate(serde_json::json!({ "http.nope": { "equals": "x" } }))
        .expect_err("unknown path must fail")
        .to_string();
    assert!(msg.contains("http.nope"), "error mentions offending path: {msg}");
}
1366
1367 fn parse_op(v: serde_json::Value) -> Operator {
1368 let mut map = serde_json::Map::new();
1369 map.insert("tls.sni".to_string(), v);
1370 let raw = serde_json::Value::Object(map);
1371 match parse_predicate(raw).expect("parse check") {
1372 Predicate::Check(c) => c.op,
1373 other => panic!("expected Check, got {other:?}"),
1374 }
1375 }
1376
/// `equals` and `not_equals` with a JSON string operand carry a `Value::Str`.
#[test]
fn operator_equals_and_not_equals_on_string() {
    match parse_op(serde_json::json!({ "equals": "api" })) {
        Operator::Equals(Value::Str(s)) => assert_eq!(s, "api"),
        other => panic!("expected equals/str: {other:?}"),
    }
    match parse_op(serde_json::json!({ "not_equals": "api" })) {
        Operator::NotEquals(Value::Str(s)) => assert_eq!(s, "api"),
        other => panic!("expected not_equals/str: {other:?}"),
    }
}
1390
/// `contains` and `not_contains` with a string operand carry a `Value::Str`.
#[test]
fn operator_contains_and_not_contains_on_string() {
    match parse_op(serde_json::json!({ "contains": "foo" })) {
        Operator::Contains(Value::Str(s)) => assert_eq!(s, "foo"),
        other => panic!("expected contains/str: {other:?}"),
    }
    match parse_op(serde_json::json!({ "not_contains": "foo" })) {
        Operator::NotContains(Value::Str(s)) => assert_eq!(s, "foo"),
        other => panic!("expected not_contains/str: {other:?}"),
    }
}
1404
/// `prefix` and `suffix` with a string operand carry a `Value::Str`.
#[test]
fn operator_prefix_and_suffix_on_string() {
    match parse_op(serde_json::json!({ "prefix": "/api" })) {
        Operator::Prefix(Value::Str(s)) => assert_eq!(s, "/api"),
        other => panic!("expected prefix/str: {other:?}"),
    }
    match parse_op(serde_json::json!({ "suffix": ".json" })) {
        Operator::Suffix(Value::Str(v)) => assert_eq!(v, ".json"),
        other => panic!("expected suffix/str: {other:?}"),
    }
}
1418
/// `matches` keeps the raw pattern source at the AST level (no compile yet).
#[test]
fn operator_matches_carries_pattern_source() {
    let expected_source = "^/api/v\\d+";
    match parse_op(serde_json::json!({ "matches": "^/api/v\\d+" })) {
        Operator::Matches(pattern) => assert_eq!(pattern, expected_source),
        other => panic!("expected matches: {other:?}"),
    }
}
1427
/// `in` / `not_in` lists may mix strings and integers; order is preserved.
#[test]
fn operator_in_and_not_in_accept_mixed_scalar_types() {
    let Operator::In(included) = parse_op(serde_json::json!({ "in": ["foo", 42] })) else {
        panic!("expected in");
    };
    assert_eq!(included, vec![Value::Str("foo".into()), Value::Int(42)]);

    let Operator::NotIn(excluded) = parse_op(serde_json::json!({ "not_in": ["bar", 7] })) else {
        panic!("expected not_in");
    };
    assert_eq!(excluded, vec![Value::Str("bar".into()), Value::Int(7)]);
}
1445
/// The four numeric comparison keywords map to their integer operators.
#[test]
fn operator_numeric_comparisons() {
    let gt = parse_op(serde_json::json!({ "gt": 10 }));
    let gte = parse_op(serde_json::json!({ "gte": 10 }));
    let lt = parse_op(serde_json::json!({ "lt": 10 }));
    let lte = parse_op(serde_json::json!({ "lte": 10 }));
    assert!(matches!(gt, Operator::Gt(10)));
    assert!(matches!(gte, Operator::Gte(10)));
    assert!(matches!(lt, Operator::Lt(10)));
    assert!(matches!(lte, Operator::Lte(10)));
}
1453
/// `cidr` keeps the raw network string at the AST level.
#[test]
fn operator_cidr_carries_source_string() {
    match parse_op(serde_json::json!({ "cidr": "10.0.0.0/8" })) {
        Operator::Cidr(s) => assert_eq!(s, "10.0.0.0/8"),
        other => panic!("expected cidr: {other:?}"),
    }
}
1462
/// JSON booleans deserialize as `Value::Bool`, not as a stringified `Str`.
#[test]
fn value_untagged_priority_bool_before_str() {
    for flag in [true, false] {
        let op = parse_op(serde_json::json!({ "equals": flag }));
        assert!(matches!(op, Operator::Equals(Value::Bool(got)) if got == flag));
    }
}
1472
/// JSON integers deserialize as `Value::Int`, not as a stringified `Str`.
#[test]
fn value_untagged_priority_int_before_str() {
    assert!(matches!(
        parse_op(serde_json::json!({ "equals": 42 })),
        Operator::Equals(Value::Int(42))
    ));
}
1479
/// A numeric-looking JSON *string* stays a `Str`; it is not coerced to `Int`.
#[test]
fn value_untagged_json_string_stays_str() {
    match parse_op(serde_json::json!({ "equals": "42" })) {
        Operator::Equals(Value::Str(s)) => assert_eq!(s, "42"),
        other => panic!("expected equals/str(\"42\"): {other:?}"),
    }
}
1490
/// A regex source exactly `REGEX_PATTERN_MAX_BYTES` long is accepted verbatim.
#[test]
fn regex_pattern_exactly_at_limit_parses() {
    // Pin the constant so changing the limit is a deliberate, visible edit.
    assert_eq!(REGEX_PATTERN_MAX_BYTES, 4 * 1024);
    let at_limit = "a".repeat(REGEX_PATTERN_MAX_BYTES);
    let parsed =
        parse_predicate(serde_json::json!({ "http.uri.path": { "matches": at_limit } }))
            .expect("4 KiB pattern parses");
    match &expect_check(&parsed).op {
        Operator::Matches(src) => assert_eq!(src.len(), REGEX_PATTERN_MAX_BYTES),
        other => panic!("expected matches: {other:?}"),
    }
}
1504
/// A regex source one byte over `REGEX_PATTERN_MAX_BYTES` is rejected, and the
/// error message states the numeric limit.
#[test]
fn regex_pattern_over_limit_rejected_with_limit_in_message() {
    let pattern = "a".repeat(REGEX_PATTERN_MAX_BYTES + 1);
    let raw = serde_json::json!({ "http.uri.path": { "matches": pattern } });
    let err = parse_predicate(raw).expect_err("over-limit pattern must fail");
    let msg = err.to_string();
    // Render the limit into a named local first. The previous inline borrow of the
    // constant had been mis-decoded into a `®` character, which does not compile.
    let limit = REGEX_PATTERN_MAX_BYTES.to_string();
    assert!(
        msg.contains(limit.as_str()),
        "error mentions the limit ({REGEX_PATTERN_MAX_BYTES}): {msg}",
    );
}
1516
1517 fn value_round_trip(v: &CompiledValue) -> CompiledValue {
1526 let encoded = serde_json::to_string(v).expect("serialize value");
1527 serde_json::from_str(&encoded).expect("deserialize value")
1528 }
1529
/// `CompiledValue::Str` survives a JSON round-trip, including the empty string.
#[test]
fn compiled_value_str_round_trip_including_empty() {
    for text in ["x", ""] {
        let v = CompiledValue::Str(Arc::<str>::from(text));
        assert_eq!(value_round_trip(&v), v);
    }
}
1537
/// `CompiledValue::Bytes` survives a JSON round-trip for ASCII, empty, and
/// non-UTF-8 payloads.
#[test]
fn compiled_value_bytes_round_trip_including_empty_and_binary() {
    let payloads = [Bytes::from_static(b"hello"), Bytes::new(), Bytes::from_static(&[0xff, 0x00, 0x13])];
    for payload in payloads {
        let v = CompiledValue::Bytes(payload);
        assert_eq!(value_round_trip(&v), v);
    }
}
1547
/// `CompiledValue::Int` survives a JSON round-trip at zero and both `i64` extremes.
#[test]
fn compiled_value_int_round_trip_including_extremes() {
    [0_i64, i64::MIN, i64::MAX]
        .into_iter()
        .map(CompiledValue::Int)
        .for_each(|v| assert_eq!(value_round_trip(&v), v));
}
1555
/// `CompiledValue::Bool` survives a JSON round-trip for both truth values.
#[test]
fn compiled_value_bool_round_trip_both_variants() {
    [true, false]
        .into_iter()
        .map(CompiledValue::Bool)
        .for_each(|v| assert_eq!(value_round_trip(&v), v));
}
1563
/// `CompiledValue::Addr` survives a JSON round-trip for IPv4 and IPv6 loopback.
#[test]
fn compiled_value_addr_round_trip_v4_and_v6() {
    let addrs: [std::net::IpAddr; 2] = [Ipv4Addr::LOCALHOST.into(), Ipv6Addr::LOCALHOST.into()];
    for addr in addrs {
        let v = CompiledValue::Addr(addr);
        assert_eq!(value_round_trip(&v), v);
    }
}
1571
/// Pins the wire form of `Bytes`: a `bytes` key holding padded standard
/// base64 ("hello" -> "aGVsbG8=").
#[test]
fn compiled_value_bytes_emits_standard_base64_literal() {
    let encoded = serde_json::to_string(&CompiledValue::Bytes(Bytes::from_static(b"hello")))
        .expect("serialize");
    assert_eq!(encoded, r#"{"bytes":"aGVsbG8="}"#);
}
1581
1582 fn op_round_trip(op: &CompiledOperator) -> CompiledOperator {
1583 let encoded = serde_json::to_string(op).expect("serialize op");
1584 serde_json::from_str(&encoded).expect("deserialize op")
1585 }
1586
/// `Equals` and `NotEquals` over a string value survive a JSON round-trip.
#[test]
fn compiled_operator_equals_and_not_equals_round_trip() {
    let ops = [
        CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from("x"))),
        CompiledOperator::NotEquals(CompiledValue::Str(Arc::<str>::from("x"))),
    ];
    for op in ops {
        assert_eq!(op_round_trip(&op), op);
    }
}
1594
/// Every byte-payload operator survives a JSON round-trip.
#[test]
fn compiled_operator_bytes_variants_round_trip() {
    let needle = Bytes::from_static(b"hello");
    let variants = [
        CompiledOperator::Contains(needle.clone()),
        CompiledOperator::NotContains(needle.clone()),
        CompiledOperator::Prefix(needle.clone()),
        CompiledOperator::Suffix(needle),
    ];
    variants.into_iter().for_each(|op| assert_eq!(op_round_trip(&op), op));
}
1608
/// A compiled regex round-trips through JSON as its source text and recompiles
/// to an equal operator with the same `as_str()`.
#[test]
fn compiled_operator_matches_round_trip_preserves_pattern_source() {
    let source = "^/api/v[0-9]+";
    let op = CompiledOperator::Matches(Regex::new(source).expect("compile"));
    let decoded = op_round_trip(&op);
    assert_eq!(decoded, op);
    match decoded {
        CompiledOperator::Matches(re) => assert_eq!(re.as_str(), source),
        other => panic!("expected matches, got {other:?}"),
    }
}
1620
/// `In` and `NotIn` over a mixed str/int list survive a JSON round-trip.
#[test]
fn compiled_operator_in_and_not_in_round_trip_mixed_values() {
    let xs = vec![CompiledValue::Str(Arc::<str>::from("a")), CompiledValue::Int(42)];
    let in_op = CompiledOperator::In(xs.clone());
    assert_eq!(op_round_trip(&in_op), in_op);
    // Named `excluded_op`: the previous borrow of `not_in_op` had been
    // mis-decoded into a `¬` character, which broke compilation.
    let excluded_op = CompiledOperator::NotIn(xs);
    assert_eq!(op_round_trip(&excluded_op), excluded_op);
}
1629
/// The four integer comparison operators survive a JSON round-trip.
#[test]
fn compiled_operator_numeric_comparisons_round_trip() {
    let threshold = 100;
    let ops = [
        CompiledOperator::Gt(threshold),
        CompiledOperator::Gte(threshold),
        CompiledOperator::Lt(threshold),
        CompiledOperator::Lte(threshold),
    ];
    for op in ops.iter() {
        assert_eq!(&op_round_trip(op), op);
    }
}
1642
/// A CIDR operator survives a JSON round-trip unchanged.
#[test]
fn compiled_operator_cidr_round_trip_preserves_canonical_form() {
    let net: IpNet = "10.0.0.0/8".parse().expect("parse");
    let op = CompiledOperator::Cidr(net);
    assert_eq!(op_round_trip(&op), op);
}
1648
/// Deserializing `matches` with an uncompilable regex fails, and the error
/// quotes the bad source.
#[test]
fn compiled_operator_matches_with_invalid_regex_is_rejected() {
    let msg = serde_json::from_str::<CompiledOperator>(r#"{"matches":"["}"#)
        .expect_err("invalid regex must fail to deserialize")
        .to_string();
    assert!(msg.contains('['), "error mentions offending regex source: {msg}");
}
1660
/// Pins the exact serialized JSON of a header-equals `PredicateInst` and checks
/// that it decodes back to the same value.
#[test]
fn predicate_inst_pins_exact_wire_shape_for_http_header_equals() {
    const EXPECTED: &str =
        r#"{"path":{"http_header":"host"},"op":{"equals":{"str":"example.com"}}}"#;
    let inst = PredicateInst {
        path: FieldPath::HttpHeader(Arc::from("host")),
        op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from("example.com"))),
    };
    let encoded = serde_json::to_string(&inst).expect("serialize");
    assert_eq!(encoded, EXPECTED);
    let decoded = serde_json::from_str::<PredicateInst>(&encoded).expect("deserialize");
    assert_eq!(decoded, inst);
}
1672
/// A `PredicateInst` holding a compiled regex survives a JSON round-trip.
#[test]
fn predicate_inst_round_trip_with_regex_operator() {
    let inst = PredicateInst {
        path: FieldPath::HttpUriPath,
        op: CompiledOperator::Matches(Regex::new("^/api").expect("compile")),
    };
    let decoded = serde_json::to_string(&inst)
        .map(|encoded| serde_json::from_str::<PredicateInst>(&encoded).expect("deserialize"))
        .expect("serialize");
    assert_eq!(decoded, inst);
}
1683
1684 fn http_header_equals(name: &str, value: &str) -> PredicateInst {
1692 PredicateInst {
1693 path: FieldPath::HttpHeader(Arc::from(name)),
1694 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from(value))),
1695 }
1696 }
1697
1698 fn http_uri_path_equals(value: &str) -> PredicateInst {
1699 PredicateInst {
1700 path: FieldPath::HttpUriPath,
1701 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from(value))),
1702 }
1703 }
1704
1705 fn http_uri_path_prefix(value: &str) -> PredicateInst {
1706 PredicateInst {
1707 path: FieldPath::HttpUriPath,
1708 op: CompiledOperator::Prefix(Bytes::copy_from_slice(value.as_bytes())),
1709 }
1710 }
1711
1712 fn tls_sni_equals(value: &str) -> PredicateInst {
1713 PredicateInst {
1714 path: FieldPath::TlsSni,
1715 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from(value))),
1716 }
1717 }
1718
1719 fn conn_with_sni(sni: &str) -> Arc<ConnContext> {
1720 let conn = make_conn();
1721 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
1722 sni: Some(sni.to_string()),
1723 alpn: None,
1724 version: None,
1725 peer_cert: None,
1726 });
1727 conn
1728 }
1729
1730 fn req_with_header(name: &str, value: &str) -> Request {
1731 http::Request::builder()
1732 .method("GET")
1733 .uri("/")
1734 .header(name, value)
1735 .body(Body::Empty)
1736 .expect("build req")
1737 }
1738
1739 fn req_with_uri(uri: &str) -> Request {
1740 http::Request::builder().method("GET").uri(uri).body(Body::Empty).expect("build req")
1741 }
1742
/// Header-equals matches when the header is present with exactly that value.
#[test]
fn predicate_test_http_header_equals_matches_when_present_and_equal() {
    let conn = make_conn();
    let request = req_with_header("upgrade", "websocket");
    let predicate = http_header_equals("upgrade", "websocket");
    assert!(predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1750
/// Header-equals does not match when the request lacks that header.
#[test]
fn predicate_test_http_header_equals_misses_when_header_absent() {
    let conn = make_conn();
    let request = req_with_header("host", "example.com");
    let predicate = http_header_equals("upgrade", "websocket");
    assert!(!predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1758
/// Header *values* compare case-sensitively: "WebSocket" is not "websocket".
#[test]
fn predicate_test_http_header_equals_value_is_case_sensitive() {
    let conn = make_conn();
    let request = req_with_header("upgrade", "WebSocket");
    let predicate = http_header_equals("upgrade", "websocket");
    assert!(!predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1770
/// Header *names* match case-insensitively: a request sent with "Upgrade"
/// satisfies a predicate on "upgrade".
#[test]
fn predicate_test_http_header_equals_name_lookup_is_case_insensitive() {
    let conn = make_conn();
    let request = req_with_header("Upgrade", "websocket");
    let predicate = http_header_equals("upgrade", "websocket");
    assert!(predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1783
/// HTTP header predicates never match an L4 view, which has no request.
#[test]
fn predicate_test_http_header_equals_misses_on_l4_view() {
    let conn = make_conn();
    let predicate = http_header_equals("upgrade", "websocket");
    assert!(!predicate.test(&PredicateView::L4 { conn: &conn, peek: None }));
}
1793
/// Path-equals matches an identical request path.
#[test]
fn predicate_test_http_uri_path_equals_matches_exact() {
    let conn = make_conn();
    let request = req_with_uri("/api/v1/users");
    let predicate = http_uri_path_equals("/api/v1/users");
    assert!(predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1801
/// Path-equals is exact: a leading substring of the path does not match.
#[test]
fn predicate_test_http_uri_path_equals_misses_on_substring() {
    let conn = make_conn();
    let request = req_with_uri("/api/v1/users");
    let predicate = http_uri_path_equals("/api");
    assert!(!predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1812
/// Path-prefix matches when the request path begins with the prefix.
#[test]
fn predicate_test_http_uri_path_prefix_matches_when_path_starts_with() {
    let conn = make_conn();
    let request = req_with_uri("/api/v1/users");
    let predicate = http_uri_path_prefix("/api");
    assert!(predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1820
/// Path-prefix does not match a path that starts differently.
#[test]
fn predicate_test_http_uri_path_prefix_misses_when_no_prefix() {
    let conn = make_conn();
    let request = req_with_uri("/admin");
    let predicate = http_uri_path_prefix("/api");
    assert!(!predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1828
/// SNI-equals matches on an L7 view when the connection recorded that SNI.
#[test]
fn predicate_test_tls_sni_equals_matches_when_set() {
    let conn = conn_with_sni("api.example.com");
    let request = req_with_uri("/");
    let predicate = tls_sni_equals("api.example.com");
    assert!(predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1839
/// SNI-equals does not match a connection built by `make_conn()`, which is
/// presumably TLS-less (verify against `make_conn`).
#[test]
fn predicate_test_tls_sni_equals_misses_when_unset() {
    let conn = make_conn();
    let request = req_with_uri("/");
    let predicate = tls_sni_equals("api.example.com");
    assert!(!predicate.test(&PredicateView::L7Req { conn: &conn, req: &request }));
}
1849
/// SNI lives on the connection, so it is also visible from an L4 view.
#[test]
fn predicate_test_tls_sni_equals_works_in_l4_view_too() {
    let conn = conn_with_sni("api.example.com");
    let predicate = tls_sni_equals("api.example.com");
    assert!(predicate.test(&PredicateView::L4 { conn: &conn, peek: None }));
}
1861
1862 fn pred(path: FieldPath, op: CompiledOperator) -> PredicateInst {
1872 PredicateInst { path, op }
1873 }
1874
1875 fn str_val(s: &str) -> CompiledValue {
1876 CompiledValue::Str(Arc::<str>::from(s))
1877 }
1878
1879 fn bytes_val(b: &[u8]) -> CompiledValue {
1880 CompiledValue::Bytes(Bytes::copy_from_slice(b))
1881 }
1882
1883 fn b(b: &[u8]) -> Bytes {
1884 Bytes::copy_from_slice(b)
1885 }
1886
1887 fn make_conn_with(remote: &str, local: &str) -> Arc<ConnContext> {
1888 Arc::new(ConnContext {
1889 id: ConnId(1),
1890 remote: remote.parse().expect("parse remote"),
1891 local: local.parse().expect("parse local"),
1892 transport: Transport::Tcp,
1893 entered_at: Instant::now(),
1894 tls: Mutex::new(None),
1895 http_version: OnceLock::new(),
1896 user: Mutex::new(http::Extensions::new()),
1897 })
1898 }
1899
1900 fn make_conn_with_transport(t: Transport) -> Arc<ConnContext> {
1901 Arc::new(ConnContext {
1902 id: ConnId(1),
1903 remote: "127.0.0.1:0".parse().expect("remote"),
1904 local: "127.0.0.1:0".parse().expect("local"),
1905 transport: t,
1906 entered_at: Instant::now(),
1907 tls: Mutex::new(None),
1908 http_version: OnceLock::new(),
1909 user: Mutex::new(http::Extensions::new()),
1910 })
1911 }
1912
1913 fn conn_with_tls_alpn(alpn: &[u8]) -> Arc<ConnContext> {
1914 let conn = make_conn();
1915 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
1916 sni: None,
1917 alpn: Some(alpn.to_vec()),
1918 version: None,
1919 peer_cert: None,
1920 });
1921 conn
1922 }
1923
1924 fn conn_with_tls_version(v: crate::conn_context::TlsVersion) -> Arc<ConnContext> {
1925 let conn = make_conn();
1926 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
1927 sni: None,
1928 alpn: None,
1929 version: Some(v),
1930 peer_cert: None,
1931 });
1932 conn
1933 }
1934
/// Equals/NotEquals on a string field (`tls.sni`): hit and miss in both polarities.
#[test]
fn matrix_equality_str_happy_and_miss() {
    let conn = conn_with_sni("api.example.com");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let sni_check = |op: CompiledOperator| pred(FieldPath::TlsSni, op).test(&view);
    assert!(sni_check(CompiledOperator::Equals(str_val("api.example.com"))));
    assert!(!sni_check(CompiledOperator::Equals(str_val("other.example.com"))));
    assert!(sni_check(CompiledOperator::NotEquals(str_val("other.example.com"))));
    assert!(!sni_check(CompiledOperator::NotEquals(str_val("api.example.com"))));
}
1953
/// Equals/NotEquals on a bytes field (`tls.alpn`): hit and miss in both polarities.
#[test]
fn matrix_equality_bytes_happy_and_miss() {
    let conn = conn_with_tls_alpn(b"h2");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let alpn_check = |op: CompiledOperator| pred(FieldPath::TlsAlpn, op).test(&view);
    assert!(alpn_check(CompiledOperator::Equals(bytes_val(b"h2"))));
    assert!(!alpn_check(CompiledOperator::Equals(bytes_val(b"http/1.1"))));
    assert!(alpn_check(CompiledOperator::NotEquals(bytes_val(b"http/1.1"))));
    assert!(!alpn_check(CompiledOperator::NotEquals(bytes_val(b"h2"))));
}
1964
/// Equals/NotEquals on an int field (`remote.port`): hit and miss in both polarities.
#[test]
fn matrix_equality_int_happy_and_miss() {
    let conn = make_conn_with("127.0.0.1:9090", "127.0.0.1:80");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let port_check = |op: CompiledOperator| pred(FieldPath::RemotePort, op).test(&view);
    assert!(port_check(CompiledOperator::Equals(CompiledValue::Int(9090))));
    assert!(!port_check(CompiledOperator::Equals(CompiledValue::Int(81))));
    assert!(port_check(CompiledOperator::NotEquals(CompiledValue::Int(81))));
    assert!(!port_check(CompiledOperator::NotEquals(CompiledValue::Int(9090))));
}
1982
/// Equals/NotEquals on an address field (`remote.ip`): hit and miss in both polarities.
#[test]
fn matrix_equality_addr_happy_and_miss() {
    let conn = make_conn_with("10.0.0.5:55555", "127.0.0.1:80");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let peer = "10.0.0.5".parse::<std::net::IpAddr>().unwrap();
    let neighbour = "10.0.0.6".parse::<std::net::IpAddr>().unwrap();
    let ip_check = |op: CompiledOperator| pred(FieldPath::RemoteIp, op).test(&view);
    assert!(ip_check(CompiledOperator::Equals(CompiledValue::Addr(peer))));
    assert!(!ip_check(CompiledOperator::Equals(CompiledValue::Addr(neighbour))));
    assert!(ip_check(CompiledOperator::NotEquals(CompiledValue::Addr(neighbour))));
    assert!(!ip_check(CompiledOperator::NotEquals(CompiledValue::Addr(peer))));
}
2000
/// Equals/NotEquals on the `transport` enum field, compared by its string
/// spelling ("tcp" / "udp"). Covers both polarities on both transports, like
/// the other equality-matrix tests.
#[test]
fn matrix_equality_enum_transport_happy_and_miss() {
    let tcp = make_conn_with_transport(Transport::Tcp);
    let udp = make_conn_with_transport(Transport::Udp);
    let v_tcp = PredicateView::L4 { conn: &tcp, peek: None };
    let v_udp = PredicateView::L4 { conn: &udp, peek: None };
    assert!(pred(FieldPath::Transport, CompiledOperator::Equals(str_val("tcp"))).test(&v_tcp));
    assert!(!pred(FieldPath::Transport, CompiledOperator::Equals(str_val("udp"))).test(&v_tcp));
    assert!(pred(FieldPath::Transport, CompiledOperator::Equals(str_val("udp"))).test(&v_udp));
    assert!(!pred(FieldPath::Transport, CompiledOperator::Equals(str_val("tcp"))).test(&v_udp));
    // NotEquals was previously unexercised for this field.
    assert!(pred(FieldPath::Transport, CompiledOperator::NotEquals(str_val("udp"))).test(&v_tcp));
    assert!(!pred(FieldPath::Transport, CompiledOperator::NotEquals(str_val("tcp"))).test(&v_tcp));
}
2011
/// Equals/NotEquals on the `tls.version` enum field, compared by its string
/// spelling ("1.3"), in both polarities.
#[test]
fn matrix_equality_enum_tls_version_happy_and_miss() {
    let conn = conn_with_tls_version(crate::conn_context::TlsVersion::Tls13);
    let v = PredicateView::L4 { conn: &conn, peek: None };
    assert!(pred(FieldPath::TlsVersion, CompiledOperator::Equals(str_val("1.3"))).test(&v));
    assert!(!pred(FieldPath::TlsVersion, CompiledOperator::Equals(str_val("1.2"))).test(&v));
    assert!(pred(FieldPath::TlsVersion, CompiledOperator::NotEquals(str_val("1.2"))).test(&v));
    // Negative NotEquals case, mirroring the other equality-matrix tests.
    assert!(!pred(FieldPath::TlsVersion, CompiledOperator::NotEquals(str_val("1.3"))).test(&v));
}
2020
/// With no TLS version recorded, neither Equals nor NotEquals matches: an
/// absent field fails both polarities.
#[test]
fn matrix_equality_enum_tls_version_misses_when_absent() {
    let conn = make_conn();
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let version_check = |op: CompiledOperator| pred(FieldPath::TlsVersion, op).test(&view);
    assert!(!version_check(CompiledOperator::Equals(str_val("1.3"))));
    assert!(!version_check(CompiledOperator::NotEquals(str_val("1.3"))));
}
2030
/// Equals/NotEquals on the `http.method` enum field, in both polarities.
#[test]
fn matrix_equality_enum_http_method_happy_and_miss() {
    let conn = make_conn();
    let req = http::Request::builder()
        .method("POST")
        .uri("/")
        .body(Body::Empty)
        .expect("build req");
    let v = PredicateView::L7Req { conn: &conn, req: &req };
    assert!(pred(FieldPath::HttpMethod, CompiledOperator::Equals(str_val("POST"))).test(&v));
    assert!(!pred(FieldPath::HttpMethod, CompiledOperator::Equals(str_val("GET"))).test(&v));
    assert!(pred(FieldPath::HttpMethod, CompiledOperator::NotEquals(str_val("GET"))).test(&v));
    // Negative NotEquals case, mirroring the other equality-matrix tests.
    assert!(!pred(FieldPath::HttpMethod, CompiledOperator::NotEquals(str_val("POST"))).test(&v));
}
2040
/// In/NotIn on a string field (`tls.sni`) against a hitting and a missing list.
#[test]
fn matrix_in_list_str_happy_and_miss() {
    let conn = conn_with_sni("api.example.com");
    let view = PredicateView::L4 { conn: &conn, peek: None };
    let sni_check = |op: CompiledOperator| pred(FieldPath::TlsSni, op).test(&view);
    let hits = vec![str_val("a.example.com"), str_val("api.example.com")];
    let misses = vec![str_val("a.example.com"), str_val("b.example.com")];
    assert!(sni_check(CompiledOperator::In(hits.clone())));
    assert!(!sni_check(CompiledOperator::In(misses.clone())));
    assert!(sni_check(CompiledOperator::NotIn(misses)));
    assert!(!sni_check(CompiledOperator::NotIn(hits)));
}
2054
/// In/NotIn on a bytes field (`tls.alpn`) against a hitting and a missing list,
/// in both polarities.
#[test]
fn matrix_in_list_bytes_happy_and_miss() {
    let conn = conn_with_tls_alpn(b"h2");
    let v = PredicateView::L4 { conn: &conn, peek: None };
    let list = vec![bytes_val(b"http/1.1"), bytes_val(b"h2")];
    assert!(pred(FieldPath::TlsAlpn, CompiledOperator::In(list.clone())).test(&v));
    // Negative NotIn case that the str variant pins but this one previously skipped.
    assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::NotIn(list)).test(&v));
    let list_miss = vec![bytes_val(b"http/1.0"), bytes_val(b"http/1.1")];
    assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::In(list_miss.clone())).test(&v));
    assert!(pred(FieldPath::TlsAlpn, CompiledOperator::NotIn(list_miss)).test(&v));
}
2065
/// In/NotIn on an int field (`remote.port`) against a hitting and a missing
/// list, in both polarities.
#[test]
fn matrix_in_list_int_happy_and_miss() {
    let conn = make_conn_with("127.0.0.1:443", "127.0.0.1:80");
    let v = PredicateView::L4 { conn: &conn, peek: None };
    let in_list = vec![CompiledValue::Int(80), CompiledValue::Int(443)];
    assert!(pred(FieldPath::RemotePort, CompiledOperator::In(in_list.clone())).test(&v));
    // Negative NotIn case that the str variant pins but this one previously skipped.
    assert!(!pred(FieldPath::RemotePort, CompiledOperator::NotIn(in_list)).test(&v));
    let miss_list = vec![CompiledValue::Int(80), CompiledValue::Int(81)];
    assert!(!pred(FieldPath::RemotePort, CompiledOperator::In(miss_list.clone())).test(&v));
    assert!(pred(FieldPath::RemotePort, CompiledOperator::NotIn(miss_list)).test(&v));
}
2076
2077 #[test]
2078 fn matrix_in_list_addr_happy_and_miss_mixed_family() {
2079 let conn = make_conn_with("10.0.0.5:55555", "127.0.0.1:80");
2080 let v = PredicateView::L4 { conn: &conn, peek: None };
2081 let v4: std::net::IpAddr = "10.0.0.5".parse().unwrap();
2082 let v6: std::net::IpAddr = "::1".parse().unwrap();
2083 let list = vec![CompiledValue::Addr(v6), CompiledValue::Addr(v4)];
2084 assert!(pred(FieldPath::RemoteIp, CompiledOperator::In(list.clone())).test(&v));
2085 let miss = vec![CompiledValue::Addr(v6)];
2086 assert!(!pred(FieldPath::RemoteIp, CompiledOperator::In(miss.clone())).test(&v));
2087 assert!(pred(FieldPath::RemoteIp, CompiledOperator::NotIn(miss)).test(&v));
2088 }
2089
2090 #[test]
2091 fn matrix_in_list_enum_transport_happy_and_miss() {
2092 let conn = make_conn_with_transport(Transport::Udp);
2093 let v = PredicateView::L4 { conn: &conn, peek: None };
2094 let list = vec![str_val("tcp"), str_val("udp")];
2095 assert!(pred(FieldPath::Transport, CompiledOperator::In(list)).test(&v));
2096 let miss = vec![str_val("tcp")];
2097 assert!(!pred(FieldPath::Transport, CompiledOperator::In(miss.clone())).test(&v));
2098 assert!(pred(FieldPath::Transport, CompiledOperator::NotIn(miss)).test(&v));
2099 }
2100
2101 #[test]
2104 fn matrix_substring_on_str_happy_and_miss() {
2105 let conn = make_conn();
2106 let req =
2107 http::Request::builder().method("GET").uri("/api/v1/users").body(Body::Empty).unwrap();
2108 let v = PredicateView::L7Req { conn: &conn, req: &req };
2109 assert!(pred(FieldPath::HttpUriPath, CompiledOperator::Contains(b(b"/v1/"))).test(&v));
2110 assert!(!pred(FieldPath::HttpUriPath, CompiledOperator::Contains(b(b"/v2/"))).test(&v));
2111 assert!(pred(FieldPath::HttpUriPath, CompiledOperator::NotContains(b(b"/v2/"))).test(&v));
2112 assert!(!pred(FieldPath::HttpUriPath, CompiledOperator::NotContains(b(b"/v1/"))).test(&v));
2113 }
2114
2115 #[test]
2116 fn matrix_substring_on_bytes_happy_and_miss() {
2117 let conn = conn_with_tls_alpn(b"http/1.1");
2118 let v = PredicateView::L4 { conn: &conn, peek: None };
2119 assert!(pred(FieldPath::TlsAlpn, CompiledOperator::Contains(b(b"/1."))).test(&v));
2120 assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::Contains(b(b"/2."))).test(&v));
2121 assert!(pred(FieldPath::TlsAlpn, CompiledOperator::NotContains(b(b"/2."))).test(&v));
2122 }
2123
2124 #[test]
2127 fn matrix_prefix_suffix_on_str_happy_and_miss() {
2128 let conn = make_conn();
2129 let req =
2130 http::Request::builder().method("GET").uri("/api/file.json?q=1").body(Body::Empty).unwrap();
2131 let v = PredicateView::L7Req { conn: &conn, req: &req };
2132 assert!(pred(FieldPath::HttpUriPath, CompiledOperator::Prefix(b(b"/api"))).test(&v));
2133 assert!(!pred(FieldPath::HttpUriPath, CompiledOperator::Prefix(b(b"/admin"))).test(&v));
2134 assert!(pred(FieldPath::HttpUriPath, CompiledOperator::Suffix(b(b".json"))).test(&v));
2135 assert!(!pred(FieldPath::HttpUriPath, CompiledOperator::Suffix(b(b".html"))).test(&v));
2136 }
2137
2138 #[test]
2139 fn matrix_prefix_suffix_on_bytes_happy_and_miss() {
2140 let conn = conn_with_tls_alpn(b"http/1.1");
2141 let v = PredicateView::L4 { conn: &conn, peek: None };
2142 assert!(pred(FieldPath::TlsAlpn, CompiledOperator::Prefix(b(b"http"))).test(&v));
2143 assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::Prefix(b(b"h2"))).test(&v));
2144 assert!(pred(FieldPath::TlsAlpn, CompiledOperator::Suffix(b(b"1.1"))).test(&v));
2145 assert!(!pred(FieldPath::TlsAlpn, CompiledOperator::Suffix(b(b"2.0"))).test(&v));
2146 }
2147
2148 #[test]
2151 fn matrix_regex_matches_on_str_happy_and_miss() {
2152 let conn = make_conn();
2153 let req =
2154 http::Request::builder().method("GET").uri("/api/v3/orders").body(Body::Empty).unwrap();
2155 let v = PredicateView::L7Req { conn: &conn, req: &req };
2156 let re = Regex::new(r"^/api/v\d+/orders").expect("compile regex");
2157 assert!(pred(FieldPath::HttpUriPath, CompiledOperator::Matches(re)).test(&v));
2158 let re_miss = Regex::new(r"^/admin").expect("compile regex");
2159 assert!(!pred(FieldPath::HttpUriPath, CompiledOperator::Matches(re_miss)).test(&v));
2160 }
2161
2162 #[test]
2163 fn matrix_regex_matches_on_header_happy_and_miss() {
2164 let conn = make_conn();
2165 let req = http::Request::builder()
2166 .method("GET")
2167 .uri("/")
2168 .header("user-agent", "Mozilla/5.0 (Macintosh; Intel)")
2169 .body(Body::Empty)
2170 .unwrap();
2171 let v = PredicateView::L7Req { conn: &conn, req: &req };
2172 let re = Regex::new(r"(?i)mozilla").expect("compile");
2173 assert!(
2174 pred(FieldPath::HttpHeader(Arc::from("user-agent")), CompiledOperator::Matches(re)).test(&v)
2175 );
2176 let re_miss = Regex::new(r"^curl").expect("compile");
2177 assert!(
2178 !pred(FieldPath::HttpHeader(Arc::from("user-agent")), CompiledOperator::Matches(re_miss))
2179 .test(&v)
2180 );
2181 }
2182
2183 #[test]
2186 fn matrix_numeric_cmp_gt_gte_lt_lte_happy_and_miss() {
2187 let conn = make_conn_with("127.0.0.1:1024", "127.0.0.1:443");
2188 let v = PredicateView::L4 { conn: &conn, peek: None };
2189 assert!(pred(FieldPath::RemotePort, CompiledOperator::Gt(1023)).test(&v));
2191 assert!(!pred(FieldPath::RemotePort, CompiledOperator::Gt(1024)).test(&v));
2192 assert!(pred(FieldPath::RemotePort, CompiledOperator::Gte(1024)).test(&v));
2194 assert!(!pred(FieldPath::RemotePort, CompiledOperator::Gte(1025)).test(&v));
2195 assert!(pred(FieldPath::RemotePort, CompiledOperator::Lt(1025)).test(&v));
2197 assert!(!pred(FieldPath::RemotePort, CompiledOperator::Lt(1024)).test(&v));
2198 assert!(pred(FieldPath::RemotePort, CompiledOperator::Lte(1024)).test(&v));
2200 assert!(!pred(FieldPath::RemotePort, CompiledOperator::Lte(1023)).test(&v));
2201 }
2202
2203 #[test]
2204 fn matrix_numeric_cmp_local_port_too() {
2205 let conn = make_conn_with("127.0.0.1:0", "127.0.0.1:8443");
2207 let v = PredicateView::L4 { conn: &conn, peek: None };
2208 assert!(pred(FieldPath::LocalPort, CompiledOperator::Gt(8000)).test(&v));
2209 assert!(!pred(FieldPath::LocalPort, CompiledOperator::Gt(9000)).test(&v));
2210 }
2211
2212 #[test]
2215 fn matrix_cidr_v4_happy_and_miss() {
2216 let conn = make_conn_with("10.0.5.7:0", "127.0.0.1:0");
2217 let v = PredicateView::L4 { conn: &conn, peek: None };
2218 let ten = IpNet::from_str("10.0.0.0/8").unwrap();
2219 let nineteen2 = IpNet::from_str("192.168.0.0/16").unwrap();
2220 assert!(pred(FieldPath::RemoteIp, CompiledOperator::Cidr(ten)).test(&v));
2221 assert!(!pred(FieldPath::RemoteIp, CompiledOperator::Cidr(nineteen2)).test(&v));
2222 }
2223
2224 #[test]
2225 fn matrix_cidr_v6_happy_and_miss() {
2226 let conn = make_conn_with("[2001:db8::5]:0", "127.0.0.1:0");
2227 let v = PredicateView::L4 { conn: &conn, peek: None };
2228 let net = IpNet::from_str("2001:db8::/32").unwrap();
2229 let other = IpNet::from_str("2001:dead::/32").unwrap();
2230 assert!(pred(FieldPath::RemoteIp, CompiledOperator::Cidr(net)).test(&v));
2231 assert!(!pred(FieldPath::RemoteIp, CompiledOperator::Cidr(other)).test(&v));
2232 }
2233
2234 #[test]
2235 fn matrix_cidr_v4_against_v6_addr_misses() {
2236 let conn = make_conn_with("[2001:db8::5]:0", "127.0.0.1:0");
2238 let v = PredicateView::L4 { conn: &conn, peek: None };
2239 let v4 = IpNet::from_str("0.0.0.0/0").unwrap();
2240 assert!(!pred(FieldPath::RemoteIp, CompiledOperator::Cidr(v4)).test(&v));
2241 }
2242
2243 #[test]
2247 fn http_uri_query_reader_returns_empty_when_query_absent() {
2248 let conn = make_conn();
2251 let req = http::Request::builder().method("GET").uri("/no-q").body(Body::Empty).unwrap();
2252 let v = PredicateView::L7Req { conn: &conn, req: &req };
2253 assert!(pred(FieldPath::HttpUriQuery, CompiledOperator::Equals(str_val(""))).test(&v));
2254 assert!(!pred(FieldPath::HttpUriQuery, CompiledOperator::Equals(str_val("q=1"))).test(&v));
2255 }
2256
2257 #[test]
2258 fn http_uri_query_reader_matches_present_query() {
2259 let conn = make_conn();
2260 let req = http::Request::builder().method("GET").uri("/x?a=1&b=2").body(Body::Empty).unwrap();
2261 let v = PredicateView::L7Req { conn: &conn, req: &req };
2262 assert!(pred(FieldPath::HttpUriQuery, CompiledOperator::Equals(str_val("a=1&b=2"))).test(&v));
2263 assert!(pred(FieldPath::HttpUriQuery, CompiledOperator::Contains(b(b"b=2"))).test(&v));
2264 }
2265
2266 #[test]
2267 fn local_ip_reader_uses_local_socket() {
2268 let conn = make_conn_with("10.0.0.5:0", "127.0.0.1:8443");
2269 let v = PredicateView::L4 { conn: &conn, peek: None };
2270 let local: std::net::IpAddr = "127.0.0.1".parse().unwrap();
2271 assert!(
2272 pred(FieldPath::LocalIp, CompiledOperator::Equals(CompiledValue::Addr(local))).test(&v)
2273 );
2274 }
2275
2276 #[test]
2277 fn http_header_lookup_misses_for_non_utf8_value() {
2278 let conn = make_conn();
2281 let bad =
2282 http::HeaderValue::from_bytes(&[0xff, 0xfe, 0xfd]).expect("non-utf8 header value parses");
2283 let mut builder = http::Request::builder().method("GET").uri("/");
2284 builder.headers_mut().expect("headers").insert("x-bad", bad);
2285 let req: Request = builder.body(Body::Empty).expect("build request");
2286 let v = PredicateView::L7Req { conn: &conn, req: &req };
2287 assert!(
2288 !pred(
2289 FieldPath::HttpHeader(Arc::from("x-bad")),
2290 CompiledOperator::Equals(str_val("anything")),
2291 )
2292 .test(&v)
2293 );
2294 }
2295
2296 fn rcgen_cert_with_cn(cn: &str) -> rustls_pki_types::CertificateDer<'static> {
2299 let mut params = rcgen::CertificateParams::default();
2300 params.distinguished_name = rcgen::DistinguishedName::new();
2301 params.distinguished_name.push(rcgen::DnType::CommonName, cn);
2302 let key = rcgen::KeyPair::generate().expect("rcgen keypair");
2303 let cert = params.self_signed(&key).expect("self-sign cert");
2304 cert.der().clone()
2305 }
2306
2307 fn rcgen_cert_no_cn() -> rustls_pki_types::CertificateDer<'static> {
2308 let params = rcgen::CertificateParams::default();
2311 let mut params = params;
2314 params.distinguished_name = rcgen::DistinguishedName::new();
2315 let key = rcgen::KeyPair::generate().expect("rcgen keypair");
2316 let cert = params.self_signed(&key).expect("self-sign cert");
2317 cert.der().clone()
2318 }
2319
2320 fn conn_with_peer_cert(cert: rustls_pki_types::CertificateDer<'static>) -> Arc<ConnContext> {
2321 let conn = make_conn();
2322 *conn.tls.lock() = Some(crate::conn_context::TlsInfo {
2323 sni: None,
2324 alpn: None,
2325 version: None,
2326 peer_cert: Some(cert),
2327 });
2328 conn
2329 }
2330
2331 #[test]
2332 fn peer_cert_subject_cn_extraction_returns_cn_string() {
2333 let cert = rcgen_cert_with_cn("client.internal");
2335 assert_eq!(peer_cert_subject_cn(cert.as_ref()), Some("client.internal".to_string()));
2336 }
2337
2338 #[test]
2339 fn peer_cert_subject_cn_returns_none_for_malformed_der() {
2340 assert_eq!(peer_cert_subject_cn(&[0x30, 0x80, 0x00, 0x00]), None);
2342 assert_eq!(peer_cert_subject_cn(b"not a cert at all"), None);
2343 }
2344
2345 #[test]
2346 fn peer_cert_subject_cn_returns_none_when_dn_has_no_cn() {
2347 let cert = rcgen_cert_no_cn();
2348 assert_eq!(peer_cert_subject_cn(cert.as_ref()), None);
2349 }
2350
2351 #[test]
2352 fn matrix_peer_cert_subject_cn_equals_happy_and_miss() {
2353 let cert = rcgen_cert_with_cn("ops-bot");
2354 let conn = conn_with_peer_cert(cert);
2355 let v = PredicateView::L4 { conn: &conn, peek: None };
2356 assert!(
2357 pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Equals(str_val("ops-bot"))).test(&v)
2358 );
2359 assert!(
2360 !pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Equals(str_val("attacker")))
2361 .test(&v)
2362 );
2363 }
2364
2365 #[test]
2366 fn matrix_peer_cert_subject_cn_string_ops_happy_and_miss() {
2367 let cert = rcgen_cert_with_cn("svc-payments-prod");
2368 let conn = conn_with_peer_cert(cert);
2369 let v = PredicateView::L4 { conn: &conn, peek: None };
2370 assert!(pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Prefix(b(b"svc-"))).test(&v));
2372 assert!(
2373 !pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Prefix(b(b"client-"))).test(&v)
2374 );
2375 assert!(pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Suffix(b(b"-prod"))).test(&v));
2377 assert!(
2379 pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Contains(b(b"payments"))).test(&v)
2380 );
2381 let re = Regex::new(r"^svc-[a-z]+-(prod|stg)$").expect("regex");
2383 assert!(pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Matches(re)).test(&v));
2384 let list = vec![str_val("svc-other-prod"), str_val("svc-payments-prod")];
2386 assert!(pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::In(list)).test(&v));
2387 }
2388
2389 #[test]
2390 fn peer_cert_subject_cn_misses_when_cert_absent() {
2391 let conn = make_conn();
2394 let v = PredicateView::L4 { conn: &conn, peek: None };
2395 assert!(
2396 !pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Equals(str_val("anything")))
2397 .test(&v)
2398 );
2399 }
2400
2401 #[test]
2402 fn peer_cert_subject_cn_misses_when_cert_has_no_cn() {
2403 let cert = rcgen_cert_no_cn();
2406 let conn = conn_with_peer_cert(cert);
2407 let v = PredicateView::L4 { conn: &conn, peek: None };
2408 assert!(
2409 !pred(FieldPath::TlsPeerCertSubjectCn, CompiledOperator::Equals(str_val("ops-bot"))).test(&v)
2410 );
2411 }
2412
2413 fn req_with_body(body_bytes: &[u8]) -> Request {
2421 http::Request::builder()
2422 .method("POST")
2423 .uri("/upload")
2424 .body(Body::Static(Bytes::copy_from_slice(body_bytes)))
2425 .expect("build req with body")
2426 }
2427
2428 #[test]
2429 fn matrix_http_body_equality_happy_and_miss() {
2430 let conn = make_conn();
2431 let req = req_with_body(b"hello world");
2432 let v = PredicateView::L7Req { conn: &conn, req: &req };
2433 assert!(
2434 pred(FieldPath::HttpBody, CompiledOperator::Equals(bytes_val(b"hello world"))).test(&v)
2435 );
2436 assert!(!pred(FieldPath::HttpBody, CompiledOperator::Equals(bytes_val(b"wrong"))).test(&v));
2437 assert!(pred(FieldPath::HttpBody, CompiledOperator::NotEquals(bytes_val(b"wrong"))).test(&v));
2438 }
2439
2440 #[test]
2441 fn matrix_http_body_substring_happy_and_miss() {
2442 let conn = make_conn();
2443 let req = req_with_body(b"prelude payload trailer");
2444 let v = PredicateView::L7Req { conn: &conn, req: &req };
2445 assert!(pred(FieldPath::HttpBody, CompiledOperator::Contains(b(b"payload"))).test(&v));
2446 assert!(!pred(FieldPath::HttpBody, CompiledOperator::Contains(b(b"missing"))).test(&v));
2447 assert!(pred(FieldPath::HttpBody, CompiledOperator::NotContains(b(b"missing"))).test(&v));
2448 }
2449
2450 #[test]
2451 fn matrix_http_body_prefix_suffix_happy_and_miss() {
2452 let conn = make_conn();
2453 let req = req_with_body(b"START middle END");
2454 let v = PredicateView::L7Req { conn: &conn, req: &req };
2455 assert!(pred(FieldPath::HttpBody, CompiledOperator::Prefix(b(b"START"))).test(&v));
2456 assert!(!pred(FieldPath::HttpBody, CompiledOperator::Prefix(b(b"BEGIN"))).test(&v));
2457 assert!(pred(FieldPath::HttpBody, CompiledOperator::Suffix(b(b"END"))).test(&v));
2458 assert!(!pred(FieldPath::HttpBody, CompiledOperator::Suffix(b(b"FIN"))).test(&v));
2459 }
2460
2461 #[test]
2462 fn matrix_http_body_in_list_happy_and_miss() {
2463 let conn = make_conn();
2464 let req = req_with_body(b"one");
2465 let v = PredicateView::L7Req { conn: &conn, req: &req };
2466 let list = vec![bytes_val(b"two"), bytes_val(b"one")];
2467 assert!(pred(FieldPath::HttpBody, CompiledOperator::In(list)).test(&v));
2468 let miss = vec![bytes_val(b"two"), bytes_val(b"three")];
2469 assert!(!pred(FieldPath::HttpBody, CompiledOperator::In(miss.clone())).test(&v));
2470 assert!(pred(FieldPath::HttpBody, CompiledOperator::NotIn(miss)).test(&v));
2471 }
2472
2473 #[test]
2474 fn http_body_misses_on_l4_view() {
2475 let conn = make_conn();
2478 let v = PredicateView::L4 { conn: &conn, peek: None };
2479 assert!(!pred(FieldPath::HttpBody, CompiledOperator::Contains(b(b"x"))).test(&v));
2480 }
2481
2482 #[test]
2483 #[should_panic(expected = "lazy-buffer invariant")]
2484 fn http_body_panics_when_lazy_buffer_invariant_violated() {
2485 let conn = make_conn();
2493 let req = http::Request::builder().method("POST").uri("/").body(Body::Empty).unwrap();
2494 let v = PredicateView::L7Req { conn: &conn, req: &req };
2495 let _ = pred(FieldPath::HttpBody, CompiledOperator::Contains(b(b"x"))).test(&v);
2496 }
2497
2498 #[test]
2507 fn matrix_peek_substring_happy_and_miss() {
2508 let buf: &[u8] = &[0x16, 0x03, 0x01, 0x00, 0x40, 0x01];
2510 let conn = make_conn();
2511 let v = PredicateView::L4 { conn: &conn, peek: Some(buf) };
2512 assert!(pred(FieldPath::Peek, CompiledOperator::Prefix(b(b"\x16\x03"))).test(&v));
2513 assert!(!pred(FieldPath::Peek, CompiledOperator::Prefix(b(b"\x14\x03"))).test(&v));
2514 assert!(pred(FieldPath::Peek, CompiledOperator::Contains(b(b"\x03\x01"))).test(&v));
2515 assert!(!pred(FieldPath::Peek, CompiledOperator::Contains(b(b"\xff\xff"))).test(&v));
2516 }
2517
2518 #[test]
2519 fn matrix_peek_equality_happy_and_miss() {
2520 let buf: &[u8] = b"GET";
2521 let conn = make_conn();
2522 let v = PredicateView::L4 { conn: &conn, peek: Some(buf) };
2523 assert!(pred(FieldPath::Peek, CompiledOperator::Equals(bytes_val(b"GET"))).test(&v));
2524 assert!(!pred(FieldPath::Peek, CompiledOperator::Equals(bytes_val(b"PUT"))).test(&v));
2525 assert!(pred(FieldPath::Peek, CompiledOperator::NotEquals(bytes_val(b"PUT"))).test(&v));
2526 }
2527
2528 #[test]
2529 fn matrix_peek_in_list_happy_and_miss() {
2530 let buf: &[u8] = b"PRI ";
2531 let conn = make_conn();
2532 let v = PredicateView::L4 { conn: &conn, peek: Some(buf) };
2533 let list = vec![bytes_val(b"GET "), bytes_val(b"PRI ")];
2535 assert!(pred(FieldPath::Peek, CompiledOperator::In(list)).test(&v));
2536 let miss = vec![bytes_val(b"POST"), bytes_val(b"HEAD")];
2537 assert!(!pred(FieldPath::Peek, CompiledOperator::In(miss.clone())).test(&v));
2538 assert!(pred(FieldPath::Peek, CompiledOperator::NotIn(miss)).test(&v));
2539 }
2540
2541 #[test]
2542 fn peek_misses_when_buffer_absent_on_l4_view() {
2543 let conn = make_conn();
2546 let v = PredicateView::L4 { conn: &conn, peek: None };
2547 assert!(!pred(FieldPath::Peek, CompiledOperator::Prefix(b(b"\x16"))).test(&v));
2548 let req = http::Request::builder().method("GET").uri("/").body(Body::Empty).unwrap();
2550 let v7 = PredicateView::L7Req { conn: &conn, req: &req };
2551 assert!(!pred(FieldPath::Peek, CompiledOperator::Prefix(b(b"\x16"))).test(&v7));
2552 }
2553}