Skip to main content

vane_core/
ir.rs

1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::ops::Index;
4use std::path::PathBuf;
5use std::time::SystemTime;
6
7use crate::fetch::{SymbolicFetchRef, Terminator};
8use crate::middleware::SymbolicMiddlewareRef;
9use crate::predicate::PredicateInst;
10
11macro_rules! id_newtype {
12	($name:ident) => {
13		#[derive(
14			Copy, Clone, Eq, PartialEq, Hash, Debug, PartialOrd, Ord, serde::Serialize, serde::Deserialize,
15		)]
16		pub struct $name(u32);
17
18		impl $name {
19			#[must_use]
20			pub const fn new(raw: u32) -> Self {
21				Self(raw)
22			}
23
24			#[must_use]
25			pub const fn get(self) -> u32 {
26				self.0
27			}
28		}
29	};
30}
31
32id_newtype!(NodeId);
33id_newtype!(PredicateId);
34id_newtype!(MiddlewareId);
35id_newtype!(FetchId);
36id_newtype!(TerminatorId);
37
38#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
39pub enum BodySide {
40	Request,
41	Response,
42}
43
44#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
45pub enum Node {
46	Check {
47		predicate: PredicateId,
48		on_match: NodeId,
49		on_miss: NodeId,
50		collect_body_before: Option<BodySide>,
51		#[serde(default)]
52		body_limit: usize,
53	},
54	Middleware {
55		id: MiddlewareId,
56		next: NodeId,
57		on_error: Option<NodeId>,
58		collect_body_before: Option<BodySide>,
59		#[serde(default)]
60		body_limit: usize,
61	},
62	Fetch {
63		id: FetchId,
64		next_response: Option<NodeId>,
65		next_tunnel: Option<NodeId>,
66		collect_body_before: Option<BodySide>,
67		#[serde(default)]
68		body_limit: usize,
69	},
70	Upgrade {
71		next: NodeId,
72	},
73	Terminate(TerminatorId),
74}
75
76impl Node {
77	#[must_use]
78	pub const fn collect_body_before(&self) -> Option<BodySide> {
79		match self {
80			Self::Check { collect_body_before, .. }
81			| Self::Middleware { collect_body_before, .. }
82			| Self::Fetch { collect_body_before, .. } => *collect_body_before,
83			Self::Upgrade { .. } | Self::Terminate(_) => None,
84		}
85	}
86
87	#[must_use]
88	pub const fn body_limit(&self) -> usize {
89		match self {
90			Self::Check { body_limit, .. }
91			| Self::Middleware { body_limit, .. }
92			| Self::Fetch { body_limit, .. } => *body_limit,
93			_ => 0,
94		}
95	}
96}
97
98#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
99pub struct FlowGraphMeta {
100	pub version_hash: [u8; 32],
101	pub compiled_at: SystemTime,
102	pub source_files: Vec<PathBuf>,
103	// `feature_set` is a compile-time slice the daemon fills in at link, not
104	// a user-authored value; dry-run JSON omits it and deserialization
105	// restores the empty slice. Engine's link step installs the real value.
106	#[serde(skip, default = "empty_feature_set")]
107	pub feature_set: &'static [&'static str],
108
109	/// Map of L7-listener entry `NodeId` → synthesised
110	/// `Terminate(WriteHttpResponse)` `NodeId`. The executor jumps here
111	/// when an L7 request middleware returns
112	/// `Decision::Short(ShortCircuit::Response(_))`: it sets the response
113	/// slot and walks to the synth target so the response runs through
114	/// the standard `WriteHttpResponse` write path. Empty for L4-only
115	/// graphs and for any L7 entry whose listener is not bound to a
116	/// post-`Upgrade` chain (which the lower pass guarantees never
117	/// happens for legal L7 listeners). See spec/architecture/02-flow.md
118	/// § _`FlowGraph` metadata_.
119	///
120	/// `#[serde(default)]` keeps older dry-run JSON snapshots
121	/// deserializable: missing field decodes as an empty map, which
122	/// matches the legacy "no L7 listeners" graph shape.
123	#[serde(default)]
124	pub short_circuit_response_entry: std::collections::BTreeMap<NodeId, NodeId>,
125
126	/// Per-listener cert pool. Symbolic — each entry is the aggregated
127	/// `(default, sni_certs)` view across every rule on the bind
128	/// address that carried a `tls` block; the engine's `link` stage
129	/// reads PEM files referenced here and builds a `rustls::ServerConfig`
130	/// with an SNI resolver that falls back to `default` for unmatched
131	/// SNI. Listeners absent from this map are cleartext. See
132	/// `spec/architecture/08-tls.md` § _TLS termination (L4 → L7
133	/// upgrade)_ and § _SNI normalization_.
134	///
135	/// `#[serde(default)]` for the same wire-compat reason as the map
136	/// above.
137	#[serde(default)]
138	pub listener_tls: std::collections::BTreeMap<SocketAddr, crate::rule::ListenerTlsSpec>,
139}
140
/// Serde default for `FlowGraphMeta::feature_set`: a slice with no features.
const fn empty_feature_set() -> &'static [&'static str] {
	const NO_FEATURES: &[&str] = &[];
	NO_FEATURES
}
144
145#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
146pub struct SymbolicFlowGraph {
147	pub nodes: Vec<Node>,
148	pub predicates: Vec<PredicateInst>,
149	pub middlewares: Vec<SymbolicMiddlewareRef>,
150	pub fetches: Vec<SymbolicFetchRef>,
151	pub terminators: Vec<Terminator>,
152	pub entries: HashMap<SocketAddr, NodeId>,
153	pub meta: FlowGraphMeta,
154}
155
156impl Index<NodeId> for SymbolicFlowGraph {
157	type Output = Node;
158	fn index(&self, id: NodeId) -> &Node {
159		&self.nodes[id.get() as usize]
160	}
161}
162
163impl Index<PredicateId> for SymbolicFlowGraph {
164	type Output = PredicateInst;
165	fn index(&self, id: PredicateId) -> &PredicateInst {
166		&self.predicates[id.get() as usize]
167	}
168}
169
170impl Index<MiddlewareId> for SymbolicFlowGraph {
171	type Output = SymbolicMiddlewareRef;
172	fn index(&self, id: MiddlewareId) -> &SymbolicMiddlewareRef {
173		&self.middlewares[id.get() as usize]
174	}
175}
176
177impl Index<FetchId> for SymbolicFlowGraph {
178	type Output = SymbolicFetchRef;
179	fn index(&self, id: FetchId) -> &SymbolicFetchRef {
180		&self.fetches[id.get() as usize]
181	}
182}
183
184impl Index<TerminatorId> for SymbolicFlowGraph {
185	type Output = Terminator;
186	fn index(&self, id: TerminatorId) -> &Terminator {
187		&self.terminators[id.get() as usize]
188	}
189}
190
#[cfg(test)]
mod tests {
	use std::collections::hash_map::DefaultHasher;
	use std::hash::{Hash, Hasher};
	use std::sync::Arc;

	use serde_json::Value;

	use super::*;
	use crate::fetch::{FetchKind, SymbolicFetchRef, Terminator};
	use crate::middleware::{MiddlewareKind, SymbolicMiddlewareRef};
	use crate::predicate::{CompiledOperator, CompiledValue, FieldPath, PredicateInst};

	/// Raw values every id newtype is exercised with: both ends of the `u32`
	/// range plus two values in between.
	const RAW_SAMPLES: [u32; 4] = [0, 1, 42, u32::MAX];

	/// Hashes `t` with a freshly constructed `DefaultHasher`.
	fn hash_of<T: Hash>(t: &T) -> u64 {
		let mut h = DefaultHasher::new();
		t.hash(&mut h);
		h.finish()
	}

	/// Generates one test that checks `new`/`get`, equality, hash equality and
	/// the serde round trip of an id newtype, so the id types share one body
	/// instead of four hand-copied ones.
	macro_rules! id_round_trip_test {
		($test_name:ident, $id:ident) => {
			#[test]
			fn $test_name() {
				for raw in RAW_SAMPLES {
					let a = $id::new(raw);
					let b = $id::new(raw);
					assert_eq!(a.get(), raw);
					assert_eq!(a, b);
					assert_eq!(hash_of(&a), hash_of(&b));
					let encoded = serde_json::to_string(&a).expect("serialize");
					let decoded: $id = serde_json::from_str(&encoded).expect("deserialize");
					assert_eq!(decoded, a);
				}
			}
		};
	}

	#[test]
	fn new_then_get_round_trips_raw_u32() {
		for raw in RAW_SAMPLES {
			assert_eq!(NodeId::new(raw).get(), raw);
		}
	}

	#[test]
	fn node_id_equality_is_structural() {
		assert_eq!(NodeId::new(7), NodeId::new(7));
		assert_ne!(NodeId::new(7), NodeId::new(8));
	}

	#[test]
	fn node_id_ordering_follows_raw_u32() {
		assert!(NodeId::new(1) < NodeId::new(2));
		assert!(NodeId::new(u32::MAX) > NodeId::new(0));
	}

	#[test]
	fn node_id_serde_round_trip() {
		let id = NodeId::new(0x0bad_f00d);
		let encoded = serde_json::to_string(&id).expect("serialize");
		let decoded: NodeId = serde_json::from_str(&encoded).expect("deserialize");
		assert_eq!(decoded, id);
	}

	#[test]
	fn body_side_serde_round_trip_per_variant() {
		for s in [BodySide::Request, BodySide::Response] {
			let encoded = serde_json::to_string(&s).expect("serialize");
			let decoded: BodySide = serde_json::from_str(&encoded).expect("deserialize");
			assert_eq!(decoded, s);
		}
	}

	id_round_trip_test!(predicate_id_new_get_round_trip_and_hash_eq, PredicateId);
	id_round_trip_test!(middleware_id_new_get_round_trip_and_hash_eq, MiddlewareId);
	id_round_trip_test!(fetch_id_new_get_round_trip_and_hash_eq, FetchId);
	id_round_trip_test!(terminator_id_new_get_round_trip_and_hash_eq, TerminatorId);

	// The newtype wrappers are distinct types — a function accepting `NodeId`
	// refuses a `PredicateId` at compile time. `_id_types_are_distinct` is a
	// compile-only witness that the signatures pin the right types; any mix-up
	// at a call site would fail to type-check.
	fn _id_types_are_distinct(
		_n: NodeId,
		_p: PredicateId,
		_m: MiddlewareId,
		_f: FetchId,
		_t: TerminatorId,
	) {
	}

	/// `Check` node whose ids are all zero, with the given body settings.
	fn check_node(collect_body_before: Option<BodySide>, body_limit: usize) -> Node {
		Node::Check {
			predicate: PredicateId::new(0),
			on_match: NodeId::new(0),
			on_miss: NodeId::new(0),
			collect_body_before,
			body_limit,
		}
	}

	/// `Middleware` node with zero ids and no `on_error`, with the given body
	/// settings.
	fn middleware_node(collect_body_before: Option<BodySide>, body_limit: usize) -> Node {
		Node::Middleware {
			id: MiddlewareId::new(0),
			next: NodeId::new(0),
			on_error: None,
			collect_body_before,
			body_limit,
		}
	}

	/// `Fetch` node with a zero id and no successors, with the given body
	/// settings.
	fn fetch_node(collect_body_before: Option<BodySide>, body_limit: usize) -> Node {
		Node::Fetch {
			id: FetchId::new(0),
			next_response: None,
			next_tunnel: None,
			collect_body_before,
			body_limit,
		}
	}

	#[test]
	fn node_check_collect_body_before_returns_stored_flag() {
		let some = check_node(Some(BodySide::Request), 0);
		assert_eq!(some.collect_body_before(), Some(BodySide::Request));

		let none = check_node(None, 0);
		assert_eq!(none.collect_body_before(), None);
	}

	#[test]
	fn node_middleware_collect_body_before_returns_stored_flag() {
		let some = middleware_node(Some(BodySide::Response), 0);
		assert_eq!(some.collect_body_before(), Some(BodySide::Response));

		let none = middleware_node(None, 0);
		assert_eq!(none.collect_body_before(), None);
	}

	#[test]
	fn node_fetch_collect_body_before_returns_stored_flag() {
		let some = fetch_node(Some(BodySide::Request), 0);
		assert_eq!(some.collect_body_before(), Some(BodySide::Request));

		let none = fetch_node(None, 0);
		assert_eq!(none.collect_body_before(), None);
	}

	#[test]
	fn node_upgrade_collect_body_before_is_always_none() {
		let n = Node::Upgrade { next: NodeId::new(0) };
		assert_eq!(n.collect_body_before(), None);
	}

	#[test]
	fn node_terminate_collect_body_before_is_always_none() {
		let n = Node::Terminate(TerminatorId::new(0));
		assert_eq!(n.collect_body_before(), None);
	}

	#[test]
	fn node_body_limit_returns_stored_value_for_body_carrying_variants() {
		assert_eq!(check_node(None, 1024).body_limit(), 1024);
		assert_eq!(middleware_node(None, 2048).body_limit(), 2048);
		assert_eq!(fetch_node(None, 4096).body_limit(), 4096);
	}

	#[test]
	fn node_body_limit_is_zero_for_upgrade_and_terminate() {
		assert_eq!(Node::Upgrade { next: NodeId::new(0) }.body_limit(), 0);
		assert_eq!(Node::Terminate(TerminatorId::new(0)).body_limit(), 0);
	}

	fn sample_predicate() -> PredicateInst {
		PredicateInst {
			path: FieldPath::TlsSni,
			op: CompiledOperator::Equals(CompiledValue::Str(Arc::from("a"))),
		}
	}

	fn sample_middleware() -> SymbolicMiddlewareRef {
		SymbolicMiddlewareRef {
			name: Arc::from("noop"),
			args: Value::Null,
			kind: MiddlewareKind::L7Request,
			stateless: true,
			needs_body: false,
			on_error: None,
		}
	}

	fn sample_fetch() -> SymbolicFetchRef {
		SymbolicFetchRef { kind: FetchKind::HttpProxy, args: Value::Null }
	}

	fn sample_meta() -> FlowGraphMeta {
		FlowGraphMeta {
			version_hash: [0; 32],
			compiled_at: SystemTime::UNIX_EPOCH,
			source_files: vec![],
			feature_set: &[],
			short_circuit_response_entry: std::collections::BTreeMap::new(),
			listener_tls: std::collections::BTreeMap::new(),
		}
	}

	/// Graph holding exactly one entry in every table, so id `0` is valid for
	/// each `Index` impl.
	fn one_of_each_graph() -> SymbolicFlowGraph {
		SymbolicFlowGraph {
			nodes: vec![Node::Terminate(TerminatorId::new(0))],
			predicates: vec![sample_predicate()],
			middlewares: vec![sample_middleware()],
			fetches: vec![sample_fetch()],
			terminators: vec![Terminator::WriteHttpResponse],
			entries: HashMap::new(),
			meta: sample_meta(),
		}
	}

	#[test]
	fn index_by_node_id_returns_matching_node() {
		let g = one_of_each_graph();
		match &g[NodeId::new(0)] {
			Node::Terminate(t) => assert_eq!(*t, TerminatorId::new(0)),
			other => panic!("expected Terminate, got {other:?}"),
		}
	}

	#[test]
	fn index_by_predicate_id_returns_matching_predicate() {
		let g = one_of_each_graph();
		assert_eq!(g[PredicateId::new(0)], sample_predicate());
	}

	#[test]
	fn index_by_middleware_id_returns_matching_middleware() {
		let g = one_of_each_graph();
		assert_eq!(g[MiddlewareId::new(0)], sample_middleware());
	}

	#[test]
	fn index_by_fetch_id_returns_matching_fetch() {
		let g = one_of_each_graph();
		assert_eq!(g[FetchId::new(0)].kind, FetchKind::HttpProxy);
	}

	#[test]
	fn index_by_terminator_id_returns_matching_terminator() {
		let g = one_of_each_graph();
		assert_eq!(g[TerminatorId::new(0)], Terminator::WriteHttpResponse);
	}

	/// Serializes `n` to JSON and decodes it back.
	fn node_round_trip(n: &Node) -> Node {
		let encoded = serde_json::to_string(n).expect("serialize node");
		serde_json::from_str(&encoded).expect("deserialize node")
	}

	#[test]
	fn node_check_serde_round_trip_with_and_without_collect_flag() {
		// A non-zero `body_limit` is used so a dropped field cannot hide
		// behind the serde default.
		let with = Node::Check {
			predicate: PredicateId::new(3),
			on_match: NodeId::new(4),
			on_miss: NodeId::new(5),
			collect_body_before: Some(BodySide::Request),
			body_limit: 4096,
		};
		match node_round_trip(&with) {
			Node::Check { predicate, on_match, on_miss, collect_body_before, body_limit } => {
				assert_eq!(predicate, PredicateId::new(3));
				assert_eq!(on_match, NodeId::new(4));
				assert_eq!(on_miss, NodeId::new(5));
				assert_eq!(collect_body_before, Some(BodySide::Request));
				assert_eq!(body_limit, 4096);
			}
			other => panic!("expected Check, got {other:?}"),
		}

		match node_round_trip(&check_node(None, 0)) {
			Node::Check { collect_body_before, body_limit, .. } => {
				assert_eq!(collect_body_before, None);
				assert_eq!(body_limit, 0);
			}
			other => panic!("expected Check, got {other:?}"),
		}
	}

	#[test]
	fn node_middleware_serde_round_trip_with_and_without_collect_flag() {
		let with = Node::Middleware {
			id: MiddlewareId::new(1),
			next: NodeId::new(2),
			on_error: Some(NodeId::new(9)),
			collect_body_before: Some(BodySide::Response),
			body_limit: 8192,
		};
		match node_round_trip(&with) {
			Node::Middleware { id, next, on_error, collect_body_before, body_limit } => {
				assert_eq!(id, MiddlewareId::new(1));
				assert_eq!(next, NodeId::new(2));
				assert_eq!(on_error, Some(NodeId::new(9)));
				assert_eq!(collect_body_before, Some(BodySide::Response));
				assert_eq!(body_limit, 8192);
			}
			other => panic!("expected Middleware, got {other:?}"),
		}

		match node_round_trip(&middleware_node(None, 0)) {
			Node::Middleware { on_error, collect_body_before, body_limit, .. } => {
				assert_eq!(on_error, None);
				assert_eq!(collect_body_before, None);
				assert_eq!(body_limit, 0);
			}
			other => panic!("expected Middleware, got {other:?}"),
		}
	}

	#[test]
	fn node_fetch_serde_round_trip_with_and_without_collect_flag() {
		let with = Node::Fetch {
			id: FetchId::new(7),
			next_response: Some(NodeId::new(8)),
			next_tunnel: Some(NodeId::new(9)),
			collect_body_before: Some(BodySide::Request),
			body_limit: 16_384,
		};
		match node_round_trip(&with) {
			Node::Fetch { id, next_response, next_tunnel, collect_body_before, body_limit } => {
				assert_eq!(id, FetchId::new(7));
				assert_eq!(next_response, Some(NodeId::new(8)));
				assert_eq!(next_tunnel, Some(NodeId::new(9)));
				assert_eq!(collect_body_before, Some(BodySide::Request));
				assert_eq!(body_limit, 16_384);
			}
			other => panic!("expected Fetch, got {other:?}"),
		}

		match node_round_trip(&fetch_node(None, 0)) {
			Node::Fetch { next_response, next_tunnel, collect_body_before, body_limit, .. } => {
				assert_eq!(next_response, None);
				assert_eq!(next_tunnel, None);
				assert_eq!(collect_body_before, None);
				assert_eq!(body_limit, 0);
			}
			other => panic!("expected Fetch, got {other:?}"),
		}
	}

	// `body_limit` is `#[serde(default)]`: a payload that omits the field
	// must still decode, with the limit reading as 0.
	#[test]
	fn node_missing_body_limit_decodes_as_zero() {
		let payloads = [
			r#"{"Check":{"predicate":1,"on_match":2,"on_miss":3,"collect_body_before":null}}"#,
			r#"{"Middleware":{"id":1,"next":2,"on_error":null,"collect_body_before":null}}"#,
			r#"{"Fetch":{"id":1,"next_response":null,"next_tunnel":null,"collect_body_before":null}}"#,
		];
		for payload in payloads {
			let node: Node =
				serde_json::from_str(payload).expect("deserialize node without body_limit");
			assert_eq!(node.body_limit(), 0, "payload: {payload}");
		}
	}

	#[test]
	fn node_upgrade_serde_round_trip() {
		let n = Node::Upgrade { next: NodeId::new(11) };
		match node_round_trip(&n) {
			Node::Upgrade { next } => assert_eq!(next, NodeId::new(11)),
			other => panic!("expected Upgrade, got {other:?}"),
		}
	}

	#[test]
	fn node_terminate_serde_round_trip() {
		let n = Node::Terminate(TerminatorId::new(13));
		match node_round_trip(&n) {
			Node::Terminate(t) => assert_eq!(t, TerminatorId::new(13)),
			other => panic!("expected Terminate, got {other:?}"),
		}
	}

	// Forward direction only: the serialized form must carry the
	// `version_hash` field. The full serialize/deserialize round trip is
	// covered by `flow_graph_meta_round_trip_preserves_all_but_feature_set`.
	#[test]
	fn flow_graph_meta_serializes_and_emits_version_hash_field() {
		let meta = sample_meta();
		let encoded = serde_json::to_string(&meta).expect("serialize meta");
		assert!(encoded.contains("version_hash"), "expected version_hash field in {encoded}");
	}

	#[test]
	fn flow_graph_meta_round_trip_preserves_all_but_feature_set() {
		// 02-flow.md § _FlowGraph metadata_: feature_set is a compile-time
		// slice the daemon fills in at link and is NOT emitted to dry-run JSON.
		// version_hash / compiled_at / source_files must round-trip.
		use std::time::Duration;
		let meta = FlowGraphMeta {
			version_hash: [0x42; 32],
			compiled_at: SystemTime::UNIX_EPOCH + Duration::from_secs(1_000_000_000),
			source_files: vec![PathBuf::from("/a.json"), PathBuf::from("/b.json")],
			feature_set: &["h3", "wasm"],
			short_circuit_response_entry: std::collections::BTreeMap::new(),
			listener_tls: std::collections::BTreeMap::new(),
		};
		let encoded = serde_json::to_string(&meta).expect("serialize meta");
		assert!(
			!encoded.contains("feature_set"),
			"feature_set must be skipped in dry-run JSON, got: {encoded}",
		);
		let decoded: FlowGraphMeta = serde_json::from_str(&encoded).expect("deserialize meta");
		assert_eq!(decoded.version_hash, meta.version_hash);
		assert_eq!(decoded.compiled_at, meta.compiled_at);
		assert_eq!(decoded.source_files, meta.source_files);
		// feature_set is restored to the empty slice by #[serde(skip, default=...)].
		assert!(decoded.feature_set.is_empty(), "feature_set must default to empty on deserialize");
	}

	// Both map fields are `#[serde(default)]` so that snapshots written
	// before the fields existed still decode; strip them from a serialized
	// meta and check they come back as empty maps.
	#[test]
	fn flow_graph_meta_missing_map_fields_decode_as_empty() {
		let mut json = serde_json::to_value(sample_meta()).expect("serialize meta");
		let fields = json.as_object_mut().expect("meta serializes as a JSON object");
		assert!(fields.remove("short_circuit_response_entry").is_some());
		assert!(fields.remove("listener_tls").is_some());

		let decoded: FlowGraphMeta = serde_json::from_value(json).expect("deserialize legacy meta");
		assert!(decoded.short_circuit_response_entry.is_empty());
		assert!(decoded.listener_tls.is_empty());
	}
}