1use std::collections::{BTreeMap, HashMap};
2use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
3use std::path::PathBuf;
4use std::str::FromStr;
5use std::sync::Arc;
6use std::time::SystemTime;
7
8use base64::Engine as _;
9use base64::engine::general_purpose::STANDARD as B64;
10use sha2::{Digest, Sha256};
11
12use crate::compile::analyze::{AnalyzedRule, AnalyzedRuleSet, Posture};
13use crate::conn_context::Transport;
14use crate::error::Error;
15use crate::fetch::{FetchKind, FetchPhase, SymbolicFetchRef, Terminator};
16use crate::ir::{
17 BodySide, FetchId, FlowGraphMeta, ListenerKind, MiddlewareId, Node, NodeId, PredicateId,
18 SymbolicFlowGraph, TerminatorId,
19};
20use crate::metadata::{FetchMetadataProvider, MiddlewareMetadataProvider};
21use crate::middleware::{MiddlewareKind, SymbolicMiddlewareRef};
22use crate::predicate::{
23 CompiledOperator, CompiledValue, FieldPath, FieldValueType, Operator, Predicate, PredicateInst,
24 Value,
25};
26use crate::rule::SourceInfo;
27
28pub fn lower(
37 set: AnalyzedRuleSet,
38 mw_meta: &dyn MiddlewareMetadataProvider,
39 fetch_meta: &dyn FetchMetadataProvider,
40) -> Result<SymbolicFlowGraph, Error> {
41 let version_hash = hash_rules(&set.rules);
42 let mut builder = Builder::new();
43
44 let groups = group_by_listener(&set.rules)?;
45 for (transport, addrs, rules) in groups {
46 let resolved_tls = resolve_listener_tls(&addrs, &rules)?;
51 validate_zero_rtt_for_listener(&addrs, &rules, resolved_tls.as_ref())?;
57 let entry = builder.lower_port(&rules, mw_meta, fetch_meta)?;
58 for addr in &addrs {
59 builder.entries.insert(*addr, entry);
60 }
61 if let Some(spec) = resolved_tls {
62 for addr in &addrs {
63 builder.listener_tls.insert(*addr, spec.clone());
64 }
65 }
66 let kind = derive_listener_kind(&builder.nodes, &builder.fetches, entry);
67 validate_listener_fetches(&addrs, transport, &builder.nodes, &builder.fetches, entry)?;
73 for addr in addrs {
74 builder.listener_kinds.insert(addr, kind);
75 builder.listener_transports.insert(addr, transport);
76 }
77 }
78
79 warn_missing_plaintext_port_80_for_http01(&builder.listener_tls, &builder.listener_kinds);
87
88 let annotations = inject_acme_http01_routes(&mut builder)?;
93
94 validate_unique_body_reader_per_path(&builder.nodes, &builder.entries)?;
102
103 Ok(SymbolicFlowGraph {
104 nodes: builder.nodes,
105 predicates: builder.predicates,
106 middlewares: builder.middlewares,
107 fetches: builder.fetches,
108 terminators: builder.terminators,
109 entries: builder.entries,
110 meta: FlowGraphMeta {
111 version_hash,
112 compiled_at: SystemTime::now(),
113 source_files: set.source_files,
114 feature_set: &[],
115 short_circuit_response_entry: builder.short_circuit_response_entry,
116 listener_tls: builder.listener_tls,
117 listener_kinds: builder.listener_kinds,
118 listener_transports: builder.listener_transports,
119 annotations,
120 },
121 })
122}
123
124fn derive_listener_kind(
140 nodes: &[Node],
141 fetches: &[SymbolicFetchRef],
142 entry: NodeId,
143) -> ListenerKind {
144 let mut seen_l4 = false;
145 let mut seen_l7 = false;
146 let mut visited = std::collections::HashSet::new();
147 let mut queue = std::collections::VecDeque::from([entry]);
148 while let Some(id) = queue.pop_front() {
149 if !visited.insert(id) {
150 continue;
151 }
152 let Some(node) = nodes.get(id.get() as usize) else { continue };
153 match node {
154 Node::Check { on_match, on_miss, .. } => {
155 queue.push_back(*on_match);
156 queue.push_back(*on_miss);
157 }
158 Node::Middleware { next, on_error, .. } => {
159 queue.push_back(*next);
160 if let Some(e) = on_error {
161 queue.push_back(*e);
162 }
163 }
164 Node::Fetch { id, next_response, next_tunnel, .. } => {
165 match fetches[id.get() as usize].kind.phase() {
166 FetchPhase::L4 => seen_l4 = true,
167 FetchPhase::L7 => seen_l7 = true,
168 }
169 if let Some(n) = next_response {
170 queue.push_back(*n);
171 }
172 if let Some(n) = next_tunnel {
173 queue.push_back(*n);
174 }
175 }
176 Node::Upgrade { next } => queue.push_back(*next),
177 Node::Terminate(_) => {}
178 }
179 }
180 match (seen_l4, seen_l7) {
181 (true, true) => ListenerKind::Auto,
182 (false, true) => ListenerKind::Http,
183 (true | false, false) => ListenerKind::Raw,
188 }
189}
190
191fn validate_listener_fetches(
212 addrs: &[SocketAddr],
213 listener_transport: Transport,
214 nodes: &[Node],
215 fetches: &[SymbolicFetchRef],
216 entry: NodeId,
217) -> Result<(), Error> {
218 let mut visited = std::collections::HashSet::new();
219 let mut queue = std::collections::VecDeque::from([entry]);
220 while let Some(id) = queue.pop_front() {
221 if !visited.insert(id) {
222 continue;
223 }
224 let Some(node) = nodes.get(id.get() as usize) else { continue };
225 match node {
226 Node::Check { on_match, on_miss, .. } => {
227 queue.push_back(*on_match);
228 queue.push_back(*on_miss);
229 }
230 Node::Middleware { next, on_error, .. } => {
231 queue.push_back(*next);
232 if let Some(e) = on_error {
233 queue.push_back(*e);
234 }
235 }
236 Node::Fetch { id, next_response, next_tunnel, .. } => {
237 let fetch = &fetches[id.get() as usize];
238 if matches!(fetch.kind, FetchKind::L4Forward) {
239 let fetch_transport =
240 match fetch.args.get("transport").and_then(serde_json::Value::as_str) {
241 Some("udp") => Some(Transport::Udp),
242 Some("tcp") | None => Some(Transport::Tcp),
243 Some(other) => {
244 return Err(Error::compile(format!(
245 "listener {addrs:?}: L4Forward fetch carries unknown transport {other:?}",
246 )));
247 }
248 };
249 if let Some(ft) = fetch_transport
250 && ft != listener_transport
251 {
252 let upstream =
253 fetch.args.get("upstream").and_then(serde_json::Value::as_str).unwrap_or("<unknown>");
254 return Err(Error::compile(format!(
255 "listener {addrs:?} declared {listener_transport:?} but reachable L4Forward (upstream {upstream:?}) carries transport {ft:?} — listener prefix and fetch transport must agree",
256 )));
257 }
258 }
259 if let Some(n) = next_response {
260 queue.push_back(*n);
261 }
262 if let Some(n) = next_tunnel {
263 queue.push_back(*n);
264 }
265 }
266 Node::Upgrade { next } => queue.push_back(*next),
267 Node::Terminate(_) => {}
268 }
269 }
270 Ok(())
271}
272
273fn peek_retry_buffer_required(args: &serde_json::Value) -> bool {
284 let Some(retry) = args.get("retry") else {
285 return false;
286 };
287 let max_attempts = retry.get("max_attempts").and_then(serde_json::Value::as_u64).unwrap_or(1);
288 if max_attempts <= 1 {
289 return false;
290 }
291 let buffering =
292 retry.get("buffering").and_then(serde_json::Value::as_str).unwrap_or("opportunistic");
293 buffering == "force"
294}
295
/// Crate-visible shims that expose this module's private listener helpers to unit tests.
#[cfg(test)]
pub(crate) mod test_only {
    use std::net::SocketAddr;

    use super::{Error, ListenerKind, Node, NodeId, SymbolicFetchRef, Transport};

    /// Test entry point for the private `derive_listener_kind`.
    pub(crate) fn derive_listener_kind_for_test(
        nodes: &[Node],
        fetches: &[SymbolicFetchRef],
        entry: NodeId,
    ) -> ListenerKind {
        super::derive_listener_kind(nodes, fetches, entry)
    }

    /// Test entry point for the private `parse_listen`.
    pub(crate) fn parse_listen_for_test(spec: &str) -> Result<(Transport, Vec<SocketAddr>), Error> {
        super::parse_listen(spec)
    }

    /// Test entry point for the private `validate_listener_fetches`.
    pub(crate) fn validate_listener_fetches_for_test(
        addrs: &[SocketAddr],
        listener_transport: Transport,
        nodes: &[Node],
        fetches: &[SymbolicFetchRef],
        entry: NodeId,
    ) -> Result<(), Error> {
        super::validate_listener_fetches(addrs, listener_transport, nodes, fetches, entry)
    }
}
330
// Unit tests for listen-spec parsing (`parse_listen`), driven through the
// `test_only::parse_listen_for_test` shim.
#[cfg(test)]
mod listen_parse_tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

    use super::test_only::parse_listen_for_test;
    use crate::conn_context::Transport;

    // Parses `s` and panics if it is rejected; used by the tests that expect success.
    fn parse(s: &str) -> (Transport, Vec<SocketAddr>) {
        parse_listen_for_test(s).expect("parse listen ok")
    }

    // A bare `:port` defaults to TCP and expands to the IPv4 and then the IPv6
    // unspecified address.
    #[test]
    fn bare_dual_stack_defaults_to_tcp() {
        let (t, addrs) = parse(":443");
        assert_eq!(t, Transport::Tcp);
        assert_eq!(
            addrs,
            vec![
                SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 443),
                SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 443),
            ]
        );
    }

    // An explicit IPv4 address binds only that address (no dual-stack expansion).
    #[test]
    fn bare_specific_v4_defaults_to_tcp() {
        let (t, addrs) = parse("0.0.0.0:443");
        assert_eq!(t, Transport::Tcp);
        assert_eq!(addrs, vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 443)]);
    }

    // An explicit bracketed IPv6 address binds only that address.
    #[test]
    fn bare_specific_v6_defaults_to_tcp() {
        let (t, addrs) = parse("[::]:443");
        assert_eq!(t, Transport::Tcp);
        assert_eq!(addrs, vec![SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 443)]);
    }

    // The `tcp:` prefix keeps the dual-stack expansion of a bare port.
    #[test]
    fn tcp_prefix_dual_stack_yields_tcp() {
        let (t, addrs) = parse("tcp:443");
        assert_eq!(t, Transport::Tcp);
        assert_eq!(addrs.len(), 2, "dual-stack expansion preserved under prefix");
    }

    // The `udp:` prefix selects UDP and keeps the dual-stack expansion.
    #[test]
    fn udp_prefix_dual_stack_yields_udp() {
        let (t, addrs) = parse("udp:443");
        assert_eq!(t, Transport::Udp);
        assert_eq!(addrs.len(), 2, "dual-stack expansion preserved under prefix");
    }

    // A prefix may be combined with an explicit IPv4 address.
    #[test]
    fn tcp_prefix_specific_v4_address() {
        let (t, addrs) = parse("tcp:0.0.0.0:443");
        assert_eq!(t, Transport::Tcp);
        assert_eq!(addrs, vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 443)]);
    }

    // A prefix may be combined with a bracketed IPv6 address.
    #[test]
    fn udp_prefix_v6_unspecified() {
        let (t, addrs) = parse("udp:[::]:443");
        assert_eq!(t, Transport::Udp);
        assert_eq!(addrs, vec![SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 443)]);
    }

    // A bracketed IPv6 loopback address is preserved as given.
    #[test]
    fn tcp_prefix_v6_specific_loopback() {
        let (t, addrs) = parse("tcp:[::1]:443");
        assert_eq!(t, Transport::Tcp);
        assert_eq!(addrs, vec![SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 443)]);
    }

    // Prefixes are case-sensitive: `TCP:` is not accepted.
    #[test]
    fn uppercase_prefix_rejected() {
        let err = parse_listen_for_test("TCP:443").expect_err("uppercase prefix must reject");
        assert!(err.to_string().contains("bad listen spec"), "{err}");
    }

    // Only the exact `tcp`/`udp` prefixes are recognised.
    #[test]
    fn unknown_prefix_rejected() {
        let err = parse_listen_for_test("udpx:443").expect_err("unknown prefix must reject");
        assert!(err.to_string().contains("bad listen spec"), "{err}");
    }

    // Port 0 (OS-assigned) is rejected under the `udp:` prefix.
    #[test]
    fn udp_prefix_with_zero_port_rejected() {
        let err = parse_listen_for_test("udp::0").expect_err("port 0 must reject");
        assert!(err.to_string().contains("wildcard port rejected"), "{err}");
    }

    // Port 0 (OS-assigned) is rejected under the `tcp:` prefix.
    #[test]
    fn tcp_prefix_with_zero_port_rejected() {
        let err = parse_listen_for_test("tcp::0").expect_err("port 0 must reject");
        assert!(err.to_string().contains("wildcard port rejected"), "{err}");
    }

    // `udp::443` is the `udp:` prefix followed by the bare dual-stack spec `:443`;
    // only one prefix is stripped.
    #[test]
    fn udp_double_colon_strips_one_prefix() {
        let (t, addrs) = parse("udp::443");
        assert_eq!(t, Transport::Udp);
        assert_eq!(addrs.len(), 2);
        assert_eq!(addrs[0].port(), 443);
    }
}
444
// Unit tests for `validate_listener_fetches`: listener transport vs. the transport of
// every reachable L4Forward fetch.
#[cfg(test)]
mod listener_fetch_validation_tests {
    use std::net::SocketAddr;
    use std::str::FromStr as _;

    use super::test_only::validate_listener_fetches_for_test;
    use crate::conn_context::Transport;
    use crate::fetch::{FetchKind, SymbolicFetchRef, Terminator};
    use crate::ir::{FetchId, Node, NodeId, TerminatorId};

    // Fetch node for fetch `id` whose tunnel edge leads to node `term`; no response edge
    // and no body collection.
    fn fetch_node(id: u32, term: u32) -> Node {
        Node::Fetch {
            id: FetchId::new(id),
            next_response: None,
            next_tunnel: Some(NodeId::new(term)),
            collect_body_before: None,
            body_limit: 0,
        }
    }

    // L4Forward fetch to a fixed upstream with an explicit `transport` arg.
    fn l4_fetch(transport: &str) -> SymbolicFetchRef {
        SymbolicFetchRef {
            kind: FetchKind::L4Forward,
            args: serde_json::json!({ "upstream": "127.0.0.1:9", "transport": transport }),
            retry_buffer_required: false,
            allow_zero_rtt: None,
        }
    }

    // L7 (HttpProxy) fetch; the validator does not inspect its transport.
    fn l7_fetch() -> SymbolicFetchRef {
        SymbolicFetchRef {
            kind: FetchKind::HttpProxy,
            args: serde_json::json!({ "upstream": "127.0.0.1:9" }),
            retry_buffer_required: false,
            allow_zero_rtt: None,
        }
    }

    // Single listener address used in all tests (and asserted on in error messages).
    fn addr() -> Vec<SocketAddr> {
        vec![SocketAddr::from_str("0.0.0.0:443").expect("addr")]
    }

    // Graph: node0 Fetch(udp) -> node1 Terminate. Matching transports pass.
    #[test]
    fn udp_listener_with_udp_l4_forward_passes() {
        let nodes = vec![fetch_node(0, 1), Node::Terminate(TerminatorId::new(0))];
        let fetches = vec![l4_fetch("udp")];
        validate_listener_fetches_for_test(&addr(), Transport::Udp, &nodes, &fetches, NodeId::new(0))
            .expect("udp listener + udp L4Forward must pass");
    }

    #[test]
    fn tcp_listener_with_tcp_l4_forward_passes() {
        let nodes = vec![fetch_node(0, 1), Node::Terminate(TerminatorId::new(0))];
        let fetches = vec![l4_fetch("tcp")];
        validate_listener_fetches_for_test(&addr(), Transport::Tcp, &nodes, &fetches, NodeId::new(0))
            .expect("tcp listener + tcp L4Forward must pass");
    }

    // An L4Forward without a `transport` arg is treated as TCP.
    #[test]
    fn tcp_listener_with_l4_forward_default_transport_passes() {
        let nodes = vec![fetch_node(0, 1), Node::Terminate(TerminatorId::new(0))];
        let fetches = vec![SymbolicFetchRef {
            kind: FetchKind::L4Forward,
            args: serde_json::json!({ "upstream": "127.0.0.1:9" }),
            retry_buffer_required: false,
            allow_zero_rtt: None,
        }];
        validate_listener_fetches_for_test(&addr(), Transport::Tcp, &nodes, &fetches, NodeId::new(0))
            .expect("tcp listener + default-transport L4Forward must pass");
    }

    // The mismatch error must name the listener address, both transports and the upstream.
    #[test]
    fn tcp_listener_with_udp_l4_forward_compile_errors() {
        let nodes = vec![fetch_node(0, 1), Node::Terminate(TerminatorId::new(0))];
        let fetches = vec![l4_fetch("udp")];
        let err =
            validate_listener_fetches_for_test(&addr(), Transport::Tcp, &nodes, &fetches, NodeId::new(0))
                .expect_err("tcp listener + udp L4Forward must reject");
        let msg = err.to_string();
        assert!(msg.contains("0.0.0.0:443"), "error names listener address: {msg}");
        assert!(msg.contains("Tcp"), "error names listener transport: {msg}");
        assert!(msg.contains("Udp"), "error names fetch transport: {msg}");
        assert!(msg.contains("127.0.0.1:9"), "error names offending fetch: {msg}");
    }

    #[test]
    fn udp_listener_with_tcp_l4_forward_compile_errors() {
        let nodes = vec![fetch_node(0, 1), Node::Terminate(TerminatorId::new(0))];
        let fetches = vec![l4_fetch("tcp")];
        let err =
            validate_listener_fetches_for_test(&addr(), Transport::Udp, &nodes, &fetches, NodeId::new(0))
                .expect_err("udp listener + tcp L4Forward must reject");
        let msg = err.to_string();
        assert!(msg.contains("0.0.0.0:443"), "error names listener address: {msg}");
        assert!(msg.contains("Udp"), "error names listener transport: {msg}");
        assert!(msg.contains("Tcp"), "error names fetch transport: {msg}");
    }

    // Graph: node0 Upgrade -> node1 Fetch(HttpProxy) -> node2 Terminate. L7 fetches are not
    // transport-checked, so even a UDP listener passes.
    // NOTE(review): `fetch_node` wires the L7 fetch through `next_tunnel`; that is a fixture
    // shortcut, not how `lower_rule` wires HttpProxy.
    #[test]
    fn udp_listener_with_l7_only_passes() {
        let nodes = vec![
            Node::Upgrade { next: NodeId::new(1) },
            fetch_node(0, 2),
            Node::Terminate(TerminatorId::new(0)),
        ];
        let fetches = vec![l7_fetch()];
        // NOTE(review): this statement has no effect; it appears to exist only to keep the
        // `Terminator` import used.
        let _ = Terminator::WriteHttpResponse;
        validate_listener_fetches_for_test(&addr(), Transport::Udp, &nodes, &fetches, NodeId::new(0))
            .expect("udp listener + l7-only must pass (kind derivation handles Http)");
    }

    // Graph: node0 Check(match -> node1, miss -> node3); node1 Fetch(udp) -> node2;
    // node3 Fetch(tcp) -> node2. The TCP branch is reachable only on a miss and must still be
    // rejected.
    #[test]
    fn udp_listener_with_mixed_l4_branches_compile_errors() {
        let nodes = vec![
            Node::Check {
                predicate: crate::ir::PredicateId::new(0),
                on_match: NodeId::new(1),
                on_miss: NodeId::new(3),
                collect_body_before: None,
                body_limit: 0,
            },
            fetch_node(0, 2),
            Node::Terminate(TerminatorId::new(0)),
            fetch_node(1, 2),
        ];
        let fetches = vec![l4_fetch("udp"), l4_fetch("tcp")];
        let err =
            validate_listener_fetches_for_test(&addr(), Transport::Udp, &nodes, &fetches, NodeId::new(0))
                .expect_err("udp listener + mixed L4 branches must reject");
        assert!(err.to_string().contains("must agree"), "{err}");
    }
}
583
584struct Builder {
585 nodes: Vec<Node>,
586 predicates: Vec<PredicateInst>,
587 pred_dedup: HashMap<PredicateInst, PredicateId>,
588 middlewares: Vec<SymbolicMiddlewareRef>,
589 mw_dedup: HashMap<(String, String), MiddlewareId>,
590 fetches: Vec<SymbolicFetchRef>,
591 terminators: Vec<Terminator>,
592 term_dedup: HashMap<Terminator, TerminatorId>,
593 entries: HashMap<SocketAddr, NodeId>,
594 short_circuit_response_entry: std::collections::BTreeMap<NodeId, NodeId>,
600 listener_tls: std::collections::BTreeMap<SocketAddr, crate::rule::ListenerTlsSpec>,
605 listener_kinds: std::collections::BTreeMap<SocketAddr, ListenerKind>,
609 listener_transports: std::collections::BTreeMap<SocketAddr, Transport>,
613}
614
615impl Builder {
616 fn new() -> Self {
617 Self {
618 nodes: Vec::new(),
619 predicates: Vec::new(),
620 pred_dedup: HashMap::new(),
621 middlewares: Vec::new(),
622 mw_dedup: HashMap::new(),
623 fetches: Vec::new(),
624 terminators: Vec::new(),
625 term_dedup: HashMap::new(),
626 entries: HashMap::new(),
627 short_circuit_response_entry: std::collections::BTreeMap::new(),
628 listener_tls: std::collections::BTreeMap::new(),
629 listener_kinds: std::collections::BTreeMap::new(),
630
631 listener_transports: std::collections::BTreeMap::new(),
632 }
633 }
634
635 fn intern_predicate(&mut self, p: PredicateInst) -> PredicateId {
636 if let Some(&id) = self.pred_dedup.get(&p) {
637 return id;
638 }
639 let id = PredicateId::new(u32::try_from(self.predicates.len()).expect("predicate id fits u32"));
640 self.predicates.push(p.clone());
641 self.pred_dedup.insert(p, id);
642 id
643 }
644
645 fn intern_middleware(&mut self, r: SymbolicMiddlewareRef) -> MiddlewareId {
646 if r.stateless {
647 let key = (r.name.to_string(), canonical_json(&r.args));
648 if let Some(&id) = self.mw_dedup.get(&key) {
649 return id;
650 }
651 let id =
652 MiddlewareId::new(u32::try_from(self.middlewares.len()).expect("middleware id fits u32"));
653 self.middlewares.push(r);
654 self.mw_dedup.insert(key, id);
655 id
656 } else {
657 let id =
658 MiddlewareId::new(u32::try_from(self.middlewares.len()).expect("middleware id fits u32"));
659 self.middlewares.push(r);
660 id
661 }
662 }
663
664 fn push_fetch(&mut self, r: SymbolicFetchRef) -> FetchId {
665 let id = FetchId::new(u32::try_from(self.fetches.len()).expect("fetch id fits u32"));
666 self.fetches.push(r);
667 id
668 }
669
670 fn intern_terminator(&mut self, t: Terminator) -> TerminatorId {
671 if let Some(&id) = self.term_dedup.get(&t) {
672 return id;
673 }
674 let id =
675 TerminatorId::new(u32::try_from(self.terminators.len()).expect("terminator id fits u32"));
676 self.terminators.push(t);
677 self.term_dedup.insert(t, id);
678 id
679 }
680
681 fn push_node(&mut self, n: Node) -> NodeId {
682 let id = NodeId::new(u32::try_from(self.nodes.len()).expect("node id fits u32"));
683 self.nodes.push(n);
684 id
685 }
686
687 fn lower_port(
688 &mut self,
689 rules: &[&AnalyzedRule],
690 mw_meta: &dyn MiddlewareMetadataProvider,
691 fetch_meta: &dyn FetchMetadataProvider,
692 ) -> Result<NodeId, Error> {
693 let posture = rules.first().map_or(Posture::L7, |r| r.posture);
694 if rules.iter().any(|r| r.posture != posture) {
695 return Err(Error::compile(
696 "mixed L4 and L7 rules on one listener require protocol_detect".to_string(),
697 ));
698 }
699
700 let mut ordered: Vec<&AnalyzedRule> = rules.to_vec();
702 ordered.sort_by(|a, b| {
703 b.inspection_level
704 .cmp(&a.inspection_level)
705 .then(b.specificity.cmp(&a.specificity))
706 .then(a.raw.name.cmp(&b.raw.name))
707 });
708
709 let needs_fallback = ordered.iter().any(|r| r.raw.match_predicate.is_some());
717 let fallback_miss =
718 if needs_fallback { self.synthesize_default_miss() } else { NodeId::new(0) };
719
720 let mut current_miss = fallback_miss;
728 for rule in ordered.iter().rev() {
729 let chain_entry = self.lower_rule(rule, current_miss, mw_meta, fetch_meta)?;
730 current_miss = chain_entry;
731 }
732 let inner_entry = current_miss;
733
734 match posture {
735 Posture::L7 => {
736 let synth_tid = self.intern_terminator(Terminator::WriteHttpResponse);
753 let synth_node = self.push_node(Node::Terminate(synth_tid));
754 let listener_entry = self.push_node(Node::Upgrade { next: inner_entry });
755 self.short_circuit_response_entry.insert(inner_entry, synth_node);
756 Ok(listener_entry)
757 }
758 Posture::L4 => Ok(inner_entry),
759 }
760 }
761
762 fn synthesize_default_miss(&mut self) -> NodeId {
763 let tid = self.intern_terminator(Terminator::Close);
768 self.push_node(Node::Terminate(tid))
769 }
770
771 fn lower_rule(
772 &mut self,
773 rule: &AnalyzedRule,
774 on_miss: NodeId,
775 mw_meta: &dyn MiddlewareMetadataProvider,
776 fetch_meta: &dyn FetchMetadataProvider,
777 ) -> Result<NodeId, Error> {
778 let fetch_kind = rule.raw.terminate.kind;
785 let retry_buffer_required = peek_retry_buffer_required(&rule.raw.terminate.args);
786 let fid = self.push_fetch(SymbolicFetchRef {
787 kind: fetch_kind,
788 args: rule.raw.terminate.args.clone(),
789 retry_buffer_required,
790 allow_zero_rtt: rule.raw.allow_zero_rtt,
797 });
798 let (next_response, next_tunnel) = match fetch_kind {
799 FetchKind::HttpProxy | FetchKind::HttpSynthesize | FetchKind::AcmeChallenge => {
800 let tid = self.intern_terminator(Terminator::WriteHttpResponse);
801 let term_node = self.push_node(Node::Terminate(tid));
802 (Some(term_node), None)
803 }
804 FetchKind::L4Forward => {
805 let tid = self.intern_terminator(Terminator::ByteTunnel);
806 let term_node = self.push_node(Node::Terminate(tid));
807 (None, Some(term_node))
808 }
809 FetchKind::WebSocketUpgrade => {
810 let resp_tid = self.intern_terminator(Terminator::WriteHttpResponse);
811 let resp_node = self.push_node(Node::Terminate(resp_tid));
812 let tun_tid = self.intern_terminator(Terminator::ByteTunnel);
813 let tun_node = self.push_node(Node::Terminate(tun_tid));
814 (Some(resp_node), Some(tun_node))
815 }
816 };
817 let _ = fetch_meta;
818 let fetch_node_idx = self.nodes.len();
819 let fetch_node_id = NodeId::new(u32::try_from(fetch_node_idx).expect("node id fits u32"));
820 let (fetch_collect, fetch_body_limit) = if retry_buffer_required {
828 (Some(BodySide::Request), rule.raw.max_body_bytes_request)
829 } else {
830 (None, 0)
831 };
832 self.nodes.push(Node::Fetch {
833 id: fid,
834 next_response,
835 next_tunnel,
836 collect_body_before: fetch_collect,
837 body_limit: fetch_body_limit,
838 });
839
840 let mut head = fetch_node_id;
843 let mut req_first_reader_seen = false;
844 let mut resp_first_reader_seen = false;
845 for mw_ref in rule.raw.middleware_chain.iter().rev() {
847 let meta = mw_meta
848 .get(&mw_ref.name)
849 .ok_or_else(|| Error::compile(format!("unknown middleware: {:?}", mw_ref.name)))?;
850 let sym = SymbolicMiddlewareRef {
851 name: Arc::from(mw_ref.name.as_str()),
852 args: mw_ref.args.clone(),
853 kind: meta.kind,
854 stateless: meta.stateless,
855 needs_body: meta.needs_body,
856 on_error: None,
857 };
858 let id = self.intern_middleware(sym);
859 let node = Node::Middleware {
860 id,
861 next: head,
862 on_error: None,
863 collect_body_before: None,
864 body_limit: 0,
865 };
866 head = self.push_node(node);
867 }
868
869 let chain_entry_before_upgrade = head;
873 let _ = (&mut req_first_reader_seen, &mut resp_first_reader_seen);
874 if rule.needs_request_body {
875 self.mark_request_reader(
876 chain_entry_before_upgrade,
877 mw_meta,
878 rule.raw.max_body_bytes_request,
879 )?;
880 }
881 if rule.needs_response_body {
882 self.mark_response_reader(
883 chain_entry_before_upgrade,
884 mw_meta,
885 rule.raw.max_body_bytes_response,
886 )?;
887 }
888
889 let _ = rule.raw.match_predicate.as_ref().map(predicate_uniform_level).transpose()?;
902
903 if let Some(pred) = &rule.raw.match_predicate {
904 head = self.lower_predicate(pred, head, on_miss, &rule.raw.source)?;
905 }
906
907 Ok(head)
908 }
909
910 fn lower_predicate(
911 &mut self,
912 pred: &Predicate,
913 on_match: NodeId,
914 on_miss: NodeId,
915 source: &SourceInfo,
916 ) -> Result<NodeId, Error> {
917 match pred {
918 Predicate::Check(c) => {
919 let inst =
920 PredicateInst { path: c.path.clone(), op: compile_operator(&c.op, &c.path, source)? };
921 let pid = self.intern_predicate(inst);
922 let collect_body_before =
923 if matches!(c.path, FieldPath::HttpBody) { Some(BodySide::Request) } else { None };
924 let node =
925 Node::Check { predicate: pid, on_match, on_miss, collect_body_before, body_limit: 0 };
926 Ok(self.push_node(node))
927 }
928 Predicate::AnyOf(any_of) => {
929 if any_of.any_of.is_empty() {
934 return Ok(on_miss);
936 }
937 let mut cur_miss = on_miss;
938 for child in any_of.any_of.iter().rev() {
939 cur_miss = self.lower_predicate(child, on_match, cur_miss, source)?;
940 }
941 Ok(cur_miss)
942 }
943 Predicate::AllOf(all_of) => {
944 if all_of.all_of.is_empty() {
949 return Ok(on_match);
951 }
952 let mut cur_match = on_match;
953 for child in all_of.all_of.iter().rev() {
954 cur_match = self.lower_predicate(child, cur_match, on_miss, source)?;
955 }
956 Ok(cur_match)
957 }
958 Predicate::Not(not) => {
959 self.lower_predicate(¬.not, on_miss, on_match, source)
961 }
962 }
963 }
964
965 fn mark_request_reader(
966 &mut self,
967 chain_head: NodeId,
968 _mw_meta: &dyn MiddlewareMetadataProvider,
969 body_limit: usize,
970 ) -> Result<(), Error> {
971 self.mark_first_body_reader_dfs(chain_head, BodySide::Request, body_limit);
972 Ok(())
973 }
974
975 fn mark_response_reader(
976 &mut self,
977 chain_head: NodeId,
978 _mw_meta: &dyn MiddlewareMetadataProvider,
979 body_limit: usize,
980 ) -> Result<(), Error> {
981 self.mark_first_body_reader_dfs(chain_head, BodySide::Response, body_limit);
982 Ok(())
983 }
984
985 fn mark_first_body_reader_dfs(&mut self, chain_head: NodeId, side: BodySide, body_limit: usize) {
1003 use std::collections::HashSet;
1004 let mut stack: Vec<(NodeId, bool)> = vec![(chain_head, false)];
1005 let mut visited: HashSet<(u32, bool)> = HashSet::new();
1006 while let Some((cur, already_marked)) = stack.pop() {
1007 if !visited.insert((cur.get(), already_marked)) {
1008 continue;
1009 }
1010 let idx = cur.get() as usize;
1011 match &self.nodes[idx] {
1012 Node::Middleware { id, next, on_error, .. } => {
1013 let sym = &self.middlewares[id.get() as usize];
1014 let is_reader = match side {
1015 BodySide::Request => sym.kind == MiddlewareKind::L7Request && sym.needs_body,
1016 BodySide::Response => sym.kind == MiddlewareKind::L7Response && sym.needs_body,
1017 };
1018 let next_id = *next;
1019 let on_error_id = *on_error;
1020 let now_marked = if is_reader && !already_marked {
1021 if let Node::Middleware { collect_body_before, body_limit: bl, .. } =
1022 &mut self.nodes[idx]
1023 {
1024 *collect_body_before = Some(side);
1025 *bl = body_limit;
1026 }
1027 true
1028 } else {
1029 already_marked
1030 };
1031 stack.push((next_id, now_marked));
1032 if let Some(eid) = on_error_id {
1033 stack.push((eid, now_marked));
1034 }
1035 }
1036 Node::Check { on_match, on_miss, .. } => {
1037 let m = *on_match;
1038 let s = *on_miss;
1039 stack.push((m, already_marked));
1040 stack.push((s, already_marked));
1041 }
1042 Node::Fetch { next_response, next_tunnel, .. } => {
1043 if matches!(side, BodySide::Response) {
1048 if let Some(n) = next_response {
1049 stack.push((*n, already_marked));
1050 }
1051 if let Some(t) = next_tunnel {
1052 stack.push((*t, already_marked));
1053 }
1054 }
1055 }
1056 Node::Upgrade { next } => {
1057 let n = *next;
1058 stack.push((n, already_marked));
1059 }
1060 Node::Terminate(_) => {}
1061 }
1062 }
1063 }
1064}
1065
/// Inspection depth needed to evaluate a predicate leaf, from shallowest to deepest.
///
/// Declaration order is significant: `collect_levels` compares levels via `as u8` to
/// keep the deepest one. The discriminants are spelled out so that ordering is explicit.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
enum Level {
    /// Connection metadata only (transport, local/remote IP and port).
    L4Only = 0,
    /// Peeked connection data (`Peek`) or TLS handshake / client-certificate fields.
    L4Peek = 1,
    /// HTTP request head: method, URI path, URI query, headers.
    L7Header = 2,
    /// HTTP body.
    L7Body = 3,
}
1073
1074fn field_path_level(path: &FieldPath) -> Level {
1075 match path {
1076 FieldPath::Transport
1077 | FieldPath::RemoteIp
1078 | FieldPath::RemotePort
1079 | FieldPath::LocalIp
1080 | FieldPath::LocalPort => Level::L4Only,
1081 FieldPath::Peek
1082 | FieldPath::TlsSni
1083 | FieldPath::TlsAlpn
1084 | FieldPath::TlsVersion
1085 | FieldPath::TlsPeerCertPresent
1086 | FieldPath::TlsPeerCertSubjectCn
1087 | FieldPath::TlsPeerCertSanDns
1088 | FieldPath::TlsPeerCertFingerprintSha256
1089 | FieldPath::TlsPeerCertSpkiSha256
1090 | FieldPath::TlsPeerCertIssuerCn
1091 | FieldPath::TlsPeerCertSerial => Level::L4Peek,
1092 FieldPath::HttpMethod
1093 | FieldPath::HttpUriPath
1094 | FieldPath::HttpUriQuery
1095 | FieldPath::HttpHeader(_) => Level::L7Header,
1096 FieldPath::HttpBody => Level::L7Body,
1097 }
1098}
1099
1100const fn level_is_l4(l: Level) -> bool {
1101 matches!(l, Level::L4Only | Level::L4Peek)
1102}
1103
1104fn predicate_uniform_level(pred: &Predicate) -> Result<Level, Error> {
1110 let mut acc: Option<Level> = None;
1111 collect_levels(pred, &mut acc)?;
1112 Ok(acc.unwrap_or(Level::L4Only))
1113}
1114
1115fn collect_levels(pred: &Predicate, acc: &mut Option<Level>) -> Result<(), Error> {
1116 match pred {
1117 Predicate::Check(c) => {
1118 let leaf = field_path_level(&c.path);
1119 match *acc {
1120 None => *acc = Some(leaf),
1121 Some(existing) if level_is_l4(existing) == level_is_l4(leaf) => {
1122 if (leaf as u8) > (existing as u8) {
1123 *acc = Some(leaf);
1124 }
1125 }
1126 Some(existing) => {
1127 return Err(Error::compile(format!(
1128 "cross-level any_of / all_of / not not supported: predicate mixes {existing:?} and {leaf:?} leaves"
1129 )));
1130 }
1131 }
1132 Ok(())
1133 }
1134 Predicate::AnyOf(a) => {
1135 for child in &a.any_of {
1136 collect_levels(child, acc)?;
1137 }
1138 Ok(())
1139 }
1140 Predicate::AllOf(a) => {
1141 for child in &a.all_of {
1142 collect_levels(child, acc)?;
1143 }
1144 Ok(())
1145 }
1146 Predicate::Not(n) => collect_levels(&n.not, acc),
1147 }
1148}
1149
1150#[allow(dead_code)]
1151fn predicate_is_l4(pred: Option<&Predicate>) -> bool {
1152 let Some(Predicate::Check(c)) = pred else {
1153 return false;
1154 };
1155 matches!(
1156 c.path,
1157 FieldPath::Transport
1158 | FieldPath::RemoteIp
1159 | FieldPath::RemotePort
1160 | FieldPath::LocalIp
1161 | FieldPath::LocalPort
1162 | FieldPath::Peek
1163 | FieldPath::TlsSni
1164 | FieldPath::TlsAlpn
1165 | FieldPath::TlsVersion
1166 | FieldPath::TlsPeerCertPresent
1167 | FieldPath::TlsPeerCertSubjectCn
1168 | FieldPath::TlsPeerCertSanDns
1169 | FieldPath::TlsPeerCertFingerprintSha256
1170 | FieldPath::TlsPeerCertSpkiSha256
1171 | FieldPath::TlsPeerCertIssuerCn
1172 | FieldPath::TlsPeerCertSerial
1173 )
1174}
1175
/// One listener's worth of input: its transport, the socket addresses it binds, and the
/// analyzed rules attached to it (borrowed for `'a`).
type ListenerGroup<'a> = (Transport, Vec<SocketAddr>, Vec<&'a AnalyzedRule>);
1177
1178fn route_tls_config_into_spec(
1209 addrs: &[SocketAddr],
1210 tls: &crate::rule::TlsConfig,
1211 spec: &mut crate::rule::ListenerTlsSpec,
1212) -> Result<(), Error> {
1213 if let Some(managed) = tls.managed.as_ref() {
1214 let sni_key =
1215 tls.sni.as_deref().expect("managed validated requires tls.sni").to_ascii_lowercase();
1216 if spec.sni_certs.contains_key(&sni_key) {
1217 return Err(Error::compile(format!(
1218 "listener {addrs:?}: SNI {sni_key:?} declared as both static and managed — pick one source"
1219 )));
1220 }
1221 match spec.managed_snis.get(&sni_key) {
1222 None => {
1223 spec.managed_snis.insert(sni_key, managed.clone());
1224 }
1225 Some(existing) if existing == managed => {}
1226 Some(_) => {
1227 return Err(Error::compile(format!(
1228 "listener {addrs:?}: SNI {sni_key:?} mapped to two different `tls.managed` blocks"
1229 )));
1230 }
1231 }
1232 return Ok(());
1233 }
1234
1235 let normalised_sni = tls.sni.as_deref().map(str::to_ascii_lowercase);
1236 let normalised = crate::rule::TlsConfig {
1237 sni: normalised_sni.clone(),
1238 cert_file: tls.cert_file.clone(),
1239 key_file: tls.key_file.clone(),
1240 managed: None,
1241 enable_zero_rtt: tls.enable_zero_rtt,
1242 client_auth: tls.client_auth.clone(),
1243 ocsp_path: tls.ocsp_path.clone(),
1244 ocsp_fetch: tls.ocsp_fetch,
1245 };
1246 match normalised_sni {
1247 None => match &spec.default {
1248 None => spec.default = Some(normalised),
1249 Some(existing) if existing == &normalised => {}
1250 Some(existing) => {
1251 return Err(Error::compile(format!(
1252 "listener {addrs:?}: more than one default (sni-less) cert — {} vs {} — at most one cert may omit `sni`",
1253 display_cert_file(existing),
1254 display_cert_file(&normalised),
1255 )));
1256 }
1257 },
1258 Some(sni_key) => {
1259 if spec.managed_snis.contains_key(&sni_key) {
1260 return Err(Error::compile(format!(
1261 "listener {addrs:?}: SNI {sni_key:?} declared as both static and managed — pick one source"
1262 )));
1263 }
1264 match spec.sni_certs.get(&sni_key) {
1265 None => {
1266 spec.sni_certs.insert(sni_key, normalised);
1267 }
1268 Some(existing) if existing == &normalised => {}
1269 Some(existing) => {
1270 return Err(Error::compile(format!(
1271 "listener {addrs:?}: SNI {sni_key:?} mapped to two different certs — {} vs {}",
1272 display_cert_file(existing),
1273 display_cert_file(&normalised),
1274 )));
1275 }
1276 }
1277 }
1278 }
1279 Ok(())
1280}
1281
1282fn resolve_listener_tls(
1283 addrs: &[SocketAddr],
1284 rules: &[&AnalyzedRule],
1285) -> Result<Option<crate::rule::ListenerTlsSpec>, Error> {
1286 let any_l4 = rules.iter().any(|r| r.posture == Posture::L4);
1287 let any_tls = rules.iter().any(|r| r.raw.tls.is_some());
1288 if any_l4 && any_tls {
1289 return Err(Error::compile(format!(
1290 "listener {addrs:?}: TLS termination is L7-only — remove `tls` or change the terminator to an L7 type (http_proxy / static / websocket / redirect_https)"
1291 )));
1292 }
1293
1294 let mut spec = crate::rule::ListenerTlsSpec {
1295 default: None,
1296 sni_certs: BTreeMap::new(),
1297 managed_snis: BTreeMap::new(),
1298 client_auth: crate::rule::ClientAuthSpec::None,
1299 enable_zero_rtt: false,
1300 };
1301 for rule in rules {
1302 let Some(tls) = rule.raw.tls.as_ref() else { continue };
1303 route_tls_config_into_spec(addrs, tls, &mut spec)?;
1310 }
1311
1312 let mut resolved: Option<crate::rule::ClientAuthSpec> = None;
1323 let mut saw_any_tls_rule = false;
1324 for rule in rules {
1325 let Some(tls) = rule.raw.tls.as_ref() else { continue };
1326 saw_any_tls_rule = true;
1327 let candidate = match tls.client_auth.as_ref() {
1328 Some(ca) => compile_client_auth(addrs, ca)?,
1329 None => crate::rule::ClientAuthSpec::None,
1330 };
1331 match &resolved {
1332 None => resolved = Some(candidate),
1333 Some(existing) if existing == &candidate => {}
1334 Some(existing) => {
1335 return Err(Error::compile(format!(
1336 "listener {addrs:?}: rules disagree on `client_auth` posture — saw {existing:?} and {candidate:?}; mTLS is per-listener so every rule must declare the same `client_auth` (or all omit it)"
1337 )));
1338 }
1339 }
1340 }
1341 if saw_any_tls_rule {
1342 spec.client_auth = resolved.unwrap_or(crate::rule::ClientAuthSpec::None);
1343 }
1344
1345 let mut zero_rtt_resolved: Option<bool> = None;
1351 for rule in rules {
1352 let Some(tls) = rule.raw.tls.as_ref() else { continue };
1353 match zero_rtt_resolved {
1354 None => zero_rtt_resolved = Some(tls.enable_zero_rtt),
1355 Some(existing) if existing == tls.enable_zero_rtt => {}
1356 Some(_) => {
1357 return Err(Error::compile(format!(
1358 "listener {addrs:?}: rules disagree on `tls.enable_zero_rtt` — 0-RTT is a listener-level setting (the listener has one TLS config); every rule on the same address must agree"
1359 )));
1360 }
1361 }
1362 }
1363 if let Some(z) = zero_rtt_resolved {
1364 spec.enable_zero_rtt = z;
1365 }
1366
1367 if spec.is_empty() { Ok(None) } else { Ok(Some(spec)) }
1368}
1369
1370fn display_cert_file(tls: &crate::rule::TlsConfig) -> String {
1375 match &tls.cert_file {
1376 Some(p) => p.display().to_string(),
1377 None => "<managed>".to_owned(),
1378 }
1379}
1380
1381fn inject_acme_http01_routes(
1405 builder: &mut Builder,
1406) -> Result<Vec<crate::ir::DryRunAnnotation>, Error> {
1407 let mut annotations = Vec::new();
1408 let any_http01 = builder.listener_tls.values().any(|spec| {
1409 spec.managed_snis.values().any(|m| matches!(m.challenge, crate::rule::ChallengeKind::Http01))
1410 });
1411 if !any_http01 {
1412 return Ok(annotations);
1413 }
1414
1415 let targets: Vec<SocketAddr> = builder
1418 .listener_kinds
1419 .iter()
1420 .filter(|(addr, kind)| {
1421 addr.port() == 80
1422 && matches!(kind, ListenerKind::Http | ListenerKind::Auto)
1423 && !builder.listener_tls.contains_key(addr)
1424 })
1425 .map(|(addr, _)| *addr)
1426 .collect();
1427
1428 if targets.is_empty() {
1429 return Ok(annotations);
1430 }
1431
1432 let predicate = PredicateInst {
1437 path: crate::predicate::FieldPath::HttpUriPath,
1438 op: crate::predicate::CompiledOperator::Prefix(bytes::Bytes::from_static(
1439 b"/.well-known/acme-challenge/",
1440 )),
1441 };
1442 let pred_id = builder.intern_predicate(predicate);
1443 let acme_fetch_ref = SymbolicFetchRef {
1444 kind: FetchKind::AcmeChallenge,
1445 args: serde_json::Value::Null,
1446 retry_buffer_required: false,
1447 allow_zero_rtt: None,
1448 };
1449 let fetch_id = builder.push_fetch(acme_fetch_ref);
1450 let term_id = builder.intern_terminator(Terminator::WriteHttpResponse);
1451 let term_node = builder.push_node(Node::Terminate(term_id));
1452 let fetch_node = builder.push_node(Node::Fetch {
1453 id: fetch_id,
1454 next_response: Some(term_node),
1455 next_tunnel: None,
1456 collect_body_before: None,
1457 body_limit: 0,
1458 });
1459
1460 for addr in targets {
1461 let original_entry = *builder.entries.get(&addr).ok_or_else(|| {
1462 Error::internal(format!(
1463 "invariant: listener_kinds names {addr} but builder.entries has no matching listener-entry node; ACME http-01 injection cannot proceed",
1464 ))
1465 })?;
1466 let Some(original_l7_entry) = find_post_upgrade_node(&builder.nodes, original_entry) else {
1475 continue;
1476 };
1477 let check_node = builder.push_node(Node::Check {
1478 predicate: pred_id,
1479 on_match: fetch_node,
1480 on_miss: original_l7_entry,
1481 collect_body_before: None,
1482 body_limit: 0,
1483 });
1484 rewire_post_upgrade(&mut builder.nodes, original_entry, check_node);
1487 annotations.push(crate::ir::DryRunAnnotation {
1488 kind: "acme-injected".to_owned(),
1489 message: format!("acme http-01 challenge route injected on plaintext :80 listener {addr}"),
1490 });
1491 }
1492
1493 Ok(annotations)
1494}
1495
1496fn find_post_upgrade_node(nodes: &[Node], entry: NodeId) -> Option<NodeId> {
1500 match nodes.get(entry.get() as usize)? {
1501 Node::Upgrade { next } => Some(*next),
1502 _ => Some(entry),
1506 }
1507}
1508
1509fn rewire_post_upgrade(nodes: &mut [Node], entry: NodeId, target: NodeId) {
1514 if let Some(Node::Upgrade { next }) = nodes.get_mut(entry.get() as usize) {
1515 *next = target;
1516 }
1517}
1518
1519fn warn_missing_plaintext_port_80_for_http01(
1535 listener_tls: &std::collections::BTreeMap<SocketAddr, crate::rule::ListenerTlsSpec>,
1536 listener_kinds: &std::collections::BTreeMap<SocketAddr, ListenerKind>,
1537) {
1538 let any_http01 = listener_tls.values().any(|spec| {
1539 spec.managed_snis.values().any(|m| matches!(m.challenge, crate::rule::ChallengeKind::Http01))
1540 });
1541 if !any_http01 {
1542 return;
1543 }
1544 let has_plaintext_80 = listener_kinds.iter().any(|(addr, kind)| {
1545 addr.port() == 80
1546 && matches!(kind, ListenerKind::Http | ListenerKind::Auto)
1547 && !listener_tls.contains_key(addr)
1548 });
1549 if !has_plaintext_80 {
1550 tracing::warn!(
1551 target: "vane::compile::acme",
1552 "http-01 challenge declared but no plaintext :80 listener exists; \
1553 vaned will auto-bind :80 at runtime — the bind may fail without \
1554 CAP_NET_BIND_SERVICE or if the port is already in use",
1555 );
1556 }
1557}
1558
1559fn validate_zero_rtt_for_listener(
1572 addrs: &[SocketAddr],
1573 rules: &[&AnalyzedRule],
1574 resolved_tls: Option<&crate::rule::ListenerTlsSpec>,
1575) -> Result<(), Error> {
1576 let tls_l7 = resolved_tls.is_some();
1577 let listener_enable_zero_rtt = resolved_tls.is_some_and(|s| s.enable_zero_rtt);
1578
1579 for rule in rules {
1580 match (tls_l7, rule.raw.allow_zero_rtt) {
1581 (true, None) => {
1582 return Err(Error::compile(format!(
1583 "rule {:?} on TLS-terminating listener {addrs:?}: `allow_zero_rtt` is required (no implicit default) — set it to `true` or `false`",
1584 rule.raw.name
1585 )));
1586 }
1587 (false, Some(_)) => {
1588 return Err(Error::compile(format!(
1589 "rule {:?} on listener {addrs:?}: `allow_zero_rtt` is meaningful only on L7 rules whose listener is TLS-terminating — drop the field",
1590 rule.raw.name
1591 )));
1592 }
1593 (true, Some(true)) => {
1594 if !listener_enable_zero_rtt {
1595 return Err(Error::compile(format!(
1596 "allow_zero_rtt: true on rule {:?} but listener {addrs:?} has enable_zero_rtt: false",
1597 rule.raw.name
1598 )));
1599 }
1600 if !predicate_constrains_method_to_idempotent(rule.raw.match_predicate.as_ref()) {
1601 return Err(Error::compile(format!(
1602 "allow_zero_rtt: true on rule {:?} requires a method constraint restricted to GET / HEAD / OPTIONS",
1603 rule.raw.name
1604 )));
1605 }
1606 }
1607 (true, Some(false)) | (false, None) => {}
1608 }
1609 }
1610 Ok(())
1611}
1612
1613fn predicate_constrains_method_to_idempotent(pred: Option<&Predicate>) -> bool {
1633 let Some(pred) = pred else {
1634 return false;
1635 };
1636 match pred {
1637 Predicate::Check(c) => check_is_idempotent_method(c),
1638 Predicate::AllOf(a) => {
1639 a.all_of.iter().any(|child| predicate_constrains_method_to_idempotent(Some(child)))
1640 }
1641 Predicate::AnyOf(a) => {
1642 !a.any_of.is_empty()
1643 && a.any_of.iter().all(|child| predicate_constrains_method_to_idempotent(Some(child)))
1644 }
1645 Predicate::Not(_) => false,
1646 }
1647}
1648
1649fn check_is_idempotent_method(c: &crate::predicate::CheckMap) -> bool {
1650 use crate::predicate::{Operator, Value as PredValue};
1651 if !matches!(c.path, FieldPath::HttpMethod) {
1652 return false;
1653 }
1654 match &c.op {
1655 Operator::Equals(PredValue::Str(s)) => is_idempotent_method(s),
1656 Operator::In(values) => {
1657 !values.is_empty()
1658 && values.iter().all(|v| matches!(v, PredValue::Str(s) if is_idempotent_method(s)))
1659 }
1660 _ => false,
1661 }
1662}
1663
/// Returns `true` when `method` is exactly `GET`, `HEAD`, or `OPTIONS`.
///
/// The comparison is case-sensitive and uses no trimming, so `"get"` and
/// `"GET "` do not match.
fn is_idempotent_method(method: &str) -> bool {
    const ZERO_RTT_SAFE_METHODS: [&str; 3] = ["GET", "HEAD", "OPTIONS"];
    ZERO_RTT_SAFE_METHODS.contains(&method)
}
1667
1668fn compile_client_auth(
1672 addrs: &[SocketAddr],
1673 ca: &crate::rule::ClientAuthConfig,
1674) -> Result<crate::rule::ClientAuthSpec, Error> {
1675 use crate::rule::{ClientAuthMode, ClientAuthSpec};
1676 match ca.mode {
1677 ClientAuthMode::None => {
1678 if ca.trust_store.is_some() {
1679 return Err(Error::compile(format!(
1680 "listener {addrs:?}: `client_auth.mode = \"none\"` cannot carry a `trust_store` — drop the trust_store or change the mode"
1681 )));
1682 }
1683 Ok(ClientAuthSpec::None)
1684 }
1685 ClientAuthMode::Request | ClientAuthMode::Require => {
1686 let Some(ts) = ca.trust_store.clone() else {
1687 return Err(Error::compile(format!(
1688 "listener {addrs:?}: `client_auth.mode = \"{}\"` requires a `trust_store`",
1689 match ca.mode {
1690 ClientAuthMode::Request => "request",
1691 ClientAuthMode::Require => "require",
1692 ClientAuthMode::None => unreachable!(),
1693 }
1694 )));
1695 };
1696 if ts.ca_paths.is_empty() && ts.ca_dir.is_none() {
1697 return Err(Error::compile(format!(
1698 "listener {addrs:?}: `trust_store` requires at least one of `ca_paths` or `ca_dir`"
1699 )));
1700 }
1701 Ok(match ca.mode {
1702 ClientAuthMode::Request => ClientAuthSpec::Request { trust_store: ts },
1703 ClientAuthMode::Require => ClientAuthSpec::Require { trust_store: ts },
1704 ClientAuthMode::None => unreachable!(),
1705 })
1706 }
1707 }
1708}
1709
1710fn group_by_listener<'a>(rules: &'a [AnalyzedRule]) -> Result<Vec<ListenerGroup<'a>>, Error> {
1711 let mut groups: HashMap<(Transport, Vec<SocketAddr>), Vec<&'a AnalyzedRule>> = HashMap::new();
1720 for rule in rules {
1721 let mut transport: Option<Transport> = None;
1722 let mut addrs: Vec<SocketAddr> = Vec::new();
1723 for spec in &rule.raw.listen {
1724 let (t, more) = parse_listen(spec)?;
1725 match transport {
1726 None => transport = Some(t),
1727 Some(existing) if existing == t => {}
1728 Some(existing) => {
1729 return Err(Error::compile(format!(
1730 "rule {:?}: `listen` mixes transports {:?} and {:?} in one rule — split into separate rules",
1731 rule.raw.name, existing, t,
1732 )));
1733 }
1734 }
1735 addrs.extend(more);
1736 }
1737 addrs.sort();
1738 addrs.dedup();
1739 let transport = transport.unwrap_or(Transport::Tcp);
1744 groups.entry((transport, addrs)).or_default().push(rule);
1745 }
1746 let mut out: Vec<ListenerGroup<'_>> =
1747 groups.into_iter().map(|((transport, addrs), rules)| (transport, addrs, rules)).collect();
1748 out.sort_by(|a, b| a.1.cmp(&b.1).then(a.0.cmp(&b.0)));
1749 Ok(out)
1750}
1751
1752fn parse_listen(spec: &str) -> Result<(Transport, Vec<SocketAddr>), Error> {
1771 let s = spec.trim();
1772 let (transport, rest, prefix_seen) = if let Some(rest) = s.strip_prefix("tcp:") {
1773 (Transport::Tcp, rest, true)
1774 } else if let Some(rest) = s.strip_prefix("udp:") {
1775 (Transport::Udp, rest, true)
1776 } else {
1777 (Transport::Tcp, s, false)
1778 };
1779
1780 let owned: String;
1788 let parse_target: &str =
1789 if prefix_seen && !rest.is_empty() && rest.bytes().all(|b| b.is_ascii_digit()) {
1790 owned = format!(":{rest}");
1791 &owned
1792 } else {
1793 rest
1794 };
1795
1796 let addrs = parse_listen_address(parse_target, spec)?;
1797 Ok((transport, addrs))
1798}
1799
1800fn parse_listen_address(rest: &str, original: &str) -> Result<Vec<SocketAddr>, Error> {
1804 if rest == ":0" || rest == "*:0" {
1806 return Err(Error::compile(format!("wildcard port rejected: {original:?}")));
1807 }
1808 if let Some(port_str) = rest.strip_prefix(':').or_else(|| rest.strip_prefix("*:")) {
1810 let port = u16::from_str(port_str)
1811 .map_err(|e| Error::compile(format!("bad port in {original:?}: {e}")))?;
1812 return Ok(vec![
1813 SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port),
1814 SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port),
1815 ]);
1816 }
1817 SocketAddr::from_str(rest)
1818 .map(|a| vec![a])
1819 .map_err(|e| Error::compile(format!("bad listen spec {original:?}: {e}")))
1820}
1821
1822fn compile_operator(
1823 op: &Operator,
1824 path: &FieldPath,
1825 source: &SourceInfo,
1826) -> Result<CompiledOperator, Error> {
1827 let family = op.family();
1833 let vt = path.value_type();
1834 if !family.accepts(vt) {
1835 return Err(Error::compile(format!(
1836 "{}operator `{}` cannot apply to field `{}` (expected {}, got {})",
1837 source_prefix(source),
1838 op.name(),
1839 path.display_name(),
1840 family.family_expectation(),
1841 vt.name(),
1842 )));
1843 }
1844
1845 Ok(match op {
1846 Operator::Equals(v) => CompiledOperator::Equals(coerce_value(v, path, op.name(), source)?),
1847 Operator::NotEquals(v) => {
1848 CompiledOperator::NotEquals(coerce_value(v, path, op.name(), source)?)
1849 }
1850 Operator::Contains(v) => {
1851 CompiledOperator::Contains(value_to_bytes(v, path, op.name(), source)?)
1852 }
1853 Operator::NotContains(v) => {
1854 CompiledOperator::NotContains(value_to_bytes(v, path, op.name(), source)?)
1855 }
1856 Operator::Prefix(v) => CompiledOperator::Prefix(value_to_bytes(v, path, op.name(), source)?),
1857 Operator::Suffix(v) => CompiledOperator::Suffix(value_to_bytes(v, path, op.name(), source)?),
1858 Operator::Matches(pat) => CompiledOperator::Matches(compile_matches_regex(pat, path, source)?),
1859 Operator::In(vs) => {
1860 let mut out = Vec::with_capacity(vs.len());
1861 for v in vs {
1862 out.push(coerce_value(v, path, op.name(), source)?);
1863 }
1864 CompiledOperator::In(out)
1865 }
1866 Operator::NotIn(vs) => {
1867 let mut out = Vec::with_capacity(vs.len());
1868 for v in vs {
1869 out.push(coerce_value(v, path, op.name(), source)?);
1870 }
1871 CompiledOperator::NotIn(out)
1872 }
1873 Operator::Gt(n) => CompiledOperator::Gt(*n),
1874 Operator::Gte(n) => CompiledOperator::Gte(*n),
1875 Operator::Lt(n) => CompiledOperator::Lt(*n),
1876 Operator::Lte(n) => CompiledOperator::Lte(*n),
1877 Operator::Cidr(s) => CompiledOperator::Cidr(ipnet::IpNet::from_str(s).map_err(|e| {
1878 Error::compile(format!(
1879 "{}invalid cidr `{s}` on field `{}`: {e}",
1880 source_prefix(source),
1881 path.display_name(),
1882 ))
1883 })?),
1884 })
1885}
1886
1887fn coerce_value(
1888 v: &Value,
1889 path: &FieldPath,
1890 op_name: &'static str,
1891 source: &SourceInfo,
1892) -> Result<CompiledValue, Error> {
1893 let mismatch = || {
1894 Error::compile(format!(
1895 "{}field `{}` ({}) is not compatible with `{op_name}` value {}",
1896 source_prefix(source),
1897 path.display_name(),
1898 path.value_type().name(),
1899 value_kind(v),
1900 ))
1901 };
1902 match path.value_type() {
1903 FieldValueType::IpAddr => {
1904 let Value::Str(s) = v else {
1905 return Err(mismatch());
1906 };
1907 IpAddr::from_str(s).map(CompiledValue::Addr).map_err(|e| {
1908 Error::compile(format!(
1909 "{}field `{}` expects an ip-address string, got {s:?}: {e}",
1910 source_prefix(source),
1911 path.display_name(),
1912 ))
1913 })
1914 }
1915 FieldValueType::Int => {
1916 let Value::Int(n) = v else {
1917 return Err(mismatch());
1918 };
1919 Ok(CompiledValue::Int(*n))
1920 }
1921 FieldValueType::Bytes => {
1922 let Value::Str(s) = v else {
1927 return Err(mismatch());
1928 };
1929 let decoded = B64.decode(s.as_bytes()).map_err(|e| {
1930 Error::compile(format!(
1931 "{}operator `{op_name}` on field `{}` expected base64 string: {e}",
1932 source_prefix(source),
1933 path.display_name(),
1934 ))
1935 })?;
1936 Ok(CompiledValue::Bytes(bytes::Bytes::from(decoded)))
1937 }
1938 FieldValueType::Str => {
1939 let Value::Str(s) = v else {
1940 return Err(mismatch());
1941 };
1942 ensure_sni_ascii_lowercase(path, s, op_name, source)?;
1943 Ok(CompiledValue::Str(Arc::from(s.as_str())))
1944 }
1945 FieldValueType::Enum => {
1946 let Value::Str(s) = v else {
1947 return Err(mismatch());
1948 };
1949 coerce_enum_value(path, s, source)
1950 }
1951 FieldValueType::Bool => {
1952 let Value::Bool(b) = v else {
1953 return Err(mismatch());
1954 };
1955 Ok(CompiledValue::Bool(*b))
1956 }
1957 FieldValueType::VecStr => Err(Error::compile(format!(
1965 "{}field `{}` ({}) cannot be operand-coerced — only `contains` / `not_contains` apply to Vec<Str>",
1966 source_prefix(source),
1967 path.display_name(),
1968 path.value_type().name(),
1969 ))),
1970 }
1971}
1972
1973fn coerce_enum_value(
1974 path: &FieldPath,
1975 s: &str,
1976 source: &SourceInfo,
1977) -> Result<CompiledValue, Error> {
1978 let allowed: Option<&[&str]> = match path {
1979 FieldPath::Transport => Some(&["tcp", "udp"]),
1980 FieldPath::TlsVersion => Some(&["1.2", "1.3"]),
1981 FieldPath::HttpMethod => None,
1984 _ => unreachable!("non-enum path reached coerce_enum_value: {path:?}"),
1985 };
1986 if let Some(values) = allowed
1987 && !values.contains(&s)
1988 {
1989 return Err(Error::compile(format!(
1990 "{}field `{}` accepts {:?}, got {s:?}",
1991 source_prefix(source),
1992 path.display_name(),
1993 values,
1994 )));
1995 }
1996 Ok(CompiledValue::Str(Arc::from(s)))
1997}
1998
1999fn value_to_bytes(
2000 v: &Value,
2001 path: &FieldPath,
2002 op_name: &'static str,
2003 source: &SourceInfo,
2004) -> Result<bytes::Bytes, Error> {
2005 match v {
2014 Value::Str(s) => {
2015 if path.value_type() == FieldValueType::Bytes {
2016 B64.decode(s.as_bytes()).map(bytes::Bytes::from).map_err(|e| {
2017 Error::compile(format!(
2018 "{}operator `{op_name}` on field `{}` expected base64 string: {e}",
2019 source_prefix(source),
2020 path.display_name(),
2021 ))
2022 })
2023 } else {
2024 ensure_sni_ascii_lowercase(path, s, op_name, source)?;
2025 Ok(bytes::Bytes::copy_from_slice(s.as_bytes()))
2026 }
2027 }
2028 Value::Int(_) | Value::Bool(_) => Err(Error::compile(format!(
2029 "{}operator `{op_name}` on field `{}` expects a string value, got {}",
2030 source_prefix(source),
2031 path.display_name(),
2032 value_kind(v),
2033 ))),
2034 }
2035}
2036
2037fn validate_unique_body_reader_per_path(
2046 nodes: &[Node],
2047 entries: &std::collections::HashMap<SocketAddr, NodeId>,
2048) -> Result<(), Error> {
2049 use std::collections::HashSet;
2050 let mut visited: HashSet<(u32, u8, u8)> = HashSet::new();
2051 for &entry in entries.values() {
2052 let mut stack: Vec<(NodeId, u8, u8)> = vec![(entry, 0, 0)];
2053 while let Some((cur, req, resp)) = stack.pop() {
2054 if !visited.insert((cur.get(), req, resp)) {
2055 continue;
2056 }
2057 let idx = cur.get() as usize;
2058 let (own_req, own_resp) = match &nodes[idx] {
2059 Node::Middleware { collect_body_before, .. }
2060 | Node::Check { collect_body_before, .. }
2061 | Node::Fetch { collect_body_before, .. } => match collect_body_before {
2062 Some(BodySide::Request) => (1u8, 0u8),
2063 Some(BodySide::Response) => (0u8, 1u8),
2064 None => (0, 0),
2065 },
2066 Node::Upgrade { .. } | Node::Terminate(_) => (0, 0),
2067 };
2068 let new_req = req.saturating_add(own_req).min(2);
2069 let new_resp = resp.saturating_add(own_resp).min(2);
2070 if new_req > 1 {
2071 return Err(Error::compile(format!(
2072 "node {idx}: path through listener entry has more than one collect_body_before=Some(Request)",
2073 )));
2074 }
2075 if new_resp > 1 {
2076 return Err(Error::compile(format!(
2077 "node {idx}: path through listener entry has more than one collect_body_before=Some(Response)",
2078 )));
2079 }
2080 match &nodes[idx] {
2081 Node::Middleware { next, on_error, .. } => {
2082 stack.push((*next, new_req, new_resp));
2083 if let Some(eid) = on_error {
2084 stack.push((*eid, new_req, new_resp));
2085 }
2086 }
2087 Node::Check { on_match, on_miss, .. } => {
2088 stack.push((*on_match, new_req, new_resp));
2089 stack.push((*on_miss, new_req, new_resp));
2090 }
2091 Node::Fetch { next_response, next_tunnel, .. } => {
2092 if let Some(n) = next_response {
2093 stack.push((*n, new_req, new_resp));
2094 }
2095 if let Some(t) = next_tunnel {
2096 stack.push((*t, new_req, new_resp));
2097 }
2098 }
2099 Node::Upgrade { next } => stack.push((*next, new_req, new_resp)),
2100 Node::Terminate(_) => {}
2101 }
2102 }
2103 }
2104 Ok(())
2105}
2106
2107fn compile_matches_regex(
2121 pat: &str,
2122 path: &FieldPath,
2123 source: &SourceInfo,
2124) -> Result<fancy_regex::Regex, Error> {
2125 use crate::predicate::{
2126 REGEX_BACKTRACK_LIMIT, REGEX_DELEGATE_SIZE_LIMIT, REGEX_SMOKE_TEST_INPUT_LEN,
2127 };
2128 let re = fancy_regex::RegexBuilder::new(pat)
2129 .backtrack_limit(REGEX_BACKTRACK_LIMIT)
2130 .delegate_size_limit(REGEX_DELEGATE_SIZE_LIMIT)
2131 .build()
2132 .map_err(|e| {
2133 Error::compile(format!(
2134 "{}invalid regex in `matches` operator on field `{}`: {e}",
2135 source_prefix(source),
2136 path.display_name(),
2137 ))
2138 })?;
2139
2140 let probe: String = "a".repeat(REGEX_SMOKE_TEST_INPUT_LEN);
2145 match re.is_match(&probe) {
2146 Ok(_) => Ok(re),
2147 Err(fancy_regex::Error::RuntimeError(fancy_regex::RuntimeError::BacktrackLimitExceeded)) => {
2148 Err(Error::compile(format!(
2149 "{}regex in `matches` on field `{}` exceeded backtrack limit on smoke-test input; refusing to compile to avoid runtime ReDoS",
2150 source_prefix(source),
2151 path.display_name(),
2152 )))
2153 }
2154 Err(e) => Err(Error::compile(format!(
2155 "{}regex in `matches` on field `{}` errored on smoke test: {e}",
2156 source_prefix(source),
2157 path.display_name(),
2158 ))),
2159 }
2160}
2161
2162fn ensure_sni_ascii_lowercase(
2168 path: &FieldPath,
2169 s: &str,
2170 op_name: &'static str,
2171 source: &SourceInfo,
2172) -> Result<(), Error> {
2173 if matches!(path, FieldPath::TlsSni) && s.bytes().any(|b| b.is_ascii_uppercase()) {
2174 return Err(Error::compile(format!(
2175 "{}operator `{op_name}` on field `tls.sni`: operand {s:?} must be ASCII lowercase",
2176 source_prefix(source),
2177 )));
2178 }
2179 Ok(())
2180}
2181
2182fn value_kind(v: &Value) -> &'static str {
2183 match v {
2184 Value::Str(_) => "Str",
2185 Value::Int(_) => "Int",
2186 Value::Bool(_) => "Bool",
2187 }
2188}
2189
2190fn source_prefix(source: &SourceInfo) -> String {
2191 if source.file.as_os_str().is_empty() {
2192 String::new()
2193 } else {
2194 format!("{}:{}: ", source.file.display(), source.line)
2195 }
2196}
2197
2198fn canonical_json(v: &serde_json::Value) -> String {
2203 let mut out = String::new();
2204 crate::canonical::write_into_lossy(&mut out, v);
2205 out
2206}
2207
2208fn hash_rules(rules: &[AnalyzedRule]) -> [u8; 32] {
2219 let mut entries: Vec<serde_json::Value> = rules
2220 .iter()
2221 .map(|rule| {
2222 let mut v = serde_json::to_value(&rule.raw).unwrap_or(serde_json::Value::Null);
2223 if let serde_json::Value::Object(map) = &mut v {
2224 map.remove("source");
2225 }
2226 v
2227 })
2228 .collect();
2229 entries.sort_by(|a, b| {
2230 let an = a.get("name").and_then(serde_json::Value::as_str).unwrap_or("");
2231 let bn = b.get("name").and_then(serde_json::Value::as_str).unwrap_or("");
2232 an.cmp(bn)
2233 });
2234 let mut hasher = Sha256::new();
2235 hasher.update(canonical_json(&serde_json::Value::Array(entries)).as_bytes());
2236 let _ = PathBuf::new;
2237 hasher.finalize().into()
2238}
2239
2240#[cfg(test)]
2241mod compat_tests {
2242 use std::path::PathBuf;
2247 use std::sync::Arc;
2248
2249 use super::{SourceInfo, compile_operator};
2250 use crate::predicate::{FieldPath, Operator, Value};
2251
2252 fn src() -> SourceInfo {
2253 SourceInfo { file: PathBuf::from("rules/30-api.json"), line: 14 }
2254 }
2255
2256 fn assert_rejected_with_source(err: &crate::error::Error) {
2257 let msg = err.to_string();
2258 assert!(msg.contains("rules/30-api.json:14"), "error must carry rule source: {msg}");
2259 }
2260
2261 #[test]
2262 fn gt_on_bytes_field_rejected() {
2263 let err = compile_operator(&Operator::Gt(100), &FieldPath::HttpBody, &src())
2264 .expect_err("gt on http.body must reject");
2265 let msg = err.to_string();
2266 assert!(msg.contains("`gt`"), "{msg}");
2267 assert!(msg.contains("http.body"), "{msg}");
2268 assert!(msg.contains("expected numeric"), "{msg}");
2269 assert_rejected_with_source(&err);
2270 }
2271
2272 #[test]
2273 fn cidr_on_string_field_rejected() {
2274 let err =
2275 compile_operator(&Operator::Cidr("10.0.0.0/8".to_string()), &FieldPath::HttpUriPath, &src())
2276 .expect_err("cidr on http.uri.path must reject");
2277 let msg = err.to_string();
2278 assert!(msg.contains("`cidr`"), "{msg}");
2279 assert!(msg.contains("http.uri.path"), "{msg}");
2280 assert!(msg.contains("expected IpAddr"), "{msg}");
2281 assert_rejected_with_source(&err);
2282 }
2283
2284 #[test]
2285 fn matches_on_bytes_field_rejected() {
2286 let err = compile_operator(&Operator::Matches("^a".to_string()), &FieldPath::TlsAlpn, &src())
2287 .expect_err("matches on tls.alpn must reject");
2288 let msg = err.to_string();
2289 assert!(msg.contains("`matches`"), "{msg}");
2290 assert!(msg.contains("expected Str"), "{msg}");
2291 }
2292
2293 #[test]
2294 fn matches_on_int_field_rejected() {
2295 let err =
2296 compile_operator(&Operator::Matches("^1".to_string()), &FieldPath::RemotePort, &src())
2297 .expect_err("matches on remote.port must reject");
2298 let msg = err.to_string();
2299 assert!(msg.contains("`matches`"), "{msg}");
2300 assert!(msg.contains("expected Str"), "{msg}");
2301 }
2302
2303 #[test]
2304 fn contains_on_int_field_rejected() {
2305 let err = compile_operator(&Operator::Contains(Value::Int(1)), &FieldPath::RemotePort, &src())
2306 .expect_err("contains on remote.port must reject");
2307 let msg = err.to_string();
2308 assert!(msg.contains("`contains`"), "{msg}");
2309 assert!(msg.contains("Str, Bytes, or Vec<Str>"), "{msg}");
2310 }
2311
2312 #[test]
2313 fn prefix_on_ip_field_rejected() {
2314 let err = compile_operator(
2315 &Operator::Prefix(Value::Str("10.".to_string())),
2316 &FieldPath::RemoteIp,
2317 &src(),
2318 )
2319 .expect_err("prefix on remote.ip must reject");
2320 let msg = err.to_string();
2321 assert!(msg.contains("`prefix`"), "{msg}");
2322 assert!(msg.contains("Str or Bytes"), "{msg}");
2323 }
2324
2325 #[test]
2326 fn suffix_on_enum_field_rejected() {
2327 let err = compile_operator(
2328 &Operator::Suffix(Value::Str("p".to_string())),
2329 &FieldPath::Transport,
2330 &src(),
2331 )
2332 .expect_err("suffix on transport must reject");
2333 let msg = err.to_string();
2334 assert!(msg.contains("`suffix`"), "{msg}");
2335 }
2336
2337 #[test]
2338 fn gt_on_str_field_rejected() {
2339 let err = compile_operator(&Operator::Gt(0), &FieldPath::TlsSni, &src())
2340 .expect_err("gt on tls.sni must reject");
2341 assert!(err.to_string().contains("expected numeric"));
2342 }
2343
2344 #[test]
2345 fn cidr_on_int_field_rejected() {
2346 let err =
2347 compile_operator(&Operator::Cidr("10.0.0.0/8".to_string()), &FieldPath::RemotePort, &src())
2348 .expect_err("cidr on remote.port must reject");
2349 assert!(err.to_string().contains("expected IpAddr"));
2350 }
2351
2352 #[test]
2353 fn cidr_on_enum_field_rejected() {
2354 let err =
2355 compile_operator(&Operator::Cidr("0.0.0.0/0".to_string()), &FieldPath::HttpMethod, &src())
2356 .expect_err("cidr on http.method must reject");
2357 assert!(err.to_string().contains("expected IpAddr"));
2358 }
2359
2360 #[test]
2361 fn cidr_on_bytes_field_rejected() {
2362 let err =
2363 compile_operator(&Operator::Cidr("10.0.0.0/8".to_string()), &FieldPath::TlsAlpn, &src())
2364 .expect_err("cidr on tls.alpn must reject");
2365 assert!(err.to_string().contains("expected IpAddr"));
2366 }
2367
2368 #[test]
2369 fn substring_on_ip_field_rejected() {
2370 let err = compile_operator(
2371 &Operator::Contains(Value::Str("10.".to_string())),
2372 &FieldPath::RemoteIp,
2373 &src(),
2374 )
2375 .expect_err("contains on remote.ip must reject");
2376 assert!(err.to_string().contains("`contains`"));
2377 assert!(err.to_string().contains("Str, Bytes, or Vec<Str>"));
2378 }
2379
2380 #[test]
2381 fn substring_on_enum_field_rejected() {
2382 let err = compile_operator(
2383 &Operator::NotContains(Value::Str("p".to_string())),
2384 &FieldPath::Transport,
2385 &src(),
2386 )
2387 .expect_err("not_contains on transport must reject");
2388 assert!(err.to_string().contains("`not_contains`"));
2389 }
2390
2391 #[test]
2392 fn prefix_suffix_on_int_field_rejected() {
2393 let err = compile_operator(
2394 &Operator::Prefix(Value::Str("80".to_string())),
2395 &FieldPath::RemotePort,
2396 &src(),
2397 )
2398 .expect_err("prefix on remote.port must reject");
2399 assert!(err.to_string().contains("`prefix`"));
2400 assert!(err.to_string().contains("Str or Bytes"));
2401 }
2402
2403 #[test]
2404 fn matches_on_ip_field_rejected() {
2405 let err = compile_operator(&Operator::Matches("^10".to_string()), &FieldPath::RemoteIp, &src())
2406 .expect_err("matches on remote.ip must reject");
2407 let msg = err.to_string();
2408 assert!(msg.contains("`matches`"), "{msg}");
2409 assert!(msg.contains("expected Str"), "{msg}");
2410 }
2411
2412 #[test]
2413 fn matches_on_enum_field_rejected() {
2414 let err = compile_operator(&Operator::Matches("^t".to_string()), &FieldPath::Transport, &src())
2415 .expect_err("matches on transport must reject");
2416 assert!(err.to_string().contains("expected Str"));
2417 }
2418
2419 #[test]
2420 fn numeric_cmp_on_ip_field_rejected() {
2421 let err = compile_operator(&Operator::Lt(0), &FieldPath::RemoteIp, &src())
2422 .expect_err("lt on remote.ip must reject");
2423 assert!(err.to_string().contains("expected numeric"));
2424 }
2425
2426 #[test]
2427 fn numeric_cmp_on_enum_field_rejected() {
2428 let err = compile_operator(&Operator::Gte(0), &FieldPath::TlsVersion, &src())
2429 .expect_err("gte on tls.version must reject");
2430 assert!(err.to_string().contains("expected numeric"));
2431 }
2432
2433 #[test]
2434 fn invalid_regex_carries_source_and_field() {
2435 let err =
2436 compile_operator(&Operator::Matches("[".to_string()), &FieldPath::HttpUriPath, &src())
2437 .expect_err("unbalanced [ must reject");
2438 let msg = err.to_string();
2439 assert!(msg.contains("rules/30-api.json:14"), "{msg}");
2440 assert!(msg.contains("`matches`"), "{msg}");
2441 assert!(msg.contains("http.uri.path"), "{msg}");
2442 }
2443
2444 #[test]
2445 fn redos_pattern_rejected_at_compile_via_backtrack_smoke_test() {
2446 let err = compile_operator(
2452 &Operator::Matches("^(a+)*b\\1$".to_string()),
2453 &FieldPath::HttpUriPath,
2454 &src(),
2455 )
2456 .expect_err("redos pattern must reject");
2457 let msg = err.to_string();
2458 assert!(msg.contains("backtrack"), "error mentions backtrack limit: {msg}");
2459 assert!(msg.contains("http.uri.path"), "{msg}");
2460 }
2461
2462 #[test]
2463 fn well_behaved_regex_passes_smoke_test() {
2464 let op = compile_operator(
2466 &Operator::Matches("^(api|web|static)/[a-z0-9-]+$".to_string()),
2467 &FieldPath::HttpUriPath,
2468 &src(),
2469 )
2470 .expect("well-behaved regex compiles");
2471 match op {
2472 crate::predicate::CompiledOperator::Matches(_) => {}
2473 other => panic!("expected Matches, got {other:?}"),
2474 }
2475 }
2476
2477 #[test]
2478 fn transport_enum_rejects_unknown_literal() {
2479 let err = compile_operator(
2480 &Operator::Equals(Value::Str("ftp".to_string())),
2481 &FieldPath::Transport,
2482 &src(),
2483 )
2484 .expect_err("transport == \"ftp\" must reject");
2485 let msg = err.to_string();
2486 assert!(msg.contains("transport"), "{msg}");
2487 assert!(msg.contains("\"ftp\""), "{msg}");
2488 }
2489
2490 #[test]
2491 fn transport_enum_accepts_known_literals() {
2492 for v in ["tcp", "udp"] {
2493 compile_operator(&Operator::Equals(Value::Str(v.to_string())), &FieldPath::Transport, &src())
2494 .unwrap_or_else(|e| panic!("transport == {v:?} must compile: {e}"));
2495 }
2496 }
2497
2498 #[test]
2499 fn tls_version_enum_rejects_unknown_literal() {
2500 let err = compile_operator(
2501 &Operator::Equals(Value::Str("0.9".to_string())),
2502 &FieldPath::TlsVersion,
2503 &src(),
2504 )
2505 .expect_err("tls.version == \"0.9\" must reject");
2506 assert!(err.to_string().contains("tls.version"));
2507 }
2508
2509 #[test]
2510 fn http_method_enum_accepts_any_string() {
2511 compile_operator(
2514 &Operator::Equals(Value::Str("CONNECT".to_string())),
2515 &FieldPath::HttpMethod,
2516 &src(),
2517 )
2518 .expect("http.method == CONNECT must compile");
2519 }
2520
2521 #[test]
2522 fn equals_int_value_on_string_field_rejected() {
2523 let err = compile_operator(&Operator::Equals(Value::Int(1)), &FieldPath::TlsSni, &src())
2527 .expect_err("equals(int) on str field must reject");
2528 let msg = err.to_string();
2529 assert!(msg.contains("tls.sni"), "{msg}");
2530 assert!(msg.contains("Str"), "{msg}");
2531 }
2532
2533 #[test]
2534 fn in_list_with_mixed_types_rejected_on_int_field() {
2535 let err = compile_operator(
2536 &Operator::In(vec![Value::Int(1), Value::Str("x".to_string())]),
2537 &FieldPath::RemotePort,
2538 &src(),
2539 )
2540 .expect_err("in([int,str]) on int field must reject");
2541 assert!(err.to_string().contains("remote.port"));
2542 }
2543
2544 #[test]
2545 fn empty_source_info_omits_prefix() {
2546 let empty = SourceInfo::default();
2549 let err = compile_operator(&Operator::Gt(100), &FieldPath::HttpBody, &empty)
2550 .expect_err("gt on http.body must reject");
2551 let msg = err.to_string();
2552 assert!(!msg.contains(":0:"), "default source must not leak `:0:` prefix: {msg}");
2553 }
2554
2555 #[test]
2556 fn arc_compiled_for_str_field_uses_string_arc() {
2557 let op =
2560 compile_operator(&Operator::Equals(Value::Str("x".to_string())), &FieldPath::TlsSni, &src())
2561 .expect("legal equals/str compiles");
2562 match op {
2563 crate::predicate::CompiledOperator::Equals(crate::predicate::CompiledValue::Str(arc)) => {
2564 let _: Arc<str> = arc;
2565 }
2566 other => panic!("unexpected compiled op: {other:?}"),
2567 }
2568 }
2569
2570 #[test]
2576 fn bytes_literal_decoded_as_base64_for_contains_on_http_body() {
2577 let op = compile_operator(
2578 &Operator::Contains(Value::Str("aGVsbG8=".to_string())),
2579 &FieldPath::HttpBody,
2580 &src(),
2581 )
2582 .expect("base64 contains compiles");
2583 match op {
2584 crate::predicate::CompiledOperator::Contains(b) => {
2585 assert_eq!(b.as_ref(), b"hello", "base64 'aGVsbG8=' must decode to 'hello'");
2586 }
2587 other => panic!("expected Contains, got {other:?}"),
2588 }
2589 }
2590
2591 #[test]
2594 fn bytes_literal_decoded_as_base64_for_equals_on_tls_alpn() {
2595 let op = compile_operator(
2596 &Operator::Equals(Value::Str("aDI=".to_string())),
2597 &FieldPath::TlsAlpn,
2598 &src(),
2599 )
2600 .expect("base64 equals compiles");
2601 match op {
2602 crate::predicate::CompiledOperator::Equals(crate::predicate::CompiledValue::Bytes(b)) => {
2603 assert_eq!(b.as_ref(), b"h2", "base64 'aDI=' must decode to 'h2'");
2604 }
2605 other => panic!("expected Equals(Bytes(\"h2\")), got {other:?}"),
2606 }
2607 }
2608
2609 #[test]
2614 fn bytes_field_prefix_suffix_decodes_base64() {
2615 let prefix =
2617 compile_operator(&Operator::Prefix(Value::Str("FgM=".to_string())), &FieldPath::Peek, &src())
2618 .expect("peek prefix compiles");
2619 match prefix {
2620 crate::predicate::CompiledOperator::Prefix(b) => assert_eq!(b.as_ref(), &[0x16, 0x03]),
2621 other => panic!("expected Prefix, got {other:?}"),
2622 }
2623
2624 let suffix = compile_operator(
2626 &Operator::Suffix(Value::Str("RU5E".to_string())),
2627 &FieldPath::HttpBody,
2628 &src(),
2629 )
2630 .expect("body suffix compiles");
2631 match suffix {
2632 crate::predicate::CompiledOperator::Suffix(b) => assert_eq!(b.as_ref(), b"END"),
2633 other => panic!("expected Suffix, got {other:?}"),
2634 }
2635 }
2636
2637 #[test]
2641 fn str_field_prefix_suffix_keeps_raw_bytes() {
2642 let prefix = compile_operator(
2643 &Operator::Prefix(Value::Str("/api".to_string())),
2644 &FieldPath::HttpUriPath,
2645 &src(),
2646 )
2647 .expect("str-field prefix compiles verbatim");
2648 match prefix {
2649 crate::predicate::CompiledOperator::Prefix(b) => assert_eq!(b.as_ref(), b"/api"),
2650 other => panic!("expected Prefix, got {other:?}"),
2651 }
2652
2653 let suffix = compile_operator(
2654 &Operator::Suffix(Value::Str(".json".to_string())),
2655 &FieldPath::HttpUriPath,
2656 &src(),
2657 )
2658 .expect("str-field suffix compiles verbatim");
2659 match suffix {
2660 crate::predicate::CompiledOperator::Suffix(b) => assert_eq!(b.as_ref(), b".json"),
2661 other => panic!("expected Suffix, got {other:?}"),
2662 }
2663 }
2664
2665 #[test]
2668 fn bytes_literal_rejects_non_base64_with_source_prefix() {
2669 let err = compile_operator(
2670 &Operator::Contains(Value::Str("###".to_string())),
2671 &FieldPath::HttpBody,
2672 &src(),
2673 )
2674 .expect_err("non-base64 contains must reject");
2675 let msg = err.to_string();
2676 assert!(msg.contains("rules/30-api.json:14"), "error must carry source: {msg}");
2677 assert!(msg.contains("`contains`"), "{msg}");
2678 assert!(msg.contains("http.body"), "{msg}");
2679 assert!(msg.contains("expected base64 string"), "{msg}");
2680 }
2681
2682 #[test]
2685 fn bytes_literal_equals_rejects_non_base64() {
2686 let err = compile_operator(
2687 &Operator::Equals(Value::Str("not-valid-base64!".to_string())),
2688 &FieldPath::TlsAlpn,
2689 &src(),
2690 )
2691 .expect_err("non-base64 equals must reject");
2692 let msg = err.to_string();
2693 assert!(msg.contains("expected base64 string"), "{msg}");
2694 assert!(msg.contains("tls.alpn"), "{msg}");
2695 }
2696
2697 #[test]
2702 fn tls_sni_rejects_uppercase_ascii_in_equals() {
2703 let err = compile_operator(
2704 &Operator::Equals(Value::Str("Example.com".to_string())),
2705 &FieldPath::TlsSni,
2706 &src(),
2707 )
2708 .expect_err("uppercase tls.sni equals must reject");
2709 let msg = err.to_string();
2710 assert!(msg.contains("tls.sni"), "{msg}");
2711 assert!(msg.contains("ASCII lowercase"), "{msg}");
2712 }
2713
2714 #[test]
2715 fn tls_sni_rejects_uppercase_ascii_in_contains_prefix_suffix_in() {
2716 for op in [
2717 Operator::Contains(Value::Str("A".to_string())),
2718 Operator::NotContains(Value::Str("B".to_string())),
2719 Operator::Prefix(Value::Str("Api.".to_string())),
2720 Operator::Suffix(Value::Str(".CoM".to_string())),
2721 Operator::In(vec![Value::Str("ok.example.com".to_string()), Value::Str("X.com".to_string())]),
2722 Operator::NotIn(vec![Value::Str("Bad.com".to_string())]),
2723 ] {
2724 let err = compile_operator(&op, &FieldPath::TlsSni, &src())
2725 .expect_err("uppercase tls.sni operand must reject");
2726 let msg = err.to_string();
2727 assert!(msg.contains("tls.sni"), "{msg}");
2728 assert!(msg.contains("ASCII lowercase"), "{msg}");
2729 }
2730 }
2731
2732 #[test]
2733 fn tls_sni_accepts_lowercase_and_non_ascii_punycode() {
2734 compile_operator(
2736 &Operator::Equals(Value::Str("api.example.com".to_string())),
2737 &FieldPath::TlsSni,
2738 &src(),
2739 )
2740 .expect("lowercase tls.sni equals must compile");
2741 compile_operator(
2743 &Operator::Equals(Value::Str("xn--bcher-kva.example".to_string())),
2744 &FieldPath::TlsSni,
2745 &src(),
2746 )
2747 .expect("punycode tls.sni equals must compile");
2748 }
2749
2750 #[test]
2751 fn tls_sni_lowercase_invariant_on_compiled_values() {
2752 use crate::predicate::{CompiledOperator, CompiledValue};
2756
2757 fn check_bytes(b: &bytes::Bytes) {
2758 assert!(
2759 !b.iter().any(u8::is_ascii_uppercase),
2760 "tls.sni CompiledValue::Bytes must be ASCII lowercase, got {b:?}"
2761 );
2762 }
2763 fn check_value(v: &CompiledValue) {
2764 match v {
2765 CompiledValue::Str(s) => {
2766 assert!(
2767 !s.bytes().any(|b| b.is_ascii_uppercase()),
2768 "tls.sni CompiledValue::Str must be ASCII lowercase, got {s:?}"
2769 );
2770 }
2771 CompiledValue::Bytes(b) => check_bytes(b),
2772 other => panic!("tls.sni produced non-Str/Bytes CompiledValue: {other:?}"),
2773 }
2774 }
2775
2776 let legal = [
2777 Operator::Equals(Value::Str("a.example.com".to_string())),
2778 Operator::NotEquals(Value::Str("b.example.com".to_string())),
2779 Operator::Contains(Value::Str("api".to_string())),
2780 Operator::NotContains(Value::Str("internal".to_string())),
2781 Operator::Prefix(Value::Str("api.".to_string())),
2782 Operator::Suffix(Value::Str(".example.com".to_string())),
2783 Operator::In(vec![
2784 Value::Str("a.example.com".to_string()),
2785 Value::Str("b.example.com".to_string()),
2786 ]),
2787 Operator::NotIn(vec![Value::Str("c.example.com".to_string())]),
2788 ];
2789 for op in &legal {
2790 let compiled =
2791 compile_operator(op, &FieldPath::TlsSni, &src()).expect("legal tls.sni op compiles");
2792 match compiled {
2793 CompiledOperator::Equals(v) | CompiledOperator::NotEquals(v) => check_value(&v),
2794 CompiledOperator::Contains(b)
2795 | CompiledOperator::NotContains(b)
2796 | CompiledOperator::Prefix(b)
2797 | CompiledOperator::Suffix(b) => check_bytes(&b),
2798 CompiledOperator::In(vs) | CompiledOperator::NotIn(vs) => {
2799 for v in &vs {
2800 check_value(v);
2801 }
2802 }
2803 other => panic!("unexpected compiled op for tls.sni: {other:?}"),
2804 }
2805 }
2806 }
2807
2808 #[test]
2811 fn parse_and_lower_spec_example_decodes_base64_contains() {
2812 let raw = serde_json::json!({ "http.body": { "contains": "aGVsbG8=" } });
2816 let pred: crate::predicate::Predicate = serde_json::from_value(raw).expect("parse predicate");
2817 let check = match pred {
2818 crate::predicate::Predicate::Check(c) => c,
2819 other => panic!("expected Check, got {other:?}"),
2820 };
2821 let op = compile_operator(&check.op, &check.path, &src()).expect("lower");
2822 match op {
2823 crate::predicate::CompiledOperator::Contains(b) => assert_eq!(b.as_ref(), b"hello"),
2824 other => panic!("expected Contains, got {other:?}"),
2825 }
2826 }
2827}