use std::collections::{BTreeMap, HashMap};
use std::net::SocketAddr;
use std::ops::Index;
use std::path::PathBuf;
use std::time::SystemTime;

use crate::conn_context::Transport;
use crate::fetch::{SymbolicFetchRef, Terminator};
use crate::middleware::SymbolicMiddlewareRef;
use crate::predicate::PredicateInst;
11
12macro_rules! id_newtype {
13 ($name:ident) => {
14 #[derive(
15 Copy, Clone, Eq, PartialEq, Hash, Debug, PartialOrd, Ord, serde::Serialize, serde::Deserialize,
16 )]
17 pub struct $name(u32);
18
19 impl $name {
20 #[must_use]
27 pub(crate) const fn new(raw: u32) -> Self {
28 Self(raw)
29 }
30
31 #[cfg(any(test, feature = "test-support"))]
39 #[must_use]
40 pub const fn for_testing(raw: u32) -> Self {
41 Self(raw)
42 }
43
44 #[must_use]
45 pub const fn get(self) -> u32 {
46 self.0
47 }
48 }
49 };
50}
51
52id_newtype!(NodeId);
53id_newtype!(PredicateId);
54id_newtype!(MiddlewareId);
55id_newtype!(FetchId);
56id_newtype!(TerminatorId);
57
58#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
59pub enum BodySide {
60 Request,
61 Response,
62}
63
64#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
70pub enum ListenerKind {
71 Raw,
75 Http,
79 Auto,
83}
84
85#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
86pub enum Node {
87 Check {
88 predicate: PredicateId,
89 on_match: NodeId,
90 on_miss: NodeId,
91 collect_body_before: Option<BodySide>,
92 #[serde(default)]
93 body_limit: usize,
94 },
95 Middleware {
96 id: MiddlewareId,
97 next: NodeId,
98 on_error: Option<NodeId>,
99 collect_body_before: Option<BodySide>,
100 #[serde(default)]
101 body_limit: usize,
102 },
103 Fetch {
104 id: FetchId,
105 next_response: Option<NodeId>,
106 next_tunnel: Option<NodeId>,
107 collect_body_before: Option<BodySide>,
108 #[serde(default)]
109 body_limit: usize,
110 },
111 Upgrade {
112 next: NodeId,
113 },
114 Terminate(TerminatorId),
115}
116
117impl Node {
118 #[must_use]
119 pub const fn collect_body_before(&self) -> Option<BodySide> {
120 match self {
121 Self::Check { collect_body_before, .. }
122 | Self::Middleware { collect_body_before, .. }
123 | Self::Fetch { collect_body_before, .. } => *collect_body_before,
124 Self::Upgrade { .. } | Self::Terminate(_) => None,
125 }
126 }
127
128 #[must_use]
129 pub const fn body_limit(&self) -> usize {
130 match self {
131 Self::Check { body_limit, .. }
132 | Self::Middleware { body_limit, .. }
133 | Self::Fetch { body_limit, .. } => *body_limit,
134 _ => 0,
135 }
136 }
137}
138
139#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
140pub struct FlowGraphMeta {
141 pub version_hash: [u8; 32],
142 pub compiled_at: SystemTime,
143 pub source_files: Vec<PathBuf>,
144 #[serde(skip, default = "empty_feature_set")]
148 pub feature_set: &'static [&'static str],
149
150 #[serde(default)]
165 pub short_circuit_response_entry: std::collections::BTreeMap<NodeId, NodeId>,
166
167 #[serde(default)]
178 pub listener_tls: std::collections::BTreeMap<SocketAddr, crate::rule::ListenerTlsSpec>,
179
180 #[serde(default)]
188 pub listener_kinds: std::collections::BTreeMap<SocketAddr, ListenerKind>,
189
190 #[serde(default)]
200 pub listener_transports: std::collections::BTreeMap<SocketAddr, Transport>,
201
202 #[serde(default)]
211 pub annotations: Vec<DryRunAnnotation>,
212}
213
214#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
225pub struct DryRunAnnotation {
226 pub kind: String,
230 pub message: String,
233}
234
/// Serde default for [`FlowGraphMeta::feature_set`]: an empty feature list.
///
/// The field is skipped during (de)serialization, so this is the value every
/// deserialized `FlowGraphMeta` ends up with.
const fn empty_feature_set() -> &'static [&'static str] {
    const NO_FEATURES: &[&str] = &[];
    NO_FEATURES
}
238
239#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
240pub struct SymbolicFlowGraph {
241 pub nodes: Vec<Node>,
242 pub predicates: Vec<PredicateInst>,
243 pub middlewares: Vec<SymbolicMiddlewareRef>,
244 pub fetches: Vec<SymbolicFetchRef>,
245 pub terminators: Vec<Terminator>,
246 pub entries: HashMap<SocketAddr, NodeId>,
247 pub meta: FlowGraphMeta,
248}
249
250impl Index<NodeId> for SymbolicFlowGraph {
251 type Output = Node;
252 fn index(&self, id: NodeId) -> &Node {
253 &self.nodes[id.get() as usize]
254 }
255}
256
257impl Index<PredicateId> for SymbolicFlowGraph {
258 type Output = PredicateInst;
259 fn index(&self, id: PredicateId) -> &PredicateInst {
260 &self.predicates[id.get() as usize]
261 }
262}
263
264impl Index<MiddlewareId> for SymbolicFlowGraph {
265 type Output = SymbolicMiddlewareRef;
266 fn index(&self, id: MiddlewareId) -> &SymbolicMiddlewareRef {
267 &self.middlewares[id.get() as usize]
268 }
269}
270
271impl Index<FetchId> for SymbolicFlowGraph {
272 type Output = SymbolicFetchRef;
273 fn index(&self, id: FetchId) -> &SymbolicFetchRef {
274 &self.fetches[id.get() as usize]
275 }
276}
277
278impl Index<TerminatorId> for SymbolicFlowGraph {
279 type Output = Terminator;
280 fn index(&self, id: TerminatorId) -> &Terminator {
281 &self.terminators[id.get() as usize]
282 }
283}
284
/// Unit tests for the ID newtypes, [`Node`], [`FlowGraphMeta`] and the
/// `Index` implementations on [`SymbolicFlowGraph`].
#[cfg(test)]
mod tests {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::sync::Arc;

    use serde_json::Value;

    use super::*;
    use crate::fetch::{FetchKind, SymbolicFetchRef, Terminator};
    use crate::middleware::{MiddlewareKind, SymbolicMiddlewareRef};
    use crate::predicate::{CompiledOperator, CompiledValue, FieldPath, PredicateInst};

    // ---------------------------------------------------------------------
    // ID newtypes
    // ---------------------------------------------------------------------

    #[test]
    fn new_then_get_round_trips_raw_u32() {
        for raw in [0_u32, 1, 42, u32::MAX] {
            assert_eq!(NodeId::new(raw).get(), raw);
        }
    }

    #[test]
    fn for_testing_builds_the_same_id_as_new() {
        for raw in [0_u32, 1, 42, u32::MAX] {
            assert_eq!(NodeId::for_testing(raw), NodeId::new(raw));
            assert_eq!(NodeId::for_testing(raw).get(), raw);
        }
    }

    #[test]
    fn node_id_equality_is_structural() {
        assert_eq!(NodeId::new(7), NodeId::new(7));
        assert_ne!(NodeId::new(7), NodeId::new(8));
    }

    #[test]
    fn node_id_ordering_follows_raw_u32() {
        assert!(NodeId::new(1) < NodeId::new(2));
        assert!(NodeId::new(u32::MAX) > NodeId::new(0));
    }

    #[test]
    fn node_id_serde_round_trip() {
        let id = NodeId::new(0x0bad_f00d);
        let encoded = serde_json::to_string(&id).expect("serialize");
        let decoded: NodeId = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded, id);
    }

    #[test]
    fn body_side_serde_round_trip_per_variant() {
        for s in [BodySide::Request, BodySide::Response] {
            let encoded = serde_json::to_string(&s).expect("serialize");
            let decoded: BodySide = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, s);
        }
    }

    /// Hashes `t` with the standard library's default hasher.
    fn hash_of<T: Hash>(t: &T) -> u64 {
        let mut h = DefaultHasher::new();
        t.hash(&mut h);
        h.finish()
    }

    /// Generates one test per ID type checking that `new`/`get` round-trip,
    /// that equal IDs compare and hash equal, and that serde round-trips.
    macro_rules! id_round_trip_test {
        ($test_name:ident, $id:ident) => {
            #[test]
            fn $test_name() {
                for raw in [0_u32, 1, 42, u32::MAX] {
                    let a = $id::new(raw);
                    let b = $id::new(raw);
                    assert_eq!(a.get(), raw);
                    assert_eq!(a, b);
                    assert_eq!(hash_of(&a), hash_of(&b));
                    let encoded = serde_json::to_string(&a).expect("serialize");
                    let decoded: $id = serde_json::from_str(&encoded).expect("deserialize");
                    assert_eq!(decoded, a);
                }
            }
        };
    }

    id_round_trip_test!(predicate_id_new_get_round_trip_and_hash_eq, PredicateId);
    id_round_trip_test!(middleware_id_new_get_round_trip_and_hash_eq, MiddlewareId);
    id_round_trip_test!(fetch_id_new_get_round_trip_and_hash_eq, FetchId);
    id_round_trip_test!(terminator_id_new_get_round_trip_and_hash_eq, TerminatorId);

    // Never called: a signature that names all five ID types side by side.
    fn _id_types_are_distinct(
        _n: NodeId,
        _p: PredicateId,
        _m: MiddlewareId,
        _f: FetchId,
        _t: TerminatorId,
    ) {
    }

    // ---------------------------------------------------------------------
    // Node accessors
    // ---------------------------------------------------------------------

    #[test]
    fn node_check_collect_body_before_returns_stored_flag() {
        let some = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: Some(BodySide::Request),
            body_limit: 0,
        };
        assert_eq!(some.collect_body_before(), Some(BodySide::Request));

        let none = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: None,
            body_limit: 0,
        };
        assert_eq!(none.collect_body_before(), None);
    }

    #[test]
    fn node_middleware_collect_body_before_returns_stored_flag() {
        let some = Node::Middleware {
            id: MiddlewareId::new(0),
            next: NodeId::new(0),
            on_error: None,
            collect_body_before: Some(BodySide::Response),
            body_limit: 0,
        };
        assert_eq!(some.collect_body_before(), Some(BodySide::Response));

        let none = Node::Middleware {
            id: MiddlewareId::new(0),
            next: NodeId::new(0),
            on_error: None,
            collect_body_before: None,
            body_limit: 0,
        };
        assert_eq!(none.collect_body_before(), None);
    }

    #[test]
    fn node_fetch_collect_body_before_returns_stored_flag() {
        let some = Node::Fetch {
            id: FetchId::new(0),
            next_response: None,
            next_tunnel: None,
            collect_body_before: Some(BodySide::Request),
            body_limit: 0,
        };
        assert_eq!(some.collect_body_before(), Some(BodySide::Request));

        let none = Node::Fetch {
            id: FetchId::new(0),
            next_response: None,
            next_tunnel: None,
            collect_body_before: None,
            body_limit: 0,
        };
        assert_eq!(none.collect_body_before(), None);
    }

    #[test]
    fn node_upgrade_collect_body_before_is_always_none() {
        let n = Node::Upgrade { next: NodeId::new(0) };
        assert_eq!(n.collect_body_before(), None);
    }

    #[test]
    fn node_terminate_collect_body_before_is_always_none() {
        let n = Node::Terminate(TerminatorId::new(0));
        assert_eq!(n.collect_body_before(), None);
    }

    #[test]
    fn node_body_limit_returns_stored_limit() {
        let check = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: Some(BodySide::Request),
            body_limit: 1024,
        };
        assert_eq!(check.body_limit(), 1024);

        let middleware = Node::Middleware {
            id: MiddlewareId::new(0),
            next: NodeId::new(0),
            on_error: None,
            collect_body_before: Some(BodySide::Response),
            body_limit: 2048,
        };
        assert_eq!(middleware.body_limit(), 2048);

        let fetch = Node::Fetch {
            id: FetchId::new(0),
            next_response: None,
            next_tunnel: None,
            collect_body_before: None,
            body_limit: usize::MAX,
        };
        assert_eq!(fetch.body_limit(), usize::MAX);
    }

    #[test]
    fn node_body_limit_is_zero_for_upgrade_and_terminate() {
        assert_eq!(Node::Upgrade { next: NodeId::new(0) }.body_limit(), 0);
        assert_eq!(Node::Terminate(TerminatorId::new(0)).body_limit(), 0);
    }

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    fn sample_predicate() -> PredicateInst {
        PredicateInst {
            path: FieldPath::TlsSni,
            op: CompiledOperator::Equals(CompiledValue::Str(Arc::from("a"))),
        }
    }

    fn sample_middleware() -> SymbolicMiddlewareRef {
        SymbolicMiddlewareRef {
            name: Arc::from("noop"),
            args: Value::Null,
            kind: MiddlewareKind::L7Request,
            stateless: true,
            needs_body: false,
            on_error: None,
        }
    }

    fn sample_fetch() -> SymbolicFetchRef {
        SymbolicFetchRef {
            kind: FetchKind::HttpProxy,
            args: Value::Null,
            retry_buffer_required: false,
            allow_zero_rtt: None,
        }
    }

    fn sample_meta() -> FlowGraphMeta {
        FlowGraphMeta {
            version_hash: [0; 32],
            compiled_at: SystemTime::UNIX_EPOCH,
            source_files: vec![],
            feature_set: &[],
            short_circuit_response_entry: std::collections::BTreeMap::new(),
            listener_tls: std::collections::BTreeMap::new(),
            listener_kinds: std::collections::BTreeMap::new(),
            listener_transports: std::collections::BTreeMap::new(),
            annotations: Vec::new(),
        }
    }

    /// Graph with exactly one entry in every table, so index `0` is valid
    /// for each ID type.
    fn one_of_each_graph() -> SymbolicFlowGraph {
        SymbolicFlowGraph {
            nodes: vec![Node::Terminate(TerminatorId::new(0))],
            predicates: vec![sample_predicate()],
            middlewares: vec![sample_middleware()],
            fetches: vec![sample_fetch()],
            terminators: vec![Terminator::WriteHttpResponse],
            entries: HashMap::new(),
            meta: sample_meta(),
        }
    }

    // ---------------------------------------------------------------------
    // Index impls
    // ---------------------------------------------------------------------

    #[test]
    fn index_by_node_id_returns_matching_node() {
        let g = one_of_each_graph();
        match &g[NodeId::new(0)] {
            Node::Terminate(t) => assert_eq!(*t, TerminatorId::new(0)),
            other => panic!("expected Terminate, got {other:?}"),
        }
    }

    #[test]
    fn index_by_predicate_id_returns_matching_predicate() {
        let g = one_of_each_graph();
        assert_eq!(g[PredicateId::new(0)], sample_predicate());
    }

    #[test]
    fn index_by_middleware_id_returns_matching_middleware() {
        let g = one_of_each_graph();
        assert_eq!(g[MiddlewareId::new(0)], sample_middleware());
    }

    #[test]
    fn index_by_fetch_id_returns_matching_fetch() {
        let g = one_of_each_graph();
        assert_eq!(g[FetchId::new(0)].kind, FetchKind::HttpProxy);
    }

    #[test]
    fn index_by_terminator_id_returns_matching_terminator() {
        let g = one_of_each_graph();
        assert_eq!(g[TerminatorId::new(0)], Terminator::WriteHttpResponse);
    }

    // ---------------------------------------------------------------------
    // Node serde
    // ---------------------------------------------------------------------

    /// Serializes `n` to JSON text and deserializes it back.
    fn node_round_trip(n: &Node) -> Node {
        let encoded = serde_json::to_string(n).expect("serialize node");
        serde_json::from_str(&encoded).expect("deserialize node")
    }

    #[test]
    fn node_check_serde_round_trip_with_and_without_collect_flag() {
        let with = Node::Check {
            predicate: PredicateId::new(3),
            on_match: NodeId::new(4),
            on_miss: NodeId::new(5),
            collect_body_before: Some(BodySide::Request),
            body_limit: 0,
        };
        match node_round_trip(&with) {
            Node::Check { predicate, on_match, on_miss, collect_body_before, .. } => {
                assert_eq!(predicate, PredicateId::new(3));
                assert_eq!(on_match, NodeId::new(4));
                assert_eq!(on_miss, NodeId::new(5));
                assert_eq!(collect_body_before, Some(BodySide::Request));
            }
            other => panic!("expected Check, got {other:?}"),
        }

        let without = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: None,
            body_limit: 0,
        };
        match node_round_trip(&without) {
            Node::Check { collect_body_before, .. } => assert_eq!(collect_body_before, None),
            other => panic!("expected Check, got {other:?}"),
        }
    }

    #[test]
    fn node_middleware_serde_round_trip_with_and_without_collect_flag() {
        let with = Node::Middleware {
            id: MiddlewareId::new(1),
            next: NodeId::new(2),
            on_error: Some(NodeId::new(9)),
            collect_body_before: Some(BodySide::Response),
            body_limit: 0,
        };
        match node_round_trip(&with) {
            Node::Middleware { id, next, on_error, collect_body_before, .. } => {
                assert_eq!(id, MiddlewareId::new(1));
                assert_eq!(next, NodeId::new(2));
                assert_eq!(on_error, Some(NodeId::new(9)));
                assert_eq!(collect_body_before, Some(BodySide::Response));
            }
            other => panic!("expected Middleware, got {other:?}"),
        }

        let without = Node::Middleware {
            id: MiddlewareId::new(0),
            next: NodeId::new(0),
            on_error: None,
            collect_body_before: None,
            body_limit: 0,
        };
        match node_round_trip(&without) {
            Node::Middleware { on_error, collect_body_before, .. } => {
                assert_eq!(on_error, None);
                assert_eq!(collect_body_before, None);
            }
            other => panic!("expected Middleware, got {other:?}"),
        }
    }

    #[test]
    fn node_fetch_serde_round_trip_with_and_without_collect_flag() {
        let with = Node::Fetch {
            id: FetchId::new(7),
            next_response: Some(NodeId::new(8)),
            next_tunnel: Some(NodeId::new(9)),
            collect_body_before: Some(BodySide::Request),
            body_limit: 0,
        };
        match node_round_trip(&with) {
            Node::Fetch { id, next_response, next_tunnel, collect_body_before, .. } => {
                assert_eq!(id, FetchId::new(7));
                assert_eq!(next_response, Some(NodeId::new(8)));
                assert_eq!(next_tunnel, Some(NodeId::new(9)));
                assert_eq!(collect_body_before, Some(BodySide::Request));
            }
            other => panic!("expected Fetch, got {other:?}"),
        }

        let without = Node::Fetch {
            id: FetchId::new(0),
            next_response: None,
            next_tunnel: None,
            collect_body_before: None,
            body_limit: 0,
        };
        match node_round_trip(&without) {
            Node::Fetch { next_response, next_tunnel, collect_body_before, .. } => {
                assert_eq!(next_response, None);
                assert_eq!(next_tunnel, None);
                assert_eq!(collect_body_before, None);
            }
            other => panic!("expected Fetch, got {other:?}"),
        }
    }

    #[test]
    fn node_upgrade_serde_round_trip() {
        let n = Node::Upgrade { next: NodeId::new(11) };
        match node_round_trip(&n) {
            Node::Upgrade { next } => assert_eq!(next, NodeId::new(11)),
            other => panic!("expected Upgrade, got {other:?}"),
        }
    }

    #[test]
    fn node_terminate_serde_round_trip() {
        let n = Node::Terminate(TerminatorId::new(13));
        match node_round_trip(&n) {
            Node::Terminate(t) => assert_eq!(t, TerminatorId::new(13)),
            other => panic!("expected Terminate, got {other:?}"),
        }
    }

    #[test]
    fn node_serde_round_trip_preserves_non_zero_body_limit() {
        let check = Node::Check {
            predicate: PredicateId::new(1),
            on_match: NodeId::new(2),
            on_miss: NodeId::new(3),
            collect_body_before: Some(BodySide::Request),
            body_limit: 4096,
        };
        assert_eq!(node_round_trip(&check).body_limit(), 4096);

        let middleware = Node::Middleware {
            id: MiddlewareId::new(1),
            next: NodeId::new(2),
            on_error: None,
            collect_body_before: Some(BodySide::Response),
            body_limit: 8192,
        };
        assert_eq!(node_round_trip(&middleware).body_limit(), 8192);

        let fetch = Node::Fetch {
            id: FetchId::new(1),
            next_response: None,
            next_tunnel: None,
            collect_body_before: Some(BodySide::Request),
            body_limit: 16_384,
        };
        assert_eq!(node_round_trip(&fetch).body_limit(), 16_384);
    }

    #[test]
    fn node_body_limit_defaults_to_zero_when_absent_from_json() {
        let node = Node::Check {
            predicate: PredicateId::new(1),
            on_match: NodeId::new(2),
            on_miss: NodeId::new(3),
            collect_body_before: Some(BodySide::Request),
            body_limit: 4096,
        };
        let mut json = serde_json::to_value(&node).expect("serialize node");
        // Drop the field from the serialized form to exercise the
        // `#[serde(default)]` on `body_limit`.
        let fields = json
            .get_mut("Check")
            .and_then(|v| v.as_object_mut())
            .expect("Check serializes as an object keyed by variant name");
        assert!(fields.remove("body_limit").is_some(), "body_limit should be serialized");

        let decoded: Node = serde_json::from_value(json).expect("deserialize node");
        assert_eq!(decoded.body_limit(), 0);
        assert_eq!(decoded.collect_body_before(), Some(BodySide::Request));
    }

    // ---------------------------------------------------------------------
    // FlowGraphMeta / DryRunAnnotation serde
    // ---------------------------------------------------------------------

    #[test]
    fn flow_graph_meta_serializes_and_emits_version_hash_field() {
        let meta = sample_meta();
        let encoded = serde_json::to_string(&meta).expect("serialize meta");
        assert!(encoded.contains("version_hash"), "expected version_hash field in {encoded}");
    }

    #[test]
    fn flow_graph_meta_round_trip_preserves_all_but_feature_set() {
        use std::time::Duration;
        let meta = FlowGraphMeta {
            version_hash: [0x42; 32],
            compiled_at: SystemTime::UNIX_EPOCH + Duration::from_secs(1_000_000_000),
            source_files: vec![PathBuf::from("/a.json"), PathBuf::from("/b.json")],
            feature_set: &["h3", "wasm"],
            short_circuit_response_entry: std::collections::BTreeMap::new(),
            listener_tls: std::collections::BTreeMap::new(),
            listener_kinds: std::collections::BTreeMap::new(),
            listener_transports: std::collections::BTreeMap::new(),
            annotations: Vec::new(),
        };
        let encoded = serde_json::to_string(&meta).expect("serialize meta");
        assert!(
            !encoded.contains("feature_set"),
            "feature_set must be skipped in dry-run JSON, got: {encoded}",
        );
        let decoded: FlowGraphMeta = serde_json::from_str(&encoded).expect("deserialize meta");
        assert_eq!(decoded.version_hash, meta.version_hash);
        assert_eq!(decoded.compiled_at, meta.compiled_at);
        assert_eq!(decoded.source_files, meta.source_files);
        assert!(decoded.feature_set.is_empty(), "feature_set must default to empty on deserialize");
    }

    #[test]
    fn dry_run_annotation_serde_round_trip() {
        let annotation = DryRunAnnotation {
            kind: "example-kind".to_owned(),
            message: "example message".to_owned(),
        };
        let encoded = serde_json::to_string(&annotation).expect("serialize annotation");
        let decoded: DryRunAnnotation =
            serde_json::from_str(&encoded).expect("deserialize annotation");
        assert_eq!(decoded, annotation);
    }
}