Skip to main content

vane_core/compile/
validate.rs

1use std::collections::HashSet;
2
3use crate::error::Error;
4use crate::ir::{Node, NodeId, SymbolicFlowGraph};
5use crate::phase::{Phase, PhaseNodeKind, Transition, transition};
6
7/// Run IR-level structural and phase validation on a freshly-lowered graph.
8///
9/// # Errors
10/// Returns [`Error::compile`] on missing-id references, Fetch edges that
11/// don't match the kind's output-mode contract, acyclicity violations, or
12/// phase-state-machine mismatches.
13pub fn validate(graph: &SymbolicFlowGraph) -> Result<(), Error> {
14	check_id_ranges(graph)?;
15	check_fetch_edges(graph)?;
16	check_acyclic(graph)?;
17	check_phases(graph)?;
18	Ok(())
19}
20
21fn check_id_ranges(graph: &SymbolicFlowGraph) -> Result<(), Error> {
22	let n_nodes = u32::try_from(graph.nodes.len()).unwrap_or(u32::MAX);
23	let n_preds = u32::try_from(graph.predicates.len()).unwrap_or(u32::MAX);
24	let n_mws = u32::try_from(graph.middlewares.len()).unwrap_or(u32::MAX);
25	let n_fetches = u32::try_from(graph.fetches.len()).unwrap_or(u32::MAX);
26	let n_terms = u32::try_from(graph.terminators.len()).unwrap_or(u32::MAX);
27
28	for (idx, node) in graph.nodes.iter().enumerate() {
29		match node {
30			Node::Check { predicate, on_match, on_miss, .. } => {
31				if predicate.get() >= n_preds {
32					return Err(Error::compile(format!(
33						"node {idx}: dangling PredicateId({})",
34						predicate.get()
35					)));
36				}
37				if on_match.get() >= n_nodes {
38					return Err(Error::compile(format!("node {idx}.on_match dangling")));
39				}
40				if on_miss.get() >= n_nodes {
41					return Err(Error::compile(format!("node {idx}.on_miss dangling")));
42				}
43			}
44			Node::Middleware { id, next, on_error, .. } => {
45				if id.get() >= n_mws {
46					return Err(Error::compile(format!("node {idx}: dangling MiddlewareId({})", id.get())));
47				}
48				if next.get() >= n_nodes {
49					return Err(Error::compile(format!("node {idx}.next dangling")));
50				}
51				if let Some(e) = on_error
52					&& e.get() >= n_nodes
53				{
54					return Err(Error::compile(format!("node {idx}.on_error dangling")));
55				}
56			}
57			Node::Fetch { id, next_response, next_tunnel, .. } => {
58				if id.get() >= n_fetches {
59					return Err(Error::compile(format!("node {idx}: dangling FetchId({})", id.get())));
60				}
61				if let Some(r) = next_response
62					&& r.get() >= n_nodes
63				{
64					return Err(Error::compile(format!("node {idx}.next_response dangling")));
65				}
66				if let Some(t) = next_tunnel
67					&& t.get() >= n_nodes
68				{
69					return Err(Error::compile(format!("node {idx}.next_tunnel dangling")));
70				}
71			}
72			Node::Upgrade { next } => {
73				if next.get() >= n_nodes {
74					return Err(Error::compile(format!("node {idx}.next dangling")));
75				}
76			}
77			Node::Terminate(t) => {
78				if t.get() >= n_terms {
79					return Err(Error::compile(format!("node {idx}: dangling TerminatorId({})", t.get())));
80				}
81			}
82		}
83	}
84	Ok(())
85}
86
87fn check_fetch_edges(graph: &SymbolicFlowGraph) -> Result<(), Error> {
88	use crate::fetch::FetchKind::{HttpProxy, HttpSynthesize, L4Forward, WebSocketUpgrade};
89	for (idx, node) in graph.nodes.iter().enumerate() {
90		let Node::Fetch { id, next_response, next_tunnel, .. } = node else {
91			continue;
92		};
93		let kind = graph[*id].kind;
94		match kind {
95			HttpProxy | HttpSynthesize => {
96				if next_response.is_none() {
97					return Err(Error::compile(format!("node {idx}: {kind:?} requires next_response")));
98				}
99				if next_tunnel.is_some() {
100					return Err(Error::compile(format!("node {idx}: {kind:?} must not have next_tunnel")));
101				}
102			}
103			L4Forward => {
104				if next_tunnel.is_none() {
105					return Err(Error::compile(format!("node {idx}: L4Forward requires next_tunnel")));
106				}
107				if next_response.is_some() {
108					return Err(Error::compile(format!("node {idx}: L4Forward must not have next_response")));
109				}
110			}
111			WebSocketUpgrade => {
112				if next_response.is_none() || next_tunnel.is_none() {
113					return Err(Error::compile(format!(
114						"node {idx}: WebSocketUpgrade requires both next_response and next_tunnel"
115					)));
116				}
117			}
118		}
119	}
120	Ok(())
121}
122
123fn check_acyclic(graph: &SymbolicFlowGraph) -> Result<(), Error> {
124	#[derive(Copy, Clone)]
125	enum Color {
126		White,
127		Gray,
128		Black,
129	}
130	let mut color: Vec<Color> = (0..graph.nodes.len()).map(|_| Color::White).collect();
131
132	for start in 0..graph.nodes.len() {
133		if !matches!(color[start], Color::White) {
134			continue;
135		}
136		let mut stack: Vec<(usize, usize)> = vec![(start, 0)];
137		color[start] = Color::Gray;
138		while let Some(&(node_idx, child_idx)) = stack.last() {
139			let succs = successors(&graph.nodes[node_idx]);
140			if child_idx < succs.len() {
141				let next = succs[child_idx].get() as usize;
142				stack.last_mut().expect("non-empty").1 += 1;
143				match color[next] {
144					Color::White => {
145						color[next] = Color::Gray;
146						stack.push((next, 0));
147					}
148					Color::Gray => {
149						return Err(Error::compile(format!("cycle in graph at node {next}")));
150					}
151					Color::Black => {}
152				}
153			} else {
154				color[node_idx] = Color::Black;
155				stack.pop();
156			}
157		}
158	}
159	Ok(())
160}
161
162fn successors(node: &Node) -> Vec<NodeId> {
163	match node {
164		Node::Check { on_match, on_miss, .. } => vec![*on_match, *on_miss],
165		Node::Middleware { next, on_error, .. } => {
166			let mut v = vec![*next];
167			if let Some(e) = on_error {
168				v.push(*e);
169			}
170			v
171		}
172		Node::Fetch { next_response, next_tunnel, .. } => {
173			let mut v = Vec::new();
174			if let Some(r) = next_response {
175				v.push(*r);
176			}
177			if let Some(t) = next_tunnel {
178				v.push(*t);
179			}
180			v
181		}
182		Node::Upgrade { next } => vec![*next],
183		Node::Terminate(_) => Vec::new(),
184	}
185}
186
187fn node_kind_for_phase(graph: &SymbolicFlowGraph, node: &Node) -> PhaseNodeKind {
188	match node {
189		Node::Check { .. } => PhaseNodeKind::Check,
190		Node::Middleware { id, .. } => PhaseNodeKind::Middleware(graph[*id].kind),
191		Node::Fetch { id, .. } => PhaseNodeKind::Fetch(graph[*id].kind),
192		Node::Upgrade { .. } => PhaseNodeKind::Upgrade,
193		Node::Terminate(t) => PhaseNodeKind::Terminate(graph[*t]),
194	}
195}
196
197/// Walk each listener entry through the phase transition table.
198///
199/// Not invoked from [`validate`] today because MVP graphs lack the
200/// `protocol_detect` middleware that advances `L4Raw → L4Peeked` — that
201/// middleware lands at S1-16. Callable directly for tests and for future
202/// validators that want phase coverage.
203///
204/// # Errors
205/// Returns [`Error::compile`] on phase mismatches per 02-flow.md § _Phase
206/// state machine_.
207pub fn check_phases(graph: &SymbolicFlowGraph) -> Result<(), Error> {
208	let mut seen: HashSet<(NodeId, Phase)> = HashSet::new();
209	for &entry in graph.entries.values() {
210		visit_phase(graph, entry, Phase::L4Raw, &mut seen)?;
211	}
212	// Walk every L7 listener's synthesised `Short(Response)` target as a
213	// second-class entry rooted at `Phase::L7Response`. The lower pass
214	// always emits these as `Terminate(WriteHttpResponse)` nodes —
215	// `WriteHttpResponse` accepts `L7Response` per the transition
216	// table, so a clean lower produces a clean walk. Bogus entries
217	// (a synth target whose terminator is not `WriteHttpResponse`) get
218	// caught here with the same "phase mismatch" error shape.
219	for &synth in graph.meta.short_circuit_response_entry.values() {
220		visit_phase(graph, synth, Phase::L7Response, &mut seen)?;
221	}
222	Ok(())
223}
224
225fn visit_phase(
226	graph: &SymbolicFlowGraph,
227	id: NodeId,
228	phase: Phase,
229	seen: &mut HashSet<(NodeId, Phase)>,
230) -> Result<(), Error> {
231	if !seen.insert((id, phase)) {
232		return Ok(());
233	}
234	let node = &graph[id];
235	let kind = node_kind_for_phase(graph, node);
236	let t = transition(kind, phase).map_err(|e| {
237		Error::compile(format!(
238			"phase mismatch at NodeId({}): expected one of {:?}, got {:?}",
239			id.get(),
240			e.expected,
241			e.got,
242		))
243	})?;
244	match (t, node) {
245		(Transition::Terminal, _) => Ok(()),
246		(Transition::PassThrough, _) => {
247			for succ in successors(node) {
248				visit_phase(graph, succ, phase, seen)?;
249			}
250			Ok(())
251		}
252		(Transition::Into(next_phase), _) => {
253			for succ in successors(node) {
254				visit_phase(graph, succ, next_phase, seen)?;
255			}
256			Ok(())
257		}
258		(
259			Transition::BiOutcome { response, tunnel },
260			Node::Fetch { next_response, next_tunnel, .. },
261		) => {
262			if let Some(r) = next_response {
263				visit_phase(graph, *r, response, seen)?;
264			}
265			if let Some(t) = next_tunnel {
266				visit_phase(graph, *t, tunnel, seen)?;
267			}
268			Ok(())
269		}
270		(Transition::BiOutcome { .. }, _) => {
271			Err(Error::compile("BiOutcome transition on non-Fetch node".to_string()))
272		}
273	}
274}
275
#[cfg(test)]
mod tests {
	use std::collections::HashMap;
	use std::path::PathBuf;
	use std::time::SystemTime;

	use super::*;
	use crate::fetch::{FetchKind, SymbolicFetchRef, Terminator};
	use crate::ir::{FetchId, FlowGraphMeta, PredicateId, TerminatorId};

	/// Minimal metadata: zeroed hash, epoch timestamp, one blank source path,
	/// and no short-circuit entries or listener TLS config.
	fn empty_meta() -> FlowGraphMeta {
		FlowGraphMeta {
			version_hash: [0; 32],
			compiled_at: SystemTime::UNIX_EPOCH,
			source_files: vec![PathBuf::new()],
			feature_set: &[],
			short_circuit_response_entry: std::collections::BTreeMap::new(),
			listener_tls: std::collections::BTreeMap::new(),
		}
	}

	#[test]
	fn dangling_terminator_id_in_terminate_node_rejected() {
		let graph = SymbolicFlowGraph {
			nodes: vec![Node::Terminate(TerminatorId::new(0))],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![],
			terminators: vec![],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("dangling TerminatorId"));
	}

	#[test]
	fn dangling_node_id_in_fetch_edge_rejected() {
		let graph = SymbolicFlowGraph {
			nodes: vec![Node::Fetch {
				id: FetchId::new(0),
				next_response: Some(NodeId::new(99)),
				next_tunnel: None,
				collect_body_before: None,
				body_limit: 0,
			}],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![SymbolicFetchRef { kind: FetchKind::HttpProxy, args: serde_json::Value::Null }],
			terminators: vec![],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("next_response dangling"));
	}

	#[test]
	fn http_fetch_without_next_response_rejected() {
		let term = Node::Terminate(TerminatorId::new(0));
		let graph = SymbolicFlowGraph {
			nodes: vec![
				term,
				Node::Fetch {
					id: FetchId::new(0),
					next_response: None,
					next_tunnel: None,
					collect_body_before: None,
					body_limit: 0,
				},
			],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![SymbolicFetchRef { kind: FetchKind::HttpProxy, args: serde_json::Value::Null }],
			terminators: vec![Terminator::WriteHttpResponse],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("requires next_response"));
	}

	#[test]
	fn l4_forward_with_next_response_rejected() {
		let graph = SymbolicFlowGraph {
			nodes: vec![
				Node::Terminate(TerminatorId::new(0)),
				Node::Fetch {
					id: FetchId::new(0),
					next_response: Some(NodeId::new(0)),
					next_tunnel: Some(NodeId::new(0)),
					collect_body_before: None,
					body_limit: 0,
				},
			],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![SymbolicFetchRef { kind: FetchKind::L4Forward, args: serde_json::Value::Null }],
			terminators: vec![Terminator::ByteTunnel],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("L4Forward must not have next_response"));
	}

	#[test]
	fn cyclic_graph_is_rejected() {
		// Node 0 and Node 1 point at each other through both of their Check
		// edges (on_match and on_miss).
		let graph = SymbolicFlowGraph {
			nodes: vec![
				Node::Check {
					predicate: PredicateId::new(0),
					on_match: NodeId::new(1),
					on_miss: NodeId::new(1),
					collect_body_before: None,
					body_limit: 0,
				},
				Node::Check {
					predicate: PredicateId::new(0),
					on_match: NodeId::new(0),
					on_miss: NodeId::new(0),
					collect_body_before: None,
					body_limit: 0,
				},
			],
			predicates: vec![dummy_predicate()],
			middlewares: vec![],
			fetches: vec![],
			terminators: vec![],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("cycle"));
	}

	#[test]
	fn phase_check_rejects_write_http_response_reached_in_wrong_phase() {
		// Upgrade out-phase is L7Request (spec C5.5 patch accepts L4Raw in);
		// Terminate(WriteHttpResponse) requires L7Response — so walking
		// Upgrade directly into it is a phase mismatch the validator must
		// catch.
		let tid = TerminatorId::new(0);
		let graph = SymbolicFlowGraph {
			nodes: vec![Node::Terminate(tid), Node::Upgrade { next: NodeId::new(0) }],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![],
			terminators: vec![Terminator::WriteHttpResponse],
			entries: {
				let mut m = HashMap::new();
				m.insert("127.0.0.1:443".parse().expect("parse"), NodeId::new(1));
				m
			},
			meta: empty_meta(),
		};
		let err = check_phases(&graph).expect_err("must error");
		assert!(err.to_string().contains("phase mismatch"));
	}

	#[test]
	fn phase_check_rejects_short_circuit_synth_with_wrong_terminator() {
		// `meta.short_circuit_response_entry` values are walked at
		// `Phase::L7Response`. A synth target whose terminator does not
		// accept that phase must trip the same "phase mismatch" error
		// the standard walker uses. `Terminator::Close` is phase-agnostic
		// so it would never trip this check; `ByteTunnel` only accepts
		// `Phase::Tunnel` and is the right negative-test fixture.
		let bad_tid = TerminatorId::new(0);
		let mut meta = empty_meta();
		meta.short_circuit_response_entry.insert(NodeId::new(1), NodeId::new(0));
		let graph = SymbolicFlowGraph {
			nodes: vec![Node::Terminate(bad_tid), Node::Upgrade { next: NodeId::new(0) }],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![],
			terminators: vec![Terminator::ByteTunnel],
			// No `entries` — exercise the synth walk in isolation.
			entries: HashMap::new(),
			meta,
		};
		let err = check_phases(&graph).expect_err("must error on bad synth phase");
		assert!(err.to_string().contains("phase mismatch"), "{err}");
	}

	/// A structurally valid predicate for fixtures that need one to exist;
	/// its field and operator are irrelevant to the validators under test.
	fn dummy_predicate() -> crate::predicate::PredicateInst {
		use crate::predicate::{CompiledOperator, CompiledValue, FieldPath, PredicateInst};
		PredicateInst {
			path: FieldPath::TlsSni,
			op: CompiledOperator::Equals(CompiledValue::Str(std::sync::Arc::from("x"))),
		}
	}
}