Skip to main content

vane_core/compile/
validate.rs

1use std::collections::HashSet;
2
3use crate::error::Error;
4use crate::ir::{Node, NodeId, SymbolicFlowGraph};
5use crate::phase::{Phase, PhaseNodeKind, Transition, transition};
6
7/// Run IR-level structural and phase validation on a freshly-lowered graph.
8///
9/// # Errors
10/// Returns [`Error::compile`] on missing-id references, Fetch edges that
11/// don't match the kind's output-mode contract, acyclicity violations, or
12/// phase-state-machine mismatches.
13pub fn validate(graph: &SymbolicFlowGraph) -> Result<(), Error> {
14	check_id_ranges(graph)?;
15	check_fetch_edges(graph)?;
16	check_acyclic(graph)?;
17	check_phases(graph)?;
18	Ok(())
19}
20
21fn check_id_ranges(graph: &SymbolicFlowGraph) -> Result<(), Error> {
22	let n_nodes = u32::try_from(graph.nodes.len()).unwrap_or(u32::MAX);
23	let n_preds = u32::try_from(graph.predicates.len()).unwrap_or(u32::MAX);
24	let n_mws = u32::try_from(graph.middlewares.len()).unwrap_or(u32::MAX);
25	let n_fetches = u32::try_from(graph.fetches.len()).unwrap_or(u32::MAX);
26	let n_terms = u32::try_from(graph.terminators.len()).unwrap_or(u32::MAX);
27
28	for (idx, node) in graph.nodes.iter().enumerate() {
29		match node {
30			Node::Check { predicate, on_match, on_miss, .. } => {
31				if predicate.get() >= n_preds {
32					return Err(Error::compile(format!(
33						"node {idx}: dangling PredicateId({})",
34						predicate.get()
35					)));
36				}
37				if on_match.get() >= n_nodes {
38					return Err(Error::compile(format!("node {idx}.on_match dangling")));
39				}
40				if on_miss.get() >= n_nodes {
41					return Err(Error::compile(format!("node {idx}.on_miss dangling")));
42				}
43			}
44			Node::Middleware { id, next, on_error, .. } => {
45				if id.get() >= n_mws {
46					return Err(Error::compile(format!("node {idx}: dangling MiddlewareId({})", id.get())));
47				}
48				if next.get() >= n_nodes {
49					return Err(Error::compile(format!("node {idx}.next dangling")));
50				}
51				if let Some(e) = on_error
52					&& e.get() >= n_nodes
53				{
54					return Err(Error::compile(format!("node {idx}.on_error dangling")));
55				}
56			}
57			Node::Fetch { id, next_response, next_tunnel, .. } => {
58				if id.get() >= n_fetches {
59					return Err(Error::compile(format!("node {idx}: dangling FetchId({})", id.get())));
60				}
61				if let Some(r) = next_response
62					&& r.get() >= n_nodes
63				{
64					return Err(Error::compile(format!("node {idx}.next_response dangling")));
65				}
66				if let Some(t) = next_tunnel
67					&& t.get() >= n_nodes
68				{
69					return Err(Error::compile(format!("node {idx}.next_tunnel dangling")));
70				}
71			}
72			Node::Upgrade { next } => {
73				if next.get() >= n_nodes {
74					return Err(Error::compile(format!("node {idx}.next dangling")));
75				}
76			}
77			Node::Terminate(t) => {
78				if t.get() >= n_terms {
79					return Err(Error::compile(format!("node {idx}: dangling TerminatorId({})", t.get())));
80				}
81			}
82		}
83	}
84	Ok(())
85}
86
87fn check_fetch_edges(graph: &SymbolicFlowGraph) -> Result<(), Error> {
88	use crate::fetch::FetchKind::{HttpProxy, HttpSynthesize, L4Forward, WebSocketUpgrade};
89	for (idx, node) in graph.nodes.iter().enumerate() {
90		let Node::Fetch { id, next_response, next_tunnel, .. } = node else {
91			continue;
92		};
93		let kind = graph[*id].kind;
94		match kind {
95			HttpProxy | HttpSynthesize => {
96				if next_response.is_none() {
97					return Err(Error::compile(format!("node {idx}: {kind:?} requires next_response")));
98				}
99				if next_tunnel.is_some() {
100					return Err(Error::compile(format!("node {idx}: {kind:?} must not have next_tunnel")));
101				}
102			}
103			L4Forward => {
104				if next_tunnel.is_none() {
105					return Err(Error::compile(format!("node {idx}: L4Forward requires next_tunnel")));
106				}
107				if next_response.is_some() {
108					return Err(Error::compile(format!("node {idx}: L4Forward must not have next_response")));
109				}
110			}
111			WebSocketUpgrade => {
112				if next_response.is_none() || next_tunnel.is_none() {
113					return Err(Error::compile(format!(
114						"node {idx}: WebSocketUpgrade requires both next_response and next_tunnel"
115					)));
116				}
117			}
118		}
119	}
120	Ok(())
121}
122
123fn check_acyclic(graph: &SymbolicFlowGraph) -> Result<(), Error> {
124	#[derive(Copy, Clone)]
125	enum Color {
126		White,
127		Gray,
128		Black,
129	}
130	let mut color: Vec<Color> = (0..graph.nodes.len()).map(|_| Color::White).collect();
131
132	for start in 0..graph.nodes.len() {
133		if !matches!(color[start], Color::White) {
134			continue;
135		}
136		let mut stack: Vec<(usize, usize)> = vec![(start, 0)];
137		color[start] = Color::Gray;
138		while let Some(&(node_idx, child_idx)) = stack.last() {
139			let succs = successors(&graph.nodes[node_idx]);
140			if child_idx < succs.len() {
141				let next = succs[child_idx].get() as usize;
142				stack.last_mut().expect("non-empty").1 += 1;
143				match color[next] {
144					Color::White => {
145						color[next] = Color::Gray;
146						stack.push((next, 0));
147					}
148					Color::Gray => {
149						return Err(Error::compile(format!("cycle in graph at node {next}")));
150					}
151					Color::Black => {}
152				}
153			} else {
154				color[node_idx] = Color::Black;
155				stack.pop();
156			}
157		}
158	}
159	Ok(())
160}
161
162fn successors(node: &Node) -> Vec<NodeId> {
163	match node {
164		Node::Check { on_match, on_miss, .. } => vec![*on_match, *on_miss],
165		Node::Middleware { next, on_error, .. } => {
166			let mut v = vec![*next];
167			if let Some(e) = on_error {
168				v.push(*e);
169			}
170			v
171		}
172		Node::Fetch { next_response, next_tunnel, .. } => {
173			let mut v = Vec::new();
174			if let Some(r) = next_response {
175				v.push(*r);
176			}
177			if let Some(t) = next_tunnel {
178				v.push(*t);
179			}
180			v
181		}
182		Node::Upgrade { next } => vec![*next],
183		Node::Terminate(_) => Vec::new(),
184	}
185}
186
187fn node_kind_for_phase(graph: &SymbolicFlowGraph, node: &Node) -> PhaseNodeKind {
188	match node {
189		Node::Check { .. } => PhaseNodeKind::Check,
190		Node::Middleware { id, .. } => PhaseNodeKind::Middleware(graph[*id].kind),
191		Node::Fetch { id, .. } => PhaseNodeKind::Fetch(graph[*id].kind),
192		Node::Upgrade { .. } => PhaseNodeKind::Upgrade,
193		Node::Terminate(t) => PhaseNodeKind::Terminate(graph[*t]),
194	}
195}
196
197/// Walk each listener entry through the phase transition table.
198///
199/// Not invoked from [`validate`] today because MVP graphs lack the
200/// `protocol_detect` middleware that advances `L4Raw → L4Peeked` — that
201/// middleware lands at S1-16. Callable directly for tests and for future
202/// validators that want phase coverage.
203///
204/// # Errors
205/// Returns [`Error::compile`] on phase mismatches per 02-flow.md § _Phase
206/// state machine_.
207pub fn check_phases(graph: &SymbolicFlowGraph) -> Result<(), Error> {
208	let mut seen: HashSet<(NodeId, Phase)> = HashSet::new();
209	for &entry in graph.entries.values() {
210		visit_phase(graph, entry, Phase::L4Raw, &mut seen)?;
211	}
212	// Walk every L7 listener's synthesised `Short(Response)` target as a
213	// second-class entry rooted at `Phase::L7Response`. The lower pass
214	// always emits these as `Terminate(WriteHttpResponse)` nodes —
215	// `WriteHttpResponse` accepts `L7Response` per the transition
216	// table, so a clean lower produces a clean walk. Bogus entries
217	// (a synth target whose terminator is not `WriteHttpResponse`) get
218	// caught here with the same "phase mismatch" error shape.
219	for &synth in graph.meta.short_circuit_response_entry.values() {
220		visit_phase(graph, synth, Phase::L7Response, &mut seen)?;
221	}
222	Ok(())
223}
224
225fn visit_phase(
226	graph: &SymbolicFlowGraph,
227	id: NodeId,
228	phase: Phase,
229	seen: &mut HashSet<(NodeId, Phase)>,
230) -> Result<(), Error> {
231	if !seen.insert((id, phase)) {
232		return Ok(());
233	}
234	let node = &graph[id];
235	let kind = node_kind_for_phase(graph, node);
236	let t = transition(kind, phase).map_err(|e| {
237		Error::compile(format!(
238			"phase mismatch at NodeId({}): expected one of {:?}, got {:?}",
239			id.get(),
240			e.expected,
241			e.got,
242		))
243	})?;
244	match (t, node) {
245		(Transition::Terminal, _) => Ok(()),
246		(Transition::PassThrough, _) => {
247			for succ in successors(node) {
248				visit_phase(graph, succ, phase, seen)?;
249			}
250			Ok(())
251		}
252		(Transition::Into(next_phase), _) => {
253			for succ in successors(node) {
254				visit_phase(graph, succ, next_phase, seen)?;
255			}
256			Ok(())
257		}
258		(
259			Transition::BiOutcome { response, tunnel },
260			Node::Fetch { next_response, next_tunnel, .. },
261		) => {
262			if let Some(r) = next_response {
263				visit_phase(graph, *r, response, seen)?;
264			}
265			if let Some(t) = next_tunnel {
266				visit_phase(graph, *t, tunnel, seen)?;
267			}
268			Ok(())
269		}
270		(Transition::BiOutcome { .. }, _) => {
271			Err(Error::compile("BiOutcome transition on non-Fetch node".to_string()))
272		}
273	}
274}
275
// Unit tests for the validation passes. Each test builds a small
// `SymbolicFlowGraph` literal by hand that isolates one failure mode.
#[cfg(test)]
mod tests {
	use std::collections::HashMap;
	use std::path::PathBuf;
	use std::time::SystemTime;

	use super::*;
	use crate::fetch::{FetchKind, SymbolicFetchRef, Terminator};
	use crate::ir::{BodySide, FetchId, FlowGraphMeta, PredicateId, TerminatorId};

	/// Placeholder metadata: zeroed hash, epoch timestamp, and empty maps.
	/// In particular `short_circuit_response_entry` starts empty; tests that
	/// need synth entries insert them afterwards.
	fn empty_meta() -> FlowGraphMeta {
		FlowGraphMeta {
			version_hash: [0; 32],
			compiled_at: SystemTime::UNIX_EPOCH,
			source_files: vec![PathBuf::new()],
			feature_set: &[],
			short_circuit_response_entry: std::collections::BTreeMap::new(),
			listener_tls: std::collections::BTreeMap::new(),
			listener_kinds: std::collections::BTreeMap::new(),

			listener_transports: std::collections::BTreeMap::new(),
		}
	}

	#[test]
	fn dangling_terminator_id_in_terminate_node_rejected() {
		// The only node references TerminatorId(0), but the terminator
		// table is empty.
		let graph = SymbolicFlowGraph {
			nodes: vec![Node::Terminate(TerminatorId::new(0))],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![],
			terminators: vec![],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("dangling TerminatorId"));
	}

	#[test]
	fn dangling_node_id_in_fetch_edge_rejected() {
		// `next_response` points at NodeId(99) in a one-node graph.
		let graph = SymbolicFlowGraph {
			nodes: vec![Node::Fetch {
				id: FetchId::new(0),
				next_response: Some(NodeId::new(99)),
				next_tunnel: None,
				collect_body_before: None,
				body_limit: 0,
			}],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![SymbolicFetchRef {
				kind: FetchKind::HttpProxy,
				args: serde_json::Value::Null,
				retry_buffer_required: false,
			}],
			terminators: vec![],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("next_response dangling"));
	}

	#[test]
	fn http_fetch_without_next_response_rejected() {
		// All ids are in range, so the error comes from the fetch-edge pass:
		// an HttpProxy fetch with no `next_response`.
		let term = Node::Terminate(TerminatorId::new(0));
		let graph = SymbolicFlowGraph {
			nodes: vec![
				term,
				Node::Fetch {
					id: FetchId::new(0),
					next_response: None,
					next_tunnel: None,
					collect_body_before: None,
					body_limit: 0,
				},
			],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![SymbolicFetchRef {
				kind: FetchKind::HttpProxy,
				args: serde_json::Value::Null,
				retry_buffer_required: false,
			}],
			terminators: vec![Terminator::WriteHttpResponse],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("requires next_response"));
	}

	#[test]
	fn l4_forward_with_next_response_rejected() {
		// Both edges are set, so L4Forward's required-tunnel check passes and
		// its forbidden-`next_response` check is the one that fires.
		let graph = SymbolicFlowGraph {
			nodes: vec![
				Node::Terminate(TerminatorId::new(0)),
				Node::Fetch {
					id: FetchId::new(0),
					next_response: Some(NodeId::new(0)),
					next_tunnel: Some(NodeId::new(0)),
					collect_body_before: None,
					body_limit: 0,
				},
			],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![SymbolicFetchRef {
				kind: FetchKind::L4Forward,
				args: serde_json::Value::Null,
				retry_buffer_required: false,
			}],
			terminators: vec![Terminator::ByteTunnel],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("L4Forward must not have next_response"));
	}

	#[test]
	fn cyclic_graph_is_rejected() {
		// Node 0 and Node 1 point at each other through both their `on_match`
		// and `on_miss` edges.
		let graph = SymbolicFlowGraph {
			nodes: vec![
				Node::Check {
					predicate: PredicateId::new(0),
					on_match: NodeId::new(1),
					on_miss: NodeId::new(1),
					collect_body_before: None,
					body_limit: 0,
				},
				Node::Check {
					predicate: PredicateId::new(0),
					on_match: NodeId::new(0),
					on_miss: NodeId::new(0),
					collect_body_before: None,
					body_limit: 0,
				},
			],
			predicates: vec![dummy_predicate()],
			middlewares: vec![],
			fetches: vec![],
			terminators: vec![],
			entries: HashMap::new(),
			meta: empty_meta(),
		};
		let err = validate(&graph).expect_err("must error");
		assert!(err.to_string().contains("cycle"));
	}

	#[test]
	fn phase_check_rejects_write_http_response_reached_in_wrong_phase() {
		// Upgrade out-phase is L7Request (spec C5.5 patch accepts L4Raw in);
		// Terminate(WriteHttpResponse) requires L7Response — so walking
		// Upgrade directly into it is a phase mismatch the validator must
		// catch.
		let tid = TerminatorId::new(0);
		let graph = SymbolicFlowGraph {
			nodes: vec![Node::Terminate(tid), Node::Upgrade { next: NodeId::new(0) }],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![],
			terminators: vec![Terminator::WriteHttpResponse],
			entries: {
				let mut m = HashMap::new();
				m.insert("127.0.0.1:443".parse().expect("parse"), NodeId::new(1));
				m
			},
			meta: empty_meta(),
		};
		let err = check_phases(&graph).expect_err("must error");
		assert!(err.to_string().contains("phase mismatch"));
	}

	#[test]
	fn phase_check_rejects_short_circuit_synth_with_wrong_terminator() {
		// `meta.short_circuit_response_entry` values are walked at
		// `Phase::L7Response`. A synth target whose terminator does not
		// accept that phase must trip the same "phase mismatch" error
		// the standard walker uses. `Terminator::Close` is phase-agnostic
		// so it would never trip this check; `ByteTunnel` only accepts
		// `Phase::Tunnel` and is the right negative-test fixture.
		let bad_tid = TerminatorId::new(0);
		let mut meta = empty_meta();
		meta.short_circuit_response_entry.insert(NodeId::new(1), NodeId::new(0));
		let graph = SymbolicFlowGraph {
			nodes: vec![Node::Terminate(bad_tid), Node::Upgrade { next: NodeId::new(0) }],
			predicates: vec![],
			middlewares: vec![],
			fetches: vec![],
			terminators: vec![Terminator::ByteTunnel],
			// No `entries` — exercise the synth walk in isolation.
			entries: HashMap::new(),
			meta,
		};
		let err = check_phases(&graph).expect_err("must error on bad synth phase");
		assert!(err.to_string().contains("phase mismatch"), "{err}");
	}

	/// Minimal predicate used to fill the predicate table so that
	/// `PredicateId(0)` is in range; its content is irrelevant to these tests.
	fn dummy_predicate() -> crate::predicate::PredicateInst {
		use crate::predicate::{CompiledOperator, CompiledValue, FieldPath, PredicateInst};
		PredicateInst {
			path: FieldPath::TlsSni,
			op: CompiledOperator::Equals(CompiledValue::Str(std::sync::Arc::from("x"))),
		}
	}

	// Reference `BodySide` so its import above counts as used: none of the
	// fixtures construct one (every `collect_body_before` is `None`).
	const _: BodySide = BodySide::Request;
}