1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::ops::Index;
4use std::path::PathBuf;
5use std::time::SystemTime;
6
7use crate::conn_context::Transport;
8use crate::fetch::{SymbolicFetchRef, Terminator};
9use crate::middleware::SymbolicMiddlewareRef;
10use crate::predicate::PredicateInst;
11
12macro_rules! id_newtype {
13 ($name:ident) => {
14 #[derive(
15 Copy, Clone, Eq, PartialEq, Hash, Debug, PartialOrd, Ord, serde::Serialize, serde::Deserialize,
16 )]
17 pub struct $name(u32);
18
19 impl $name {
20 #[must_use]
21 pub const fn new(raw: u32) -> Self {
22 Self(raw)
23 }
24
25 #[must_use]
26 pub const fn get(self) -> u32 {
27 self.0
28 }
29 }
30 };
31}
32
33id_newtype!(NodeId);
34id_newtype!(PredicateId);
35id_newtype!(MiddlewareId);
36id_newtype!(FetchId);
37id_newtype!(TerminatorId);
38
39#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
40pub enum BodySide {
41 Request,
42 Response,
43}
44
45#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
51pub enum ListenerKind {
52 Raw,
56 Http,
60 Auto,
64}
65
/// One node of a `SymbolicFlowGraph`.
///
/// Every typed id field is an index into the matching table of the graph
/// (`NodeId` -> `nodes`, `PredicateId` -> `predicates`, `MiddlewareId` ->
/// `middlewares`, `FetchId` -> `fetches`, `TerminatorId` -> `terminators`).
///
/// Variant and field order is part of the serde representation (variant index and
/// field sequence in non-self-describing formats); do not reorder.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub enum Node {
    /// Evaluates `predicate`; presumably continues at `on_match` when it holds and
    /// at `on_miss` otherwise.
    Check {
        predicate: PredicateId,
        on_match: NodeId,
        on_miss: NodeId,
        /// Body side to collect before this node, if any (see `Node::collect_body_before`).
        collect_body_before: Option<BodySide>,
        /// Body-size limit; deserializes to `0` when absent.
        /// NOTE(review): whether `0` means "unlimited" or "no body" is decided by
        /// consumers — confirm.
        #[serde(default)]
        body_limit: usize,
    },
    /// Runs middleware `id`; presumably continues at `next`, or at `on_error`
    /// (when set) if the middleware fails.
    Middleware {
        id: MiddlewareId,
        next: NodeId,
        on_error: Option<NodeId>,
        /// Body side to collect before this node, if any.
        collect_body_before: Option<BodySide>,
        /// Body-size limit; deserializes to `0` when absent.
        #[serde(default)]
        body_limit: usize,
    },
    /// Performs fetch `id`. Both continuations are optional; presumably one is
    /// taken for a normal response and the other for a tunnelled connection.
    Fetch {
        id: FetchId,
        next_response: Option<NodeId>,
        next_tunnel: Option<NodeId>,
        /// Body side to collect before this node, if any.
        collect_body_before: Option<BodySide>,
        /// Body-size limit; deserializes to `0` when absent.
        #[serde(default)]
        body_limit: usize,
    },
    /// Protocol upgrade step that continues at `next`. Carries no body settings.
    Upgrade {
        next: NodeId,
    },
    /// Ends the flow with terminator `TerminatorId`. Carries no body settings.
    Terminate(TerminatorId),
}
97
98impl Node {
99 #[must_use]
100 pub const fn collect_body_before(&self) -> Option<BodySide> {
101 match self {
102 Self::Check { collect_body_before, .. }
103 | Self::Middleware { collect_body_before, .. }
104 | Self::Fetch { collect_body_before, .. } => *collect_body_before,
105 Self::Upgrade { .. } | Self::Terminate(_) => None,
106 }
107 }
108
109 #[must_use]
110 pub const fn body_limit(&self) -> usize {
111 match self {
112 Self::Check { body_limit, .. }
113 | Self::Middleware { body_limit, .. }
114 | Self::Fetch { body_limit, .. } => *body_limit,
115 _ => 0,
116 }
117 }
118}
119
/// Metadata carried alongside a compiled `SymbolicFlowGraph`.
///
/// Every map field is a `BTreeMap`, so its serialized key order is deterministic.
/// All fields after `feature_set` are `#[serde(default)]` and may be missing from
/// serialized input; they then deserialize as empty.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct FlowGraphMeta {
    /// 32-byte digest. NOTE(review): presumably identifies the compiled graph
    /// version — confirm with the code that computes it.
    pub version_hash: [u8; 32],
    /// Timestamp recorded for this compilation.
    pub compiled_at: SystemTime,
    /// Source file paths associated with this graph.
    pub source_files: Vec<PathBuf>,
    /// Static list of feature names. Never serialized; on deserialization it is
    /// always the empty slice returned by `empty_feature_set`.
    #[serde(skip, default = "empty_feature_set")]
    pub feature_set: &'static [&'static str],

    /// Node-to-node mapping. NOTE(review): presumably the node where response
    /// handling starts after a short-circuit at the key node — confirm.
    #[serde(default)]
    pub short_circuit_response_entry: std::collections::BTreeMap<NodeId, NodeId>,

    /// TLS settings per listen address.
    #[serde(default)]
    pub listener_tls: std::collections::BTreeMap<SocketAddr, crate::rule::ListenerTlsSpec>,

    /// Listener kind per listen address.
    #[serde(default)]
    pub listener_kinds: std::collections::BTreeMap<SocketAddr, ListenerKind>,

    /// Transport per listen address.
    #[serde(default)]
    pub listener_transports: std::collections::BTreeMap<SocketAddr, Transport>,

    /// Free-form notes for dry-run output.
    #[serde(default)]
    pub annotations: Vec<DryRunAnnotation>,
}
194
195#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
206pub struct DryRunAnnotation {
207 pub kind: String,
211 pub message: String,
214}
215
/// Serde `default` for `FlowGraphMeta::feature_set`, which is skipped during
/// serialization: always an empty `'static` slice.
const fn empty_feature_set() -> &'static [&'static str] {
    const NO_FEATURES: &[&str] = &[];
    NO_FEATURES
}
219
220#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
221pub struct SymbolicFlowGraph {
222 pub nodes: Vec<Node>,
223 pub predicates: Vec<PredicateInst>,
224 pub middlewares: Vec<SymbolicMiddlewareRef>,
225 pub fetches: Vec<SymbolicFetchRef>,
226 pub terminators: Vec<Terminator>,
227 pub entries: HashMap<SocketAddr, NodeId>,
228 pub meta: FlowGraphMeta,
229}
230
231impl Index<NodeId> for SymbolicFlowGraph {
232 type Output = Node;
233 fn index(&self, id: NodeId) -> &Node {
234 &self.nodes[id.get() as usize]
235 }
236}
237
238impl Index<PredicateId> for SymbolicFlowGraph {
239 type Output = PredicateInst;
240 fn index(&self, id: PredicateId) -> &PredicateInst {
241 &self.predicates[id.get() as usize]
242 }
243}
244
245impl Index<MiddlewareId> for SymbolicFlowGraph {
246 type Output = SymbolicMiddlewareRef;
247 fn index(&self, id: MiddlewareId) -> &SymbolicMiddlewareRef {
248 &self.middlewares[id.get() as usize]
249 }
250}
251
252impl Index<FetchId> for SymbolicFlowGraph {
253 type Output = SymbolicFetchRef;
254 fn index(&self, id: FetchId) -> &SymbolicFetchRef {
255 &self.fetches[id.get() as usize]
256 }
257}
258
259impl Index<TerminatorId> for SymbolicFlowGraph {
260 type Output = Terminator;
261 fn index(&self, id: TerminatorId) -> &Terminator {
262 &self.terminators[id.get() as usize]
263 }
264}
265
#[cfg(test)]
mod tests {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::sync::Arc;

    use serde_json::Value;

    use super::*;
    use crate::fetch::{FetchKind, SymbolicFetchRef, Terminator};
    use crate::middleware::{MiddlewareKind, SymbolicMiddlewareRef};
    use crate::predicate::{CompiledOperator, CompiledValue, FieldPath, PredicateInst};

    #[test]
    fn new_then_get_round_trips_raw_u32() {
        for raw in [0_u32, 1, 42, u32::MAX] {
            assert_eq!(NodeId::new(raw).get(), raw);
        }
    }

    #[test]
    fn node_id_equality_is_structural() {
        assert_eq!(NodeId::new(7), NodeId::new(7));
        assert_ne!(NodeId::new(7), NodeId::new(8));
    }

    #[test]
    fn node_id_ordering_follows_raw_u32() {
        assert!(NodeId::new(1) < NodeId::new(2));
        assert!(NodeId::new(u32::MAX) > NodeId::new(0));
    }

    #[test]
    fn node_id_serde_round_trip() {
        let id = NodeId::new(0x0bad_f00d);
        let encoded = serde_json::to_string(&id).expect("serialize");
        let decoded: NodeId = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded, id);
    }

    #[test]
    fn body_side_serde_round_trip_per_variant() {
        for s in [BodySide::Request, BodySide::Response] {
            let encoded = serde_json::to_string(&s).expect("serialize");
            let decoded: BodySide = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, s);
        }
    }

    #[test]
    fn listener_kind_serde_round_trip_per_variant() {
        for k in [ListenerKind::Raw, ListenerKind::Http, ListenerKind::Auto] {
            let encoded = serde_json::to_string(&k).expect("serialize");
            let decoded: ListenerKind = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, k);
        }
    }

    fn hash_of<T: Hash>(t: &T) -> u64 {
        let mut h = DefaultHasher::new();
        t.hash(&mut h);
        h.finish()
    }

    #[test]
    fn predicate_id_new_get_round_trip_and_hash_eq() {
        for raw in [0_u32, 1, 42, u32::MAX] {
            let a = PredicateId::new(raw);
            let b = PredicateId::new(raw);
            assert_eq!(a.get(), raw);
            assert_eq!(a, b);
            assert_eq!(hash_of(&a), hash_of(&b));
            let encoded = serde_json::to_string(&a).expect("serialize");
            let decoded: PredicateId = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, a);
        }
    }

    #[test]
    fn middleware_id_new_get_round_trip_and_hash_eq() {
        for raw in [0_u32, 1, 42, u32::MAX] {
            let a = MiddlewareId::new(raw);
            let b = MiddlewareId::new(raw);
            assert_eq!(a.get(), raw);
            assert_eq!(a, b);
            assert_eq!(hash_of(&a), hash_of(&b));
            let encoded = serde_json::to_string(&a).expect("serialize");
            let decoded: MiddlewareId = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, a);
        }
    }

    #[test]
    fn fetch_id_new_get_round_trip_and_hash_eq() {
        for raw in [0_u32, 1, 42, u32::MAX] {
            let a = FetchId::new(raw);
            let b = FetchId::new(raw);
            assert_eq!(a.get(), raw);
            assert_eq!(a, b);
            assert_eq!(hash_of(&a), hash_of(&b));
            let encoded = serde_json::to_string(&a).expect("serialize");
            let decoded: FetchId = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, a);
        }
    }

    #[test]
    fn terminator_id_new_get_round_trip_and_hash_eq() {
        for raw in [0_u32, 1, 42, u32::MAX] {
            let a = TerminatorId::new(raw);
            let b = TerminatorId::new(raw);
            assert_eq!(a.get(), raw);
            assert_eq!(a, b);
            assert_eq!(hash_of(&a), hash_of(&b));
            let encoded = serde_json::to_string(&a).expect("serialize");
            let decoded: TerminatorId = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, a);
        }
    }

    fn _id_types_are_distinct(
        _n: NodeId,
        _p: PredicateId,
        _m: MiddlewareId,
        _f: FetchId,
        _t: TerminatorId,
    ) {
    }

    #[test]
    fn node_check_collect_body_before_returns_stored_flag() {
        let some = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: Some(BodySide::Request),
            body_limit: 0,
        };
        assert_eq!(some.collect_body_before(), Some(BodySide::Request));

        let none = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: None,
            body_limit: 0,
        };
        assert_eq!(none.collect_body_before(), None);
    }

    #[test]
    fn node_middleware_collect_body_before_returns_stored_flag() {
        let some = Node::Middleware {
            id: MiddlewareId::new(0),
            next: NodeId::new(0),
            on_error: None,
            collect_body_before: Some(BodySide::Response),
            body_limit: 0,
        };
        assert_eq!(some.collect_body_before(), Some(BodySide::Response));

        let none = Node::Middleware {
            id: MiddlewareId::new(0),
            next: NodeId::new(0),
            on_error: None,
            collect_body_before: None,
            body_limit: 0,
        };
        assert_eq!(none.collect_body_before(), None);
    }

    #[test]
    fn node_fetch_collect_body_before_returns_stored_flag() {
        let some = Node::Fetch {
            id: FetchId::new(0),
            next_response: None,
            next_tunnel: None,
            collect_body_before: Some(BodySide::Request),
            body_limit: 0,
        };
        assert_eq!(some.collect_body_before(), Some(BodySide::Request));

        let none = Node::Fetch {
            id: FetchId::new(0),
            next_response: None,
            next_tunnel: None,
            collect_body_before: None,
            body_limit: 0,
        };
        assert_eq!(none.collect_body_before(), None);
    }

    #[test]
    fn node_upgrade_collect_body_before_is_always_none() {
        let n = Node::Upgrade { next: NodeId::new(0) };
        assert_eq!(n.collect_body_before(), None);
    }

    #[test]
    fn node_terminate_collect_body_before_is_always_none() {
        let n = Node::Terminate(TerminatorId::new(0));
        assert_eq!(n.collect_body_before(), None);
    }

    #[test]
    fn node_body_limit_returns_stored_value_for_body_carrying_variants() {
        let check = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: Some(BodySide::Request),
            body_limit: 1024,
        };
        assert_eq!(check.body_limit(), 1024);

        let middleware = Node::Middleware {
            id: MiddlewareId::new(0),
            next: NodeId::new(0),
            on_error: None,
            collect_body_before: Some(BodySide::Response),
            body_limit: 2048,
        };
        assert_eq!(middleware.body_limit(), 2048);

        let fetch = Node::Fetch {
            id: FetchId::new(0),
            next_response: None,
            next_tunnel: None,
            collect_body_before: None,
            body_limit: usize::MAX,
        };
        assert_eq!(fetch.body_limit(), usize::MAX);
    }

    #[test]
    fn node_body_limit_is_zero_for_upgrade_and_terminate() {
        assert_eq!(Node::Upgrade { next: NodeId::new(0) }.body_limit(), 0);
        assert_eq!(Node::Terminate(TerminatorId::new(0)).body_limit(), 0);
    }

    fn sample_predicate() -> PredicateInst {
        PredicateInst {
            path: FieldPath::TlsSni,
            op: CompiledOperator::Equals(CompiledValue::Str(Arc::from("a"))),
        }
    }

    fn sample_middleware() -> SymbolicMiddlewareRef {
        SymbolicMiddlewareRef {
            name: Arc::from("noop"),
            args: Value::Null,
            kind: MiddlewareKind::L7Request,
            stateless: true,
            needs_body: false,
            on_error: None,
        }
    }

    fn sample_fetch() -> SymbolicFetchRef {
        SymbolicFetchRef {
            kind: FetchKind::HttpProxy,
            args: Value::Null,
            retry_buffer_required: false,
            allow_zero_rtt: None,
        }
    }

    fn sample_meta() -> FlowGraphMeta {
        FlowGraphMeta {
            version_hash: [0; 32],
            compiled_at: SystemTime::UNIX_EPOCH,
            source_files: vec![],
            feature_set: &[],
            short_circuit_response_entry: std::collections::BTreeMap::new(),
            listener_tls: std::collections::BTreeMap::new(),
            listener_kinds: std::collections::BTreeMap::new(),
            listener_transports: std::collections::BTreeMap::new(),
            annotations: Vec::new(),
        }
    }

    fn one_of_each_graph() -> SymbolicFlowGraph {
        SymbolicFlowGraph {
            nodes: vec![Node::Terminate(TerminatorId::new(0))],
            predicates: vec![sample_predicate()],
            middlewares: vec![sample_middleware()],
            fetches: vec![sample_fetch()],
            terminators: vec![Terminator::WriteHttpResponse],
            entries: HashMap::new(),
            meta: sample_meta(),
        }
    }

    #[test]
    fn index_by_node_id_returns_matching_node() {
        let g = one_of_each_graph();
        match &g[NodeId::new(0)] {
            Node::Terminate(t) => assert_eq!(*t, TerminatorId::new(0)),
            other => panic!("expected Terminate, got {other:?}"),
        }
    }

    #[test]
    fn index_by_predicate_id_returns_matching_predicate() {
        let g = one_of_each_graph();
        assert_eq!(g[PredicateId::new(0)], sample_predicate());
    }

    #[test]
    fn index_by_middleware_id_returns_matching_middleware() {
        let g = one_of_each_graph();
        assert_eq!(g[MiddlewareId::new(0)], sample_middleware());
    }

    #[test]
    fn index_by_fetch_id_returns_matching_fetch() {
        let g = one_of_each_graph();
        assert_eq!(g[FetchId::new(0)].kind, FetchKind::HttpProxy);
    }

    #[test]
    fn index_by_terminator_id_returns_matching_terminator() {
        let g = one_of_each_graph();
        assert_eq!(g[TerminatorId::new(0)], Terminator::WriteHttpResponse);
    }

    fn node_round_trip(n: &Node) -> Node {
        let encoded = serde_json::to_string(n).expect("serialize node");
        serde_json::from_str(&encoded).expect("deserialize node")
    }

    #[test]
    fn node_check_serde_round_trip_with_and_without_collect_flag() {
        let with = Node::Check {
            predicate: PredicateId::new(3),
            on_match: NodeId::new(4),
            on_miss: NodeId::new(5),
            collect_body_before: Some(BodySide::Request),
            body_limit: 0,
        };
        match node_round_trip(&with) {
            Node::Check { predicate, on_match, on_miss, collect_body_before, .. } => {
                assert_eq!(predicate, PredicateId::new(3));
                assert_eq!(on_match, NodeId::new(4));
                assert_eq!(on_miss, NodeId::new(5));
                assert_eq!(collect_body_before, Some(BodySide::Request));
            }
            other => panic!("expected Check, got {other:?}"),
        }

        let without = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: None,
            body_limit: 0,
        };
        match node_round_trip(&without) {
            Node::Check { collect_body_before, .. } => assert_eq!(collect_body_before, None),
            other => panic!("expected Check, got {other:?}"),
        }
    }

    #[test]
    fn node_middleware_serde_round_trip_with_and_without_collect_flag() {
        let with = Node::Middleware {
            id: MiddlewareId::new(1),
            next: NodeId::new(2),
            on_error: Some(NodeId::new(9)),
            collect_body_before: Some(BodySide::Response),
            body_limit: 0,
        };
        match node_round_trip(&with) {
            Node::Middleware { id, next, on_error, collect_body_before, .. } => {
                assert_eq!(id, MiddlewareId::new(1));
                assert_eq!(next, NodeId::new(2));
                assert_eq!(on_error, Some(NodeId::new(9)));
                assert_eq!(collect_body_before, Some(BodySide::Response));
            }
            other => panic!("expected Middleware, got {other:?}"),
        }

        let without = Node::Middleware {
            id: MiddlewareId::new(0),
            next: NodeId::new(0),
            on_error: None,
            collect_body_before: None,
            body_limit: 0,
        };
        match node_round_trip(&without) {
            Node::Middleware { on_error, collect_body_before, .. } => {
                assert_eq!(on_error, None);
                assert_eq!(collect_body_before, None);
            }
            other => panic!("expected Middleware, got {other:?}"),
        }
    }

    #[test]
    fn node_fetch_serde_round_trip_with_and_without_collect_flag() {
        let with = Node::Fetch {
            id: FetchId::new(7),
            next_response: Some(NodeId::new(8)),
            next_tunnel: Some(NodeId::new(9)),
            collect_body_before: Some(BodySide::Request),
            body_limit: 0,
        };
        match node_round_trip(&with) {
            Node::Fetch { id, next_response, next_tunnel, collect_body_before, .. } => {
                assert_eq!(id, FetchId::new(7));
                assert_eq!(next_response, Some(NodeId::new(8)));
                assert_eq!(next_tunnel, Some(NodeId::new(9)));
                assert_eq!(collect_body_before, Some(BodySide::Request));
            }
            other => panic!("expected Fetch, got {other:?}"),
        }

        let without = Node::Fetch {
            id: FetchId::new(0),
            next_response: None,
            next_tunnel: None,
            collect_body_before: None,
            body_limit: 0,
        };
        match node_round_trip(&without) {
            Node::Fetch { next_response, next_tunnel, collect_body_before, .. } => {
                assert_eq!(next_response, None);
                assert_eq!(next_tunnel, None);
                assert_eq!(collect_body_before, None);
            }
            other => panic!("expected Fetch, got {other:?}"),
        }
    }

    #[test]
    fn node_body_limit_survives_serde_round_trip() {
        let n = Node::Fetch {
            id: FetchId::new(1),
            next_response: None,
            next_tunnel: None,
            collect_body_before: Some(BodySide::Request),
            body_limit: 65_536,
        };
        assert_eq!(node_round_trip(&n).body_limit(), 65_536);
    }

    #[test]
    fn node_body_limit_defaults_to_zero_when_missing_from_json() {
        // `body_limit` is `#[serde(default)]`, so input that omits it must still parse.
        let json = r#"{"Check":{"predicate":1,"on_match":2,"on_miss":3,"collect_body_before":null}}"#;
        let n: Node = serde_json::from_str(json).expect("deserialize node without body_limit");
        assert_eq!(n.body_limit(), 0);
        assert_eq!(n.collect_body_before(), None);
    }

    #[test]
    fn node_upgrade_serde_round_trip() {
        let n = Node::Upgrade { next: NodeId::new(11) };
        match node_round_trip(&n) {
            Node::Upgrade { next } => assert_eq!(next, NodeId::new(11)),
            other => panic!("expected Upgrade, got {other:?}"),
        }
    }

    #[test]
    fn node_terminate_serde_round_trip() {
        let n = Node::Terminate(TerminatorId::new(13));
        match node_round_trip(&n) {
            Node::Terminate(t) => assert_eq!(t, TerminatorId::new(13)),
            other => panic!("expected Terminate, got {other:?}"),
        }
    }

    #[test]
    fn flow_graph_meta_serializes_and_emits_version_hash_field() {
        let meta = sample_meta();
        let encoded = serde_json::to_string(&meta).expect("serialize meta");
        assert!(encoded.contains("version_hash"), "expected version_hash field in {encoded}");
    }

    #[test]
    fn flow_graph_meta_round_trip_preserves_all_but_feature_set() {
        use std::time::Duration;
        let meta = FlowGraphMeta {
            version_hash: [0x42; 32],
            compiled_at: SystemTime::UNIX_EPOCH + Duration::from_secs(1_000_000_000),
            source_files: vec![PathBuf::from("/a.json"), PathBuf::from("/b.json")],
            feature_set: &["h3", "wasm"],
            short_circuit_response_entry: std::collections::BTreeMap::new(),
            listener_tls: std::collections::BTreeMap::new(),
            listener_kinds: std::collections::BTreeMap::new(),
            listener_transports: std::collections::BTreeMap::new(),
            annotations: Vec::new(),
        };
        let encoded = serde_json::to_string(&meta).expect("serialize meta");
        assert!(
            !encoded.contains("feature_set"),
            "feature_set must be skipped in dry-run JSON, got: {encoded}",
        );
        let decoded: FlowGraphMeta = serde_json::from_str(&encoded).expect("deserialize meta");
        assert_eq!(decoded.version_hash, meta.version_hash);
        assert_eq!(decoded.compiled_at, meta.compiled_at);
        assert_eq!(decoded.source_files, meta.source_files);
        assert!(decoded.feature_set.is_empty(), "feature_set must default to empty on deserialize");
    }
}