1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::ops::Index;
4use std::path::PathBuf;
5use std::time::SystemTime;
6
7use crate::conn_context::Transport;
8use crate::fetch::{SymbolicFetchRef, Terminator};
9use crate::middleware::SymbolicMiddlewareRef;
10use crate::predicate::PredicateInst;
11
12macro_rules! id_newtype {
13 ($name:ident) => {
14 #[derive(
15 Copy, Clone, Eq, PartialEq, Hash, Debug, PartialOrd, Ord, serde::Serialize, serde::Deserialize,
16 )]
17 pub struct $name(u32);
18
19 impl $name {
20 #[must_use]
21 pub const fn new(raw: u32) -> Self {
22 Self(raw)
23 }
24
25 #[must_use]
26 pub const fn get(self) -> u32 {
27 self.0
28 }
29 }
30 };
31}
32
33id_newtype!(NodeId);
34id_newtype!(PredicateId);
35id_newtype!(MiddlewareId);
36id_newtype!(FetchId);
37id_newtype!(TerminatorId);
38
39#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
40pub enum BodySide {
41 Request,
42 Response,
43}
44
45#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
51pub enum ListenerKind {
52 Raw,
56 Http,
60 Auto,
64}
65
66#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
67pub enum Node {
68 Check {
69 predicate: PredicateId,
70 on_match: NodeId,
71 on_miss: NodeId,
72 collect_body_before: Option<BodySide>,
73 #[serde(default)]
74 body_limit: usize,
75 },
76 Middleware {
77 id: MiddlewareId,
78 next: NodeId,
79 on_error: Option<NodeId>,
80 collect_body_before: Option<BodySide>,
81 #[serde(default)]
82 body_limit: usize,
83 },
84 Fetch {
85 id: FetchId,
86 next_response: Option<NodeId>,
87 next_tunnel: Option<NodeId>,
88 collect_body_before: Option<BodySide>,
89 #[serde(default)]
90 body_limit: usize,
91 },
92 Upgrade {
93 next: NodeId,
94 },
95 Terminate(TerminatorId),
96}
97
98impl Node {
99 #[must_use]
100 pub const fn collect_body_before(&self) -> Option<BodySide> {
101 match self {
102 Self::Check { collect_body_before, .. }
103 | Self::Middleware { collect_body_before, .. }
104 | Self::Fetch { collect_body_before, .. } => *collect_body_before,
105 Self::Upgrade { .. } | Self::Terminate(_) => None,
106 }
107 }
108
109 #[must_use]
110 pub const fn body_limit(&self) -> usize {
111 match self {
112 Self::Check { body_limit, .. }
113 | Self::Middleware { body_limit, .. }
114 | Self::Fetch { body_limit, .. } => *body_limit,
115 _ => 0,
116 }
117 }
118}
119
120#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
121pub struct FlowGraphMeta {
122 pub version_hash: [u8; 32],
123 pub compiled_at: SystemTime,
124 pub source_files: Vec<PathBuf>,
125 #[serde(skip, default = "empty_feature_set")]
129 pub feature_set: &'static [&'static str],
130
131 #[serde(default)]
146 pub short_circuit_response_entry: std::collections::BTreeMap<NodeId, NodeId>,
147
148 #[serde(default)]
160 pub listener_tls: std::collections::BTreeMap<SocketAddr, crate::rule::ListenerTlsSpec>,
161
162 #[serde(default)]
170 pub listener_kinds: std::collections::BTreeMap<SocketAddr, ListenerKind>,
171
172 #[serde(default)]
182 pub listener_transports: std::collections::BTreeMap<SocketAddr, Transport>,
183}
184
/// Serde default for the skipped `FlowGraphMeta::feature_set` field.
///
/// Returns an empty `'static` slice, so deserialized metadata never reports
/// any features.
const fn empty_feature_set() -> &'static [&'static str] {
    const NO_FEATURES: &[&str] = &[];
    NO_FEATURES
}
188
189#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
190pub struct SymbolicFlowGraph {
191 pub nodes: Vec<Node>,
192 pub predicates: Vec<PredicateInst>,
193 pub middlewares: Vec<SymbolicMiddlewareRef>,
194 pub fetches: Vec<SymbolicFetchRef>,
195 pub terminators: Vec<Terminator>,
196 pub entries: HashMap<SocketAddr, NodeId>,
197 pub meta: FlowGraphMeta,
198}
199
200impl Index<NodeId> for SymbolicFlowGraph {
201 type Output = Node;
202 fn index(&self, id: NodeId) -> &Node {
203 &self.nodes[id.get() as usize]
204 }
205}
206
207impl Index<PredicateId> for SymbolicFlowGraph {
208 type Output = PredicateInst;
209 fn index(&self, id: PredicateId) -> &PredicateInst {
210 &self.predicates[id.get() as usize]
211 }
212}
213
214impl Index<MiddlewareId> for SymbolicFlowGraph {
215 type Output = SymbolicMiddlewareRef;
216 fn index(&self, id: MiddlewareId) -> &SymbolicMiddlewareRef {
217 &self.middlewares[id.get() as usize]
218 }
219}
220
221impl Index<FetchId> for SymbolicFlowGraph {
222 type Output = SymbolicFetchRef;
223 fn index(&self, id: FetchId) -> &SymbolicFetchRef {
224 &self.fetches[id.get() as usize]
225 }
226}
227
228impl Index<TerminatorId> for SymbolicFlowGraph {
229 type Output = Terminator;
230 fn index(&self, id: TerminatorId) -> &Terminator {
231 &self.terminators[id.get() as usize]
232 }
233}
234
#[cfg(test)]
mod tests {
    use std::collections::hash_map::DefaultHasher;
    use std::collections::BTreeMap;
    use std::fmt::Debug;
    use std::hash::{Hash, Hasher};
    use std::sync::Arc;
    use std::time::Duration;

    use serde::de::DeserializeOwned;
    use serde::Serialize;
    use serde_json::Value;

    use super::*;
    use crate::fetch::{FetchKind, SymbolicFetchRef, Terminator};
    use crate::middleware::{MiddlewareKind, SymbolicMiddlewareRef};
    use crate::predicate::{CompiledOperator, CompiledValue, FieldPath, PredicateInst};

    /// Raw values every ID-newtype test exercises, including both extremes.
    const SAMPLE_RAWS: [u32; 4] = [0, 1, 42, u32::MAX];

    fn hash_of<T: Hash>(t: &T) -> u64 {
        let mut h = DefaultHasher::new();
        t.hash(&mut h);
        h.finish()
    }

    /// Asserts the contract shared by every `id_newtype!` type: `new`/`get`
    /// round-trip, structural equality with matching hashes, and a lossless
    /// JSON round-trip.
    fn assert_id_contract<T>(make: impl Fn(u32) -> T, get: impl Fn(T) -> u32)
    where
        T: Copy + Eq + Hash + Debug + Serialize + DeserializeOwned,
    {
        for raw in SAMPLE_RAWS {
            let a = make(raw);
            let b = make(raw);
            assert_eq!(get(a), raw);
            assert_eq!(a, b);
            assert_eq!(hash_of(&a), hash_of(&b));
            let encoded = serde_json::to_string(&a).expect("serialize");
            let decoded: T = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, a);
        }
    }

    #[test]
    fn new_then_get_round_trips_raw_u32() {
        for raw in SAMPLE_RAWS {
            assert_eq!(NodeId::new(raw).get(), raw);
        }
    }

    #[test]
    fn node_id_equality_is_structural() {
        assert_eq!(NodeId::new(7), NodeId::new(7));
        assert_ne!(NodeId::new(7), NodeId::new(8));
    }

    #[test]
    fn node_id_ordering_follows_raw_u32() {
        assert!(NodeId::new(1) < NodeId::new(2));
        assert!(NodeId::new(u32::MAX) > NodeId::new(0));
    }

    #[test]
    fn node_id_serde_round_trip() {
        let id = NodeId::new(0x0bad_f00d);
        let encoded = serde_json::to_string(&id).expect("serialize");
        let decoded: NodeId = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded, id);
    }

    #[test]
    fn node_id_new_get_round_trip_and_hash_eq() {
        assert_id_contract(NodeId::new, NodeId::get);
    }

    #[test]
    fn predicate_id_new_get_round_trip_and_hash_eq() {
        assert_id_contract(PredicateId::new, PredicateId::get);
    }

    #[test]
    fn middleware_id_new_get_round_trip_and_hash_eq() {
        assert_id_contract(MiddlewareId::new, MiddlewareId::get);
    }

    #[test]
    fn fetch_id_new_get_round_trip_and_hash_eq() {
        assert_id_contract(FetchId::new, FetchId::get);
    }

    #[test]
    fn terminator_id_new_get_round_trip_and_hash_eq() {
        assert_id_contract(TerminatorId::new, TerminatorId::get);
    }

    #[test]
    fn body_side_serde_round_trip_per_variant() {
        for s in [BodySide::Request, BodySide::Response] {
            let encoded = serde_json::to_string(&s).expect("serialize");
            let decoded: BodySide = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, s);
        }
    }

    #[test]
    fn listener_kind_serde_round_trip_per_variant() {
        for kind in [ListenerKind::Raw, ListenerKind::Http, ListenerKind::Auto] {
            let encoded = serde_json::to_string(&kind).expect("serialize");
            let decoded: ListenerKind = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, kind);
        }
    }

    #[test]
    fn node_check_collect_body_before_returns_stored_flag() {
        let some = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: Some(BodySide::Request),
            body_limit: 0,
        };
        assert_eq!(some.collect_body_before(), Some(BodySide::Request));

        let none = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: None,
            body_limit: 0,
        };
        assert_eq!(none.collect_body_before(), None);
    }

    #[test]
    fn node_middleware_collect_body_before_returns_stored_flag() {
        let some = Node::Middleware {
            id: MiddlewareId::new(0),
            next: NodeId::new(0),
            on_error: None,
            collect_body_before: Some(BodySide::Response),
            body_limit: 0,
        };
        assert_eq!(some.collect_body_before(), Some(BodySide::Response));

        let none = Node::Middleware {
            id: MiddlewareId::new(0),
            next: NodeId::new(0),
            on_error: None,
            collect_body_before: None,
            body_limit: 0,
        };
        assert_eq!(none.collect_body_before(), None);
    }

    #[test]
    fn node_fetch_collect_body_before_returns_stored_flag() {
        let some = Node::Fetch {
            id: FetchId::new(0),
            next_response: None,
            next_tunnel: None,
            collect_body_before: Some(BodySide::Request),
            body_limit: 0,
        };
        assert_eq!(some.collect_body_before(), Some(BodySide::Request));

        let none = Node::Fetch {
            id: FetchId::new(0),
            next_response: None,
            next_tunnel: None,
            collect_body_before: None,
            body_limit: 0,
        };
        assert_eq!(none.collect_body_before(), None);
    }

    #[test]
    fn node_upgrade_collect_body_before_is_always_none() {
        let n = Node::Upgrade { next: NodeId::new(0) };
        assert_eq!(n.collect_body_before(), None);
    }

    #[test]
    fn node_terminate_collect_body_before_is_always_none() {
        let n = Node::Terminate(TerminatorId::new(0));
        assert_eq!(n.collect_body_before(), None);
    }

    #[test]
    fn node_body_limit_returns_stored_value_for_body_capable_variants() {
        let check = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: None,
            body_limit: 1024,
        };
        assert_eq!(check.body_limit(), 1024);

        let middleware = Node::Middleware {
            id: MiddlewareId::new(0),
            next: NodeId::new(0),
            on_error: None,
            collect_body_before: None,
            body_limit: 2048,
        };
        assert_eq!(middleware.body_limit(), 2048);

        let fetch = Node::Fetch {
            id: FetchId::new(0),
            next_response: None,
            next_tunnel: None,
            collect_body_before: None,
            body_limit: 4096,
        };
        assert_eq!(fetch.body_limit(), 4096);
    }

    #[test]
    fn node_body_limit_is_zero_for_upgrade_and_terminate() {
        let upgrade = Node::Upgrade { next: NodeId::new(0) };
        assert_eq!(upgrade.body_limit(), 0);

        let terminate = Node::Terminate(TerminatorId::new(0));
        assert_eq!(terminate.body_limit(), 0);
    }

    #[test]
    fn node_body_limit_defaults_to_zero_when_absent_from_json() {
        let cases = [
            r#"{"Check":{"predicate":1,"on_match":2,"on_miss":3,"collect_body_before":null}}"#,
            r#"{"Middleware":{"id":1,"next":2,"on_error":null,"collect_body_before":"Response"}}"#,
            r#"{"Fetch":{"id":1,"next_response":null,"next_tunnel":null,"collect_body_before":"Request"}}"#,
        ];
        for json in cases {
            let node: Node =
                serde_json::from_str(json).expect("deserialize node without body_limit");
            assert_eq!(node.body_limit(), 0, "body_limit must default to 0 for {json}");
        }
    }

    fn sample_predicate() -> PredicateInst {
        PredicateInst {
            path: FieldPath::TlsSni,
            op: CompiledOperator::Equals(CompiledValue::Str(Arc::from("a"))),
        }
    }

    fn sample_middleware() -> SymbolicMiddlewareRef {
        SymbolicMiddlewareRef {
            name: Arc::from("noop"),
            args: Value::Null,
            kind: MiddlewareKind::L7Request,
            stateless: true,
            needs_body: false,
            on_error: None,
        }
    }

    fn sample_fetch() -> SymbolicFetchRef {
        SymbolicFetchRef { kind: FetchKind::HttpProxy, args: Value::Null, retry_buffer_required: false }
    }

    /// Minimal metadata with every map empty.
    fn sample_meta() -> FlowGraphMeta {
        FlowGraphMeta {
            version_hash: [0; 32],
            compiled_at: SystemTime::UNIX_EPOCH,
            source_files: vec![],
            feature_set: &[],
            short_circuit_response_entry: BTreeMap::new(),
            listener_tls: BTreeMap::new(),
            listener_kinds: BTreeMap::new(),
            listener_transports: BTreeMap::new(),
        }
    }

    /// Graph with exactly one entry in each arena, so every ID type can be
    /// indexed with raw value 0.
    fn one_of_each_graph() -> SymbolicFlowGraph {
        SymbolicFlowGraph {
            nodes: vec![Node::Terminate(TerminatorId::new(0))],
            predicates: vec![sample_predicate()],
            middlewares: vec![sample_middleware()],
            fetches: vec![sample_fetch()],
            terminators: vec![Terminator::WriteHttpResponse],
            entries: HashMap::new(),
            meta: sample_meta(),
        }
    }

    #[test]
    fn index_by_node_id_returns_matching_node() {
        let g = one_of_each_graph();
        match &g[NodeId::new(0)] {
            Node::Terminate(t) => assert_eq!(*t, TerminatorId::new(0)),
            other => panic!("expected Terminate, got {other:?}"),
        }
    }

    #[test]
    fn index_by_predicate_id_returns_matching_predicate() {
        let g = one_of_each_graph();
        assert_eq!(g[PredicateId::new(0)], sample_predicate());
    }

    #[test]
    fn index_by_middleware_id_returns_matching_middleware() {
        let g = one_of_each_graph();
        assert_eq!(g[MiddlewareId::new(0)], sample_middleware());
    }

    #[test]
    fn index_by_fetch_id_returns_matching_fetch() {
        let g = one_of_each_graph();
        assert_eq!(g[FetchId::new(0)].kind, FetchKind::HttpProxy);
    }

    #[test]
    fn index_by_terminator_id_returns_matching_terminator() {
        let g = one_of_each_graph();
        assert_eq!(g[TerminatorId::new(0)], Terminator::WriteHttpResponse);
    }

    fn node_round_trip(n: &Node) -> Node {
        let encoded = serde_json::to_string(n).expect("serialize node");
        serde_json::from_str(&encoded).expect("deserialize node")
    }

    #[test]
    fn node_check_serde_round_trip_with_and_without_collect_flag() {
        let with = Node::Check {
            predicate: PredicateId::new(3),
            on_match: NodeId::new(4),
            on_miss: NodeId::new(5),
            collect_body_before: Some(BodySide::Request),
            body_limit: 64,
        };
        match node_round_trip(&with) {
            Node::Check { predicate, on_match, on_miss, collect_body_before, body_limit } => {
                assert_eq!(predicate, PredicateId::new(3));
                assert_eq!(on_match, NodeId::new(4));
                assert_eq!(on_miss, NodeId::new(5));
                assert_eq!(collect_body_before, Some(BodySide::Request));
                assert_eq!(body_limit, 64);
            }
            other => panic!("expected Check, got {other:?}"),
        }

        let without = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: None,
            body_limit: 0,
        };
        match node_round_trip(&without) {
            Node::Check { collect_body_before, .. } => assert_eq!(collect_body_before, None),
            other => panic!("expected Check, got {other:?}"),
        }
    }

    #[test]
    fn node_middleware_serde_round_trip_with_and_without_collect_flag() {
        let with = Node::Middleware {
            id: MiddlewareId::new(1),
            next: NodeId::new(2),
            on_error: Some(NodeId::new(9)),
            collect_body_before: Some(BodySide::Response),
            body_limit: 128,
        };
        match node_round_trip(&with) {
            Node::Middleware { id, next, on_error, collect_body_before, body_limit } => {
                assert_eq!(id, MiddlewareId::new(1));
                assert_eq!(next, NodeId::new(2));
                assert_eq!(on_error, Some(NodeId::new(9)));
                assert_eq!(collect_body_before, Some(BodySide::Response));
                assert_eq!(body_limit, 128);
            }
            other => panic!("expected Middleware, got {other:?}"),
        }

        let without = Node::Middleware {
            id: MiddlewareId::new(0),
            next: NodeId::new(0),
            on_error: None,
            collect_body_before: None,
            body_limit: 0,
        };
        match node_round_trip(&without) {
            Node::Middleware { on_error, collect_body_before, .. } => {
                assert_eq!(on_error, None);
                assert_eq!(collect_body_before, None);
            }
            other => panic!("expected Middleware, got {other:?}"),
        }
    }

    #[test]
    fn node_fetch_serde_round_trip_with_and_without_collect_flag() {
        let with = Node::Fetch {
            id: FetchId::new(7),
            next_response: Some(NodeId::new(8)),
            next_tunnel: Some(NodeId::new(9)),
            collect_body_before: Some(BodySide::Request),
            body_limit: 256,
        };
        match node_round_trip(&with) {
            Node::Fetch { id, next_response, next_tunnel, collect_body_before, body_limit } => {
                assert_eq!(id, FetchId::new(7));
                assert_eq!(next_response, Some(NodeId::new(8)));
                assert_eq!(next_tunnel, Some(NodeId::new(9)));
                assert_eq!(collect_body_before, Some(BodySide::Request));
                assert_eq!(body_limit, 256);
            }
            other => panic!("expected Fetch, got {other:?}"),
        }

        let without = Node::Fetch {
            id: FetchId::new(0),
            next_response: None,
            next_tunnel: None,
            collect_body_before: None,
            body_limit: 0,
        };
        match node_round_trip(&without) {
            Node::Fetch { next_response, next_tunnel, collect_body_before, .. } => {
                assert_eq!(next_response, None);
                assert_eq!(next_tunnel, None);
                assert_eq!(collect_body_before, None);
            }
            other => panic!("expected Fetch, got {other:?}"),
        }
    }

    #[test]
    fn node_upgrade_serde_round_trip() {
        let n = Node::Upgrade { next: NodeId::new(11) };
        match node_round_trip(&n) {
            Node::Upgrade { next } => assert_eq!(next, NodeId::new(11)),
            other => panic!("expected Upgrade, got {other:?}"),
        }
    }

    #[test]
    fn node_terminate_serde_round_trip() {
        let n = Node::Terminate(TerminatorId::new(13));
        match node_round_trip(&n) {
            Node::Terminate(t) => assert_eq!(t, TerminatorId::new(13)),
            other => panic!("expected Terminate, got {other:?}"),
        }
    }

    #[test]
    fn flow_graph_meta_serializes_and_emits_version_hash_field() {
        let meta = sample_meta();
        let encoded = serde_json::to_string(&meta).expect("serialize meta");
        assert!(encoded.contains("version_hash"), "expected version_hash field in {encoded}");
    }

    #[test]
    fn flow_graph_meta_round_trip_preserves_all_but_feature_set() {
        let listener: SocketAddr = "127.0.0.1:8080".parse().expect("valid socket addr");

        let mut short_circuit = BTreeMap::new();
        short_circuit.insert(NodeId::new(1), NodeId::new(2));
        let mut kinds = BTreeMap::new();
        kinds.insert(listener, ListenerKind::Http);

        let meta = FlowGraphMeta {
            version_hash: [0x42; 32],
            compiled_at: SystemTime::UNIX_EPOCH + Duration::from_secs(1_000_000_000),
            source_files: vec![PathBuf::from("/a.json"), PathBuf::from("/b.json")],
            feature_set: &["h3", "wasm"],
            short_circuit_response_entry: short_circuit,
            listener_kinds: kinds,
            ..sample_meta()
        };
        let encoded = serde_json::to_string(&meta).expect("serialize meta");
        assert!(
            !encoded.contains("feature_set"),
            "feature_set must be skipped in dry-run JSON, got: {encoded}",
        );
        let decoded: FlowGraphMeta = serde_json::from_str(&encoded).expect("deserialize meta");
        assert_eq!(decoded.version_hash, meta.version_hash);
        assert_eq!(decoded.compiled_at, meta.compiled_at);
        assert_eq!(decoded.source_files, meta.source_files);
        assert_eq!(decoded.short_circuit_response_entry, meta.short_circuit_response_entry);
        assert_eq!(decoded.listener_kinds, meta.listener_kinds);
        assert!(decoded.listener_tls.is_empty());
        assert!(decoded.listener_transports.is_empty());
        assert!(decoded.feature_set.is_empty(), "feature_set must default to empty on deserialize");
    }
}