1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::ops::Index;
4use std::path::PathBuf;
5use std::time::SystemTime;
6
7use crate::fetch::{SymbolicFetchRef, Terminator};
8use crate::middleware::SymbolicMiddlewareRef;
9use crate::predicate::PredicateInst;
10
/// Defines a `u32`-backed identifier newtype named `$name`.
///
/// The generated type is `Copy`, totally ordered, hashable and serde-(de)serializable, and
/// exposes a `const` constructor (`new`) and accessor (`get`). Every invocation produces its
/// own struct, so ids minted for one table are not interchangeable with ids for another.
macro_rules! id_newtype {
    ($name:ident) => {
        #[doc = concat!("`u32`-backed identifier newtype `", stringify!($name), "`.")]
        #[derive(
            Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize,
        )]
        pub struct $name(u32);

        impl $name {
            /// Wraps `raw` as-is; no range or existence check is performed.
            #[must_use]
            pub const fn new(raw: u32) -> Self {
                $name(raw)
            }

            /// Returns the wrapped raw `u32`.
            #[must_use]
            pub const fn get(self) -> u32 {
                let $name(raw) = self;
                raw
            }
        }
    };
}
31
// One id type per table of `SymbolicFlowGraph`; each `Index` impl on the graph accepts only
// its own id type, so an id for one table cannot be used to index another.
id_newtype!(NodeId); // indexes `SymbolicFlowGraph::nodes`
id_newtype!(PredicateId); // indexes `SymbolicFlowGraph::predicates`
id_newtype!(MiddlewareId); // indexes `SymbolicFlowGraph::middlewares`
id_newtype!(FetchId); // indexes `SymbolicFlowGraph::fetches`
id_newtype!(TerminatorId); // indexes `SymbolicFlowGraph::terminators`
37
/// Selects either the request body or the response body.
///
/// Stored in the `collect_body_before` field of [`Node`] variants.
/// NOTE: variant order determines serde's variant index for non-self-describing formats;
/// do not reorder casually.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
pub enum BodySide {
    /// The request side.
    Request,
    /// The response side.
    Response,
}
43
/// One vertex of a [`SymbolicFlowGraph`]; successor edges are [`NodeId`]s that index
/// `SymbolicFlowGraph::nodes`.
///
/// NOTE: serde's derived representation depends on variant and field order for
/// non-self-describing formats; do not reorder variants or fields casually.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub enum Node {
    /// Node that references a predicate and two successors.
    Check {
        /// Index into `SymbolicFlowGraph::predicates`.
        predicate: PredicateId,
        /// Successor taken when the predicate matches (per the field name; the executor is
        /// not in this file).
        on_match: NodeId,
        /// Successor taken when the predicate does not match (per the field name).
        on_miss: NodeId,
        /// Which body, if any, this node asks to have collected; see
        /// [`Node::collect_body_before`].
        collect_body_before: Option<BodySide>,
        /// Deserializes to `0` when the field is absent.
        // NOTE(review): whether 0 means "no limit" or "collect nothing" is not defined in this
        // file — confirm with the executor.
        #[serde(default)]
        body_limit: usize,
    },
    /// Node that references a middleware, its normal successor and an optional error successor.
    Middleware {
        /// Index into `SymbolicFlowGraph::middlewares`.
        id: MiddlewareId,
        /// Successor node (per the field name, the normal path).
        next: NodeId,
        /// Optional successor (per the field name, used when the middleware errors).
        on_error: Option<NodeId>,
        /// Which body, if any, this node asks to have collected.
        collect_body_before: Option<BodySide>,
        /// Deserializes to `0` when the field is absent.
        #[serde(default)]
        body_limit: usize,
    },
    /// Node that references a fetch and optional response/tunnel successors.
    Fetch {
        /// Index into `SymbolicFlowGraph::fetches`.
        id: FetchId,
        /// Optional successor for the response path (per the field name).
        next_response: Option<NodeId>,
        /// Optional successor for the tunnel path (per the field name).
        next_tunnel: Option<NodeId>,
        /// Which body, if any, this node asks to have collected.
        collect_body_before: Option<BodySide>,
        /// Deserializes to `0` when the field is absent.
        #[serde(default)]
        body_limit: usize,
    },
    /// Upgrade step with a single successor; carries no body-collection settings.
    Upgrade {
        /// Successor node.
        next: NodeId,
    },
    /// Terminal node; the payload indexes `SymbolicFlowGraph::terminators`.
    Terminate(TerminatorId),
}
75
76impl Node {
77 #[must_use]
78 pub const fn collect_body_before(&self) -> Option<BodySide> {
79 match self {
80 Self::Check { collect_body_before, .. }
81 | Self::Middleware { collect_body_before, .. }
82 | Self::Fetch { collect_body_before, .. } => *collect_body_before,
83 Self::Upgrade { .. } | Self::Terminate(_) => None,
84 }
85 }
86
87 #[must_use]
88 pub const fn body_limit(&self) -> usize {
89 match self {
90 Self::Check { body_limit, .. }
91 | Self::Middleware { body_limit, .. }
92 | Self::Fetch { body_limit, .. } => *body_limit,
93 _ => 0,
94 }
95 }
96}
97
/// Metadata carried alongside a [`SymbolicFlowGraph`].
///
/// NOTE: serde's derived representation depends on field order for non-self-describing
/// formats; do not reorder fields casually.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct FlowGraphMeta {
    /// 32-byte version hash.
    // NOTE(review): what is hashed is not shown in this file — confirm with the compiler.
    pub version_hash: [u8; 32],
    /// Compilation timestamp (per the field name).
    pub compiled_at: SystemTime,
    /// Source files the graph was built from (per the field name).
    pub source_files: Vec<PathBuf>,
    /// Never serialized; on deserialization it is filled with the empty slice returned by
    /// `empty_feature_set`.
    #[serde(skip, default = "empty_feature_set")]
    pub feature_set: &'static [&'static str],

    /// `NodeId` -> `NodeId` map; deserializes to an empty map when absent.
    /// `BTreeMap` keeps iteration and serialization order sorted by key.
    #[serde(default)]
    pub short_circuit_response_entry: std::collections::BTreeMap<NodeId, NodeId>,

    /// Listener TLS settings keyed by socket address; deserializes to an empty map when absent.
    /// `BTreeMap` keeps iteration and serialization order sorted by address.
    #[serde(default)]
    pub listener_tls: std::collections::BTreeMap<SocketAddr, crate::rule::ListenerTlsSpec>,
}
140
/// Serde `default` for `FlowGraphMeta::feature_set`.
///
/// That field is `#[serde(skip)]`, so a deserialized value always carries this empty slice.
const fn empty_feature_set() -> &'static [&'static str] {
    const NO_FEATURES: &[&str] = &[];
    NO_FEATURES
}
144
/// A flow graph whose nodes and referenced components live in flat tables.
///
/// Each table is addressed by its own id newtype through the `Index` impls on this type
/// (`graph[NodeId]`, `graph[PredicateId]`, ...), and those impls panic on out-of-range ids.
///
/// NOTE: serde's derived representation depends on field order for non-self-describing
/// formats; do not reorder fields casually.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct SymbolicFlowGraph {
    /// Node table, indexed by [`NodeId`].
    pub nodes: Vec<Node>,
    /// Predicate table, indexed by [`PredicateId`].
    pub predicates: Vec<PredicateInst>,
    /// Middleware table, indexed by [`MiddlewareId`].
    pub middlewares: Vec<SymbolicMiddlewareRef>,
    /// Fetch table, indexed by [`FetchId`].
    pub fetches: Vec<SymbolicFetchRef>,
    /// Terminator table, indexed by [`TerminatorId`].
    pub terminators: Vec<Terminator>,
    /// Entry node per socket address.
    // NOTE(review): unlike the `BTreeMap`s in `FlowGraphMeta`, `HashMap` iteration (and hence
    // serialized key order) is unspecified and may differ between runs — confirm nothing
    // relies on stable output here.
    pub entries: HashMap<SocketAddr, NodeId>,
    /// Metadata for this graph.
    pub meta: FlowGraphMeta,
}
155
156impl Index<NodeId> for SymbolicFlowGraph {
157 type Output = Node;
158 fn index(&self, id: NodeId) -> &Node {
159 &self.nodes[id.get() as usize]
160 }
161}
162
163impl Index<PredicateId> for SymbolicFlowGraph {
164 type Output = PredicateInst;
165 fn index(&self, id: PredicateId) -> &PredicateInst {
166 &self.predicates[id.get() as usize]
167 }
168}
169
170impl Index<MiddlewareId> for SymbolicFlowGraph {
171 type Output = SymbolicMiddlewareRef;
172 fn index(&self, id: MiddlewareId) -> &SymbolicMiddlewareRef {
173 &self.middlewares[id.get() as usize]
174 }
175}
176
177impl Index<FetchId> for SymbolicFlowGraph {
178 type Output = SymbolicFetchRef;
179 fn index(&self, id: FetchId) -> &SymbolicFetchRef {
180 &self.fetches[id.get() as usize]
181 }
182}
183
184impl Index<TerminatorId> for SymbolicFlowGraph {
185 type Output = Terminator;
186 fn index(&self, id: TerminatorId) -> &Terminator {
187 &self.terminators[id.get() as usize]
188 }
189}
190
#[cfg(test)]
mod tests {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::sync::Arc;

    use serde_json::Value;

    use super::*;
    use crate::fetch::{FetchKind, SymbolicFetchRef, Terminator};
    use crate::middleware::{MiddlewareKind, SymbolicMiddlewareRef};
    use crate::predicate::{CompiledOperator, CompiledValue, FieldPath, PredicateInst};

    // ---- id newtypes --------------------------------------------------------

    #[test]
    fn new_then_get_round_trips_raw_u32() {
        for raw in [0_u32, 1, 42, u32::MAX] {
            assert_eq!(NodeId::new(raw).get(), raw);
        }
    }

    #[test]
    fn node_id_equality_is_structural() {
        assert_eq!(NodeId::new(7), NodeId::new(7));
        assert_ne!(NodeId::new(7), NodeId::new(8));
    }

    #[test]
    fn node_id_ordering_follows_raw_u32() {
        assert!(NodeId::new(1) < NodeId::new(2));
        assert!(NodeId::new(u32::MAX) > NodeId::new(0));
    }

    #[test]
    fn node_id_serde_round_trip() {
        let id = NodeId::new(0x0bad_f00d);
        let encoded = serde_json::to_string(&id).expect("serialize");
        let decoded: NodeId = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded, id);
    }

    #[test]
    fn body_side_serde_round_trip_per_variant() {
        for s in [BodySide::Request, BodySide::Response] {
            let encoded = serde_json::to_string(&s).expect("serialize");
            let decoded: BodySide = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, s);
        }
    }

    fn hash_of<T: Hash>(t: &T) -> u64 {
        let mut h = DefaultHasher::new();
        t.hash(&mut h);
        h.finish()
    }

    #[test]
    fn predicate_id_new_get_round_trip_and_hash_eq() {
        for raw in [0_u32, 1, 42, u32::MAX] {
            let a = PredicateId::new(raw);
            let b = PredicateId::new(raw);
            assert_eq!(a.get(), raw);
            assert_eq!(a, b);
            assert_eq!(hash_of(&a), hash_of(&b));
            let encoded = serde_json::to_string(&a).expect("serialize");
            let decoded: PredicateId = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, a);
        }
    }

    #[test]
    fn middleware_id_new_get_round_trip_and_hash_eq() {
        for raw in [0_u32, 1, 42, u32::MAX] {
            let a = MiddlewareId::new(raw);
            let b = MiddlewareId::new(raw);
            assert_eq!(a.get(), raw);
            assert_eq!(a, b);
            assert_eq!(hash_of(&a), hash_of(&b));
            let encoded = serde_json::to_string(&a).expect("serialize");
            let decoded: MiddlewareId = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, a);
        }
    }

    #[test]
    fn fetch_id_new_get_round_trip_and_hash_eq() {
        for raw in [0_u32, 1, 42, u32::MAX] {
            let a = FetchId::new(raw);
            let b = FetchId::new(raw);
            assert_eq!(a.get(), raw);
            assert_eq!(a, b);
            assert_eq!(hash_of(&a), hash_of(&b));
            let encoded = serde_json::to_string(&a).expect("serialize");
            let decoded: FetchId = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, a);
        }
    }

    #[test]
    fn terminator_id_new_get_round_trip_and_hash_eq() {
        for raw in [0_u32, 1, 42, u32::MAX] {
            let a = TerminatorId::new(raw);
            let b = TerminatorId::new(raw);
            assert_eq!(a.get(), raw);
            assert_eq!(a, b);
            assert_eq!(hash_of(&a), hash_of(&b));
            let encoded = serde_json::to_string(&a).expect("serialize");
            let decoded: TerminatorId = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, a);
        }
    }

    // Never called: it only has to type-check, with one parameter of each id type.
    fn _id_types_are_distinct(
        _n: NodeId,
        _p: PredicateId,
        _m: MiddlewareId,
        _f: FetchId,
        _t: TerminatorId,
    ) {
    }

    // ---- Node accessors -----------------------------------------------------

    #[test]
    fn node_check_collect_body_before_returns_stored_flag() {
        let some = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: Some(BodySide::Request),
            body_limit: 0,
        };
        assert_eq!(some.collect_body_before(), Some(BodySide::Request));

        let none = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: None,
            body_limit: 0,
        };
        assert_eq!(none.collect_body_before(), None);
    }

    #[test]
    fn node_middleware_collect_body_before_returns_stored_flag() {
        let some = Node::Middleware {
            id: MiddlewareId::new(0),
            next: NodeId::new(0),
            on_error: None,
            collect_body_before: Some(BodySide::Response),
            body_limit: 0,
        };
        assert_eq!(some.collect_body_before(), Some(BodySide::Response));

        let none = Node::Middleware {
            id: MiddlewareId::new(0),
            next: NodeId::new(0),
            on_error: None,
            collect_body_before: None,
            body_limit: 0,
        };
        assert_eq!(none.collect_body_before(), None);
    }

    #[test]
    fn node_fetch_collect_body_before_returns_stored_flag() {
        let some = Node::Fetch {
            id: FetchId::new(0),
            next_response: None,
            next_tunnel: None,
            collect_body_before: Some(BodySide::Request),
            body_limit: 0,
        };
        assert_eq!(some.collect_body_before(), Some(BodySide::Request));

        let none = Node::Fetch {
            id: FetchId::new(0),
            next_response: None,
            next_tunnel: None,
            collect_body_before: None,
            body_limit: 0,
        };
        assert_eq!(none.collect_body_before(), None);
    }

    #[test]
    fn node_upgrade_collect_body_before_is_always_none() {
        let n = Node::Upgrade { next: NodeId::new(0) };
        assert_eq!(n.collect_body_before(), None);
    }

    #[test]
    fn node_terminate_collect_body_before_is_always_none() {
        let n = Node::Terminate(TerminatorId::new(0));
        assert_eq!(n.collect_body_before(), None);
    }

    #[test]
    fn node_body_limit_returns_stored_value_for_carrying_variants() {
        let check = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: None,
            body_limit: 1024,
        };
        assert_eq!(check.body_limit(), 1024);

        let middleware = Node::Middleware {
            id: MiddlewareId::new(0),
            next: NodeId::new(0),
            on_error: None,
            collect_body_before: None,
            body_limit: 2048,
        };
        assert_eq!(middleware.body_limit(), 2048);

        let fetch = Node::Fetch {
            id: FetchId::new(0),
            next_response: None,
            next_tunnel: None,
            collect_body_before: None,
            body_limit: usize::MAX,
        };
        assert_eq!(fetch.body_limit(), usize::MAX);
    }

    #[test]
    fn node_body_limit_is_zero_for_upgrade_and_terminate() {
        assert_eq!(Node::Upgrade { next: NodeId::new(0) }.body_limit(), 0);
        assert_eq!(Node::Terminate(TerminatorId::new(0)).body_limit(), 0);
    }

    // ---- graph indexing -----------------------------------------------------

    fn sample_predicate() -> PredicateInst {
        PredicateInst {
            path: FieldPath::TlsSni,
            op: CompiledOperator::Equals(CompiledValue::Str(Arc::from("a"))),
        }
    }

    fn sample_middleware() -> SymbolicMiddlewareRef {
        SymbolicMiddlewareRef {
            name: Arc::from("noop"),
            args: Value::Null,
            kind: MiddlewareKind::L7Request,
            stateless: true,
            needs_body: false,
            on_error: None,
        }
    }

    fn sample_fetch() -> SymbolicFetchRef {
        SymbolicFetchRef { kind: FetchKind::HttpProxy, args: Value::Null }
    }

    fn sample_meta() -> FlowGraphMeta {
        FlowGraphMeta {
            version_hash: [0; 32],
            compiled_at: SystemTime::UNIX_EPOCH,
            source_files: vec![],
            feature_set: &[],
            short_circuit_response_entry: std::collections::BTreeMap::new(),
            listener_tls: std::collections::BTreeMap::new(),
        }
    }

    fn one_of_each_graph() -> SymbolicFlowGraph {
        SymbolicFlowGraph {
            nodes: vec![Node::Terminate(TerminatorId::new(0))],
            predicates: vec![sample_predicate()],
            middlewares: vec![sample_middleware()],
            fetches: vec![sample_fetch()],
            terminators: vec![Terminator::WriteHttpResponse],
            entries: HashMap::new(),
            meta: sample_meta(),
        }
    }

    #[test]
    fn index_by_node_id_returns_matching_node() {
        let g = one_of_each_graph();
        match &g[NodeId::new(0)] {
            Node::Terminate(t) => assert_eq!(*t, TerminatorId::new(0)),
            other => panic!("expected Terminate, got {other:?}"),
        }
    }

    #[test]
    fn index_by_predicate_id_returns_matching_predicate() {
        let g = one_of_each_graph();
        assert_eq!(g[PredicateId::new(0)], sample_predicate());
    }

    #[test]
    fn index_by_middleware_id_returns_matching_middleware() {
        let g = one_of_each_graph();
        assert_eq!(g[MiddlewareId::new(0)], sample_middleware());
    }

    #[test]
    fn index_by_fetch_id_returns_matching_fetch() {
        let g = one_of_each_graph();
        assert_eq!(g[FetchId::new(0)].kind, FetchKind::HttpProxy);
    }

    #[test]
    fn index_by_terminator_id_returns_matching_terminator() {
        let g = one_of_each_graph();
        assert_eq!(g[TerminatorId::new(0)], Terminator::WriteHttpResponse);
    }

    // ---- Node serde ---------------------------------------------------------

    fn node_round_trip(n: &Node) -> Node {
        let encoded = serde_json::to_string(n).expect("serialize node");
        serde_json::from_str(&encoded).expect("deserialize node")
    }

    #[test]
    fn node_check_serde_round_trip_with_and_without_collect_flag() {
        let with = Node::Check {
            predicate: PredicateId::new(3),
            on_match: NodeId::new(4),
            on_miss: NodeId::new(5),
            collect_body_before: Some(BodySide::Request),
            body_limit: 0,
        };
        match node_round_trip(&with) {
            Node::Check { predicate, on_match, on_miss, collect_body_before, .. } => {
                assert_eq!(predicate, PredicateId::new(3));
                assert_eq!(on_match, NodeId::new(4));
                assert_eq!(on_miss, NodeId::new(5));
                assert_eq!(collect_body_before, Some(BodySide::Request));
            }
            other => panic!("expected Check, got {other:?}"),
        }

        let without = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: None,
            body_limit: 0,
        };
        match node_round_trip(&without) {
            Node::Check { collect_body_before, .. } => assert_eq!(collect_body_before, None),
            other => panic!("expected Check, got {other:?}"),
        }
    }

    #[test]
    fn node_middleware_serde_round_trip_with_and_without_collect_flag() {
        let with = Node::Middleware {
            id: MiddlewareId::new(1),
            next: NodeId::new(2),
            on_error: Some(NodeId::new(9)),
            collect_body_before: Some(BodySide::Response),
            body_limit: 0,
        };
        match node_round_trip(&with) {
            Node::Middleware { id, next, on_error, collect_body_before, .. } => {
                assert_eq!(id, MiddlewareId::new(1));
                assert_eq!(next, NodeId::new(2));
                assert_eq!(on_error, Some(NodeId::new(9)));
                assert_eq!(collect_body_before, Some(BodySide::Response));
            }
            other => panic!("expected Middleware, got {other:?}"),
        }

        let without = Node::Middleware {
            id: MiddlewareId::new(0),
            next: NodeId::new(0),
            on_error: None,
            collect_body_before: None,
            body_limit: 0,
        };
        match node_round_trip(&without) {
            Node::Middleware { on_error, collect_body_before, .. } => {
                assert_eq!(on_error, None);
                assert_eq!(collect_body_before, None);
            }
            other => panic!("expected Middleware, got {other:?}"),
        }
    }

    #[test]
    fn node_fetch_serde_round_trip_with_and_without_collect_flag() {
        let with = Node::Fetch {
            id: FetchId::new(7),
            next_response: Some(NodeId::new(8)),
            next_tunnel: Some(NodeId::new(9)),
            collect_body_before: Some(BodySide::Request),
            body_limit: 0,
        };
        match node_round_trip(&with) {
            Node::Fetch { id, next_response, next_tunnel, collect_body_before, .. } => {
                assert_eq!(id, FetchId::new(7));
                assert_eq!(next_response, Some(NodeId::new(8)));
                assert_eq!(next_tunnel, Some(NodeId::new(9)));
                assert_eq!(collect_body_before, Some(BodySide::Request));
            }
            other => panic!("expected Fetch, got {other:?}"),
        }

        let without = Node::Fetch {
            id: FetchId::new(0),
            next_response: None,
            next_tunnel: None,
            collect_body_before: None,
            body_limit: 0,
        };
        match node_round_trip(&without) {
            Node::Fetch { next_response, next_tunnel, collect_body_before, .. } => {
                assert_eq!(next_response, None);
                assert_eq!(next_tunnel, None);
                assert_eq!(collect_body_before, None);
            }
            other => panic!("expected Fetch, got {other:?}"),
        }
    }

    #[test]
    fn node_upgrade_serde_round_trip() {
        let n = Node::Upgrade { next: NodeId::new(11) };
        match node_round_trip(&n) {
            Node::Upgrade { next } => assert_eq!(next, NodeId::new(11)),
            other => panic!("expected Upgrade, got {other:?}"),
        }
    }

    #[test]
    fn node_terminate_serde_round_trip() {
        let n = Node::Terminate(TerminatorId::new(13));
        match node_round_trip(&n) {
            Node::Terminate(t) => assert_eq!(t, TerminatorId::new(13)),
            other => panic!("expected Terminate, got {other:?}"),
        }
    }

    #[test]
    fn node_serde_round_trip_preserves_body_limit() {
        let n = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: None,
            body_limit: 4096,
        };
        assert_eq!(node_round_trip(&n).body_limit(), 4096);
    }

    // `body_limit` is `#[serde(default)]`: input produced before the field existed must still
    // deserialize, with the limit falling back to 0.
    #[test]
    fn node_body_limit_defaults_to_zero_when_absent_from_json() {
        let inputs = [
            r#"{"Check":{"predicate":1,"on_match":2,"on_miss":3,"collect_body_before":null}}"#,
            r#"{"Middleware":{"id":1,"next":2,"on_error":null,"collect_body_before":null}}"#,
            r#"{"Fetch":{"id":1,"next_response":null,"next_tunnel":null,"collect_body_before":"Request"}}"#,
        ];
        for json in inputs {
            let n: Node = serde_json::from_str(json).expect("deserialize node without body_limit");
            assert_eq!(n.body_limit(), 0, "body_limit should default to 0 for {json}");
        }
    }

    // ---- FlowGraphMeta serde ------------------------------------------------

    #[test]
    fn flow_graph_meta_serializes_and_emits_version_hash_field() {
        let meta = sample_meta();
        let encoded = serde_json::to_string(&meta).expect("serialize meta");
        assert!(encoded.contains("version_hash"), "expected version_hash field in {encoded}");
    }

    #[test]
    fn flow_graph_meta_round_trip_preserves_all_but_feature_set() {
        use std::time::Duration;
        let meta = FlowGraphMeta {
            version_hash: [0x42; 32],
            compiled_at: SystemTime::UNIX_EPOCH + Duration::from_secs(1_000_000_000),
            source_files: vec![PathBuf::from("/a.json"), PathBuf::from("/b.json")],
            feature_set: &["h3", "wasm"],
            short_circuit_response_entry: std::collections::BTreeMap::new(),
            listener_tls: std::collections::BTreeMap::new(),
        };
        let encoded = serde_json::to_string(&meta).expect("serialize meta");
        assert!(
            !encoded.contains("feature_set"),
            "feature_set must be skipped in dry-run JSON, got: {encoded}",
        );
        let decoded: FlowGraphMeta = serde_json::from_str(&encoded).expect("deserialize meta");
        assert_eq!(decoded.version_hash, meta.version_hash);
        assert_eq!(decoded.compiled_at, meta.compiled_at);
        assert_eq!(decoded.source_files, meta.source_files);
        assert!(decoded.feature_set.is_empty(), "feature_set must default to empty on deserialize");
        assert!(decoded.short_circuit_response_entry.is_empty());
        assert!(decoded.listener_tls.is_empty());
    }
}