1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::ops::Index;
4use std::path::PathBuf;
5use std::time::SystemTime;
6
7use crate::fetch::{SymbolicFetchRef, Terminator};
8use crate::middleware::SymbolicMiddlewareRef;
9use crate::predicate::PredicateInst;
10
/// Declares a strongly-typed, `u32`-backed id newtype named `$name`.
///
/// The generated tuple struct is `Copy`, totally ordered, hashable and
/// (de)serializable with serde, and exposes `new` / `get` to convert from and to
/// the raw `u32`. `#[repr(transparent)]` pins its layout to that of the wrapped
/// `u32`.
///
/// Any leading attributes — typically `///` doc comments — are forwarded to the
/// generated struct, so both of these forms are accepted:
///
/// ```text
/// id_newtype!(NodeId);
/// id_newtype!(
///     /// Typed id.
///     NodeId
/// );
/// ```
macro_rules! id_newtype {
    ($(#[$meta:meta])* $name:ident) => {
        $(#[$meta])*
        #[derive(
            Copy, Clone, Eq, PartialEq, Hash, Debug, PartialOrd, Ord,
            serde::Serialize, serde::Deserialize,
        )]
        #[repr(transparent)]
        pub struct $name(u32);

        impl $name {
            #[doc = concat!("Wraps a raw `u32` as a [`", stringify!($name), "`].")]
            #[must_use]
            pub const fn new(raw: u32) -> Self {
                Self(raw)
            }

            #[doc = concat!("Returns the raw `u32` wrapped by this [`", stringify!($name), "`].")]
            #[must_use]
            pub const fn get(self) -> u32 {
                self.0
            }
        }
    };
}
31
32id_newtype!(NodeId);
33id_newtype!(PredicateId);
34id_newtype!(MiddlewareId);
35id_newtype!(FetchId);
36id_newtype!(TerminatorId);
37
38#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
39pub enum BodySide {
40 Request,
41 Response,
42}
43
44#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
45pub enum Node {
46 Check {
47 predicate: PredicateId,
48 on_match: NodeId,
49 on_miss: NodeId,
50 collect_body_before: Option<BodySide>,
51 },
52 Middleware {
53 id: MiddlewareId,
54 next: NodeId,
55 on_error: Option<NodeId>,
56 collect_body_before: Option<BodySide>,
57 },
58 Fetch {
59 id: FetchId,
60 next_response: Option<NodeId>,
61 next_tunnel: Option<NodeId>,
62 collect_body_before: Option<BodySide>,
63 },
64 Upgrade {
65 next: NodeId,
66 },
67 Terminate(TerminatorId),
68}
69
70impl Node {
71 #[must_use]
72 pub const fn collect_body_before(&self) -> Option<BodySide> {
73 match self {
74 Self::Check { collect_body_before, .. }
75 | Self::Middleware { collect_body_before, .. }
76 | Self::Fetch { collect_body_before, .. } => *collect_body_before,
77 Self::Upgrade { .. } | Self::Terminate(_) => None,
78 }
79 }
80}
81
82#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
83pub struct FlowGraphMeta {
84 pub version_hash: [u8; 32],
85 pub compiled_at: SystemTime,
86 pub source_files: Vec<PathBuf>,
87 #[serde(skip, default = "empty_feature_set")]
91 pub feature_set: &'static [&'static str],
92}
93
/// Serde default for [`FlowGraphMeta::feature_set`]: always an empty slice.
///
/// That field is skipped during (de)serialization, so every deserialized
/// value reports no features.
const fn empty_feature_set() -> &'static [&'static str] {
    const NONE: &[&str] = &[];
    NONE
}
97
98#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
99pub struct SymbolicFlowGraph {
100 pub nodes: Vec<Node>,
101 pub predicates: Vec<PredicateInst>,
102 pub middlewares: Vec<SymbolicMiddlewareRef>,
103 pub fetches: Vec<SymbolicFetchRef>,
104 pub terminators: Vec<Terminator>,
105 pub entries: HashMap<SocketAddr, NodeId>,
106 pub meta: FlowGraphMeta,
107}
108
109impl Index<NodeId> for SymbolicFlowGraph {
110 type Output = Node;
111 fn index(&self, id: NodeId) -> &Node {
112 &self.nodes[id.get() as usize]
113 }
114}
115
116impl Index<PredicateId> for SymbolicFlowGraph {
117 type Output = PredicateInst;
118 fn index(&self, id: PredicateId) -> &PredicateInst {
119 &self.predicates[id.get() as usize]
120 }
121}
122
123impl Index<MiddlewareId> for SymbolicFlowGraph {
124 type Output = SymbolicMiddlewareRef;
125 fn index(&self, id: MiddlewareId) -> &SymbolicMiddlewareRef {
126 &self.middlewares[id.get() as usize]
127 }
128}
129
130impl Index<FetchId> for SymbolicFlowGraph {
131 type Output = SymbolicFetchRef;
132 fn index(&self, id: FetchId) -> &SymbolicFetchRef {
133 &self.fetches[id.get() as usize]
134 }
135}
136
137impl Index<TerminatorId> for SymbolicFlowGraph {
138 type Output = Terminator;
139 fn index(&self, id: TerminatorId) -> &Terminator {
140 &self.terminators[id.get() as usize]
141 }
142}
143
#[cfg(test)]
mod tests {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::sync::Arc;

    use serde_json::Value;

    // `super::*` also exposes the parent's private imports (`HashMap`, `PathBuf`,
    // `SystemTime`, `SymbolicFetchRef`, `Terminator`, `SymbolicMiddlewareRef`,
    // `PredicateInst`, ...). Only names the parent does not import are listed
    // explicitly below.
    use super::*;
    use crate::fetch::FetchKind;
    use crate::middleware::MiddlewareKind;
    use crate::predicate::{CompiledOperator, CompiledValue, FieldPath};

    /// Raw values exercised by the id tests: both `u32` bounds plus two
    /// ordinary values.
    const RAW_SAMPLES: [u32; 4] = [0, 1, 42, u32::MAX];

    /// Hashes `t` with the standard library's `DefaultHasher`.
    fn hash_of<T: Hash>(t: &T) -> u64 {
        let mut h = DefaultHasher::new();
        t.hash(&mut h);
        h.finish()
    }

    /// Checks the contract shared by every `id_newtype!` type for each value in
    /// [`RAW_SAMPLES`]:
    /// - `make` followed by `raw_of` returns the raw value;
    /// - two ids built from the same raw value are equal and hash equally;
    /// - a JSON round trip is lossless.
    fn assert_id_contract<T>(make: impl Fn(u32) -> T, raw_of: impl Fn(T) -> u32)
    where
        T: Copy + Eq + Hash + std::fmt::Debug + serde::Serialize + serde::de::DeserializeOwned,
    {
        for raw in RAW_SAMPLES {
            let a = make(raw);
            let b = make(raw);
            assert_eq!(raw_of(a), raw);
            assert_eq!(a, b);
            assert_eq!(hash_of(&a), hash_of(&b));
            let encoded = serde_json::to_string(&a).expect("serialize");
            let decoded: T = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, a);
        }
    }

    #[test]
    fn new_then_get_round_trips_raw_u32() {
        for raw in RAW_SAMPLES {
            assert_eq!(NodeId::new(raw).get(), raw);
        }
    }

    #[test]
    fn node_id_equality_is_structural() {
        assert_eq!(NodeId::new(7), NodeId::new(7));
        assert_ne!(NodeId::new(7), NodeId::new(8));
    }

    #[test]
    fn node_id_ordering_follows_raw_u32() {
        assert!(NodeId::new(1) < NodeId::new(2));
        assert!(NodeId::new(u32::MAX) > NodeId::new(0));
    }

    #[test]
    fn node_id_serde_round_trip() {
        let id = NodeId::new(0x0bad_f00d);
        let encoded = serde_json::to_string(&id).expect("serialize");
        let decoded: NodeId = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded, id);
    }

    #[test]
    fn body_side_serde_round_trip_per_variant() {
        for s in [BodySide::Request, BodySide::Response] {
            let encoded = serde_json::to_string(&s).expect("serialize");
            let decoded: BodySide = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, s);
        }
    }

    #[test]
    fn node_id_new_get_round_trip_and_hash_eq() {
        assert_id_contract(NodeId::new, NodeId::get);
    }

    #[test]
    fn predicate_id_new_get_round_trip_and_hash_eq() {
        assert_id_contract(PredicateId::new, PredicateId::get);
    }

    #[test]
    fn middleware_id_new_get_round_trip_and_hash_eq() {
        assert_id_contract(MiddlewareId::new, MiddlewareId::get);
    }

    #[test]
    fn fetch_id_new_get_round_trip_and_hash_eq() {
        assert_id_contract(FetchId::new, FetchId::get);
    }

    #[test]
    fn terminator_id_new_get_round_trip_and_hash_eq() {
        assert_id_contract(TerminatorId::new, TerminatorId::get);
    }

    #[test]
    fn node_check_collect_body_before_returns_stored_flag() {
        let some = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: Some(BodySide::Request),
        };
        assert_eq!(some.collect_body_before(), Some(BodySide::Request));

        let none = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: None,
        };
        assert_eq!(none.collect_body_before(), None);
    }

    #[test]
    fn node_middleware_collect_body_before_returns_stored_flag() {
        let some = Node::Middleware {
            id: MiddlewareId::new(0),
            next: NodeId::new(0),
            on_error: None,
            collect_body_before: Some(BodySide::Response),
        };
        assert_eq!(some.collect_body_before(), Some(BodySide::Response));

        let none = Node::Middleware {
            id: MiddlewareId::new(0),
            next: NodeId::new(0),
            on_error: None,
            collect_body_before: None,
        };
        assert_eq!(none.collect_body_before(), None);
    }

    #[test]
    fn node_fetch_collect_body_before_returns_stored_flag() {
        let some = Node::Fetch {
            id: FetchId::new(0),
            next_response: None,
            next_tunnel: None,
            collect_body_before: Some(BodySide::Request),
        };
        assert_eq!(some.collect_body_before(), Some(BodySide::Request));

        let none = Node::Fetch {
            id: FetchId::new(0),
            next_response: None,
            next_tunnel: None,
            collect_body_before: None,
        };
        assert_eq!(none.collect_body_before(), None);
    }

    #[test]
    fn node_upgrade_collect_body_before_is_always_none() {
        let n = Node::Upgrade { next: NodeId::new(0) };
        assert_eq!(n.collect_body_before(), None);
    }

    #[test]
    fn node_terminate_collect_body_before_is_always_none() {
        let n = Node::Terminate(TerminatorId::new(0));
        assert_eq!(n.collect_body_before(), None);
    }

    fn sample_predicate() -> PredicateInst {
        PredicateInst {
            path: FieldPath::TlsSni,
            op: CompiledOperator::Equals(CompiledValue::Str(Arc::from("a"))),
        }
    }

    fn sample_middleware() -> SymbolicMiddlewareRef {
        SymbolicMiddlewareRef {
            name: Arc::from("noop"),
            args: Value::Null,
            kind: MiddlewareKind::L7Request,
            stateless: true,
            needs_body: false,
            on_error: None,
        }
    }

    fn sample_fetch() -> SymbolicFetchRef {
        SymbolicFetchRef { kind: FetchKind::HttpProxy, args: Value::Null }
    }

    fn sample_meta() -> FlowGraphMeta {
        FlowGraphMeta {
            version_hash: [0; 32],
            compiled_at: SystemTime::UNIX_EPOCH,
            source_files: vec![],
            feature_set: &[],
        }
    }

    /// A graph with exactly one entry in every table, each at index 0.
    fn one_of_each_graph() -> SymbolicFlowGraph {
        SymbolicFlowGraph {
            nodes: vec![Node::Terminate(TerminatorId::new(0))],
            predicates: vec![sample_predicate()],
            middlewares: vec![sample_middleware()],
            fetches: vec![sample_fetch()],
            terminators: vec![Terminator::WriteHttpResponse],
            entries: HashMap::new(),
            meta: sample_meta(),
        }
    }

    #[test]
    fn index_by_node_id_returns_matching_node() {
        let g = one_of_each_graph();
        match &g[NodeId::new(0)] {
            Node::Terminate(t) => assert_eq!(*t, TerminatorId::new(0)),
            other => panic!("expected Terminate, got {other:?}"),
        }
    }

    #[test]
    fn index_by_predicate_id_returns_matching_predicate() {
        let g = one_of_each_graph();
        assert_eq!(g[PredicateId::new(0)], sample_predicate());
    }

    #[test]
    fn index_by_middleware_id_returns_matching_middleware() {
        let g = one_of_each_graph();
        assert_eq!(g[MiddlewareId::new(0)], sample_middleware());
    }

    #[test]
    fn index_by_fetch_id_returns_matching_fetch() {
        let g = one_of_each_graph();
        assert_eq!(g[FetchId::new(0)].kind, FetchKind::HttpProxy);
    }

    #[test]
    fn index_by_terminator_id_returns_matching_terminator() {
        let g = one_of_each_graph();
        assert_eq!(g[TerminatorId::new(0)], Terminator::WriteHttpResponse);
    }

    /// Serializes `n` to JSON and decodes it back.
    fn node_round_trip(n: &Node) -> Node {
        let encoded = serde_json::to_string(n).expect("serialize node");
        serde_json::from_str(&encoded).expect("deserialize node")
    }

    #[test]
    fn node_check_serde_round_trip_with_and_without_collect_flag() {
        let with = Node::Check {
            predicate: PredicateId::new(3),
            on_match: NodeId::new(4),
            on_miss: NodeId::new(5),
            collect_body_before: Some(BodySide::Request),
        };
        match node_round_trip(&with) {
            Node::Check { predicate, on_match, on_miss, collect_body_before } => {
                assert_eq!(predicate, PredicateId::new(3));
                assert_eq!(on_match, NodeId::new(4));
                assert_eq!(on_miss, NodeId::new(5));
                assert_eq!(collect_body_before, Some(BodySide::Request));
            }
            other => panic!("expected Check, got {other:?}"),
        }

        let without = Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(0),
            on_miss: NodeId::new(0),
            collect_body_before: None,
        };
        match node_round_trip(&without) {
            Node::Check { collect_body_before, .. } => assert_eq!(collect_body_before, None),
            other => panic!("expected Check, got {other:?}"),
        }
    }

    #[test]
    fn node_middleware_serde_round_trip_with_and_without_collect_flag() {
        let with = Node::Middleware {
            id: MiddlewareId::new(1),
            next: NodeId::new(2),
            on_error: Some(NodeId::new(9)),
            collect_body_before: Some(BodySide::Response),
        };
        match node_round_trip(&with) {
            Node::Middleware { id, next, on_error, collect_body_before } => {
                assert_eq!(id, MiddlewareId::new(1));
                assert_eq!(next, NodeId::new(2));
                assert_eq!(on_error, Some(NodeId::new(9)));
                assert_eq!(collect_body_before, Some(BodySide::Response));
            }
            other => panic!("expected Middleware, got {other:?}"),
        }

        let without = Node::Middleware {
            id: MiddlewareId::new(0),
            next: NodeId::new(0),
            on_error: None,
            collect_body_before: None,
        };
        match node_round_trip(&without) {
            Node::Middleware { on_error, collect_body_before, .. } => {
                assert_eq!(on_error, None);
                assert_eq!(collect_body_before, None);
            }
            other => panic!("expected Middleware, got {other:?}"),
        }
    }

    #[test]
    fn node_fetch_serde_round_trip_with_and_without_collect_flag() {
        let with = Node::Fetch {
            id: FetchId::new(7),
            next_response: Some(NodeId::new(8)),
            next_tunnel: Some(NodeId::new(9)),
            collect_body_before: Some(BodySide::Request),
        };
        match node_round_trip(&with) {
            Node::Fetch { id, next_response, next_tunnel, collect_body_before } => {
                assert_eq!(id, FetchId::new(7));
                assert_eq!(next_response, Some(NodeId::new(8)));
                assert_eq!(next_tunnel, Some(NodeId::new(9)));
                assert_eq!(collect_body_before, Some(BodySide::Request));
            }
            other => panic!("expected Fetch, got {other:?}"),
        }

        let without = Node::Fetch {
            id: FetchId::new(0),
            next_response: None,
            next_tunnel: None,
            collect_body_before: None,
        };
        match node_round_trip(&without) {
            Node::Fetch { next_response, next_tunnel, collect_body_before, .. } => {
                assert_eq!(next_response, None);
                assert_eq!(next_tunnel, None);
                assert_eq!(collect_body_before, None);
            }
            other => panic!("expected Fetch, got {other:?}"),
        }
    }

    #[test]
    fn node_upgrade_serde_round_trip() {
        let n = Node::Upgrade { next: NodeId::new(11) };
        match node_round_trip(&n) {
            Node::Upgrade { next } => assert_eq!(next, NodeId::new(11)),
            other => panic!("expected Upgrade, got {other:?}"),
        }
    }

    #[test]
    fn node_terminate_serde_round_trip() {
        let n = Node::Terminate(TerminatorId::new(13));
        match node_round_trip(&n) {
            Node::Terminate(t) => assert_eq!(t, TerminatorId::new(13)),
            other => panic!("expected Terminate, got {other:?}"),
        }
    }

    #[test]
    fn flow_graph_meta_serializes_and_emits_version_hash_field() {
        let meta = sample_meta();
        let encoded = serde_json::to_string(&meta).expect("serialize meta");
        assert!(encoded.contains("version_hash"), "expected version_hash field in {encoded}");
    }

    #[test]
    fn flow_graph_meta_round_trip_preserves_all_but_feature_set() {
        use std::time::Duration;
        let meta = FlowGraphMeta {
            version_hash: [0x42; 32],
            compiled_at: SystemTime::UNIX_EPOCH + Duration::from_secs(1_000_000_000),
            source_files: vec![PathBuf::from("/a.json"), PathBuf::from("/b.json")],
            feature_set: &["h3", "wasm"],
        };
        let encoded = serde_json::to_string(&meta).expect("serialize meta");
        assert!(
            !encoded.contains("feature_set"),
            "feature_set must be skipped in dry-run JSON, got: {encoded}",
        );
        let decoded: FlowGraphMeta = serde_json::from_str(&encoded).expect("deserialize meta");
        assert_eq!(decoded.version_hash, meta.version_hash);
        assert_eq!(decoded.compiled_at, meta.compiled_at);
        assert_eq!(decoded.source_files, meta.source_files);
        assert!(decoded.feature_set.is_empty(), "feature_set must default to empty on deserialize");
    }
}