Skip to main content

vane_core/compile/
validate.rs

1use std::collections::HashSet;
2
3use crate::error::Error;
4use crate::ir::{Node, NodeId, SymbolicFlowGraph};
5use crate::phase::{Phase, PhaseNodeKind, Transition, transition};
6
7/// Run IR-level structural and phase validation on a freshly-lowered graph.
8///
9/// # Errors
10/// Returns [`Error::compile`] on missing-id references, Fetch edges that
11/// don't match the kind's output-mode contract, acyclicity violations, or
12/// phase-state-machine mismatches.
13pub fn validate(graph: &SymbolicFlowGraph) -> Result<(), Error> {
14	check_id_ranges(graph)?;
15	check_fetch_edges(graph)?;
16	check_acyclic(graph)?;
17	check_phases(graph)?;
18	Ok(())
19}
20
21fn check_id_ranges(graph: &SymbolicFlowGraph) -> Result<(), Error> {
22	let n_nodes = u32::try_from(graph.nodes.len()).unwrap_or(u32::MAX);
23	let n_preds = u32::try_from(graph.predicates.len()).unwrap_or(u32::MAX);
24	let n_mws = u32::try_from(graph.middlewares.len()).unwrap_or(u32::MAX);
25	let n_fetches = u32::try_from(graph.fetches.len()).unwrap_or(u32::MAX);
26	let n_terms = u32::try_from(graph.terminators.len()).unwrap_or(u32::MAX);
27
28	for (idx, node) in graph.nodes.iter().enumerate() {
29		match node {
30			Node::Check { predicate, on_match, on_miss, .. } => {
31				if predicate.get() >= n_preds {
32					return Err(Error::compile(format!(
33						"node {idx}: dangling PredicateId({})",
34						predicate.get()
35					)));
36				}
37				if on_match.get() >= n_nodes {
38					return Err(Error::compile(format!("node {idx}.on_match dangling")));
39				}
40				if on_miss.get() >= n_nodes {
41					return Err(Error::compile(format!("node {idx}.on_miss dangling")));
42				}
43			}
44			Node::Middleware { id, next, on_error, .. } => {
45				if id.get() >= n_mws {
46					return Err(Error::compile(format!("node {idx}: dangling MiddlewareId({})", id.get())));
47				}
48				if next.get() >= n_nodes {
49					return Err(Error::compile(format!("node {idx}.next dangling")));
50				}
51				if let Some(e) = on_error
52					&& e.get() >= n_nodes
53				{
54					return Err(Error::compile(format!("node {idx}.on_error dangling")));
55				}
56			}
57			Node::Fetch { id, next_response, next_tunnel, .. } => {
58				if id.get() >= n_fetches {
59					return Err(Error::compile(format!("node {idx}: dangling FetchId({})", id.get())));
60				}
61				if let Some(r) = next_response
62					&& r.get() >= n_nodes
63				{
64					return Err(Error::compile(format!("node {idx}.next_response dangling")));
65				}
66				if let Some(t) = next_tunnel
67					&& t.get() >= n_nodes
68				{
69					return Err(Error::compile(format!("node {idx}.next_tunnel dangling")));
70				}
71			}
72			Node::Upgrade { next } => {
73				if next.get() >= n_nodes {
74					return Err(Error::compile(format!("node {idx}.next dangling")));
75				}
76			}
77			Node::Terminate(t) => {
78				if t.get() >= n_terms {
79					return Err(Error::compile(format!("node {idx}: dangling TerminatorId({})", t.get())));
80				}
81			}
82		}
83	}
84	Ok(())
85}
86
87fn check_fetch_edges(graph: &SymbolicFlowGraph) -> Result<(), Error> {
88	use crate::fetch::FetchKind::{HttpProxy, HttpSynthesize, L4Forward, WebSocketUpgrade};
89	for (idx, node) in graph.nodes.iter().enumerate() {
90		let Node::Fetch { id, next_response, next_tunnel, .. } = node else {
91			continue;
92		};
93		let kind = graph[*id].kind;
94		match kind {
95			HttpProxy | HttpSynthesize => {
96				if next_response.is_none() {
97					return Err(Error::compile(format!("node {idx}: {kind:?} requires next_response")));
98				}
99				if next_tunnel.is_some() {
100					return Err(Error::compile(format!("node {idx}: {kind:?} must not have next_tunnel")));
101				}
102			}
103			L4Forward => {
104				if next_tunnel.is_none() {
105					return Err(Error::compile(format!("node {idx}: L4Forward requires next_tunnel")));
106				}
107				if next_response.is_some() {
108					return Err(Error::compile(format!("node {idx}: L4Forward must not have next_response")));
109				}
110			}
111			WebSocketUpgrade => {
112				if next_response.is_none() || next_tunnel.is_none() {
113					return Err(Error::compile(format!(
114						"node {idx}: WebSocketUpgrade requires both next_response and next_tunnel"
115					)));
116				}
117			}
118		}
119	}
120	Ok(())
121}
122
123fn check_acyclic(graph: &SymbolicFlowGraph) -> Result<(), Error> {
124	#[derive(Copy, Clone)]
125	enum Color {
126		White,
127		Gray,
128		Black,
129	}
130	let mut color: Vec<Color> = (0..graph.nodes.len()).map(|_| Color::White).collect();
131
132	for start in 0..graph.nodes.len() {
133		if !matches!(color[start], Color::White) {
134			continue;
135		}
136		let mut stack: Vec<(usize, usize)> = vec![(start, 0)];
137		color[start] = Color::Gray;
138		while let Some(&(node_idx, child_idx)) = stack.last() {
139			let succs = successors(&graph.nodes[node_idx]);
140			if child_idx < succs.len() {
141				let next = succs[child_idx].get() as usize;
142				stack.last_mut().expect("non-empty").1 += 1;
143				match color[next] {
144					Color::White => {
145						color[next] = Color::Gray;
146						stack.push((next, 0));
147					}
148					Color::Gray => {
149						return Err(Error::compile(format!("cycle in graph at node {next}")));
150					}
151					Color::Black => {}
152				}
153			} else {
154				color[node_idx] = Color::Black;
155				stack.pop();
156			}
157		}
158	}
159	Ok(())
160}
161
162fn successors(node: &Node) -> Vec<NodeId> {
163	match node {
164		Node::Check { on_match, on_miss, .. } => vec![*on_match, *on_miss],
165		Node::Middleware { next, on_error, .. } => {
166			let mut v = vec![*next];
167			if let Some(e) = on_error {
168				v.push(*e);
169			}
170			v
171		}
172		Node::Fetch { next_response, next_tunnel, .. } => {
173			let mut v = Vec::new();
174			if let Some(r) = next_response {
175				v.push(*r);
176			}
177			if let Some(t) = next_tunnel {
178				v.push(*t);
179			}
180			v
181		}
182		Node::Upgrade { next } => vec![*next],
183		Node::Terminate(_) => Vec::new(),
184	}
185}
186
187fn node_kind_for_phase(graph: &SymbolicFlowGraph, node: &Node) -> PhaseNodeKind {
188	match node {
189		Node::Check { .. } => PhaseNodeKind::Check,
190		Node::Middleware { id, .. } => PhaseNodeKind::Middleware(graph[*id].kind),
191		Node::Fetch { id, .. } => PhaseNodeKind::Fetch(graph[*id].kind),
192		Node::Upgrade { .. } => PhaseNodeKind::Upgrade,
193		Node::Terminate(t) => PhaseNodeKind::Terminate(graph[*t]),
194	}
195}
196
197/// Walk each listener entry through the phase transition table.
198///
199/// Not invoked from [`validate`] today because MVP graphs lack the
200/// `protocol_detect` middleware that advances `L4Raw โ†’ L4Peeked` โ€” that
201/// middleware lands at S1-16. Callable directly for tests and for future
202/// validators that want phase coverage.
203///
204/// # Errors
205/// Returns [`Error::compile`] on phase mismatches per 02-flow.md ยง _Phase
206/// state machine_.
207pub fn check_phases(graph: &SymbolicFlowGraph) -> Result<(), Error> {
208	let mut seen: HashSet<(NodeId, Phase)> = HashSet::new();
209	for &entry in graph.entries.values() {
210		visit_phase(graph, entry, Phase::L4Raw, &mut seen)?;
211	}
212	Ok(())
213}
214
215fn visit_phase(
216	graph: &SymbolicFlowGraph,
217	id: NodeId,
218	phase: Phase,
219	seen: &mut HashSet<(NodeId, Phase)>,
220) -> Result<(), Error> {
221	if !seen.insert((id, phase)) {
222		return Ok(());
223	}
224	let node = &graph[id];
225	let kind = node_kind_for_phase(graph, node);
226	let t = transition(kind, phase).map_err(|e| {
227		Error::compile(format!(
228			"phase mismatch at NodeId({}): expected one of {:?}, got {:?}",
229			id.get(),
230			e.expected,
231			e.got,
232		))
233	})?;
234	match (t, node) {
235		(Transition::Terminal, _) => Ok(()),
236		(Transition::PassThrough, _) => {
237			for succ in successors(node) {
238				visit_phase(graph, succ, phase, seen)?;
239			}
240			Ok(())
241		}
242		(Transition::Into(next_phase), _) => {
243			for succ in successors(node) {
244				visit_phase(graph, succ, next_phase, seen)?;
245			}
246			Ok(())
247		}
248		(
249			Transition::BiOutcome { response, tunnel },
250			Node::Fetch { next_response, next_tunnel, .. },
251		) => {
252			if let Some(r) = next_response {
253				visit_phase(graph, *r, response, seen)?;
254			}
255			if let Some(t) = next_tunnel {
256				visit_phase(graph, *t, tunnel, seen)?;
257			}
258			Ok(())
259		}
260		(Transition::BiOutcome { .. }, _) => {
261			Err(Error::compile("BiOutcome transition on non-Fetch node".to_string()))
262		}
263	}
264}
265
#[cfg(test)]
mod tests {
	use std::collections::HashMap;
	use std::path::PathBuf;
	use std::time::SystemTime;

	use super::*;
	use crate::fetch::{FetchKind, SymbolicFetchRef, Terminator};
	use crate::ir::{BodySide, FetchId, FlowGraphMeta, PredicateId, TerminatorId};

	/// Placeholder metadata for test graphs: zeroed hash, epoch timestamp,
	/// one empty source path and no features.
	fn empty_meta() -> FlowGraphMeta {
		FlowGraphMeta {
			version_hash: [0; 32],
			compiled_at: SystemTime::UNIX_EPOCH,
			source_files: vec![PathBuf::new()],
			feature_set: &[],
		}
	}

	/// A Terminate node whose id points into an empty terminator table must
	/// be reported as dangling.
	#[test]
	fn dangling_terminator_id_in_terminate_node_rejected() {
		let graph = SymbolicFlowGraph {
			nodes: vec![Node::Terminate(TerminatorId::new(0))],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![],
			terminators: vec![],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("dangling TerminatorId"));
	}

	/// A Fetch node whose `next_response` points past the end of the node
	/// table (id 99 in a one-node graph) must be reported as dangling.
	#[test]
	fn dangling_node_id_in_fetch_edge_rejected() {
		let graph = SymbolicFlowGraph {
			nodes: vec![Node::Fetch {
				id: FetchId::new(0),
				next_response: Some(NodeId::new(99)),
				next_tunnel: None,
				collect_body_before: None,
			}],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![SymbolicFetchRef { kind: FetchKind::HttpProxy, args: serde_json::Value::Null }],
			terminators: vec![],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("next_response dangling"));
	}

	/// An `HttpProxy` Fetch with no `next_response` edge violates the
	/// kind's edge contract.
	#[test]
	fn http_fetch_without_next_response_rejected() {
		let term = Node::Terminate(TerminatorId::new(0));
		let graph = SymbolicFlowGraph {
			nodes: vec![
				term,
				Node::Fetch {
					id: FetchId::new(0),
					next_response: None,
					next_tunnel: None,
					collect_body_before: None,
				},
			],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![SymbolicFetchRef { kind: FetchKind::HttpProxy, args: serde_json::Value::Null }],
			terminators: vec![Terminator::WriteHttpResponse],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("requires next_response"));
	}

	/// An `L4Forward` Fetch that has its required `next_tunnel` but also a
	/// `next_response` edge must be rejected for the extra edge.
	#[test]
	fn l4_forward_with_next_response_rejected() {
		let graph = SymbolicFlowGraph {
			nodes: vec![
				Node::Terminate(TerminatorId::new(0)),
				Node::Fetch {
					id: FetchId::new(0),
					next_response: Some(NodeId::new(0)),
					next_tunnel: Some(NodeId::new(0)),
					collect_body_before: None,
				},
			],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![SymbolicFetchRef { kind: FetchKind::L4Forward, args: serde_json::Value::Null }],
			terminators: vec![Terminator::ByteTunnel],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("L4Forward must not have next_response"));
	}

	/// Two Check nodes that reference each other form a cycle and must be
	/// rejected.
	#[test]
	fn cyclic_graph_is_rejected() {
		// Node 0 and Node 1 point at each other via both of their Check
		// edges (on_match and on_miss).
		let graph = SymbolicFlowGraph {
			nodes: vec![
				Node::Check {
					predicate: PredicateId::new(0),
					on_match: NodeId::new(1),
					on_miss: NodeId::new(1),
					collect_body_before: None,
				},
				Node::Check {
					predicate: PredicateId::new(0),
					on_match: NodeId::new(0),
					on_miss: NodeId::new(0),
					collect_body_before: None,
				},
			],
			predicates: vec![dummy_predicate()],
			middlewares: vec![],
			fetches: vec![],
			terminators: vec![],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("cycle"));
	}

	/// Calls `check_phases` directly (not `validate`) on a graph with one
	/// listener entry that starts at the Upgrade node.
	#[test]
	fn phase_check_rejects_write_http_response_reached_in_wrong_phase() {
		// Upgrade out-phase is L7Request (spec C5.5 patch accepts L4Raw in);
		// Terminate(WriteHttpResponse) requires L7Response — so walking
		// Upgrade directly into it is a phase mismatch the validator must
		// catch.
		let tid = TerminatorId::new(0);
		let graph = SymbolicFlowGraph {
			nodes: vec![Node::Terminate(tid), Node::Upgrade { next: NodeId::new(0) }],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![],
			terminators: vec![Terminator::WriteHttpResponse],
			entries: {
				let mut m = HashMap::new();
				m.insert("127.0.0.1:443".parse().expect("parse"), NodeId::new(1));
				m
			},
			meta: empty_meta(),
		};
		let err = check_phases(&graph).expect_err("must error");
		assert!(err.to_string().contains("phase mismatch"));
	}

	/// Minimal predicate (`TlsSni` equals `"x"`) used only to populate the
	/// predicate table so `PredicateId(0)` is in range.
	fn dummy_predicate() -> crate::predicate::PredicateInst {
		use crate::predicate::{CompiledOperator, CompiledValue, FieldPath, PredicateInst};
		PredicateInst {
			path: FieldPath::TlsSni,
			op: CompiledOperator::Equals(CompiledValue::Str(std::sync::Arc::from("x"))),
		}
	}

	// Referencing `BodySide` here keeps the import above in use; the tests
	// themselves only ever set `collect_body_before: None`. Presumably
	// `BodySide` is the payload type of that `Node` field — confirm against
	// `crate::ir`.
	const _: BodySide = BodySide::Request;
}