1use std::collections::HashSet;
2
3use crate::error::Error;
4use crate::ir::{Node, NodeId, SymbolicFlowGraph};
5use crate::phase::{Phase, PhaseNodeKind, Transition, transition};
6
7pub fn validate(graph: &SymbolicFlowGraph) -> Result<(), Error> {
14 check_id_ranges(graph)?;
15 check_fetch_edges(graph)?;
16 check_acyclic(graph)?;
17 check_phases(graph)?;
18 Ok(())
19}
20
21fn check_id_ranges(graph: &SymbolicFlowGraph) -> Result<(), Error> {
22 let n_nodes = u32::try_from(graph.nodes.len()).unwrap_or(u32::MAX);
23 let n_preds = u32::try_from(graph.predicates.len()).unwrap_or(u32::MAX);
24 let n_mws = u32::try_from(graph.middlewares.len()).unwrap_or(u32::MAX);
25 let n_fetches = u32::try_from(graph.fetches.len()).unwrap_or(u32::MAX);
26 let n_terms = u32::try_from(graph.terminators.len()).unwrap_or(u32::MAX);
27
28 for (idx, node) in graph.nodes.iter().enumerate() {
29 match node {
30 Node::Check { predicate, on_match, on_miss, .. } => {
31 if predicate.get() >= n_preds {
32 return Err(Error::compile(format!(
33 "node {idx}: dangling PredicateId({})",
34 predicate.get()
35 )));
36 }
37 if on_match.get() >= n_nodes {
38 return Err(Error::compile(format!("node {idx}.on_match dangling")));
39 }
40 if on_miss.get() >= n_nodes {
41 return Err(Error::compile(format!("node {idx}.on_miss dangling")));
42 }
43 }
44 Node::Middleware { id, next, on_error, .. } => {
45 if id.get() >= n_mws {
46 return Err(Error::compile(format!("node {idx}: dangling MiddlewareId({})", id.get())));
47 }
48 if next.get() >= n_nodes {
49 return Err(Error::compile(format!("node {idx}.next dangling")));
50 }
51 if let Some(e) = on_error
52 && e.get() >= n_nodes
53 {
54 return Err(Error::compile(format!("node {idx}.on_error dangling")));
55 }
56 }
57 Node::Fetch { id, next_response, next_tunnel, .. } => {
58 if id.get() >= n_fetches {
59 return Err(Error::compile(format!("node {idx}: dangling FetchId({})", id.get())));
60 }
61 if let Some(r) = next_response
62 && r.get() >= n_nodes
63 {
64 return Err(Error::compile(format!("node {idx}.next_response dangling")));
65 }
66 if let Some(t) = next_tunnel
67 && t.get() >= n_nodes
68 {
69 return Err(Error::compile(format!("node {idx}.next_tunnel dangling")));
70 }
71 }
72 Node::Upgrade { next } => {
73 if next.get() >= n_nodes {
74 return Err(Error::compile(format!("node {idx}.next dangling")));
75 }
76 }
77 Node::Terminate(t) => {
78 if t.get() >= n_terms {
79 return Err(Error::compile(format!("node {idx}: dangling TerminatorId({})", t.get())));
80 }
81 }
82 }
83 }
84 Ok(())
85}
86
87fn check_fetch_edges(graph: &SymbolicFlowGraph) -> Result<(), Error> {
88 use crate::fetch::FetchKind::{
89 AcmeChallenge, HttpProxy, HttpSynthesize, L4Forward, WebSocketUpgrade,
90 };
91 for (idx, node) in graph.nodes.iter().enumerate() {
92 let Node::Fetch { id, next_response, next_tunnel, .. } = node else {
93 continue;
94 };
95 let kind = graph[*id].kind;
96 match kind {
97 HttpProxy | HttpSynthesize | AcmeChallenge => {
98 if next_response.is_none() {
99 return Err(Error::compile(format!("node {idx}: {kind:?} requires next_response")));
100 }
101 if next_tunnel.is_some() {
102 return Err(Error::compile(format!("node {idx}: {kind:?} must not have next_tunnel")));
103 }
104 }
105 L4Forward => {
106 if next_tunnel.is_none() {
107 return Err(Error::compile(format!("node {idx}: L4Forward requires next_tunnel")));
108 }
109 if next_response.is_some() {
110 return Err(Error::compile(format!("node {idx}: L4Forward must not have next_response")));
111 }
112 }
113 WebSocketUpgrade => {
114 if next_response.is_none() || next_tunnel.is_none() {
115 return Err(Error::compile(format!(
116 "node {idx}: WebSocketUpgrade requires both next_response and next_tunnel"
117 )));
118 }
119 }
120 }
121 }
122 Ok(())
123}
124
125fn check_acyclic(graph: &SymbolicFlowGraph) -> Result<(), Error> {
126 #[derive(Copy, Clone)]
127 enum Color {
128 White,
129 Gray,
130 Black,
131 }
132 let mut color: Vec<Color> = (0..graph.nodes.len()).map(|_| Color::White).collect();
133
134 for start in 0..graph.nodes.len() {
135 if !matches!(color[start], Color::White) {
136 continue;
137 }
138 let mut stack: Vec<(usize, usize)> = vec![(start, 0)];
139 color[start] = Color::Gray;
140 while let Some(&(node_idx, child_idx)) = stack.last() {
141 let succs = successors(&graph.nodes[node_idx]);
142 if child_idx < succs.len() {
143 let next = succs[child_idx].get() as usize;
144 stack.last_mut().expect("non-empty").1 += 1;
145 match color[next] {
146 Color::White => {
147 color[next] = Color::Gray;
148 stack.push((next, 0));
149 }
150 Color::Gray => {
151 return Err(Error::compile(format!("cycle in graph at node {next}")));
152 }
153 Color::Black => {}
154 }
155 } else {
156 color[node_idx] = Color::Black;
157 stack.pop();
158 }
159 }
160 }
161 Ok(())
162}
163
164fn successors(node: &Node) -> Vec<NodeId> {
165 match node {
166 Node::Check { on_match, on_miss, .. } => vec![*on_match, *on_miss],
167 Node::Middleware { next, on_error, .. } => {
168 let mut v = vec![*next];
169 if let Some(e) = on_error {
170 v.push(*e);
171 }
172 v
173 }
174 Node::Fetch { next_response, next_tunnel, .. } => {
175 let mut v = Vec::new();
176 if let Some(r) = next_response {
177 v.push(*r);
178 }
179 if let Some(t) = next_tunnel {
180 v.push(*t);
181 }
182 v
183 }
184 Node::Upgrade { next } => vec![*next],
185 Node::Terminate(_) => Vec::new(),
186 }
187}
188
189fn node_kind_for_phase(graph: &SymbolicFlowGraph, node: &Node) -> PhaseNodeKind {
190 match node {
191 Node::Check { .. } => PhaseNodeKind::Check,
192 Node::Middleware { id, .. } => PhaseNodeKind::Middleware(graph[*id].kind),
193 Node::Fetch { id, .. } => PhaseNodeKind::Fetch(graph[*id].kind),
194 Node::Upgrade { .. } => PhaseNodeKind::Upgrade,
195 Node::Terminate(t) => PhaseNodeKind::Terminate(graph[*t]),
196 }
197}
198
199pub fn check_phases(graph: &SymbolicFlowGraph) -> Result<(), Error> {
211 let mut seen: HashSet<(NodeId, Phase)> = HashSet::new();
212 for &entry in graph.entries.values() {
213 visit_phase(graph, entry, Phase::L4Raw, &mut seen)?;
214 }
215 for &synth in graph.meta.short_circuit_response_entry.values() {
223 visit_phase(graph, synth, Phase::L7Response, &mut seen)?;
224 }
225 Ok(())
226}
227
228fn visit_phase(
229 graph: &SymbolicFlowGraph,
230 id: NodeId,
231 phase: Phase,
232 seen: &mut HashSet<(NodeId, Phase)>,
233) -> Result<(), Error> {
234 if !seen.insert((id, phase)) {
235 return Ok(());
236 }
237 let node = &graph[id];
238 let kind = node_kind_for_phase(graph, node);
239 let t = transition(kind, phase).map_err(|e| {
240 Error::compile(format!(
241 "phase mismatch at NodeId({}): expected one of {:?}, got {:?}",
242 id.get(),
243 e.expected,
244 e.got,
245 ))
246 })?;
247 match (t, node) {
248 (Transition::Terminal, _) => Ok(()),
249 (Transition::PassThrough, _) => {
250 for succ in successors(node) {
251 visit_phase(graph, succ, phase, seen)?;
252 }
253 Ok(())
254 }
255 (Transition::Into(next_phase), _) => {
256 for succ in successors(node) {
257 visit_phase(graph, succ, next_phase, seen)?;
258 }
259 Ok(())
260 }
261 (
262 Transition::BiOutcome { response, tunnel },
263 Node::Fetch { next_response, next_tunnel, .. },
264 ) => {
265 if let Some(r) = next_response {
266 visit_phase(graph, *r, response, seen)?;
267 }
268 if let Some(t) = next_tunnel {
269 visit_phase(graph, *t, tunnel, seen)?;
270 }
271 Ok(())
272 }
273 (Transition::BiOutcome { .. }, _) => {
274 Err(Error::compile("BiOutcome transition on non-Fetch node".to_string()))
275 }
276 }
277}
278
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::path::PathBuf;
    use std::time::SystemTime;

    use super::*;
    use crate::fetch::{FetchKind, SymbolicFetchRef, Terminator};
    use crate::ir::{BodySide, FetchId, FlowGraphMeta, PredicateId, TerminatorId};

    /// Metadata with no short-circuit entries, listeners or annotations.
    fn empty_meta() -> FlowGraphMeta {
        FlowGraphMeta {
            version_hash: [0; 32],
            compiled_at: SystemTime::UNIX_EPOCH,
            source_files: vec![PathBuf::new()],
            feature_set: &[],
            short_circuit_response_entry: std::collections::BTreeMap::new(),
            listener_tls: std::collections::BTreeMap::new(),
            listener_kinds: std::collections::BTreeMap::new(),
            listener_transports: std::collections::BTreeMap::new(),
            annotations: Vec::new(),
        }
    }

    /// A fetch reference of `kind` with null args and all flags off.
    fn fetch_ref(kind: FetchKind) -> SymbolicFetchRef {
        SymbolicFetchRef {
            kind,
            args: serde_json::Value::Null,
            retry_buffer_required: false,
            allow_zero_rtt: None,
        }
    }

    #[test]
    fn dangling_terminator_id_in_terminate_node_rejected() {
        let graph = SymbolicFlowGraph {
            nodes: vec![Node::Terminate(TerminatorId::new(0))],
            predicates: vec![],
            middlewares: vec![],
            fetches: vec![],
            terminators: vec![],
            entries: HashMap::new(),
            meta: empty_meta(),
        };
        let err = validate(&graph).expect_err("must error");
        assert!(err.to_string().contains("dangling TerminatorId"));
    }

    #[test]
    fn dangling_predicate_id_in_check_node_rejected() {
        // Node 0 refers to predicate 0, but no predicates exist.
        let graph = SymbolicFlowGraph {
            nodes: vec![
                Node::Check {
                    predicate: PredicateId::new(0),
                    on_match: NodeId::new(1),
                    on_miss: NodeId::new(1),
                    collect_body_before: None,
                    body_limit: 0,
                },
                Node::Terminate(TerminatorId::new(0)),
            ],
            predicates: vec![],
            middlewares: vec![],
            fetches: vec![],
            terminators: vec![Terminator::WriteHttpResponse],
            entries: HashMap::new(),
            meta: empty_meta(),
        };
        let err = validate(&graph).expect_err("must error");
        assert!(err.to_string().contains("dangling PredicateId"), "{err}");
    }

    #[test]
    fn dangling_node_id_in_fetch_edge_rejected() {
        let graph = SymbolicFlowGraph {
            nodes: vec![Node::Fetch {
                id: FetchId::new(0),
                next_response: Some(NodeId::new(99)),
                next_tunnel: None,
                collect_body_before: None,
                body_limit: 0,
            }],
            predicates: vec![],
            middlewares: vec![],
            fetches: vec![fetch_ref(FetchKind::HttpProxy)],
            terminators: vec![],
            entries: HashMap::new(),
            meta: empty_meta(),
        };
        let err = validate(&graph).expect_err("must error");
        assert!(err.to_string().contains("next_response dangling"));
    }

    #[test]
    fn http_fetch_without_next_response_rejected() {
        let term = Node::Terminate(TerminatorId::new(0));
        let graph = SymbolicFlowGraph {
            nodes: vec![
                term,
                Node::Fetch {
                    id: FetchId::new(0),
                    next_response: None,
                    next_tunnel: None,
                    collect_body_before: None,
                    body_limit: 0,
                },
            ],
            predicates: vec![],
            middlewares: vec![],
            fetches: vec![fetch_ref(FetchKind::HttpProxy)],
            terminators: vec![Terminator::WriteHttpResponse],
            entries: HashMap::new(),
            meta: empty_meta(),
        };
        let err = validate(&graph).expect_err("must error");
        assert!(err.to_string().contains("requires next_response"));
    }

    #[test]
    fn l4_forward_with_next_response_rejected() {
        let graph = SymbolicFlowGraph {
            nodes: vec![
                Node::Terminate(TerminatorId::new(0)),
                Node::Fetch {
                    id: FetchId::new(0),
                    next_response: Some(NodeId::new(0)),
                    next_tunnel: Some(NodeId::new(0)),
                    collect_body_before: None,
                    body_limit: 0,
                },
            ],
            predicates: vec![],
            middlewares: vec![],
            fetches: vec![fetch_ref(FetchKind::L4Forward)],
            terminators: vec![Terminator::ByteTunnel],
            entries: HashMap::new(),
            meta: empty_meta(),
        };
        let err = validate(&graph).expect_err("must error");
        assert!(err.to_string().contains("L4Forward must not have next_response"));
    }

    #[test]
    fn websocket_upgrade_without_next_tunnel_rejected() {
        // WebSocketUpgrade must carry both edges; this node has only a response edge.
        let graph = SymbolicFlowGraph {
            nodes: vec![
                Node::Terminate(TerminatorId::new(0)),
                Node::Fetch {
                    id: FetchId::new(0),
                    next_response: Some(NodeId::new(0)),
                    next_tunnel: None,
                    collect_body_before: None,
                    body_limit: 0,
                },
            ],
            predicates: vec![],
            middlewares: vec![],
            fetches: vec![fetch_ref(FetchKind::WebSocketUpgrade)],
            terminators: vec![Terminator::WriteHttpResponse],
            entries: HashMap::new(),
            meta: empty_meta(),
        };
        let err = validate(&graph).expect_err("must error");
        assert!(
            err.to_string().contains("requires both next_response and next_tunnel"),
            "{err}"
        );
    }

    #[test]
    fn cyclic_graph_is_rejected() {
        // Two Check nodes whose branches all point at each other form a 2-cycle.
        let graph = SymbolicFlowGraph {
            nodes: vec![
                Node::Check {
                    predicate: PredicateId::new(0),
                    on_match: NodeId::new(1),
                    on_miss: NodeId::new(1),
                    collect_body_before: None,
                    body_limit: 0,
                },
                Node::Check {
                    predicate: PredicateId::new(0),
                    on_match: NodeId::new(0),
                    on_miss: NodeId::new(0),
                    collect_body_before: None,
                    body_limit: 0,
                },
            ],
            predicates: vec![dummy_predicate()],
            middlewares: vec![],
            fetches: vec![],
            terminators: vec![],
            entries: HashMap::new(),
            meta: empty_meta(),
        };
        let err = validate(&graph).expect_err("must error");
        assert!(err.to_string().contains("cycle"));
    }

    #[test]
    fn phase_check_rejects_write_http_response_reached_in_wrong_phase() {
        // The listener entry (walked from Phase::L4Raw) reaches an Upgrade node
        // and then a WriteHttpResponse terminator. The transition table is
        // expected to reject that terminator in the phase the Upgrade leads to.
        let tid = TerminatorId::new(0);
        let graph = SymbolicFlowGraph {
            nodes: vec![Node::Terminate(tid), Node::Upgrade { next: NodeId::new(0) }],
            predicates: vec![],
            middlewares: vec![],
            fetches: vec![],
            terminators: vec![Terminator::WriteHttpResponse],
            entries: {
                let mut m = HashMap::new();
                m.insert("127.0.0.1:443".parse().expect("parse"), NodeId::new(1));
                m
            },
            meta: empty_meta(),
        };
        let err = check_phases(&graph).expect_err("must error");
        assert!(err.to_string().contains("phase mismatch"));
    }

    #[test]
    fn phase_check_rejects_short_circuit_synth_with_wrong_terminator() {
        // The short-circuit root (walked from Phase::L7Response) is a
        // ByteTunnel terminator, which the transition table is expected to
        // reject in that phase.
        let bad_tid = TerminatorId::new(0);
        let mut meta = empty_meta();
        meta.short_circuit_response_entry.insert(NodeId::new(1), NodeId::new(0));
        let graph = SymbolicFlowGraph {
            nodes: vec![Node::Terminate(bad_tid), Node::Upgrade { next: NodeId::new(0) }],
            predicates: vec![],
            middlewares: vec![],
            fetches: vec![],
            terminators: vec![Terminator::ByteTunnel],
            entries: HashMap::new(),
            meta,
        };
        let err = check_phases(&graph).expect_err("must error on bad synth phase");
        assert!(err.to_string().contains("phase mismatch"), "{err}");
    }

    /// A placeholder predicate; its field and operator are irrelevant to these tests.
    fn dummy_predicate() -> crate::predicate::PredicateInst {
        use crate::predicate::{CompiledOperator, CompiledValue, FieldPath, PredicateInst};
        PredicateInst {
            path: FieldPath::TlsSni,
            op: CompiledOperator::Equals(CompiledValue::Str(std::sync::Arc::from("x"))),
        }
    }

    // NOTE(review): appears to exist only to keep the `BodySide` import in use;
    // confirm whether it is still needed.
    const _: BodySide = BodySide::Request;
}