Skip to main content

vane_core/compile/
validate.rs

1use std::collections::HashSet;
2
3use crate::error::Error;
4use crate::ir::{Node, NodeId, SymbolicFlowGraph};
5use crate::phase::{Phase, PhaseNodeKind, Transition, transition};
6
7/// Run IR-level structural and phase validation on a freshly-lowered graph.
8///
9/// # Errors
10/// Returns [`Error::compile`] on missing-id references, Fetch edges that
11/// don't match the kind's output-mode contract, acyclicity violations, or
12/// phase-state-machine mismatches.
13pub fn validate(graph: &SymbolicFlowGraph) -> Result<(), Error> {
14	check_id_ranges(graph)?;
15	check_fetch_edges(graph)?;
16	check_acyclic(graph)?;
17	check_phases(graph)?;
18	Ok(())
19}
20
21fn check_id_ranges(graph: &SymbolicFlowGraph) -> Result<(), Error> {
22	let n_nodes = u32::try_from(graph.nodes.len()).unwrap_or(u32::MAX);
23	let n_preds = u32::try_from(graph.predicates.len()).unwrap_or(u32::MAX);
24	let n_mws = u32::try_from(graph.middlewares.len()).unwrap_or(u32::MAX);
25	let n_fetches = u32::try_from(graph.fetches.len()).unwrap_or(u32::MAX);
26	let n_terms = u32::try_from(graph.terminators.len()).unwrap_or(u32::MAX);
27
28	for (idx, node) in graph.nodes.iter().enumerate() {
29		match node {
30			Node::Check { predicate, on_match, on_miss, .. } => {
31				if predicate.get() >= n_preds {
32					return Err(Error::compile(format!(
33						"node {idx}: dangling PredicateId({})",
34						predicate.get()
35					)));
36				}
37				if on_match.get() >= n_nodes {
38					return Err(Error::compile(format!("node {idx}.on_match dangling")));
39				}
40				if on_miss.get() >= n_nodes {
41					return Err(Error::compile(format!("node {idx}.on_miss dangling")));
42				}
43			}
44			Node::Middleware { id, next, on_error, .. } => {
45				if id.get() >= n_mws {
46					return Err(Error::compile(format!("node {idx}: dangling MiddlewareId({})", id.get())));
47				}
48				if next.get() >= n_nodes {
49					return Err(Error::compile(format!("node {idx}.next dangling")));
50				}
51				if let Some(e) = on_error
52					&& e.get() >= n_nodes
53				{
54					return Err(Error::compile(format!("node {idx}.on_error dangling")));
55				}
56			}
57			Node::Fetch { id, next_response, next_tunnel, .. } => {
58				if id.get() >= n_fetches {
59					return Err(Error::compile(format!("node {idx}: dangling FetchId({})", id.get())));
60				}
61				if let Some(r) = next_response
62					&& r.get() >= n_nodes
63				{
64					return Err(Error::compile(format!("node {idx}.next_response dangling")));
65				}
66				if let Some(t) = next_tunnel
67					&& t.get() >= n_nodes
68				{
69					return Err(Error::compile(format!("node {idx}.next_tunnel dangling")));
70				}
71			}
72			Node::Upgrade { next } => {
73				if next.get() >= n_nodes {
74					return Err(Error::compile(format!("node {idx}.next dangling")));
75				}
76			}
77			Node::Terminate(t) => {
78				if t.get() >= n_terms {
79					return Err(Error::compile(format!("node {idx}: dangling TerminatorId({})", t.get())));
80				}
81			}
82		}
83	}
84	Ok(())
85}
86
87fn check_fetch_edges(graph: &SymbolicFlowGraph) -> Result<(), Error> {
88	use crate::fetch::FetchKind::{
89		AcmeChallenge, HttpProxy, HttpSynthesize, L4Forward, WebSocketUpgrade,
90	};
91	for (idx, node) in graph.nodes.iter().enumerate() {
92		let Node::Fetch { id, next_response, next_tunnel, .. } = node else {
93			continue;
94		};
95		let kind = graph[*id].kind;
96		match kind {
97			HttpProxy | HttpSynthesize | AcmeChallenge => {
98				if next_response.is_none() {
99					return Err(Error::compile(format!("node {idx}: {kind:?} requires next_response")));
100				}
101				if next_tunnel.is_some() {
102					return Err(Error::compile(format!("node {idx}: {kind:?} must not have next_tunnel")));
103				}
104			}
105			L4Forward => {
106				if next_tunnel.is_none() {
107					return Err(Error::compile(format!("node {idx}: L4Forward requires next_tunnel")));
108				}
109				if next_response.is_some() {
110					return Err(Error::compile(format!("node {idx}: L4Forward must not have next_response")));
111				}
112			}
113			WebSocketUpgrade => {
114				if next_response.is_none() || next_tunnel.is_none() {
115					return Err(Error::compile(format!(
116						"node {idx}: WebSocketUpgrade requires both next_response and next_tunnel"
117					)));
118				}
119			}
120		}
121	}
122	Ok(())
123}
124
125fn check_acyclic(graph: &SymbolicFlowGraph) -> Result<(), Error> {
126	#[derive(Copy, Clone)]
127	enum Color {
128		White,
129		Gray,
130		Black,
131	}
132	let mut color: Vec<Color> = (0..graph.nodes.len()).map(|_| Color::White).collect();
133
134	for start in 0..graph.nodes.len() {
135		if !matches!(color[start], Color::White) {
136			continue;
137		}
138		let mut stack: Vec<(usize, usize)> = vec![(start, 0)];
139		color[start] = Color::Gray;
140		while let Some(&(node_idx, child_idx)) = stack.last() {
141			let succs = successors(&graph.nodes[node_idx]);
142			if child_idx < succs.len() {
143				let next = succs[child_idx].get() as usize;
144				stack.last_mut().expect("non-empty").1 += 1;
145				match color[next] {
146					Color::White => {
147						color[next] = Color::Gray;
148						stack.push((next, 0));
149					}
150					Color::Gray => {
151						return Err(Error::compile(format!("cycle in graph at node {next}")));
152					}
153					Color::Black => {}
154				}
155			} else {
156				color[node_idx] = Color::Black;
157				stack.pop();
158			}
159		}
160	}
161	Ok(())
162}
163
164fn successors(node: &Node) -> Vec<NodeId> {
165	match node {
166		Node::Check { on_match, on_miss, .. } => vec![*on_match, *on_miss],
167		Node::Middleware { next, on_error, .. } => {
168			let mut v = vec![*next];
169			if let Some(e) = on_error {
170				v.push(*e);
171			}
172			v
173		}
174		Node::Fetch { next_response, next_tunnel, .. } => {
175			let mut v = Vec::new();
176			if let Some(r) = next_response {
177				v.push(*r);
178			}
179			if let Some(t) = next_tunnel {
180				v.push(*t);
181			}
182			v
183		}
184		Node::Upgrade { next } => vec![*next],
185		Node::Terminate(_) => Vec::new(),
186	}
187}
188
189fn node_kind_for_phase(graph: &SymbolicFlowGraph, node: &Node) -> PhaseNodeKind {
190	match node {
191		Node::Check { .. } => PhaseNodeKind::Check,
192		Node::Middleware { id, .. } => PhaseNodeKind::Middleware(graph[*id].kind),
193		Node::Fetch { id, .. } => PhaseNodeKind::Fetch(graph[*id].kind),
194		Node::Upgrade { .. } => PhaseNodeKind::Upgrade,
195		Node::Terminate(t) => PhaseNodeKind::Terminate(graph[*t]),
196	}
197}
198
199/// Walk each listener entry through the phase transition table.
200///
201/// Callable directly for tests and for validators that want phase
202/// coverage; the regular [`validate`] entry point does not call this
203/// today because production graphs reach `L4Peeked` through the
204/// `protocol_detect` middleware that ships in `vane-engine`, not
205/// through any IR-only construction.
206///
207/// # Errors
208/// Returns [`Error::compile`] on phase mismatches per
209/// [`spec/flow-model.md` § _Phase state machine_](../../../../spec/flow-model.md#phase-state-machine).
210pub fn check_phases(graph: &SymbolicFlowGraph) -> Result<(), Error> {
211	let mut seen: HashSet<(NodeId, Phase)> = HashSet::new();
212	for &entry in graph.entries.values() {
213		visit_phase(graph, entry, Phase::L4Raw, &mut seen)?;
214	}
215	// Walk every L7 listener's synthesised `Short(Response)` target as a
216	// second-class entry rooted at `Phase::L7Response`. The lower pass
217	// always emits these as `Terminate(WriteHttpResponse)` nodes —
218	// `WriteHttpResponse` accepts `L7Response` per the transition
219	// table, so a clean lower produces a clean walk. Bogus entries
220	// (a synth target whose terminator is not `WriteHttpResponse`) get
221	// caught here with the same "phase mismatch" error shape.
222	for &synth in graph.meta.short_circuit_response_entry.values() {
223		visit_phase(graph, synth, Phase::L7Response, &mut seen)?;
224	}
225	Ok(())
226}
227
228fn visit_phase(
229	graph: &SymbolicFlowGraph,
230	id: NodeId,
231	phase: Phase,
232	seen: &mut HashSet<(NodeId, Phase)>,
233) -> Result<(), Error> {
234	if !seen.insert((id, phase)) {
235		return Ok(());
236	}
237	let node = &graph[id];
238	let kind = node_kind_for_phase(graph, node);
239	let t = transition(kind, phase).map_err(|e| {
240		Error::compile(format!(
241			"phase mismatch at NodeId({}): expected one of {:?}, got {:?}",
242			id.get(),
243			e.expected,
244			e.got,
245		))
246	})?;
247	match (t, node) {
248		(Transition::Terminal, _) => Ok(()),
249		(Transition::PassThrough, _) => {
250			for succ in successors(node) {
251				visit_phase(graph, succ, phase, seen)?;
252			}
253			Ok(())
254		}
255		(Transition::Into(next_phase), _) => {
256			for succ in successors(node) {
257				visit_phase(graph, succ, next_phase, seen)?;
258			}
259			Ok(())
260		}
261		(
262			Transition::BiOutcome { response, tunnel },
263			Node::Fetch { next_response, next_tunnel, .. },
264		) => {
265			if let Some(r) = next_response {
266				visit_phase(graph, *r, response, seen)?;
267			}
268			if let Some(t) = next_tunnel {
269				visit_phase(graph, *t, tunnel, seen)?;
270			}
271			Ok(())
272		}
273		(Transition::BiOutcome { .. }, _) => {
274			Err(Error::compile("BiOutcome transition on non-Fetch node".to_string()))
275		}
276	}
277}
278
#[cfg(test)]
mod tests {
	use std::collections::{BTreeMap, HashMap};
	use std::path::PathBuf;
	use std::time::SystemTime;

	use super::*;
	use crate::fetch::{FetchKind, SymbolicFetchRef, Terminator};
	use crate::ir::{FetchId, FlowGraphMeta, PredicateId, TerminatorId};

	/// Graph metadata with every per-listener table empty.
	fn empty_meta() -> FlowGraphMeta {
		FlowGraphMeta {
			version_hash: [0; 32],
			compiled_at: SystemTime::UNIX_EPOCH,
			source_files: vec![PathBuf::new()],
			feature_set: &[],
			short_circuit_response_entry: BTreeMap::new(),
			listener_tls: BTreeMap::new(),
			listener_kinds: BTreeMap::new(),
			listener_transports: BTreeMap::new(),
			annotations: Vec::new(),
		}
	}

	/// Fetch side-table entry of `kind` with inert arguments.
	fn fetch_ref(kind: FetchKind) -> SymbolicFetchRef {
		SymbolicFetchRef {
			kind,
			args: serde_json::Value::Null,
			retry_buffer_required: false,
			allow_zero_rtt: None,
		}
	}

	/// `Fetch` node bound to `FetchId(0)` with the given edges and no body
	/// buffering.
	fn fetch_node(next_response: Option<NodeId>, next_tunnel: Option<NodeId>) -> Node {
		Node::Fetch {
			id: FetchId::new(0),
			next_response,
			next_tunnel,
			collect_body_before: None,
			body_limit: 0,
		}
	}

	/// Predicate side-table entry; its content is irrelevant to validation.
	fn dummy_predicate() -> crate::predicate::PredicateInst {
		use crate::predicate::{CompiledOperator, CompiledValue, FieldPath, PredicateInst};
		PredicateInst {
			path: FieldPath::TlsSni,
			op: CompiledOperator::Equals(CompiledValue::Str(std::sync::Arc::from("x"))),
		}
	}

	#[test]
	fn dangling_terminator_id_in_terminate_node_rejected() {
		let graph = SymbolicFlowGraph {
			nodes: vec![Node::Terminate(TerminatorId::new(0))],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![],
			terminators: vec![],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("dangling TerminatorId"));
	}

	#[test]
	fn dangling_node_id_in_fetch_edge_rejected() {
		let graph = SymbolicFlowGraph {
			nodes: vec![fetch_node(Some(NodeId::new(99)), None)],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![fetch_ref(FetchKind::HttpProxy)],
			terminators: vec![],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("next_response dangling"));
	}

	#[test]
	fn http_fetch_without_next_response_rejected() {
		let graph = SymbolicFlowGraph {
			nodes: vec![Node::Terminate(TerminatorId::new(0)), fetch_node(None, None)],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![fetch_ref(FetchKind::HttpProxy)],
			terminators: vec![Terminator::WriteHttpResponse],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("requires next_response"));
	}

	#[test]
	fn l4_forward_with_next_response_rejected() {
		let graph = SymbolicFlowGraph {
			nodes: vec![
				Node::Terminate(TerminatorId::new(0)),
				fetch_node(Some(NodeId::new(0)), Some(NodeId::new(0))),
			],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![fetch_ref(FetchKind::L4Forward)],
			terminators: vec![Terminator::ByteTunnel],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("L4Forward must not have next_response"));
	}

	#[test]
	fn websocket_upgrade_without_next_tunnel_rejected() {
		// WebSocketUpgrade is bi-outcome: it needs both edges.
		let graph = SymbolicFlowGraph {
			nodes: vec![
				Node::Terminate(TerminatorId::new(0)),
				fetch_node(Some(NodeId::new(0)), None),
			],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![fetch_ref(FetchKind::WebSocketUpgrade)],
			terminators: vec![Terminator::WriteHttpResponse],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(
			err.to_string().contains("requires both next_response and next_tunnel"),
			"{err}"
		);
	}

	#[test]
	fn cyclic_graph_is_rejected() {
		// Node 0 and Node 1 point at each other via Check on_match edges.
		let graph = SymbolicFlowGraph {
			nodes: vec![
				Node::Check {
					predicate: PredicateId::new(0),
					on_match: NodeId::new(1),
					on_miss: NodeId::new(1),
					collect_body_before: None,
					body_limit: 0,
				},
				Node::Check {
					predicate: PredicateId::new(0),
					on_match: NodeId::new(0),
					on_miss: NodeId::new(0),
					collect_body_before: None,
					body_limit: 0,
				},
			],
			predicates: vec![dummy_predicate()],
			middlewares: vec![],
			fetches: vec![],
			terminators: vec![],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("cycle"));
	}

	#[test]
	fn minimal_well_formed_graph_is_accepted() {
		// A single listener entry that immediately closes. `Terminator::Close`
		// is phase-agnostic, so it is valid straight from `Phase::L4Raw`.
		let graph = SymbolicFlowGraph {
			nodes: vec![Node::Terminate(TerminatorId::new(0))],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![],
			terminators: vec![Terminator::Close],
			entries: {
				let mut m = HashMap::new();
				m.insert("127.0.0.1:443".parse().expect("parse"), NodeId::new(0));
				m
			},
			meta: empty_meta(),
		};
		if let Err(err) = validate(&graph) {
			panic!("well-formed graph rejected: {err}");
		}
	}

	#[test]
	fn phase_check_rejects_write_http_response_reached_in_wrong_phase() {
		// Upgrade's out-phase is `L7Request`, but `Terminate(WriteHttpResponse)`
		// requires `L7Response`. Walking Upgrade directly into it is a phase
		// mismatch the validator must catch.
		let tid = TerminatorId::new(0);
		let graph = SymbolicFlowGraph {
			nodes: vec![Node::Terminate(tid), Node::Upgrade { next: NodeId::new(0) }],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![],
			terminators: vec![Terminator::WriteHttpResponse],
			entries: {
				let mut m = HashMap::new();
				m.insert("127.0.0.1:443".parse().expect("parse"), NodeId::new(1));
				m
			},
			meta: empty_meta(),
		};
		let err = check_phases(&graph).expect_err("must error");
		assert!(err.to_string().contains("phase mismatch"));
	}

	#[test]
	fn phase_check_rejects_short_circuit_synth_with_wrong_terminator() {
		// `meta.short_circuit_response_entry` values are walked at
		// `Phase::L7Response`. A synth target whose terminator does not
		// accept that phase must trip the same "phase mismatch" error the
		// standard walker uses. `Terminator::Close` is phase-agnostic, so it
		// would never trip this check. `ByteTunnel` only accepts
		// `Phase::Tunnel`, which makes it the right negative-test fixture.
		let bad_tid = TerminatorId::new(0);
		let mut meta = empty_meta();
		meta.short_circuit_response_entry.insert(NodeId::new(1), NodeId::new(0));
		let graph = SymbolicFlowGraph {
			nodes: vec![Node::Terminate(bad_tid), Node::Upgrade { next: NodeId::new(0) }],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![],
			terminators: vec![Terminator::ByteTunnel],
			// No `entries`: exercise the synth walk in isolation.
			entries: HashMap::new(),
			meta,
		};
		let err = check_phases(&graph).expect_err("must error on bad synth phase");
		assert!(err.to_string().contains("phase mismatch"), "{err}");
	}
}
489
490	// `BodySide` import is kept here to keep test doc consistent with the
491	// `Node` field it accesses in the broader impl.
492	const _: BodySide = BodySide::Request;
493}