Skip to main content

vane_core/
ir.rs

1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::ops::Index;
4use std::path::PathBuf;
5use std::time::SystemTime;
6
7use crate::fetch::{SymbolicFetchRef, Terminator};
8use crate::middleware::SymbolicMiddlewareRef;
9use crate::predicate::PredicateInst;
10
11macro_rules! id_newtype {
12	($name:ident) => {
13		#[derive(
14			Copy, Clone, Eq, PartialEq, Hash, Debug, PartialOrd, Ord, serde::Serialize, serde::Deserialize,
15		)]
16		pub struct $name(u32);
17
18		impl $name {
19			#[must_use]
20			pub const fn new(raw: u32) -> Self {
21				Self(raw)
22			}
23
24			#[must_use]
25			pub const fn get(self) -> u32 {
26				self.0
27			}
28		}
29	};
30}
31
32id_newtype!(NodeId);
33id_newtype!(PredicateId);
34id_newtype!(MiddlewareId);
35id_newtype!(FetchId);
36id_newtype!(TerminatorId);
37
38#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
39pub enum BodySide {
40	Request,
41	Response,
42}
43
44#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
45pub enum Node {
46	Check {
47		predicate: PredicateId,
48		on_match: NodeId,
49		on_miss: NodeId,
50		collect_body_before: Option<BodySide>,
51	},
52	Middleware {
53		id: MiddlewareId,
54		next: NodeId,
55		on_error: Option<NodeId>,
56		collect_body_before: Option<BodySide>,
57	},
58	Fetch {
59		id: FetchId,
60		next_response: Option<NodeId>,
61		next_tunnel: Option<NodeId>,
62		collect_body_before: Option<BodySide>,
63	},
64	Upgrade {
65		next: NodeId,
66	},
67	Terminate(TerminatorId),
68}
69
70impl Node {
71	#[must_use]
72	pub const fn collect_body_before(&self) -> Option<BodySide> {
73		match self {
74			Self::Check { collect_body_before, .. }
75			| Self::Middleware { collect_body_before, .. }
76			| Self::Fetch { collect_body_before, .. } => *collect_body_before,
77			Self::Upgrade { .. } | Self::Terminate(_) => None,
78		}
79	}
80}
81
82#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
83pub struct FlowGraphMeta {
84	pub version_hash: [u8; 32],
85	pub compiled_at: SystemTime,
86	pub source_files: Vec<PathBuf>,
87	// `feature_set` is a compile-time slice the daemon fills in at link, not
88	// a user-authored value; dry-run JSON omits it and deserialization
89	// restores the empty slice. Engine's link step installs the real value.
90	#[serde(skip, default = "empty_feature_set")]
91	pub feature_set: &'static [&'static str],
92}
93
/// Serde default for `FlowGraphMeta::feature_set`: the empty feature list.
const fn empty_feature_set() -> &'static [&'static str] {
	const NO_FEATURES: &[&str] = &[];
	NO_FEATURES
}
97
98#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
99pub struct SymbolicFlowGraph {
100	pub nodes: Vec<Node>,
101	pub predicates: Vec<PredicateInst>,
102	pub middlewares: Vec<SymbolicMiddlewareRef>,
103	pub fetches: Vec<SymbolicFetchRef>,
104	pub terminators: Vec<Terminator>,
105	pub entries: HashMap<SocketAddr, NodeId>,
106	pub meta: FlowGraphMeta,
107}
108
109impl Index<NodeId> for SymbolicFlowGraph {
110	type Output = Node;
111	fn index(&self, id: NodeId) -> &Node {
112		&self.nodes[id.get() as usize]
113	}
114}
115
116impl Index<PredicateId> for SymbolicFlowGraph {
117	type Output = PredicateInst;
118	fn index(&self, id: PredicateId) -> &PredicateInst {
119		&self.predicates[id.get() as usize]
120	}
121}
122
123impl Index<MiddlewareId> for SymbolicFlowGraph {
124	type Output = SymbolicMiddlewareRef;
125	fn index(&self, id: MiddlewareId) -> &SymbolicMiddlewareRef {
126		&self.middlewares[id.get() as usize]
127	}
128}
129
130impl Index<FetchId> for SymbolicFlowGraph {
131	type Output = SymbolicFetchRef;
132	fn index(&self, id: FetchId) -> &SymbolicFetchRef {
133		&self.fetches[id.get() as usize]
134	}
135}
136
137impl Index<TerminatorId> for SymbolicFlowGraph {
138	type Output = Terminator;
139	fn index(&self, id: TerminatorId) -> &Terminator {
140		&self.terminators[id.get() as usize]
141	}
142}
143
/// Unit tests for the ID newtypes, `Node`, the `Index` impls on
/// `SymbolicFlowGraph`, and the serde behavior of `Node` / `FlowGraphMeta`.
#[cfg(test)]
mod tests {
	use std::collections::hash_map::DefaultHasher;
	use std::hash::{Hash, Hasher};
	use std::sync::Arc;

	use serde_json::Value;

	use super::*;
	use crate::fetch::{FetchKind, SymbolicFetchRef, Terminator};
	use crate::middleware::{MiddlewareKind, SymbolicMiddlewareRef};
	use crate::predicate::{CompiledOperator, CompiledValue, FieldPath, PredicateInst};

	#[test]
	fn new_then_get_round_trips_raw_u32() {
		for raw in [0_u32, 1, 42, u32::MAX] {
			assert_eq!(NodeId::new(raw).get(), raw);
		}
	}

	#[test]
	fn node_id_equality_is_structural() {
		assert_eq!(NodeId::new(7), NodeId::new(7));
		assert_ne!(NodeId::new(7), NodeId::new(8));
	}

	#[test]
	fn node_id_ordering_follows_raw_u32() {
		assert!(NodeId::new(1) < NodeId::new(2));
		assert!(NodeId::new(u32::MAX) > NodeId::new(0));
	}

	#[test]
	fn node_id_serde_round_trip() {
		let id = NodeId::new(0x0bad_f00d);
		let encoded = serde_json::to_string(&id).expect("serialize");
		let decoded: NodeId = serde_json::from_str(&encoded).expect("deserialize");
		assert_eq!(decoded, id);
	}

	#[test]
	fn body_side_serde_round_trip_per_variant() {
		for s in [BodySide::Request, BodySide::Response] {
			let encoded = serde_json::to_string(&s).expect("serialize");
			let decoded: BodySide = serde_json::from_str(&encoded).expect("deserialize");
			assert_eq!(decoded, s);
		}
	}

	/// Hashes `t` with a fresh `DefaultHasher` and returns the digest.
	fn hash_of<T: Hash>(t: &T) -> u64 {
		let mut h = DefaultHasher::new();
		t.hash(&mut h);
		h.finish()
	}

	/// Shared body of the per-type ID tests: for a handful of raw values,
	/// checks that `new`/`get` round-trip, that equal IDs compare and hash
	/// equal, and that the ID survives a JSON round trip.
	fn assert_id_contract<T>(new: impl Fn(u32) -> T, get: impl Fn(T) -> u32)
	where
		T: Copy + Eq + Hash + std::fmt::Debug + serde::Serialize + serde::de::DeserializeOwned,
	{
		for raw in [0_u32, 1, 42, u32::MAX] {
			let a = new(raw);
			let b = new(raw);
			assert_eq!(get(a), raw);
			assert_eq!(a, b);
			assert_eq!(hash_of(&a), hash_of(&b));
			let encoded = serde_json::to_string(&a).expect("serialize");
			let decoded: T = serde_json::from_str(&encoded).expect("deserialize");
			assert_eq!(decoded, a);
		}
	}

	#[test]
	fn predicate_id_new_get_round_trip_and_hash_eq() {
		assert_id_contract(PredicateId::new, PredicateId::get);
	}

	#[test]
	fn middleware_id_new_get_round_trip_and_hash_eq() {
		assert_id_contract(MiddlewareId::new, MiddlewareId::get);
	}

	#[test]
	fn fetch_id_new_get_round_trip_and_hash_eq() {
		assert_id_contract(FetchId::new, FetchId::get);
	}

	#[test]
	fn terminator_id_new_get_round_trip_and_hash_eq() {
		assert_id_contract(TerminatorId::new, TerminatorId::get);
	}

	// The newtype wrappers are distinct types — a function accepting `NodeId`
	// refuses a `PredicateId` at compile time. Trait coherence is the
	// compile-only witness: if any two of these names resolved to the same
	// type, the impls below would overlap and this module would fail to
	// compile (E0119).
	#[allow(dead_code)] // Compile-time witness only; never referenced at runtime.
	trait DistinctIdType {}
	impl DistinctIdType for NodeId {}
	impl DistinctIdType for PredicateId {}
	impl DistinctIdType for MiddlewareId {}
	impl DistinctIdType for FetchId {}
	impl DistinctIdType for TerminatorId {}

	#[test]
	fn node_check_collect_body_before_returns_stored_flag() {
		let some = Node::Check {
			predicate: PredicateId::new(0),
			on_match: NodeId::new(0),
			on_miss: NodeId::new(0),
			collect_body_before: Some(BodySide::Request),
		};
		assert_eq!(some.collect_body_before(), Some(BodySide::Request));

		let none = Node::Check {
			predicate: PredicateId::new(0),
			on_match: NodeId::new(0),
			on_miss: NodeId::new(0),
			collect_body_before: None,
		};
		assert_eq!(none.collect_body_before(), None);
	}

	#[test]
	fn node_middleware_collect_body_before_returns_stored_flag() {
		let some = Node::Middleware {
			id: MiddlewareId::new(0),
			next: NodeId::new(0),
			on_error: None,
			collect_body_before: Some(BodySide::Response),
		};
		assert_eq!(some.collect_body_before(), Some(BodySide::Response));

		let none = Node::Middleware {
			id: MiddlewareId::new(0),
			next: NodeId::new(0),
			on_error: None,
			collect_body_before: None,
		};
		assert_eq!(none.collect_body_before(), None);
	}

	#[test]
	fn node_fetch_collect_body_before_returns_stored_flag() {
		let some = Node::Fetch {
			id: FetchId::new(0),
			next_response: None,
			next_tunnel: None,
			collect_body_before: Some(BodySide::Request),
		};
		assert_eq!(some.collect_body_before(), Some(BodySide::Request));

		let none = Node::Fetch {
			id: FetchId::new(0),
			next_response: None,
			next_tunnel: None,
			collect_body_before: None,
		};
		assert_eq!(none.collect_body_before(), None);
	}

	#[test]
	fn node_upgrade_collect_body_before_is_always_none() {
		let n = Node::Upgrade { next: NodeId::new(0) };
		assert_eq!(n.collect_body_before(), None);
	}

	#[test]
	fn node_terminate_collect_body_before_is_always_none() {
		let n = Node::Terminate(TerminatorId::new(0));
		assert_eq!(n.collect_body_before(), None);
	}

	/// Fixture: a minimal predicate instance.
	fn sample_predicate() -> PredicateInst {
		PredicateInst {
			path: FieldPath::TlsSni,
			op: CompiledOperator::Equals(CompiledValue::Str(Arc::from("a"))),
		}
	}

	/// Fixture: a minimal middleware reference.
	fn sample_middleware() -> SymbolicMiddlewareRef {
		SymbolicMiddlewareRef {
			name: Arc::from("noop"),
			args: Value::Null,
			kind: MiddlewareKind::L7Request,
			stateless: true,
			needs_body: false,
			on_error: None,
		}
	}

	/// Fixture: a minimal fetch reference.
	fn sample_fetch() -> SymbolicFetchRef {
		SymbolicFetchRef { kind: FetchKind::HttpProxy, args: Value::Null }
	}

	/// Fixture: metadata with zeroed hash, epoch timestamp, and empty lists.
	fn sample_meta() -> FlowGraphMeta {
		FlowGraphMeta {
			version_hash: [0; 32],
			compiled_at: SystemTime::UNIX_EPOCH,
			source_files: vec![],
			feature_set: &[],
		}
	}

	/// Fixture: a graph with exactly one entry in every table, so ID 0 is
	/// valid for each `Index` impl.
	fn one_of_each_graph() -> SymbolicFlowGraph {
		SymbolicFlowGraph {
			nodes: vec![Node::Terminate(TerminatorId::new(0))],
			predicates: vec![sample_predicate()],
			middlewares: vec![sample_middleware()],
			fetches: vec![sample_fetch()],
			terminators: vec![Terminator::WriteHttpResponse],
			entries: HashMap::new(),
			meta: sample_meta(),
		}
	}

	#[test]
	fn index_by_node_id_returns_matching_node() {
		let g = one_of_each_graph();
		match &g[NodeId::new(0)] {
			Node::Terminate(t) => assert_eq!(*t, TerminatorId::new(0)),
			other => panic!("expected Terminate, got {other:?}"),
		}
	}

	#[test]
	fn index_by_predicate_id_returns_matching_predicate() {
		let g = one_of_each_graph();
		assert_eq!(g[PredicateId::new(0)], sample_predicate());
	}

	#[test]
	fn index_by_middleware_id_returns_matching_middleware() {
		let g = one_of_each_graph();
		assert_eq!(g[MiddlewareId::new(0)], sample_middleware());
	}

	#[test]
	fn index_by_fetch_id_returns_matching_fetch() {
		let g = one_of_each_graph();
		assert_eq!(g[FetchId::new(0)].kind, FetchKind::HttpProxy);
	}

	#[test]
	fn index_by_terminator_id_returns_matching_terminator() {
		let g = one_of_each_graph();
		assert_eq!(g[TerminatorId::new(0)], Terminator::WriteHttpResponse);
	}

	/// Serializes `n` to JSON and deserializes it back. `Node` does not derive
	/// `PartialEq`, so callers inspect the result with `match`.
	fn node_round_trip(n: &Node) -> Node {
		let encoded = serde_json::to_string(n).expect("serialize node");
		serde_json::from_str(&encoded).expect("deserialize node")
	}

	#[test]
	fn node_check_serde_round_trip_with_and_without_collect_flag() {
		let with = Node::Check {
			predicate: PredicateId::new(3),
			on_match: NodeId::new(4),
			on_miss: NodeId::new(5),
			collect_body_before: Some(BodySide::Request),
		};
		match node_round_trip(&with) {
			Node::Check { predicate, on_match, on_miss, collect_body_before } => {
				assert_eq!(predicate, PredicateId::new(3));
				assert_eq!(on_match, NodeId::new(4));
				assert_eq!(on_miss, NodeId::new(5));
				assert_eq!(collect_body_before, Some(BodySide::Request));
			}
			other => panic!("expected Check, got {other:?}"),
		}

		let without = Node::Check {
			predicate: PredicateId::new(0),
			on_match: NodeId::new(0),
			on_miss: NodeId::new(0),
			collect_body_before: None,
		};
		match node_round_trip(&without) {
			Node::Check { collect_body_before, .. } => assert_eq!(collect_body_before, None),
			other => panic!("expected Check, got {other:?}"),
		}
	}

	#[test]
	fn node_middleware_serde_round_trip_with_and_without_collect_flag() {
		let with = Node::Middleware {
			id: MiddlewareId::new(1),
			next: NodeId::new(2),
			on_error: Some(NodeId::new(9)),
			collect_body_before: Some(BodySide::Response),
		};
		match node_round_trip(&with) {
			Node::Middleware { id, next, on_error, collect_body_before } => {
				assert_eq!(id, MiddlewareId::new(1));
				assert_eq!(next, NodeId::new(2));
				assert_eq!(on_error, Some(NodeId::new(9)));
				assert_eq!(collect_body_before, Some(BodySide::Response));
			}
			other => panic!("expected Middleware, got {other:?}"),
		}

		let without = Node::Middleware {
			id: MiddlewareId::new(0),
			next: NodeId::new(0),
			on_error: None,
			collect_body_before: None,
		};
		match node_round_trip(&without) {
			Node::Middleware { on_error, collect_body_before, .. } => {
				assert_eq!(on_error, None);
				assert_eq!(collect_body_before, None);
			}
			other => panic!("expected Middleware, got {other:?}"),
		}
	}

	#[test]
	fn node_fetch_serde_round_trip_with_and_without_collect_flag() {
		let with = Node::Fetch {
			id: FetchId::new(7),
			next_response: Some(NodeId::new(8)),
			next_tunnel: Some(NodeId::new(9)),
			collect_body_before: Some(BodySide::Request),
		};
		match node_round_trip(&with) {
			Node::Fetch { id, next_response, next_tunnel, collect_body_before } => {
				assert_eq!(id, FetchId::new(7));
				assert_eq!(next_response, Some(NodeId::new(8)));
				assert_eq!(next_tunnel, Some(NodeId::new(9)));
				assert_eq!(collect_body_before, Some(BodySide::Request));
			}
			other => panic!("expected Fetch, got {other:?}"),
		}

		let without = Node::Fetch {
			id: FetchId::new(0),
			next_response: None,
			next_tunnel: None,
			collect_body_before: None,
		};
		match node_round_trip(&without) {
			Node::Fetch { next_response, next_tunnel, collect_body_before, .. } => {
				assert_eq!(next_response, None);
				assert_eq!(next_tunnel, None);
				assert_eq!(collect_body_before, None);
			}
			other => panic!("expected Fetch, got {other:?}"),
		}
	}

	#[test]
	fn node_upgrade_serde_round_trip() {
		let n = Node::Upgrade { next: NodeId::new(11) };
		match node_round_trip(&n) {
			Node::Upgrade { next } => assert_eq!(next, NodeId::new(11)),
			other => panic!("expected Upgrade, got {other:?}"),
		}
	}

	#[test]
	fn node_terminate_serde_round_trip() {
		let n = Node::Terminate(TerminatorId::new(13));
		match node_round_trip(&n) {
			Node::Terminate(t) => assert_eq!(t, TerminatorId::new(13)),
			other => panic!("expected Terminate, got {other:?}"),
		}
	}

	// Serialization smoke test: the emitted JSON must carry the
	// `version_hash` field. The full serialize → deserialize round trip
	// (`FlowGraphMeta` derives both directions) is covered by
	// `flow_graph_meta_round_trip_preserves_all_but_feature_set` below.
	#[test]
	fn flow_graph_meta_serializes_and_emits_version_hash_field() {
		let meta = sample_meta();
		let encoded = serde_json::to_string(&meta).expect("serialize meta");
		assert!(encoded.contains("version_hash"), "expected version_hash field in {encoded}");
	}

	#[test]
	fn flow_graph_meta_round_trip_preserves_all_but_feature_set() {
		// 02-flow.md § _FlowGraph metadata_: feature_set is a compile-time
		// slice the daemon fills in at link and is NOT emitted to dry-run JSON.
		// version_hash / compiled_at / source_files must round-trip.
		use std::time::Duration;
		let meta = FlowGraphMeta {
			version_hash: [0x42; 32],
			compiled_at: SystemTime::UNIX_EPOCH + Duration::from_secs(1_000_000_000),
			source_files: vec![PathBuf::from("/a.json"), PathBuf::from("/b.json")],
			feature_set: &["h3", "wasm"],
		};
		let encoded = serde_json::to_string(&meta).expect("serialize meta");
		assert!(
			!encoded.contains("feature_set"),
			"feature_set must be skipped in dry-run JSON, got: {encoded}",
		);
		let decoded: FlowGraphMeta = serde_json::from_str(&encoded).expect("deserialize meta");
		assert_eq!(decoded.version_hash, meta.version_hash);
		assert_eq!(decoded.compiled_at, meta.compiled_at);
		assert_eq!(decoded.source_files, meta.source_files);
		// feature_set is restored to the empty slice by #[serde(skip, default=...)].
		assert!(decoded.feature_set.is_empty(), "feature_set must default to empty on deserialize");
	}
}
567}