1use std::hash::{Hash, Hasher};
2use std::net::IpAddr;
3use std::sync::Arc;
4
5use bytes::Bytes;
6use ipnet::IpNet;
7
8use crate::body::Request;
9use crate::conn_context::ConnContext;
10
/// A field of the connection or HTTP request that a predicate inspects.
///
/// In configuration these are written as dotted paths parsed by `parse_field_path`
/// (e.g. `"remote.ip"`, `"http.header.host"`). The serde derive below is a separate,
/// internal representation: unit variants serialize as their snake_case names
/// (`"remote_ip"`) and `HttpHeader` as `{"http_header": "<name>"}`; this is the form used
/// when a `PredicateInst` is (de)serialized.
///
/// Equality and hashing are derived, so `HttpHeader` compares by header-name content rather
/// than by `Arc` identity (and `"Host"` != `"host"`).
// NOTE: keep the variant order stable — serde formats that are not self-describing may
// encode the variant index.
#[derive(Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FieldPath {
    /// Config spelling: `"transport"`.
    Transport,
    /// Config spelling: `"remote.ip"`.
    RemoteIp,
    /// Config spelling: `"remote.port"`.
    RemotePort,
    /// Config spelling: `"local.ip"`.
    LocalIp,
    /// Config spelling: `"local.port"`.
    LocalPort,
    /// Config spelling: `"peek"`.
    Peek,
    /// Config spelling: `"tls.sni"`.
    TlsSni,
    /// Config spelling: `"tls.alpn"`.
    TlsAlpn,
    /// Config spelling: `"tls.version"`.
    TlsVersion,
    /// Config spelling: `"tls.peer_cert.subject_cn"`.
    TlsPeerCertSubjectCn,
    /// Config spelling: `"http.method"`.
    HttpMethod,
    /// Config spelling: `"http.uri.path"`.
    HttpUriPath,
    /// Config spelling: `"http.uri.query"`.
    HttpUriQuery,
    /// Config spelling: `"http.header.<name>"`; holds `<name>` as written (the parser only
    /// accepts lowercase, non-empty names).
    HttpHeader(Arc<str>),
    /// Config spelling: `"http.body"`.
    HttpBody,
}
30
/// A literal operand of a compiled operator.
///
/// `PartialEq`, `Eq` and `Hash` are implemented by hand below: values of different
/// variants are never equal (so `Str("42") != Int(42)`), and `Str` compares and hashes by
/// string content rather than `Arc` identity. Serialization goes through `ValueShadow` in
/// `serde_impls` (`Bytes` are written as base64).
#[derive(Clone, Debug)]
pub enum CompiledValue {
    Str(Arc<str>),
    Bytes(Bytes),
    Int(i64),
    Bool(bool),
    Addr(IpAddr),
}
39
40impl PartialEq for CompiledValue {
41 fn eq(&self, other: &Self) -> bool {
42 match (self, other) {
43 (Self::Str(a), Self::Str(b)) => a.as_ref() == b.as_ref(),
44 (Self::Bytes(a), Self::Bytes(b)) => a == b,
45 (Self::Int(a), Self::Int(b)) => a == b,
46 (Self::Bool(a), Self::Bool(b)) => a == b,
47 (Self::Addr(a), Self::Addr(b)) => a == b,
48 _ => false,
49 }
50 }
51}
52
53impl Eq for CompiledValue {}
54
55impl Hash for CompiledValue {
56 fn hash<H: Hasher>(&self, state: &mut H) {
57 std::mem::discriminant(self).hash(state);
58 match self {
59 Self::Str(s) => s.as_ref().hash(state),
60 Self::Bytes(b) => b.hash(state),
61 Self::Int(i) => i.hash(state),
62 Self::Bool(b) => b.hash(state),
63 Self::Addr(a) => a.hash(state),
64 }
65 }
66}
67
/// An operator whose operand has been converted into a ready-to-evaluate form.
///
/// `PartialEq`, `Eq` and `Hash` are implemented by hand below (presumably because
/// `fancy_regex::Regex` provides neither). `Matches` compares and hashes by pattern source
/// (`Regex::as_str`), so two regexes compiled separately from the same pattern are equal.
/// `In`/`NotIn` compare their value lists element by element, in order. Serialization goes
/// through `OperatorShadow` in `serde_impls`.
#[derive(Clone, Debug)]
pub enum CompiledOperator {
    Equals(CompiledValue),
    NotEquals(CompiledValue),
    Contains(Bytes),
    NotContains(Bytes),
    Prefix(Bytes),
    Suffix(Bytes),
    Matches(fancy_regex::Regex),
    In(Vec<CompiledValue>),
    NotIn(Vec<CompiledValue>),
    Gt(i64),
    Gte(i64),
    Lt(i64),
    Lte(i64),
    Cidr(IpNet),
}
85
86impl PartialEq for CompiledOperator {
87 fn eq(&self, other: &Self) -> bool {
88 match (self, other) {
89 (Self::Equals(a), Self::Equals(b)) | (Self::NotEquals(a), Self::NotEquals(b)) => a == b,
90 (Self::Contains(a), Self::Contains(b))
91 | (Self::NotContains(a), Self::NotContains(b))
92 | (Self::Prefix(a), Self::Prefix(b))
93 | (Self::Suffix(a), Self::Suffix(b)) => a == b,
94 (Self::Matches(a), Self::Matches(b)) => a.as_str() == b.as_str(),
95 (Self::In(a), Self::In(b)) | (Self::NotIn(a), Self::NotIn(b)) => a == b,
96 (Self::Gt(a), Self::Gt(b))
97 | (Self::Gte(a), Self::Gte(b))
98 | (Self::Lt(a), Self::Lt(b))
99 | (Self::Lte(a), Self::Lte(b)) => a == b,
100 (Self::Cidr(a), Self::Cidr(b)) => a == b,
101 _ => false,
102 }
103 }
104}
105
106impl Eq for CompiledOperator {}
107
108impl Hash for CompiledOperator {
109 fn hash<H: Hasher>(&self, state: &mut H) {
110 std::mem::discriminant(self).hash(state);
111 match self {
112 Self::Equals(v) | Self::NotEquals(v) => v.hash(state),
113 Self::Contains(b) | Self::NotContains(b) | Self::Prefix(b) | Self::Suffix(b) => {
114 b.hash(state);
115 }
116 Self::Matches(r) => r.as_str().hash(state),
117 Self::In(v) | Self::NotIn(v) => v.hash(state),
118 Self::Gt(i) | Self::Gte(i) | Self::Lt(i) | Self::Lte(i) => i.hash(state),
119 Self::Cidr(n) => n.hash(state),
120 }
121 }
122}
123
/// A compiled check: a field path together with the compiled operator applied to it.
///
/// `Eq` and `Hash` are derived from `FieldPath` and the hand-written `CompiledOperator`
/// impls, so instances compare by content (header names, string values, regex pattern
/// sources) rather than by `Arc`/regex identity, and can be deduplicated or used as
/// hash-map keys.
#[derive(Clone, Debug, Eq, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
pub struct PredicateInst {
    pub path: FieldPath,
    pub op: CompiledOperator,
}
129
/// The data a compiled predicate is evaluated against (see `PredicateInst::test`).
pub enum PredicateView<'a> {
    /// Connection-level view: the connection context plus optional peeked bytes.
    // NOTE(review): `peek` presumably holds bytes read ahead from the start of the stream
    // without consuming them — confirm with the L4 caller.
    L4 { conn: &'a Arc<ConnContext>, peek: Option<&'a [u8]> },
    /// Request-level view: the connection context plus the HTTP request.
    L7Req { conn: &'a Arc<ConnContext>, req: &'a Request },
}
134
135impl<'a> PredicateView<'a> {
136 #[must_use]
145 pub fn build(
146 conn: &'a Arc<ConnContext>,
147 req: Option<&'a Request>,
148 _l4: Option<&'a crate::l4::L4Conn>,
149 ) -> Self {
150 match req {
151 Some(r) => Self::L7Req { conn, req: r },
152 None => Self::L4 { conn, peek: None },
153 }
154 }
155}
156
157impl PredicateInst {
158 #[must_use]
165 pub fn test(&self, view: &PredicateView<'_>) -> bool {
166 match (&self.path, &self.op, view) {
167 (
168 FieldPath::RemoteIp,
169 CompiledOperator::Equals(CompiledValue::Addr(expected)),
170 PredicateView::L4 { conn, .. } | PredicateView::L7Req { conn, .. },
171 ) => conn.remote.ip() == *expected,
172
173 (
174 FieldPath::HttpMethod,
175 CompiledOperator::Equals(CompiledValue::Str(expected)),
176 PredicateView::L7Req { req, .. },
177 ) => req.method().as_str() == expected.as_ref(),
178
179 _ => false,
183 }
184 }
185}
186
/// Maximum length, in bytes, of a `matches` regex pattern's source text accepted from
/// configuration; longer patterns are rejected by `validate_operator` while parsing a check.
pub const REGEX_PATTERN_MAX_BYTES: usize = 4 * 1024;
188
189#[derive(Debug, Clone, serde::Serialize)]
190pub enum Predicate {
191 AnyOf(AnyOfP),
192 Not(NotP),
193 Check(CheckMap),
194}
195
196#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
197#[serde(deny_unknown_fields)]
198pub struct AnyOfP {
199 pub any_of: Vec<Predicate>,
200}
201
202#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
203#[serde(deny_unknown_fields)]
204pub struct NotP {
205 pub not: Box<Predicate>,
206}
207
208#[derive(Debug, Clone, serde::Serialize)]
209pub struct CheckMap {
210 pub path: FieldPath,
211 pub op: Operator,
212}
213
/// A comparison operator as written in configuration, before compilation.
///
/// Externally tagged with snake_case names, so each operator is a single-key object such as
/// `{"equals": "a"}`, `{"not_in": ["x", 1]}`, `{"gt": 1024}` or `{"cidr": "10.0.0.0/8"}`.
/// `Matches` and `Cidr` keep their source strings here; the regex length limit
/// (`REGEX_PATTERN_MAX_BYTES`) is checked by `validate_operator` when a check is parsed.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Operator {
    /// `{"equals": <value>}`
    Equals(Value),
    /// `{"not_equals": <value>}`
    NotEquals(Value),
    /// `{"contains": <value>}`
    Contains(Value),
    /// `{"not_contains": <value>}`
    NotContains(Value),
    /// `{"prefix": <value>}`
    Prefix(Value),
    /// `{"suffix": <value>}`
    Suffix(Value),
    /// `{"matches": "<regex source>"}`
    Matches(String),
    /// `{"in": [<value>, ...]}`
    In(Vec<Value>),
    /// `{"not_in": [<value>, ...]}`
    NotIn(Vec<Value>),
    /// `{"gt": <i64>}`
    Gt(i64),
    /// `{"gte": <i64>}`
    Gte(i64),
    /// `{"lt": <i64>}`
    Lt(i64),
    /// `{"lte": <i64>}`
    Lte(i64),
    /// `{"cidr": "<network>"}`, e.g. `"10.0.0.0/8"`.
    Cidr(String),
}
232
/// A scalar operand in configuration.
///
/// `untagged`: serde tries the variants in declaration order, so JSON `true`/`false` become
/// `Bool`, integers within `i64` range become `Int`, and only JSON strings become `Str` (the
/// string `"42"` stays `Str`). Other JSON values (floats, `null`, arrays, objects) match no
/// variant and fail to deserialize. Keep the order: it decides which variant wins.
#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
#[serde(untagged)]
pub enum Value {
    Bool(bool),
    Int(i64),
    Str(String),
}
240
241impl<'de> serde::Deserialize<'de> for Predicate {
242 fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
243 let v = serde_json::Value::deserialize(de)?;
244 let serde_json::Value::Object(ref map) = v else {
245 return Err(serde::de::Error::custom("predicate must be a JSON object"));
246 };
247 if map.len() == 1 {
248 let (key, _) = map.iter().next().expect("len == 1");
249 match key.as_str() {
250 "any_of" => {
251 return serde_json::from_value::<AnyOfP>(v)
252 .map(Predicate::AnyOf)
253 .map_err(serde::de::Error::custom);
254 }
255 "not" => {
256 return serde_json::from_value::<NotP>(v)
257 .map(Predicate::Not)
258 .map_err(serde::de::Error::custom);
259 }
260 _ => {}
261 }
262 }
263 serde_json::from_value::<CheckMap>(v).map(Predicate::Check).map_err(serde::de::Error::custom)
264 }
265}
266
267impl<'de> serde::Deserialize<'de> for CheckMap {
268 fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
269 struct Visitor;
270
271 impl<'de> serde::de::Visitor<'de> for Visitor {
272 type Value = CheckMap;
273
274 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275 f.write_str("a single-key object of the form {\"<field-path>\": {\"<operator>\": <value>}}")
276 }
277
278 fn visit_map<M: serde::de::MapAccess<'de>>(self, mut map: M) -> Result<CheckMap, M::Error> {
279 let Some(key) = map.next_key::<String>()? else {
280 return Err(serde::de::Error::invalid_length(0, &"exactly one key"));
281 };
282 let path = parse_field_path(&key).map_err(serde::de::Error::custom)?;
283 let op: Operator = map.next_value()?;
284 if map.next_key::<serde::de::IgnoredAny>()?.is_some() {
285 return Err(serde::de::Error::custom("check object must have exactly one key"));
286 }
287 validate_operator(&op).map_err(serde::de::Error::custom)?;
288 Ok(CheckMap { path, op })
289 }
290 }
291
292 de.deserialize_map(Visitor)
293 }
294}
295
296fn parse_field_path(s: &str) -> Result<FieldPath, String> {
297 if s.chars().any(|c| c.is_ascii_uppercase()) {
298 return Err(format!(
299 "field path must be lowercase: {:?} — did you mean {:?}?",
300 s,
301 s.to_ascii_lowercase(),
302 ));
303 }
304 match s {
305 "transport" => Ok(FieldPath::Transport),
306 "remote.ip" => Ok(FieldPath::RemoteIp),
307 "remote.port" => Ok(FieldPath::RemotePort),
308 "local.ip" => Ok(FieldPath::LocalIp),
309 "local.port" => Ok(FieldPath::LocalPort),
310 "peek" => Ok(FieldPath::Peek),
311 "tls.sni" => Ok(FieldPath::TlsSni),
312 "tls.alpn" => Ok(FieldPath::TlsAlpn),
313 "tls.version" => Ok(FieldPath::TlsVersion),
314 "tls.peer_cert.subject_cn" => Ok(FieldPath::TlsPeerCertSubjectCn),
315 "http.method" => Ok(FieldPath::HttpMethod),
316 "http.uri.path" => Ok(FieldPath::HttpUriPath),
317 "http.uri.query" => Ok(FieldPath::HttpUriQuery),
318 "http.body" => Ok(FieldPath::HttpBody),
319 other if other.starts_with("http.header.") => {
320 let name = &other["http.header.".len()..];
321 if name.is_empty() {
322 return Err(format!("http.header.* requires a header name: {other:?}"));
323 }
324 Ok(FieldPath::HttpHeader(Arc::from(name)))
325 }
326 other => Err(format!("unknown field path: {other:?}")),
327 }
328}
329
330fn validate_operator(op: &Operator) -> Result<(), String> {
331 if let Operator::Matches(pattern) = op
332 && pattern.len() > REGEX_PATTERN_MAX_BYTES
333 {
334 return Err(format!(
335 "regex pattern source exceeds {REGEX_PATTERN_MAX_BYTES}-byte limit: got {} bytes",
336 pattern.len(),
337 ));
338 }
339 Ok(())
340}
341
342mod serde_impls {
343 use base64::Engine as _;
344 use base64::engine::general_purpose::STANDARD as B64;
345 use bytes::Bytes;
346 use std::net::IpAddr;
347 use std::sync::Arc;
348
349 use super::{CompiledOperator, CompiledValue};
350
351 pub(super) fn ser_bytes<S: serde::Serializer>(b: &Bytes, s: S) -> Result<S::Ok, S::Error> {
352 s.serialize_str(&B64.encode(b))
353 }
354
355 pub(super) fn de_bytes<'de, D: serde::Deserializer<'de>>(d: D) -> Result<Bytes, D::Error> {
356 use serde::Deserialize as _;
357 let s = String::deserialize(d)?;
358 B64.decode(s.as_bytes()).map(Bytes::from).map_err(serde::de::Error::custom)
359 }
360
361 pub(super) fn ser_regex<S: serde::Serializer>(
362 r: &fancy_regex::Regex,
363 s: S,
364 ) -> Result<S::Ok, S::Error> {
365 s.serialize_str(r.as_str())
366 }
367
368 pub(super) fn de_regex<'de, D: serde::Deserializer<'de>>(
369 d: D,
370 ) -> Result<fancy_regex::Regex, D::Error> {
371 use serde::Deserialize as _;
372 let s = String::deserialize(d)?;
373 fancy_regex::Regex::new(&s)
374 .map_err(|e| serde::de::Error::custom(format!("invalid regex {s:?}: {e}")))
375 }
376
377 #[derive(serde::Serialize, serde::Deserialize)]
379 #[serde(rename_all = "snake_case")]
380 pub(super) enum ValueShadow {
381 Str(Arc<str>),
382 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
383 Bytes(Bytes),
384 Int(i64),
385 Bool(bool),
386 Addr(IpAddr),
387 }
388
389 impl From<&CompiledValue> for ValueShadow {
390 fn from(v: &CompiledValue) -> Self {
391 match v {
392 CompiledValue::Str(s) => Self::Str(Arc::clone(s)),
393 CompiledValue::Bytes(b) => Self::Bytes(b.clone()),
394 CompiledValue::Int(i) => Self::Int(*i),
395 CompiledValue::Bool(b) => Self::Bool(*b),
396 CompiledValue::Addr(a) => Self::Addr(*a),
397 }
398 }
399 }
400
401 impl From<ValueShadow> for CompiledValue {
402 fn from(v: ValueShadow) -> Self {
403 match v {
404 ValueShadow::Str(s) => Self::Str(s),
405 ValueShadow::Bytes(b) => Self::Bytes(b),
406 ValueShadow::Int(i) => Self::Int(i),
407 ValueShadow::Bool(b) => Self::Bool(b),
408 ValueShadow::Addr(a) => Self::Addr(a),
409 }
410 }
411 }
412
413 #[derive(serde::Serialize, serde::Deserialize)]
416 #[serde(rename_all = "snake_case")]
417 pub(super) enum OperatorShadow {
418 Equals(CompiledValue),
419 NotEquals(CompiledValue),
420 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
421 Contains(Bytes),
422 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
423 NotContains(Bytes),
424 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
425 Prefix(Bytes),
426 #[serde(serialize_with = "ser_bytes", deserialize_with = "de_bytes")]
427 Suffix(Bytes),
428 #[serde(serialize_with = "ser_regex", deserialize_with = "de_regex")]
429 Matches(fancy_regex::Regex),
430 In(Vec<CompiledValue>),
431 NotIn(Vec<CompiledValue>),
432 Gt(i64),
433 Gte(i64),
434 Lt(i64),
435 Lte(i64),
436 Cidr(ipnet::IpNet),
437 }
438
439 impl From<&CompiledOperator> for OperatorShadow {
440 fn from(op: &CompiledOperator) -> Self {
441 match op {
442 CompiledOperator::Equals(v) => Self::Equals(v.clone()),
443 CompiledOperator::NotEquals(v) => Self::NotEquals(v.clone()),
444 CompiledOperator::Contains(b) => Self::Contains(b.clone()),
445 CompiledOperator::NotContains(b) => Self::NotContains(b.clone()),
446 CompiledOperator::Prefix(b) => Self::Prefix(b.clone()),
447 CompiledOperator::Suffix(b) => Self::Suffix(b.clone()),
448 CompiledOperator::Matches(r) => {
449 Self::Matches(fancy_regex::Regex::new(r.as_str()).expect("round-trippable"))
450 }
451 CompiledOperator::In(vs) => Self::In(vs.clone()),
452 CompiledOperator::NotIn(vs) => Self::NotIn(vs.clone()),
453 CompiledOperator::Gt(i) => Self::Gt(*i),
454 CompiledOperator::Gte(i) => Self::Gte(*i),
455 CompiledOperator::Lt(i) => Self::Lt(*i),
456 CompiledOperator::Lte(i) => Self::Lte(*i),
457 CompiledOperator::Cidr(n) => Self::Cidr(*n),
458 }
459 }
460 }
461
462 impl From<OperatorShadow> for CompiledOperator {
463 fn from(op: OperatorShadow) -> Self {
464 match op {
465 OperatorShadow::Equals(v) => Self::Equals(v),
466 OperatorShadow::NotEquals(v) => Self::NotEquals(v),
467 OperatorShadow::Contains(b) => Self::Contains(b),
468 OperatorShadow::NotContains(b) => Self::NotContains(b),
469 OperatorShadow::Prefix(b) => Self::Prefix(b),
470 OperatorShadow::Suffix(b) => Self::Suffix(b),
471 OperatorShadow::Matches(r) => Self::Matches(r),
472 OperatorShadow::In(vs) => Self::In(vs),
473 OperatorShadow::NotIn(vs) => Self::NotIn(vs),
474 OperatorShadow::Gt(i) => Self::Gt(i),
475 OperatorShadow::Gte(i) => Self::Gte(i),
476 OperatorShadow::Lt(i) => Self::Lt(i),
477 OperatorShadow::Lte(i) => Self::Lte(i),
478 OperatorShadow::Cidr(n) => Self::Cidr(n),
479 }
480 }
481 }
482}
483
484impl serde::Serialize for CompiledValue {
485 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
486 serde_impls::ValueShadow::from(self).serialize(s)
487 }
488}
489
490impl<'de> serde::Deserialize<'de> for CompiledValue {
491 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
492 serde_impls::ValueShadow::deserialize(d).map(Self::from)
493 }
494}
495
496impl serde::Serialize for CompiledOperator {
497 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
498 serde_impls::OperatorShadow::from(self).serialize(s)
499 }
500}
501
502impl<'de> serde::Deserialize<'de> for CompiledOperator {
503 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
504 serde_impls::OperatorShadow::deserialize(d).map(Self::from)
505 }
506}
507
508#[cfg(test)]
509mod tests {
510 use std::collections::hash_map::DefaultHasher;
511 use std::hash::Hash;
512 use std::net::{Ipv4Addr, Ipv6Addr};
513 use std::str::FromStr;
514 use std::sync::OnceLock;
515 use std::time::Instant;
516
517 use bytes::Bytes;
518 use fancy_regex::Regex;
519 use ipnet::IpNet;
520 use parking_lot::Mutex;
521
522 use super::*;
523 use crate::body::{Body, Request};
524 use crate::conn_context::{ConnId, Transport};
525
526 fn hash_of<T: Hash>(v: &T) -> u64 {
530 let mut h = DefaultHasher::new();
531 v.hash(&mut h);
532 h.finish()
533 }
534
535 fn make_conn() -> Arc<ConnContext> {
536 Arc::new(ConnContext {
537 id: ConnId(1),
538 remote: "127.0.0.1:0".parse().expect("parse remote"),
539 local: "127.0.0.1:0".parse().expect("parse local"),
540 transport: Transport::Tcp,
541 entered_at: Instant::now(),
542 tls: Mutex::new(None),
543 http_version: OnceLock::new(),
544 user: Mutex::new(http::Extensions::new()),
545 })
546 }
547
548 #[test]
549 fn field_path_http_header_is_equal_by_string_content_not_arc_identity() {
550 let a = FieldPath::HttpHeader(Arc::from("host"));
551 let b = FieldPath::HttpHeader(Arc::from("host"));
552 assert_eq!(a, b);
553 assert_eq!(hash_of(&a), hash_of(&b));
554 let upper = FieldPath::HttpHeader(Arc::from("Host"));
559 assert_ne!(a, upper);
560 }
561
562 #[test]
563 fn field_path_simple_variants_are_self_equal_and_mutually_distinct() {
564 let paths = [
565 FieldPath::Transport,
566 FieldPath::RemoteIp,
567 FieldPath::RemotePort,
568 FieldPath::LocalIp,
569 FieldPath::LocalPort,
570 FieldPath::Peek,
571 FieldPath::TlsSni,
572 FieldPath::TlsAlpn,
573 FieldPath::TlsVersion,
574 FieldPath::TlsPeerCertSubjectCn,
575 FieldPath::HttpMethod,
576 FieldPath::HttpUriPath,
577 FieldPath::HttpUriQuery,
578 FieldPath::HttpBody,
579 ];
580 for (i, a) in paths.iter().enumerate() {
581 for (j, b) in paths.iter().enumerate() {
582 if i == j {
583 assert_eq!(a, b);
584 } else {
585 assert_ne!(a, b);
586 }
587 }
588 }
589 }
590
591 #[test]
592 fn compiled_value_str_is_equal_by_content_not_arc_identity() {
593 let a = CompiledValue::Str(Arc::<str>::from("x"));
594 let b = CompiledValue::Str(Arc::<str>::from("x"));
595 assert_eq!(a, b);
596 assert_eq!(hash_of(&a), hash_of(&b));
597 let c = CompiledValue::Str(Arc::<str>::from("y"));
598 assert_ne!(a, c);
599 }
600
601 #[test]
602 fn compiled_value_cross_variant_inequality() {
603 let s = CompiledValue::Str(Arc::<str>::from("42"));
604 let i = CompiledValue::Int(42);
605 assert_ne!(s, i);
606 }
607
608 #[test]
609 fn compiled_value_bytes_int_bool_addr_self_equal() {
610 assert_eq!(
611 CompiledValue::Bytes(Bytes::from_static(b"abc")),
612 CompiledValue::Bytes(Bytes::copy_from_slice(b"abc")),
613 );
614 assert_eq!(CompiledValue::Int(7), CompiledValue::Int(7));
615 assert_ne!(CompiledValue::Int(7), CompiledValue::Int(8));
616 assert_eq!(CompiledValue::Bool(true), CompiledValue::Bool(true));
617 assert_ne!(CompiledValue::Bool(true), CompiledValue::Bool(false));
618 assert_eq!(
619 CompiledValue::Addr(Ipv4Addr::new(10, 0, 0, 1).into()),
620 CompiledValue::Addr(Ipv4Addr::new(10, 0, 0, 1).into()),
621 );
622 assert_ne!(
623 CompiledValue::Addr(Ipv4Addr::new(10, 0, 0, 1).into()),
624 CompiledValue::Addr(Ipv6Addr::LOCALHOST.into()),
625 );
626 }
627
628 #[test]
629 fn compiled_operator_matches_equal_by_pattern_source() {
630 let a = CompiledOperator::Matches(Regex::new("^/api").expect("compile a"));
631 let b = CompiledOperator::Matches(Regex::new("^/api").expect("compile b"));
632 assert_eq!(a, b);
633 assert_eq!(hash_of(&a), hash_of(&b));
634 }
635
636 #[test]
637 fn compiled_operator_matches_distinct_patterns_unequal() {
638 let a = CompiledOperator::Matches(Regex::new("a|b").expect("compile a"));
641 let b = CompiledOperator::Matches(Regex::new("b|a").expect("compile b"));
642 assert_ne!(a, b);
643 }
644
645 #[test]
646 fn compiled_operator_cidr_equal_by_canonical_form() {
647 let a = CompiledOperator::Cidr(IpNet::from_str("10.0.0.0/8").expect("parse a"));
648 let b = CompiledOperator::Cidr(IpNet::from_str("10.0.0.0/8").expect("parse b"));
649 assert_eq!(a, b);
650 assert_eq!(hash_of(&a), hash_of(&b));
651 }
652
653 #[test]
654 fn compiled_operator_cidr_distinct_networks_unequal() {
655 let a = CompiledOperator::Cidr(IpNet::from_str("10.0.0.0/8").expect("parse a"));
656 let b = CompiledOperator::Cidr(IpNet::from_str("10.0.0.0/16").expect("parse b"));
657 assert_ne!(a, b);
658 }
659
660 #[test]
661 fn compiled_operator_in_is_order_sensitive() {
662 let xs =
663 vec![CompiledValue::Str(Arc::<str>::from("a")), CompiledValue::Str(Arc::<str>::from("b"))];
664 let ys =
665 vec![CompiledValue::Str(Arc::<str>::from("b")), CompiledValue::Str(Arc::<str>::from("a"))];
666 assert_ne!(CompiledOperator::In(xs.clone()), CompiledOperator::In(ys.clone()));
667 assert_ne!(CompiledOperator::NotIn(xs), CompiledOperator::NotIn(ys));
668 }
669
670 #[test]
671 fn compiled_operator_numeric_comparisons_distinct_per_variant() {
672 let ops = [
674 CompiledOperator::Gt(10),
675 CompiledOperator::Gte(10),
676 CompiledOperator::Lt(10),
677 CompiledOperator::Lte(10),
678 ];
679 for (i, a) in ops.iter().enumerate() {
680 for (j, b) in ops.iter().enumerate() {
681 if i == j {
682 assert_eq!(a, b);
683 } else {
684 assert_ne!(a, b);
685 }
686 }
687 }
688 }
689
690 #[test]
691 fn compiled_operator_bytes_variants_distinguished() {
692 let payload = Bytes::from_static(b"abc");
693 let ops = [
694 CompiledOperator::Contains(payload.clone()),
695 CompiledOperator::NotContains(payload.clone()),
696 CompiledOperator::Prefix(payload.clone()),
697 CompiledOperator::Suffix(payload),
698 ];
699 for (i, a) in ops.iter().enumerate() {
700 for (j, b) in ops.iter().enumerate() {
701 if i == j {
702 assert_eq!(a, b);
703 } else {
704 assert_ne!(a, b);
705 }
706 }
707 }
708 }
709
710 #[test]
711 fn predicate_inst_equal_across_independent_construction_paths() {
712 let lhs = PredicateInst {
713 path: FieldPath::HttpHeader(Arc::from("host")),
714 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from("example.com"))),
715 };
716 let rhs = PredicateInst {
717 path: FieldPath::HttpHeader(Arc::from("host")),
718 op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from("example.com"))),
719 };
720 assert_eq!(lhs, rhs);
721 assert_eq!(hash_of(&lhs), hash_of(&rhs));
722 }
723
724 #[test]
725 fn predicate_inst_equal_with_regex_operator_from_separate_compiles() {
726 let lhs = PredicateInst {
727 path: FieldPath::HttpUriPath,
728 op: CompiledOperator::Matches(Regex::new("^/").expect("compile a")),
729 };
730 let rhs = PredicateInst {
731 path: FieldPath::HttpUriPath,
732 op: CompiledOperator::Matches(Regex::new("^/").expect("compile b")),
733 };
734 assert_eq!(lhs, rhs);
735 assert_eq!(hash_of(&lhs), hash_of(&rhs));
736 }
737
738 #[test]
739 fn predicate_inst_unequal_on_path_difference() {
740 let value = CompiledValue::Str(Arc::<str>::from("x"));
741 let a =
742 PredicateInst { path: FieldPath::HttpUriPath, op: CompiledOperator::Equals(value.clone()) };
743 let b = PredicateInst { path: FieldPath::HttpUriQuery, op: CompiledOperator::Equals(value) };
744 assert_ne!(a, b);
745 }
746
747 #[test]
748 fn predicate_view_variants_construct() {
749 let conn = make_conn();
750 let peek_bytes: &[u8] = b"\x16\x03\x01";
751 let l4 = PredicateView::L4 { conn: &conn, peek: Some(peek_bytes) };
752 match l4 {
753 PredicateView::L4 { peek, .. } => assert_eq!(peek.map(<[u8]>::len), Some(3)),
754 PredicateView::L7Req { .. } => panic!("wrong variant"),
755 }
756
757 let conn2 = make_conn();
758 let req: Request =
759 http::Request::builder().method("GET").uri("/").body(Body::Empty).expect("build request");
760 let l7 = PredicateView::L7Req { conn: &conn2, req: &req };
761 match l7 {
762 PredicateView::L7Req { .. } => {}
763 PredicateView::L4 { .. } => panic!("wrong variant"),
764 }
765 }
766
767 fn parse_predicate(v: serde_json::Value) -> Result<Predicate, serde_json::Error> {
771 serde_json::from_value(v)
772 }
773
774 fn expect_check(p: &Predicate) -> &CheckMap {
775 match p {
776 Predicate::Check(c) => c,
777 other => panic!("expected Predicate::Check, got {other:?}"),
778 }
779 }
780
781 #[test]
782 fn parse_any_of_happy_path() {
783 let raw = serde_json::json!({
784 "any_of": [
785 { "tls.sni": { "equals": "a" } },
786 { "tls.sni": { "equals": "b" } },
787 ],
788 });
789 let p = parse_predicate(raw).expect("parse any_of");
790 let Predicate::AnyOf(AnyOfP { any_of }) = p else {
791 panic!("expected AnyOf");
792 };
793 assert_eq!(any_of.len(), 2);
794 let c0 = expect_check(&any_of[0]);
795 let c1 = expect_check(&any_of[1]);
796 assert_eq!(c0.path, FieldPath::TlsSni);
797 assert_eq!(c1.path, FieldPath::TlsSni);
798 match (&c0.op, &c1.op) {
799 (Operator::Equals(Value::Str(a)), Operator::Equals(Value::Str(b))) => {
800 assert_eq!(a, "a");
801 assert_eq!(b, "b");
802 }
803 (a, b) => panic!("unexpected ops: {a:?} / {b:?}"),
804 }
805 }
806
807 #[test]
808 fn parse_not_happy_path() {
809 let raw = serde_json::json!({
810 "not": { "tls.sni": { "equals": "internal" } },
811 });
812 let p = parse_predicate(raw).expect("parse not");
813 let Predicate::Not(NotP { not }) = p else {
814 panic!("expected Not");
815 };
816 let inner = expect_check(¬);
817 assert_eq!(inner.path, FieldPath::TlsSni);
818 match &inner.op {
819 Operator::Equals(Value::Str(s)) => assert_eq!(s, "internal"),
820 other => panic!("unexpected op: {other:?}"),
821 }
822 }
823
824 #[test]
825 fn parse_check_across_representative_paths() {
826 let cases = [
827 (serde_json::json!({ "tls.sni": { "equals": "api.example.com" } }), FieldPath::TlsSni),
828 (serde_json::json!({ "remote.port": { "gt": 1024 } }), FieldPath::RemotePort),
829 (serde_json::json!({ "http.method": { "equals": "GET" } }), FieldPath::HttpMethod),
830 (serde_json::json!({ "http.uri.path": { "prefix": "/api" } }), FieldPath::HttpUriPath),
831 (
832 serde_json::json!({ "http.header.host": { "equals": "a.example.com" } }),
833 FieldPath::HttpHeader(Arc::from("host")),
834 ),
835 (serde_json::json!({ "http.body": { "contains": "hello" } }), FieldPath::HttpBody),
836 ];
837 for (raw, expected_path) in cases {
838 let p = parse_predicate(raw.clone()).unwrap_or_else(|e| panic!("parse {raw}: {e}"));
839 let c = expect_check(&p);
840 assert_eq!(c.path, expected_path, "input: {raw}");
841 }
842 }
843
844 #[test]
845 fn parse_any_of_with_extra_key_is_rejected() {
846 let raw = serde_json::json!({
849 "any_of": [ { "tls.sni": { "equals": "a" } } ],
850 "extra": true,
851 });
852 let err = parse_predicate(raw).expect_err("must reject extra key on any_of");
853 let _ = err.to_string();
854 }
855
856 #[test]
857 fn parse_http_header_any_of_is_a_check_not_combinator() {
858 let raw = serde_json::json!({ "http.header.any_of": { "equals": "x" } });
861 let p = parse_predicate(raw).expect("parse");
862 let c = expect_check(&p);
863 assert_eq!(c.path, FieldPath::HttpHeader(Arc::from("any_of")));
864 }
865
866 #[test]
867 fn parse_uppercase_field_path_suggests_lowercase() {
868 let raw = serde_json::json!({ "http.header.Host": { "equals": "x" } });
869 let err = parse_predicate(raw).expect_err("uppercase must fail");
870 let msg = err.to_string();
871 assert!(msg.contains("http.header.Host"), "error mentions offending input: {msg}");
872 assert!(msg.contains("did you mean"), "error includes suggestion phrase: {msg}");
873 assert!(msg.contains("http.header.host"), "error contains lowercased form: {msg}");
874 }
875
876 #[test]
877 fn parse_multi_key_check_is_rejected() {
878 let raw = serde_json::json!({
879 "http.uri.path": { "matches": "^/" },
880 "http.method": { "equals": "GET" },
881 });
882 let err = parse_predicate(raw).expect_err("multi-key check must fail");
883 let _ = err.to_string();
884 }
885
886 #[test]
887 fn parse_empty_http_header_name_is_rejected() {
888 let raw = serde_json::json!({ "http.header.": { "equals": "x" } });
889 let err = parse_predicate(raw).expect_err("empty header name must fail");
890 let _ = err.to_string();
891 }
892
893 #[test]
894 fn parse_unknown_field_path_is_rejected_with_name() {
895 let raw = serde_json::json!({ "http.nope": { "equals": "x" } });
896 let err = parse_predicate(raw).expect_err("unknown path must fail");
897 let msg = err.to_string();
898 assert!(msg.contains("http.nope"), "error mentions offending path: {msg}");
899 }
900
901 fn parse_op(v: serde_json::Value) -> Operator {
902 let mut map = serde_json::Map::new();
903 map.insert("tls.sni".to_string(), v);
904 let raw = serde_json::Value::Object(map);
905 match parse_predicate(raw).expect("parse check") {
906 Predicate::Check(c) => c.op,
907 other => panic!("expected Check, got {other:?}"),
908 }
909 }
910
911 #[test]
912 fn operator_equals_and_not_equals_on_string() {
913 let eq = parse_op(serde_json::json!({ "equals": "api" }));
914 match eq {
915 Operator::Equals(Value::Str(s)) => assert_eq!(s, "api"),
916 other => panic!("expected equals/str: {other:?}"),
917 }
918 let neq = parse_op(serde_json::json!({ "not_equals": "api" }));
919 match neq {
920 Operator::NotEquals(Value::Str(s)) => assert_eq!(s, "api"),
921 other => panic!("expected not_equals/str: {other:?}"),
922 }
923 }
924
925 #[test]
926 fn operator_contains_and_not_contains_on_string() {
927 let c = parse_op(serde_json::json!({ "contains": "foo" }));
928 match c {
929 Operator::Contains(Value::Str(s)) => assert_eq!(s, "foo"),
930 other => panic!("expected contains/str: {other:?}"),
931 }
932 let nc = parse_op(serde_json::json!({ "not_contains": "foo" }));
933 match nc {
934 Operator::NotContains(Value::Str(s)) => assert_eq!(s, "foo"),
935 other => panic!("expected not_contains/str: {other:?}"),
936 }
937 }
938
939 #[test]
940 fn operator_prefix_and_suffix_on_string() {
941 let p = parse_op(serde_json::json!({ "prefix": "/api" }));
942 match p {
943 Operator::Prefix(Value::Str(s)) => assert_eq!(s, "/api"),
944 other => panic!("expected prefix/str: {other:?}"),
945 }
946 let s = parse_op(serde_json::json!({ "suffix": ".json" }));
947 match s {
948 Operator::Suffix(Value::Str(v)) => assert_eq!(v, ".json"),
949 other => panic!("expected suffix/str: {other:?}"),
950 }
951 }
952
953 #[test]
954 fn operator_matches_carries_pattern_source() {
955 let op = parse_op(serde_json::json!({ "matches": "^/api/v\\d+" }));
956 match op {
957 Operator::Matches(pattern) => assert_eq!(pattern, "^/api/v\\d+"),
958 other => panic!("expected matches: {other:?}"),
959 }
960 }
961
962 #[test]
963 fn operator_in_and_not_in_accept_mixed_scalar_types() {
964 let op = parse_op(serde_json::json!({ "in": ["foo", 42] }));
965 let Operator::In(xs) = op else {
966 panic!("expected in");
967 };
968 assert_eq!(xs.len(), 2);
969 assert_eq!(xs[0], Value::Str("foo".into()));
970 assert_eq!(xs[1], Value::Int(42));
971 let op2 = parse_op(serde_json::json!({ "not_in": ["bar", 7] }));
972 let Operator::NotIn(ys) = op2 else {
973 panic!("expected not_in");
974 };
975 assert_eq!(ys.len(), 2);
976 assert_eq!(ys[0], Value::Str("bar".into()));
977 assert_eq!(ys[1], Value::Int(7));
978 }
979
980 #[test]
981 fn operator_numeric_comparisons() {
982 assert!(matches!(parse_op(serde_json::json!({ "gt": 10 })), Operator::Gt(10)));
983 assert!(matches!(parse_op(serde_json::json!({ "gte": 10 })), Operator::Gte(10)));
984 assert!(matches!(parse_op(serde_json::json!({ "lt": 10 })), Operator::Lt(10)));
985 assert!(matches!(parse_op(serde_json::json!({ "lte": 10 })), Operator::Lte(10)));
986 }
987
988 #[test]
989 fn operator_cidr_carries_source_string() {
990 let op = parse_op(serde_json::json!({ "cidr": "10.0.0.0/8" }));
991 match op {
992 Operator::Cidr(s) => assert_eq!(s, "10.0.0.0/8"),
993 other => panic!("expected cidr: {other:?}"),
994 }
995 }
996
997 #[test]
998 fn value_untagged_priority_bool_before_str() {
999 let op_t = parse_op(serde_json::json!({ "equals": true }));
1002 assert!(matches!(op_t, Operator::Equals(Value::Bool(true))));
1003 let op_f = parse_op(serde_json::json!({ "equals": false }));
1004 assert!(matches!(op_f, Operator::Equals(Value::Bool(false))));
1005 }
1006
1007 #[test]
1008 fn value_untagged_priority_int_before_str() {
1009 let op = parse_op(serde_json::json!({ "equals": 42 }));
1011 assert!(matches!(op, Operator::Equals(Value::Int(42))));
1012 }
1013
1014 #[test]
1015 fn value_untagged_json_string_stays_str() {
1016 let op = parse_op(serde_json::json!({ "equals": "42" }));
1019 match op {
1020 Operator::Equals(Value::Str(s)) => assert_eq!(s, "42"),
1021 other => panic!("expected equals/str(\"42\"): {other:?}"),
1022 }
1023 }
1024
1025 #[test]
1026 fn regex_pattern_exactly_at_limit_parses() {
1027 assert_eq!(REGEX_PATTERN_MAX_BYTES, 4 * 1024);
1029 let pattern = "a".repeat(REGEX_PATTERN_MAX_BYTES);
1030 let raw = serde_json::json!({ "http.uri.path": { "matches": pattern } });
1031 let p = parse_predicate(raw).expect("4 KiB pattern parses");
1032 let c = expect_check(&p);
1033 match &c.op {
1034 Operator::Matches(src) => assert_eq!(src.len(), REGEX_PATTERN_MAX_BYTES),
1035 other => panic!("expected matches: {other:?}"),
1036 }
1037 }
1038
/// A `matches` pattern one byte over `REGEX_PATTERN_MAX_BYTES` is rejected, and the error
/// message names the limit so operators can see why their rule was refused.
#[test]
fn regex_pattern_over_limit_rejected_with_limit_in_message() {
    let pattern = "a".repeat(REGEX_PATTERN_MAX_BYTES + 1);
    let raw = serde_json::json!({ "http.uri.path": { "matches": pattern } });
    let err = parse_predicate(raw).expect_err("over-limit pattern must fail");
    let msg = err.to_string();
    // Was `®EX_PATTERN_MAX_BYTES`: an HTML-entity mangling of `&REGEX_PATTERN_MAX_BYTES`.
    let limit = REGEX_PATTERN_MAX_BYTES.to_string();
    assert!(
        msg.contains(&limit),
        "error mentions the limit ({REGEX_PATTERN_MAX_BYTES}): {msg}",
    );
}
1050
1051 fn value_round_trip(v: &CompiledValue) -> CompiledValue {
1060 let encoded = serde_json::to_string(v).expect("serialize value");
1061 serde_json::from_str(&encoded).expect("deserialize value")
1062 }
1063
/// `CompiledValue::Str` survives a JSON round trip, including the empty string.
#[test]
fn compiled_value_str_round_trip_including_empty() {
    for text in ["x", ""] {
        let value = CompiledValue::Str(Arc::<str>::from(text));
        assert_eq!(value_round_trip(&value), value);
    }
}
1071
/// `CompiledValue::Bytes` round-trips for ASCII, empty, and non-UTF-8 payloads.
#[test]
fn compiled_value_bytes_round_trip_including_empty_and_binary() {
    let samples = [
        CompiledValue::Bytes(Bytes::from_static(b"hello")),
        CompiledValue::Bytes(Bytes::new()),
        // 0xff is never valid UTF-8, so this proves a binary-safe encoding.
        CompiledValue::Bytes(Bytes::from_static(&[0xff, 0x00, 0x13])),
    ];
    for value in samples {
        assert_eq!(value_round_trip(&value), value);
    }
}
1081
/// `CompiledValue::Int` round-trips at zero and at both ends of the `i64` range.
#[test]
fn compiled_value_int_round_trip_including_extremes() {
    [0_i64, i64::MIN, i64::MAX]
        .into_iter()
        .map(CompiledValue::Int)
        .for_each(|value| assert_eq!(value_round_trip(&value), value));
}
1089
/// `CompiledValue::Bool` round-trips for both `true` and `false`.
#[test]
fn compiled_value_bool_round_trip_both_variants() {
    [true, false]
        .into_iter()
        .map(CompiledValue::Bool)
        .for_each(|value| assert_eq!(value_round_trip(&value), value));
}
1097
/// `CompiledValue::Addr` round-trips for both IPv4 and IPv6 addresses.
#[test]
fn compiled_value_addr_round_trip_v4_and_v6() {
    let samples = [
        CompiledValue::Addr(Ipv4Addr::LOCALHOST.into()),
        CompiledValue::Addr(Ipv6Addr::LOCALHOST.into()),
    ];
    for value in samples {
        assert_eq!(value_round_trip(&value), value);
    }
}
1105
/// Pins the wire form of `CompiledValue::Bytes`: a `bytes` key holding standard, padded base64.
#[test]
fn compiled_value_bytes_emits_standard_base64_literal() {
    let hello = CompiledValue::Bytes(Bytes::from_static(b"hello"));
    assert_eq!(
        serde_json::to_string(&hello).expect("serialize"),
        r#"{"bytes":"aGVsbG8="}"#,
    );
}
1115
1116 fn op_round_trip(op: &CompiledOperator) -> CompiledOperator {
1117 let encoded = serde_json::to_string(op).expect("serialize op");
1118 serde_json::from_str(&encoded).expect("deserialize op")
1119 }
1120
/// `Equals` and `NotEquals` round-trip with a string operand.
#[test]
fn compiled_operator_equals_and_not_equals_round_trip() {
    let operand = CompiledValue::Str(Arc::<str>::from("x"));
    let ops = [
        CompiledOperator::Equals(operand.clone()),
        CompiledOperator::NotEquals(operand),
    ];
    for op in ops {
        assert_eq!(op_round_trip(&op), op);
    }
}
1128
/// Every byte-payload operator (`Contains`, `NotContains`, `Prefix`, `Suffix`) round-trips.
#[test]
fn compiled_operator_bytes_variants_round_trip() {
    let payload = Bytes::from_static(b"hello");
    // Variant constructors used as plain function pointers.
    let constructors: [fn(Bytes) -> CompiledOperator; 4] = [
        CompiledOperator::Contains,
        CompiledOperator::NotContains,
        CompiledOperator::Prefix,
        CompiledOperator::Suffix,
    ];
    for make in constructors {
        let op = make(payload.clone());
        assert_eq!(op_round_trip(&op), op);
    }
}
1142
/// `Matches` round-trips, and the decoded regex has exactly the original source text.
#[test]
fn compiled_operator_matches_round_trip_preserves_pattern_source() {
    const PATTERN: &str = "^/api/v[0-9]+";
    let original = CompiledOperator::Matches(Regex::new(PATTERN).expect("compile"));
    let decoded = op_round_trip(&original);
    assert_eq!(decoded, original);
    // Equality alone could be loose for regexes, so also check the source string directly.
    match decoded {
        CompiledOperator::Matches(r) => assert_eq!(r.as_str(), PATTERN),
        other => panic!("expected matches, got {other:?}"),
    }
}
1154
/// `In` and `NotIn` round-trip with a heterogeneous value list (string + int).
#[test]
fn compiled_operator_in_and_not_in_round_trip_mixed_values() {
    let xs = vec![CompiledValue::Str(Arc::<str>::from("a")), CompiledValue::Int(42)];
    let in_op = CompiledOperator::In(xs.clone());
    assert_eq!(op_round_trip(&in_op), in_op);
    let not_in_op = CompiledOperator::NotIn(xs);
    // Was `¬_in_op`: an HTML-entity mangling of `&not_in_op` that did not compile.
    assert_eq!(op_round_trip(&not_in_op), not_in_op);
}
1163
/// The integer comparison operators (`Gt`, `Gte`, `Lt`, `Lte`) all round-trip.
#[test]
fn compiled_operator_numeric_comparisons_round_trip() {
    let constructors: [fn(i64) -> CompiledOperator; 4] = [
        CompiledOperator::Gt,
        CompiledOperator::Gte,
        CompiledOperator::Lt,
        CompiledOperator::Lte,
    ];
    for op in constructors.map(|make| make(100)) {
        assert_eq!(op_round_trip(&op), op);
    }
}
1176
/// `Cidr` round-trips for a canonical network string.
#[test]
fn compiled_operator_cidr_round_trip_preserves_canonical_form() {
    let net: IpNet = "10.0.0.0/8".parse().expect("parse");
    let op = CompiledOperator::Cidr(net);
    assert_eq!(op_round_trip(&op), op);
}
1182
/// Deserializing `Matches` with an uncompilable pattern fails, and the error quotes the source.
#[test]
fn compiled_operator_matches_with_invalid_regex_is_rejected() {
    // An unclosed character class is never a valid regex.
    const RAW: &str = r#"{"matches":"["}"#;
    let msg = serde_json::from_str::<CompiledOperator>(RAW)
        .expect_err("invalid regex must fail to deserialize")
        .to_string();
    assert!(msg.contains('['), "error mentions offending regex source: {msg}");
}
1194
/// Pins the exact JSON wire shape of a header-equals `PredicateInst` and checks it round-trips.
#[test]
fn predicate_inst_pins_exact_wire_shape_for_http_header_equals() {
    const EXPECTED: &str =
        r#"{"path":{"http_header":"host"},"op":{"equals":{"str":"example.com"}}}"#;

    let inst = PredicateInst {
        path: FieldPath::HttpHeader(Arc::from("host")),
        op: CompiledOperator::Equals(CompiledValue::Str(Arc::<str>::from("example.com"))),
    };

    let encoded = serde_json::to_string(&inst).expect("serialize");
    assert_eq!(encoded, EXPECTED);

    let decoded = serde_json::from_str::<PredicateInst>(&encoded).expect("deserialize");
    assert_eq!(decoded, inst);
}
1206
/// A `PredicateInst` carrying a compiled regex survives a JSON round trip.
#[test]
fn predicate_inst_round_trip_with_regex_operator() {
    let regex = Regex::new("^/api").expect("compile");
    let inst = PredicateInst {
        path: FieldPath::HttpUriPath,
        op: CompiledOperator::Matches(regex),
    };
    let json = serde_json::to_string(&inst).expect("serialize");
    let decoded = serde_json::from_str::<PredicateInst>(&json).expect("deserialize");
    assert_eq!(decoded, inst);
}
1217}