1use std::collections::HashSet;
2
3use crate::error::Error;
4use crate::ir::{Node, NodeId, SymbolicFlowGraph};
5use crate::phase::{Phase, PhaseNodeKind, Transition, transition};
6
7pub fn validate(graph: &SymbolicFlowGraph) -> Result<(), Error> {
14 check_id_ranges(graph)?;
15 check_fetch_edges(graph)?;
16 check_acyclic(graph)?;
17 check_phases(graph)?;
18 Ok(())
19}
20
21fn check_id_ranges(graph: &SymbolicFlowGraph) -> Result<(), Error> {
22 let n_nodes = u32::try_from(graph.nodes.len()).unwrap_or(u32::MAX);
23 let n_preds = u32::try_from(graph.predicates.len()).unwrap_or(u32::MAX);
24 let n_mws = u32::try_from(graph.middlewares.len()).unwrap_or(u32::MAX);
25 let n_fetches = u32::try_from(graph.fetches.len()).unwrap_or(u32::MAX);
26 let n_terms = u32::try_from(graph.terminators.len()).unwrap_or(u32::MAX);
27
28 for (idx, node) in graph.nodes.iter().enumerate() {
29 match node {
30 Node::Check { predicate, on_match, on_miss, .. } => {
31 if predicate.get() >= n_preds {
32 return Err(Error::compile(format!(
33 "node {idx}: dangling PredicateId({})",
34 predicate.get()
35 )));
36 }
37 if on_match.get() >= n_nodes {
38 return Err(Error::compile(format!("node {idx}.on_match dangling")));
39 }
40 if on_miss.get() >= n_nodes {
41 return Err(Error::compile(format!("node {idx}.on_miss dangling")));
42 }
43 }
44 Node::Middleware { id, next, on_error, .. } => {
45 if id.get() >= n_mws {
46 return Err(Error::compile(format!("node {idx}: dangling MiddlewareId({})", id.get())));
47 }
48 if next.get() >= n_nodes {
49 return Err(Error::compile(format!("node {idx}.next dangling")));
50 }
51 if let Some(e) = on_error
52 && e.get() >= n_nodes
53 {
54 return Err(Error::compile(format!("node {idx}.on_error dangling")));
55 }
56 }
57 Node::Fetch { id, next_response, next_tunnel, .. } => {
58 if id.get() >= n_fetches {
59 return Err(Error::compile(format!("node {idx}: dangling FetchId({})", id.get())));
60 }
61 if let Some(r) = next_response
62 && r.get() >= n_nodes
63 {
64 return Err(Error::compile(format!("node {idx}.next_response dangling")));
65 }
66 if let Some(t) = next_tunnel
67 && t.get() >= n_nodes
68 {
69 return Err(Error::compile(format!("node {idx}.next_tunnel dangling")));
70 }
71 }
72 Node::Upgrade { next } => {
73 if next.get() >= n_nodes {
74 return Err(Error::compile(format!("node {idx}.next dangling")));
75 }
76 }
77 Node::Terminate(t) => {
78 if t.get() >= n_terms {
79 return Err(Error::compile(format!("node {idx}: dangling TerminatorId({})", t.get())));
80 }
81 }
82 }
83 }
84 Ok(())
85}
86
87fn check_fetch_edges(graph: &SymbolicFlowGraph) -> Result<(), Error> {
88 use crate::fetch::FetchKind::{HttpProxy, HttpSynthesize, L4Forward, WebSocketUpgrade};
89 for (idx, node) in graph.nodes.iter().enumerate() {
90 let Node::Fetch { id, next_response, next_tunnel, .. } = node else {
91 continue;
92 };
93 let kind = graph[*id].kind;
94 match kind {
95 HttpProxy | HttpSynthesize => {
96 if next_response.is_none() {
97 return Err(Error::compile(format!("node {idx}: {kind:?} requires next_response")));
98 }
99 if next_tunnel.is_some() {
100 return Err(Error::compile(format!("node {idx}: {kind:?} must not have next_tunnel")));
101 }
102 }
103 L4Forward => {
104 if next_tunnel.is_none() {
105 return Err(Error::compile(format!("node {idx}: L4Forward requires next_tunnel")));
106 }
107 if next_response.is_some() {
108 return Err(Error::compile(format!("node {idx}: L4Forward must not have next_response")));
109 }
110 }
111 WebSocketUpgrade => {
112 if next_response.is_none() || next_tunnel.is_none() {
113 return Err(Error::compile(format!(
114 "node {idx}: WebSocketUpgrade requires both next_response and next_tunnel"
115 )));
116 }
117 }
118 }
119 }
120 Ok(())
121}
122
123fn check_acyclic(graph: &SymbolicFlowGraph) -> Result<(), Error> {
124 #[derive(Copy, Clone)]
125 enum Color {
126 White,
127 Gray,
128 Black,
129 }
130 let mut color: Vec<Color> = (0..graph.nodes.len()).map(|_| Color::White).collect();
131
132 for start in 0..graph.nodes.len() {
133 if !matches!(color[start], Color::White) {
134 continue;
135 }
136 let mut stack: Vec<(usize, usize)> = vec![(start, 0)];
137 color[start] = Color::Gray;
138 while let Some(&(node_idx, child_idx)) = stack.last() {
139 let succs = successors(&graph.nodes[node_idx]);
140 if child_idx < succs.len() {
141 let next = succs[child_idx].get() as usize;
142 stack.last_mut().expect("non-empty").1 += 1;
143 match color[next] {
144 Color::White => {
145 color[next] = Color::Gray;
146 stack.push((next, 0));
147 }
148 Color::Gray => {
149 return Err(Error::compile(format!("cycle in graph at node {next}")));
150 }
151 Color::Black => {}
152 }
153 } else {
154 color[node_idx] = Color::Black;
155 stack.pop();
156 }
157 }
158 }
159 Ok(())
160}
161
162fn successors(node: &Node) -> Vec<NodeId> {
163 match node {
164 Node::Check { on_match, on_miss, .. } => vec![*on_match, *on_miss],
165 Node::Middleware { next, on_error, .. } => {
166 let mut v = vec![*next];
167 if let Some(e) = on_error {
168 v.push(*e);
169 }
170 v
171 }
172 Node::Fetch { next_response, next_tunnel, .. } => {
173 let mut v = Vec::new();
174 if let Some(r) = next_response {
175 v.push(*r);
176 }
177 if let Some(t) = next_tunnel {
178 v.push(*t);
179 }
180 v
181 }
182 Node::Upgrade { next } => vec![*next],
183 Node::Terminate(_) => Vec::new(),
184 }
185}
186
187fn node_kind_for_phase(graph: &SymbolicFlowGraph, node: &Node) -> PhaseNodeKind {
188 match node {
189 Node::Check { .. } => PhaseNodeKind::Check,
190 Node::Middleware { id, .. } => PhaseNodeKind::Middleware(graph[*id].kind),
191 Node::Fetch { id, .. } => PhaseNodeKind::Fetch(graph[*id].kind),
192 Node::Upgrade { .. } => PhaseNodeKind::Upgrade,
193 Node::Terminate(t) => PhaseNodeKind::Terminate(graph[*t]),
194 }
195}
196
197pub fn check_phases(graph: &SymbolicFlowGraph) -> Result<(), Error> {
208 let mut seen: HashSet<(NodeId, Phase)> = HashSet::new();
209 for &entry in graph.entries.values() {
210 visit_phase(graph, entry, Phase::L4Raw, &mut seen)?;
211 }
212 for &synth in graph.meta.short_circuit_response_entry.values() {
220 visit_phase(graph, synth, Phase::L7Response, &mut seen)?;
221 }
222 Ok(())
223}
224
225fn visit_phase(
226 graph: &SymbolicFlowGraph,
227 id: NodeId,
228 phase: Phase,
229 seen: &mut HashSet<(NodeId, Phase)>,
230) -> Result<(), Error> {
231 if !seen.insert((id, phase)) {
232 return Ok(());
233 }
234 let node = &graph[id];
235 let kind = node_kind_for_phase(graph, node);
236 let t = transition(kind, phase).map_err(|e| {
237 Error::compile(format!(
238 "phase mismatch at NodeId({}): expected one of {:?}, got {:?}",
239 id.get(),
240 e.expected,
241 e.got,
242 ))
243 })?;
244 match (t, node) {
245 (Transition::Terminal, _) => Ok(()),
246 (Transition::PassThrough, _) => {
247 for succ in successors(node) {
248 visit_phase(graph, succ, phase, seen)?;
249 }
250 Ok(())
251 }
252 (Transition::Into(next_phase), _) => {
253 for succ in successors(node) {
254 visit_phase(graph, succ, next_phase, seen)?;
255 }
256 Ok(())
257 }
258 (
259 Transition::BiOutcome { response, tunnel },
260 Node::Fetch { next_response, next_tunnel, .. },
261 ) => {
262 if let Some(r) = next_response {
263 visit_phase(graph, *r, response, seen)?;
264 }
265 if let Some(t) = next_tunnel {
266 visit_phase(graph, *t, tunnel, seen)?;
267 }
268 Ok(())
269 }
270 (Transition::BiOutcome { .. }, _) => {
271 Err(Error::compile("BiOutcome transition on non-Fetch node".to_string()))
272 }
273 }
274}
275
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, HashMap};
    use std::path::PathBuf;
    use std::time::SystemTime;

    use super::*;
    use crate::fetch::{FetchKind, SymbolicFetchRef, Terminator};
    use crate::ir::{BodySide, FetchId, FlowGraphMeta, PredicateId, TerminatorId};

    // NOTE(review): presumably present only to keep the `BodySide` import in
    // use; confirm before removing.
    const _: BodySide = BodySide::Request;

    /// Metadata with no listeners and no short-circuit entries.
    fn empty_meta() -> FlowGraphMeta {
        FlowGraphMeta {
            version_hash: [0; 32],
            compiled_at: SystemTime::UNIX_EPOCH,
            source_files: vec![PathBuf::new()],
            feature_set: &[],
            short_circuit_response_entry: BTreeMap::new(),
            listener_tls: BTreeMap::new(),
            listener_kinds: BTreeMap::new(),
            listener_transports: BTreeMap::new(),
        }
    }

    /// A predicate instance whose content is irrelevant to the tests.
    fn dummy_predicate() -> crate::predicate::PredicateInst {
        use crate::predicate::{CompiledOperator, CompiledValue, FieldPath, PredicateInst};
        PredicateInst {
            path: FieldPath::TlsSni,
            op: CompiledOperator::Equals(CompiledValue::Str(std::sync::Arc::from("x"))),
        }
    }

    /// Graph with the given nodes, fetch table and terminator table; all
    /// other tables, the entry map and the metadata are empty.
    fn graph_with(
        nodes: Vec<Node>,
        fetches: Vec<SymbolicFetchRef>,
        terminators: Vec<Terminator>,
    ) -> SymbolicFlowGraph {
        SymbolicFlowGraph {
            nodes,
            predicates: Vec::new(),
            middlewares: Vec::new(),
            fetches,
            terminators,
            entries: HashMap::new(),
            meta: empty_meta(),
        }
    }

    /// Fetch-table entry of the given kind with null args.
    fn fetch_ref(kind: FetchKind) -> SymbolicFetchRef {
        SymbolicFetchRef {
            kind,
            args: serde_json::Value::Null,
            retry_buffer_required: false,
        }
    }

    /// Fetch node referring to `FetchId(0)` with the given outgoing edges.
    fn fetch_node(next_response: Option<NodeId>, next_tunnel: Option<NodeId>) -> Node {
        Node::Fetch {
            id: FetchId::new(0),
            next_response,
            next_tunnel,
            collect_body_before: None,
            body_limit: 0,
        }
    }

    /// Check node on `PredicateId(0)` whose both branches lead to `target`.
    fn check_node(target: u32) -> Node {
        Node::Check {
            predicate: PredicateId::new(0),
            on_match: NodeId::new(target),
            on_miss: NodeId::new(target),
            collect_body_before: None,
            body_limit: 0,
        }
    }

    #[test]
    fn dangling_terminator_id_in_terminate_node_rejected() {
        // Terminator table is empty, so TerminatorId(0) is out of range.
        let graph = graph_with(vec![Node::Terminate(TerminatorId::new(0))], vec![], vec![]);
        let err = validate(&graph).expect_err("must error");
        assert!(err.to_string().contains("dangling TerminatorId"));
    }

    #[test]
    fn dangling_node_id_in_fetch_edge_rejected() {
        // Only one node exists, so NodeId(99) is out of range.
        let graph = graph_with(
            vec![fetch_node(Some(NodeId::new(99)), None)],
            vec![fetch_ref(FetchKind::HttpProxy)],
            vec![],
        );
        let err = validate(&graph).expect_err("must error");
        assert!(err.to_string().contains("next_response dangling"));
    }

    #[test]
    fn http_fetch_without_next_response_rejected() {
        let graph = graph_with(
            vec![Node::Terminate(TerminatorId::new(0)), fetch_node(None, None)],
            vec![fetch_ref(FetchKind::HttpProxy)],
            vec![Terminator::WriteHttpResponse],
        );
        let err = validate(&graph).expect_err("must error");
        assert!(err.to_string().contains("requires next_response"));
    }

    #[test]
    fn l4_forward_with_next_response_rejected() {
        let graph = graph_with(
            vec![
                Node::Terminate(TerminatorId::new(0)),
                fetch_node(Some(NodeId::new(0)), Some(NodeId::new(0))),
            ],
            vec![fetch_ref(FetchKind::L4Forward)],
            vec![Terminator::ByteTunnel],
        );
        let err = validate(&graph).expect_err("must error");
        assert!(err.to_string().contains("L4Forward must not have next_response"));
    }

    #[test]
    fn cyclic_graph_is_rejected() {
        // Node 0 branches to node 1 and node 1 branches back to node 0.
        let mut graph = graph_with(vec![check_node(1), check_node(0)], vec![], vec![]);
        graph.predicates = vec![dummy_predicate()];
        let err = validate(&graph).expect_err("must error");
        assert!(err.to_string().contains("cycle"));
    }

    #[test]
    fn phase_check_rejects_write_http_response_reached_in_wrong_phase() {
        let tid = TerminatorId::new(0);
        let mut graph = graph_with(
            vec![Node::Terminate(tid), Node::Upgrade { next: NodeId::new(0) }],
            vec![],
            vec![Terminator::WriteHttpResponse],
        );
        graph
            .entries
            .insert("127.0.0.1:443".parse().expect("parse"), NodeId::new(1));
        let err = check_phases(&graph).expect_err("must error");
        assert!(err.to_string().contains("phase mismatch"));
    }

    #[test]
    fn phase_check_rejects_short_circuit_synth_with_wrong_terminator() {
        let bad_tid = TerminatorId::new(0);
        let mut graph = graph_with(
            vec![Node::Terminate(bad_tid), Node::Upgrade { next: NodeId::new(0) }],
            vec![],
            vec![Terminator::ByteTunnel],
        );
        graph
            .meta
            .short_circuit_response_entry
            .insert(NodeId::new(1), NodeId::new(0));
        let err = check_phases(&graph).expect_err("must error on bad synth phase");
        assert!(err.to_string().contains("phase mismatch"), "{err}");
    }
}