1use std::collections::HashSet;
2
3use crate::error::Error;
4use crate::ir::{Node, NodeId, SymbolicFlowGraph};
5use crate::phase::{Phase, PhaseNodeKind, Transition, transition};
6
7pub fn validate(graph: &SymbolicFlowGraph) -> Result<(), Error> {
14 check_id_ranges(graph)?;
15 check_fetch_edges(graph)?;
16 check_acyclic(graph)?;
17 check_phases(graph)?;
18 Ok(())
19}
20
21fn check_id_ranges(graph: &SymbolicFlowGraph) -> Result<(), Error> {
22 let n_nodes = u32::try_from(graph.nodes.len()).unwrap_or(u32::MAX);
23 let n_preds = u32::try_from(graph.predicates.len()).unwrap_or(u32::MAX);
24 let n_mws = u32::try_from(graph.middlewares.len()).unwrap_or(u32::MAX);
25 let n_fetches = u32::try_from(graph.fetches.len()).unwrap_or(u32::MAX);
26 let n_terms = u32::try_from(graph.terminators.len()).unwrap_or(u32::MAX);
27
28 for (idx, node) in graph.nodes.iter().enumerate() {
29 match node {
30 Node::Check { predicate, on_match, on_miss, .. } => {
31 if predicate.get() >= n_preds {
32 return Err(Error::compile(format!(
33 "node {idx}: dangling PredicateId({})",
34 predicate.get()
35 )));
36 }
37 if on_match.get() >= n_nodes {
38 return Err(Error::compile(format!("node {idx}.on_match dangling")));
39 }
40 if on_miss.get() >= n_nodes {
41 return Err(Error::compile(format!("node {idx}.on_miss dangling")));
42 }
43 }
44 Node::Middleware { id, next, on_error, .. } => {
45 if id.get() >= n_mws {
46 return Err(Error::compile(format!("node {idx}: dangling MiddlewareId({})", id.get())));
47 }
48 if next.get() >= n_nodes {
49 return Err(Error::compile(format!("node {idx}.next dangling")));
50 }
51 if let Some(e) = on_error
52 && e.get() >= n_nodes
53 {
54 return Err(Error::compile(format!("node {idx}.on_error dangling")));
55 }
56 }
57 Node::Fetch { id, next_response, next_tunnel, .. } => {
58 if id.get() >= n_fetches {
59 return Err(Error::compile(format!("node {idx}: dangling FetchId({})", id.get())));
60 }
61 if let Some(r) = next_response
62 && r.get() >= n_nodes
63 {
64 return Err(Error::compile(format!("node {idx}.next_response dangling")));
65 }
66 if let Some(t) = next_tunnel
67 && t.get() >= n_nodes
68 {
69 return Err(Error::compile(format!("node {idx}.next_tunnel dangling")));
70 }
71 }
72 Node::Upgrade { next } => {
73 if next.get() >= n_nodes {
74 return Err(Error::compile(format!("node {idx}.next dangling")));
75 }
76 }
77 Node::Terminate(t) => {
78 if t.get() >= n_terms {
79 return Err(Error::compile(format!("node {idx}: dangling TerminatorId({})", t.get())));
80 }
81 }
82 }
83 }
84 Ok(())
85}
86
87fn check_fetch_edges(graph: &SymbolicFlowGraph) -> Result<(), Error> {
88 use crate::fetch::FetchKind::{HttpProxy, HttpSynthesize, L4Forward, WebSocketUpgrade};
89 for (idx, node) in graph.nodes.iter().enumerate() {
90 let Node::Fetch { id, next_response, next_tunnel, .. } = node else {
91 continue;
92 };
93 let kind = graph[*id].kind;
94 match kind {
95 HttpProxy | HttpSynthesize => {
96 if next_response.is_none() {
97 return Err(Error::compile(format!("node {idx}: {kind:?} requires next_response")));
98 }
99 if next_tunnel.is_some() {
100 return Err(Error::compile(format!("node {idx}: {kind:?} must not have next_tunnel")));
101 }
102 }
103 L4Forward => {
104 if next_tunnel.is_none() {
105 return Err(Error::compile(format!("node {idx}: L4Forward requires next_tunnel")));
106 }
107 if next_response.is_some() {
108 return Err(Error::compile(format!("node {idx}: L4Forward must not have next_response")));
109 }
110 }
111 WebSocketUpgrade => {
112 if next_response.is_none() || next_tunnel.is_none() {
113 return Err(Error::compile(format!(
114 "node {idx}: WebSocketUpgrade requires both next_response and next_tunnel"
115 )));
116 }
117 }
118 }
119 }
120 Ok(())
121}
122
123fn check_acyclic(graph: &SymbolicFlowGraph) -> Result<(), Error> {
124 #[derive(Copy, Clone)]
125 enum Color {
126 White,
127 Gray,
128 Black,
129 }
130 let mut color: Vec<Color> = (0..graph.nodes.len()).map(|_| Color::White).collect();
131
132 for start in 0..graph.nodes.len() {
133 if !matches!(color[start], Color::White) {
134 continue;
135 }
136 let mut stack: Vec<(usize, usize)> = vec![(start, 0)];
137 color[start] = Color::Gray;
138 while let Some(&(node_idx, child_idx)) = stack.last() {
139 let succs = successors(&graph.nodes[node_idx]);
140 if child_idx < succs.len() {
141 let next = succs[child_idx].get() as usize;
142 stack.last_mut().expect("non-empty").1 += 1;
143 match color[next] {
144 Color::White => {
145 color[next] = Color::Gray;
146 stack.push((next, 0));
147 }
148 Color::Gray => {
149 return Err(Error::compile(format!("cycle in graph at node {next}")));
150 }
151 Color::Black => {}
152 }
153 } else {
154 color[node_idx] = Color::Black;
155 stack.pop();
156 }
157 }
158 }
159 Ok(())
160}
161
162fn successors(node: &Node) -> Vec<NodeId> {
163 match node {
164 Node::Check { on_match, on_miss, .. } => vec![*on_match, *on_miss],
165 Node::Middleware { next, on_error, .. } => {
166 let mut v = vec![*next];
167 if let Some(e) = on_error {
168 v.push(*e);
169 }
170 v
171 }
172 Node::Fetch { next_response, next_tunnel, .. } => {
173 let mut v = Vec::new();
174 if let Some(r) = next_response {
175 v.push(*r);
176 }
177 if let Some(t) = next_tunnel {
178 v.push(*t);
179 }
180 v
181 }
182 Node::Upgrade { next } => vec![*next],
183 Node::Terminate(_) => Vec::new(),
184 }
185}
186
187fn node_kind_for_phase(graph: &SymbolicFlowGraph, node: &Node) -> PhaseNodeKind {
188 match node {
189 Node::Check { .. } => PhaseNodeKind::Check,
190 Node::Middleware { id, .. } => PhaseNodeKind::Middleware(graph[*id].kind),
191 Node::Fetch { id, .. } => PhaseNodeKind::Fetch(graph[*id].kind),
192 Node::Upgrade { .. } => PhaseNodeKind::Upgrade,
193 Node::Terminate(t) => PhaseNodeKind::Terminate(graph[*t]),
194 }
195}
196
197pub fn check_phases(graph: &SymbolicFlowGraph) -> Result<(), Error> {
208 let mut seen: HashSet<(NodeId, Phase)> = HashSet::new();
209 for &entry in graph.entries.values() {
210 visit_phase(graph, entry, Phase::L4Raw, &mut seen)?;
211 }
212 for &synth in graph.meta.short_circuit_response_entry.values() {
220 visit_phase(graph, synth, Phase::L7Response, &mut seen)?;
221 }
222 Ok(())
223}
224
225fn visit_phase(
226 graph: &SymbolicFlowGraph,
227 id: NodeId,
228 phase: Phase,
229 seen: &mut HashSet<(NodeId, Phase)>,
230) -> Result<(), Error> {
231 if !seen.insert((id, phase)) {
232 return Ok(());
233 }
234 let node = &graph[id];
235 let kind = node_kind_for_phase(graph, node);
236 let t = transition(kind, phase).map_err(|e| {
237 Error::compile(format!(
238 "phase mismatch at NodeId({}): expected one of {:?}, got {:?}",
239 id.get(),
240 e.expected,
241 e.got,
242 ))
243 })?;
244 match (t, node) {
245 (Transition::Terminal, _) => Ok(()),
246 (Transition::PassThrough, _) => {
247 for succ in successors(node) {
248 visit_phase(graph, succ, phase, seen)?;
249 }
250 Ok(())
251 }
252 (Transition::Into(next_phase), _) => {
253 for succ in successors(node) {
254 visit_phase(graph, succ, next_phase, seen)?;
255 }
256 Ok(())
257 }
258 (
259 Transition::BiOutcome { response, tunnel },
260 Node::Fetch { next_response, next_tunnel, .. },
261 ) => {
262 if let Some(r) = next_response {
263 visit_phase(graph, *r, response, seen)?;
264 }
265 if let Some(t) = next_tunnel {
266 visit_phase(graph, *t, tunnel, seen)?;
267 }
268 Ok(())
269 }
270 (Transition::BiOutcome { .. }, _) => {
271 Err(Error::compile("BiOutcome transition on non-Fetch node".to_string()))
272 }
273 }
274}
275
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::path::PathBuf;
    use std::time::SystemTime;

    use super::*;
    use crate::fetch::{FetchKind, SymbolicFetchRef, Terminator};
    use crate::ir::{BodySide, FetchId, FlowGraphMeta, PredicateId, TerminatorId};

    // NOTE(review): presumably present only to keep the `BodySide` import referenced — confirm.
    const _: BodySide = BodySide::Request;

    /// Metadata with placeholder values and both per-node maps empty.
    fn empty_meta() -> FlowGraphMeta {
        FlowGraphMeta {
            version_hash: [0; 32],
            compiled_at: SystemTime::UNIX_EPOCH,
            source_files: vec![PathBuf::new()],
            feature_set: &[],
            short_circuit_response_entry: std::collections::BTreeMap::new(),
            listener_tls: std::collections::BTreeMap::new(),
        }
    }

    /// Arbitrary predicate used only to fill slot 0 of the predicate table.
    fn dummy_predicate() -> crate::predicate::PredicateInst {
        use crate::predicate::{CompiledOperator, CompiledValue, FieldPath, PredicateInst};
        PredicateInst {
            path: FieldPath::TlsSni,
            op: CompiledOperator::Equals(CompiledValue::Str(std::sync::Arc::from("x"))),
        }
    }

    /// Graph holding `nodes` with every side table and the entry map empty.
    fn graph_of(nodes: Vec<Node>) -> SymbolicFlowGraph {
        SymbolicFlowGraph {
            nodes,
            predicates: Vec::new(),
            middlewares: Vec::new(),
            fetches: Vec::new(),
            terminators: Vec::new(),
            entries: HashMap::new(),
            meta: empty_meta(),
        }
    }

    /// Fetch node referring to fetch slot 0 with the given outgoing edges.
    fn fetch_node(next_response: Option<NodeId>, next_tunnel: Option<NodeId>) -> Node {
        Node::Fetch {
            id: FetchId::new(0),
            next_response,
            next_tunnel,
            collect_body_before: None,
            body_limit: 0,
        }
    }

    /// Fetch table with a single entry of the given kind and null args.
    fn single_fetch(kind: FetchKind) -> Vec<SymbolicFetchRef> {
        vec![SymbolicFetchRef { kind, args: serde_json::Value::Null }]
    }

    /// Check node on predicate slot 0 whose two branches both go to `target`.
    fn check_node(target: NodeId) -> Node {
        Node::Check {
            predicate: PredicateId::new(0),
            on_match: target,
            on_miss: target,
            collect_body_before: None,
            body_limit: 0,
        }
    }

    /// Runs `validate` and returns the text of the error it must produce.
    fn rejection(graph: &SymbolicFlowGraph) -> String {
        validate(graph).expect_err("must error").to_string()
    }

    #[test]
    fn dangling_terminator_id_in_terminate_node_rejected() {
        // The terminator table is empty, so id 0 is out of range.
        let graph = graph_of(vec![Node::Terminate(TerminatorId::new(0))]);
        assert!(rejection(&graph).contains("dangling TerminatorId"));
    }

    #[test]
    fn dangling_node_id_in_fetch_edge_rejected() {
        // Only one node exists; the response edge points at node 99.
        let mut graph = graph_of(vec![fetch_node(Some(NodeId::new(99)), None)]);
        graph.fetches = single_fetch(FetchKind::HttpProxy);
        assert!(rejection(&graph).contains("next_response dangling"));
    }

    #[test]
    fn http_fetch_without_next_response_rejected() {
        let mut graph = graph_of(vec![
            Node::Terminate(TerminatorId::new(0)),
            fetch_node(None, None),
        ]);
        graph.fetches = single_fetch(FetchKind::HttpProxy);
        graph.terminators = vec![Terminator::WriteHttpResponse];
        assert!(rejection(&graph).contains("requires next_response"));
    }

    #[test]
    fn l4_forward_with_next_response_rejected() {
        let sink = NodeId::new(0);
        let mut graph = graph_of(vec![
            Node::Terminate(TerminatorId::new(0)),
            fetch_node(Some(sink), Some(sink)),
        ]);
        graph.fetches = single_fetch(FetchKind::L4Forward);
        graph.terminators = vec![Terminator::ByteTunnel];
        assert!(rejection(&graph).contains("L4Forward must not have next_response"));
    }

    #[test]
    fn cyclic_graph_is_rejected() {
        // Node 0 branches to node 1 and node 1 branches back to node 0.
        let mut graph = graph_of(vec![check_node(NodeId::new(1)), check_node(NodeId::new(0))]);
        graph.predicates = vec![dummy_predicate()];
        assert!(rejection(&graph).contains("cycle"));
    }

    #[test]
    fn phase_check_rejects_write_http_response_reached_in_wrong_phase() {
        let tid = TerminatorId::new(0);
        // Entry -> Upgrade (node 1) -> Terminate(WriteHttpResponse) (node 0).
        let mut graph = graph_of(vec![Node::Terminate(tid), Node::Upgrade { next: NodeId::new(0) }]);
        graph.terminators = vec![Terminator::WriteHttpResponse];
        graph.entries.insert("127.0.0.1:443".parse().expect("parse"), NodeId::new(1));

        let err = check_phases(&graph).expect_err("must error");
        assert!(err.to_string().contains("phase mismatch"));
    }

    #[test]
    fn phase_check_rejects_short_circuit_synth_with_wrong_terminator() {
        let bad_tid = TerminatorId::new(0);
        // The short-circuit value points at node 0, a ByteTunnel terminator,
        // which is then visited in the L7 response phase.
        let mut graph =
            graph_of(vec![Node::Terminate(bad_tid), Node::Upgrade { next: NodeId::new(0) }]);
        graph.terminators = vec![Terminator::ByteTunnel];
        graph.meta.short_circuit_response_entry.insert(NodeId::new(1), NodeId::new(0));

        let err = check_phases(&graph).expect_err("must error on bad synth phase");
        assert!(err.to_string().contains("phase mismatch"), "{err}");
    }
}