1use std::collections::HashSet;
2
3use crate::error::Error;
4use crate::ir::{Node, NodeId, SymbolicFlowGraph};
5use crate::phase::{Phase, PhaseNodeKind, Transition, transition};
6
7pub fn validate(graph: &SymbolicFlowGraph) -> Result<(), Error> {
14 check_id_ranges(graph)?;
15 check_fetch_edges(graph)?;
16 check_acyclic(graph)?;
17 check_phases(graph)?;
18 Ok(())
19}
20
21fn check_id_ranges(graph: &SymbolicFlowGraph) -> Result<(), Error> {
22 let n_nodes = u32::try_from(graph.nodes.len()).unwrap_or(u32::MAX);
23 let n_preds = u32::try_from(graph.predicates.len()).unwrap_or(u32::MAX);
24 let n_mws = u32::try_from(graph.middlewares.len()).unwrap_or(u32::MAX);
25 let n_fetches = u32::try_from(graph.fetches.len()).unwrap_or(u32::MAX);
26 let n_terms = u32::try_from(graph.terminators.len()).unwrap_or(u32::MAX);
27
28 for (idx, node) in graph.nodes.iter().enumerate() {
29 match node {
30 Node::Check { predicate, on_match, on_miss, .. } => {
31 if predicate.get() >= n_preds {
32 return Err(Error::compile(format!(
33 "node {idx}: dangling PredicateId({})",
34 predicate.get()
35 )));
36 }
37 if on_match.get() >= n_nodes {
38 return Err(Error::compile(format!("node {idx}.on_match dangling")));
39 }
40 if on_miss.get() >= n_nodes {
41 return Err(Error::compile(format!("node {idx}.on_miss dangling")));
42 }
43 }
44 Node::Middleware { id, next, on_error, .. } => {
45 if id.get() >= n_mws {
46 return Err(Error::compile(format!("node {idx}: dangling MiddlewareId({})", id.get())));
47 }
48 if next.get() >= n_nodes {
49 return Err(Error::compile(format!("node {idx}.next dangling")));
50 }
51 if let Some(e) = on_error
52 && e.get() >= n_nodes
53 {
54 return Err(Error::compile(format!("node {idx}.on_error dangling")));
55 }
56 }
57 Node::Fetch { id, next_response, next_tunnel, .. } => {
58 if id.get() >= n_fetches {
59 return Err(Error::compile(format!("node {idx}: dangling FetchId({})", id.get())));
60 }
61 if let Some(r) = next_response
62 && r.get() >= n_nodes
63 {
64 return Err(Error::compile(format!("node {idx}.next_response dangling")));
65 }
66 if let Some(t) = next_tunnel
67 && t.get() >= n_nodes
68 {
69 return Err(Error::compile(format!("node {idx}.next_tunnel dangling")));
70 }
71 }
72 Node::Upgrade { next } => {
73 if next.get() >= n_nodes {
74 return Err(Error::compile(format!("node {idx}.next dangling")));
75 }
76 }
77 Node::Terminate(t) => {
78 if t.get() >= n_terms {
79 return Err(Error::compile(format!("node {idx}: dangling TerminatorId({})", t.get())));
80 }
81 }
82 }
83 }
84 Ok(())
85}
86
87fn check_fetch_edges(graph: &SymbolicFlowGraph) -> Result<(), Error> {
88 use crate::fetch::FetchKind::{HttpProxy, HttpSynthesize, L4Forward, WebSocketUpgrade};
89 for (idx, node) in graph.nodes.iter().enumerate() {
90 let Node::Fetch { id, next_response, next_tunnel, .. } = node else {
91 continue;
92 };
93 let kind = graph[*id].kind;
94 match kind {
95 HttpProxy | HttpSynthesize => {
96 if next_response.is_none() {
97 return Err(Error::compile(format!("node {idx}: {kind:?} requires next_response")));
98 }
99 if next_tunnel.is_some() {
100 return Err(Error::compile(format!("node {idx}: {kind:?} must not have next_tunnel")));
101 }
102 }
103 L4Forward => {
104 if next_tunnel.is_none() {
105 return Err(Error::compile(format!("node {idx}: L4Forward requires next_tunnel")));
106 }
107 if next_response.is_some() {
108 return Err(Error::compile(format!("node {idx}: L4Forward must not have next_response")));
109 }
110 }
111 WebSocketUpgrade => {
112 if next_response.is_none() || next_tunnel.is_none() {
113 return Err(Error::compile(format!(
114 "node {idx}: WebSocketUpgrade requires both next_response and next_tunnel"
115 )));
116 }
117 }
118 }
119 }
120 Ok(())
121}
122
123fn check_acyclic(graph: &SymbolicFlowGraph) -> Result<(), Error> {
124 #[derive(Copy, Clone)]
125 enum Color {
126 White,
127 Gray,
128 Black,
129 }
130 let mut color: Vec<Color> = (0..graph.nodes.len()).map(|_| Color::White).collect();
131
132 for start in 0..graph.nodes.len() {
133 if !matches!(color[start], Color::White) {
134 continue;
135 }
136 let mut stack: Vec<(usize, usize)> = vec![(start, 0)];
137 color[start] = Color::Gray;
138 while let Some(&(node_idx, child_idx)) = stack.last() {
139 let succs = successors(&graph.nodes[node_idx]);
140 if child_idx < succs.len() {
141 let next = succs[child_idx].get() as usize;
142 stack.last_mut().expect("non-empty").1 += 1;
143 match color[next] {
144 Color::White => {
145 color[next] = Color::Gray;
146 stack.push((next, 0));
147 }
148 Color::Gray => {
149 return Err(Error::compile(format!("cycle in graph at node {next}")));
150 }
151 Color::Black => {}
152 }
153 } else {
154 color[node_idx] = Color::Black;
155 stack.pop();
156 }
157 }
158 }
159 Ok(())
160}
161
162fn successors(node: &Node) -> Vec<NodeId> {
163 match node {
164 Node::Check { on_match, on_miss, .. } => vec![*on_match, *on_miss],
165 Node::Middleware { next, on_error, .. } => {
166 let mut v = vec![*next];
167 if let Some(e) = on_error {
168 v.push(*e);
169 }
170 v
171 }
172 Node::Fetch { next_response, next_tunnel, .. } => {
173 let mut v = Vec::new();
174 if let Some(r) = next_response {
175 v.push(*r);
176 }
177 if let Some(t) = next_tunnel {
178 v.push(*t);
179 }
180 v
181 }
182 Node::Upgrade { next } => vec![*next],
183 Node::Terminate(_) => Vec::new(),
184 }
185}
186
187fn node_kind_for_phase(graph: &SymbolicFlowGraph, node: &Node) -> PhaseNodeKind {
188 match node {
189 Node::Check { .. } => PhaseNodeKind::Check,
190 Node::Middleware { id, .. } => PhaseNodeKind::Middleware(graph[*id].kind),
191 Node::Fetch { id, .. } => PhaseNodeKind::Fetch(graph[*id].kind),
192 Node::Upgrade { .. } => PhaseNodeKind::Upgrade,
193 Node::Terminate(t) => PhaseNodeKind::Terminate(graph[*t]),
194 }
195}
196
197pub fn check_phases(graph: &SymbolicFlowGraph) -> Result<(), Error> {
208 let mut seen: HashSet<(NodeId, Phase)> = HashSet::new();
209 for &entry in graph.entries.values() {
210 visit_phase(graph, entry, Phase::L4Raw, &mut seen)?;
211 }
212 Ok(())
213}
214
215fn visit_phase(
216 graph: &SymbolicFlowGraph,
217 id: NodeId,
218 phase: Phase,
219 seen: &mut HashSet<(NodeId, Phase)>,
220) -> Result<(), Error> {
221 if !seen.insert((id, phase)) {
222 return Ok(());
223 }
224 let node = &graph[id];
225 let kind = node_kind_for_phase(graph, node);
226 let t = transition(kind, phase).map_err(|e| {
227 Error::compile(format!(
228 "phase mismatch at NodeId({}): expected one of {:?}, got {:?}",
229 id.get(),
230 e.expected,
231 e.got,
232 ))
233 })?;
234 match (t, node) {
235 (Transition::Terminal, _) => Ok(()),
236 (Transition::PassThrough, _) => {
237 for succ in successors(node) {
238 visit_phase(graph, succ, phase, seen)?;
239 }
240 Ok(())
241 }
242 (Transition::Into(next_phase), _) => {
243 for succ in successors(node) {
244 visit_phase(graph, succ, next_phase, seen)?;
245 }
246 Ok(())
247 }
248 (
249 Transition::BiOutcome { response, tunnel },
250 Node::Fetch { next_response, next_tunnel, .. },
251 ) => {
252 if let Some(r) = next_response {
253 visit_phase(graph, *r, response, seen)?;
254 }
255 if let Some(t) = next_tunnel {
256 visit_phase(graph, *t, tunnel, seen)?;
257 }
258 Ok(())
259 }
260 (Transition::BiOutcome { .. }, _) => {
261 Err(Error::compile("BiOutcome transition on non-Fetch node".to_string()))
262 }
263 }
264}
265
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::path::PathBuf;
    use std::time::SystemTime;

    use super::*;
    use crate::fetch::{FetchKind, SymbolicFetchRef, Terminator};
    use crate::ir::{BodySide, FetchId, FlowGraphMeta, PredicateId, TerminatorId};

    /// Placeholder metadata; none of the checks in the parent module read `meta`.
    fn empty_meta() -> FlowGraphMeta {
        FlowGraphMeta {
            version_hash: [0; 32],
            compiled_at: SystemTime::UNIX_EPOCH,
            source_files: vec![PathBuf::new()],
            feature_set: &[],
        }
    }

    /// Builds a graph holding `nodes`, with every side table and the entry map
    /// empty. Each test fills in only the tables it exercises.
    fn graph_with_nodes(nodes: Vec<Node>) -> SymbolicFlowGraph {
        SymbolicFlowGraph {
            nodes,
            predicates: vec![],
            middlewares: vec![],
            fetches: vec![],
            terminators: vec![],
            entries: HashMap::new(),
            meta: empty_meta(),
        }
    }

    /// A fetch-table entry of the given kind with no arguments.
    fn fetch_of(kind: FetchKind) -> SymbolicFetchRef {
        SymbolicFetchRef { kind, args: serde_json::Value::Null }
    }

    /// Runs full validation and returns the error text. Panics (failing the
    /// test) if validation unexpectedly succeeds.
    fn validation_error(graph: &SymbolicFlowGraph) -> String {
        validate(graph).expect_err("must error").to_string()
    }

    #[test]
    fn dangling_terminator_id_in_terminate_node_rejected() {
        // The terminator table is empty, so TerminatorId(0) points nowhere.
        let graph = graph_with_nodes(vec![Node::Terminate(TerminatorId::new(0))]);
        assert!(validation_error(&graph).contains("dangling TerminatorId"));
    }

    #[test]
    fn dangling_node_id_in_fetch_edge_rejected() {
        let mut graph = graph_with_nodes(vec![Node::Fetch {
            id: FetchId::new(0),
            next_response: Some(NodeId::new(99)),
            next_tunnel: None,
            collect_body_before: None,
        }]);
        graph.fetches = vec![fetch_of(FetchKind::HttpProxy)];
        assert!(validation_error(&graph).contains("next_response dangling"));
    }

    #[test]
    fn http_fetch_without_next_response_rejected() {
        let mut graph = graph_with_nodes(vec![
            Node::Terminate(TerminatorId::new(0)),
            Node::Fetch {
                id: FetchId::new(0),
                next_response: None,
                next_tunnel: None,
                collect_body_before: None,
            },
        ]);
        graph.fetches = vec![fetch_of(FetchKind::HttpProxy)];
        graph.terminators = vec![Terminator::WriteHttpResponse];
        assert!(validation_error(&graph).contains("requires next_response"));
    }

    #[test]
    fn l4_forward_with_next_response_rejected() {
        let mut graph = graph_with_nodes(vec![
            Node::Terminate(TerminatorId::new(0)),
            Node::Fetch {
                id: FetchId::new(0),
                next_response: Some(NodeId::new(0)),
                next_tunnel: Some(NodeId::new(0)),
                collect_body_before: None,
            },
        ]);
        graph.fetches = vec![fetch_of(FetchKind::L4Forward)];
        graph.terminators = vec![Terminator::ByteTunnel];
        assert!(validation_error(&graph).contains("L4Forward must not have next_response"));
    }

    #[test]
    fn cyclic_graph_is_rejected() {
        // Nodes 0 and 1 point at each other on both branches.
        let mut graph = graph_with_nodes(vec![
            Node::Check {
                predicate: PredicateId::new(0),
                on_match: NodeId::new(1),
                on_miss: NodeId::new(1),
                collect_body_before: None,
            },
            Node::Check {
                predicate: PredicateId::new(0),
                on_match: NodeId::new(0),
                on_miss: NodeId::new(0),
                collect_body_before: None,
            },
        ]);
        graph.predicates = vec![dummy_predicate()];
        assert!(validation_error(&graph).contains("cycle"));
    }

    #[test]
    fn phase_check_rejects_write_http_response_reached_in_wrong_phase() {
        let tid = TerminatorId::new(0);
        let mut graph =
            graph_with_nodes(vec![Node::Terminate(tid), Node::Upgrade { next: NodeId::new(0) }]);
        graph.terminators = vec![Terminator::WriteHttpResponse];
        graph.entries.insert("127.0.0.1:443".parse().expect("parse"), NodeId::new(1));

        let err = check_phases(&graph).expect_err("must error");
        assert!(err.to_string().contains("phase mismatch"));
    }

    /// An arbitrary, well-formed predicate for tests that only need the
    /// predicate table to be non-empty.
    fn dummy_predicate() -> crate::predicate::PredicateInst {
        use crate::predicate::{CompiledOperator, CompiledValue, FieldPath, PredicateInst};
        PredicateInst {
            path: FieldPath::TlsSni,
            op: CompiledOperator::Equals(CompiledValue::Str(std::sync::Arc::from("x"))),
        }
    }

    // `BodySide` is imported only for this item; referencing it keeps the
    // import from being reported as unused.
    const _: BodySide = BodySide::Request;
}