Skip to main content

vane_core/compile/
analyze.rs

1use crate::compile::expand::RawRuleSet;
2use crate::error::{Diagnostics, Error};
3use crate::fetch::FetchKind;
4use crate::metadata::{FetchMetadataProvider, MiddlewareMetadataProvider};
5use crate::predicate::{FieldPath, Predicate};
6use crate::rule::RawRule;
7
/// How deep into a connection a rule's match predicate has to look.
///
/// Variants are declared from shallowest to deepest. The derived `Ord`
/// follows declaration order, and `analyze_rule` relies on that ordering
/// both to keep the maximum level seen across a predicate's checks and to
/// decide whether a predicate is compatible with an L4 terminator.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
pub enum InspectionLevel {
	/// Transport, remote/local IP and remote/local port fields only.
	L4Only,
	/// `Peek` and the TLS fields (SNI, ALPN, version, peer-cert data).
	L4Peek,
	/// HTTP method, URI path, URI query, or a header.
	L7Header,
	/// The HTTP body.
	L7Body,
}
15
/// Layer at which an analyzed rule is classified.
///
/// `analyze_rule` assigns `L4` only when the rule's terminator is
/// `FetchKind::L4Forward` and its predicate needs no more than
/// [`InspectionLevel::L4Peek`]; every other accepted rule is `L7`.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub enum Posture {
	/// L4 terminator with an L4-level (or absent) predicate.
	L4,
	/// Any rule whose terminator is not `FetchKind::L4Forward`.
	L7,
}
21
/// A [`RawRule`] together with the facts `analyze_rule` derived from it.
#[derive(Debug, Clone)]
pub struct AnalyzedRule {
	/// The rule exactly as it was handed to the analyzer.
	pub raw: RawRule,
	/// Deepest level required by any check in the match predicate;
	/// [`InspectionLevel::L4Only`] when the rule has no predicate.
	pub inspection_level: InspectionLevel,
	/// Number of `Predicate::Check` leaves in the match predicate
	/// (zero when the rule has no predicate).
	pub specificity: usize,
	/// Whether the rule is classified as L4 or L7.
	pub posture: Posture,
	/// Set when the predicate reads `FieldPath::HttpBody`, or when an
	/// `L7Request` middleware in the chain declares `needs_body`.
	pub needs_request_body: bool,
	/// Set when an `L7Response` middleware in the chain declares
	/// `needs_body`.
	pub needs_response_body: bool,
}
31
/// Output of the analyze stage.
#[derive(Debug, Clone)]
pub struct AnalyzedRuleSet {
	/// Rules that passed analysis, in the order they appeared in the
	/// input `RawRuleSet`.
	pub rules: Vec<AnalyzedRule>,
	/// Source file list carried over unchanged from the input
	/// `RawRuleSet`.
	pub source_files: Vec<std::path::PathBuf>,
}
37
38/// Compute per-rule inspection level, specificity, posture (L4 vs L7), and
39/// `LazyBuffer` per-side buffer triggers.
40///
41/// # Errors
42/// Returns [`Error::compile`] when a referenced middleware name is missing
43/// from the provider registry (so compile-time analysis cannot decide what
44/// phase it sits in or whether it buffers the body).
45pub fn analyze(
46	set: RawRuleSet,
47	mw_meta: &dyn MiddlewareMetadataProvider,
48	fetch_meta: &dyn FetchMetadataProvider,
49) -> Result<AnalyzedRuleSet, Error> {
50	let (rules, d) = analyze_collecting(set, mw_meta, fetch_meta);
51	d.into_result(rules).map_err(Error::from)
52}
53
54/// Push+continue form of [`analyze`]: every rule is analyzed
55/// independently; per-rule errors are collected and the offending
56/// rule is dropped from the returned [`AnalyzedRuleSet`]. The caller
57/// uses [`Diagnostics::has_fatal`] at the stage boundary to decide
58/// whether to bail or feed the (partial) set into the next stage —
59/// today the compile pipeline always bails because every downstream
60/// stage assumes a complete rule set, but the partial set is still
61/// useful for the dry-run dump endpoint.
62pub fn analyze_collecting(
63	set: RawRuleSet,
64	mw_meta: &dyn MiddlewareMetadataProvider,
65	fetch_meta: &dyn FetchMetadataProvider,
66) -> (AnalyzedRuleSet, Diagnostics) {
67	let mut analyzed = Vec::with_capacity(set.rules.len());
68	let mut d = Diagnostics::new();
69	for raw in set.rules {
70		match analyze_rule(raw, mw_meta, fetch_meta) {
71			Ok(rule) => analyzed.push(rule),
72			Err(e) => d.push(e),
73		}
74	}
75	(AnalyzedRuleSet { rules: analyzed, source_files: set.source_files }, d)
76}
77
78fn analyze_rule(
79	raw: RawRule,
80	mw_meta: &dyn MiddlewareMetadataProvider,
81	fetch_meta: &dyn FetchMetadataProvider,
82) -> Result<AnalyzedRule, Error> {
83	// Per-rule TLS validation runs at the analyze stage so the lower
84	// pass — which aggregates resolved specs into per-listener pools —
85	// can assume each `TlsConfig` is internally consistent. Surfacing
86	// the violation through the rule name keeps multi-file configs
87	// debuggable.
88	if let Some(tls) = raw.tls.as_ref() {
89		tls.validate().map_err(|e| Error::compile(format!("rule {:?}: {}", raw.name, e)))?;
90	}
91
92	let fetch_kind = Some(raw.terminate.kind);
93	let fetch_phase = fetch_phase_of(fetch_kind);
94
95	let mut max_level = InspectionLevel::L4Only;
96	let mut specificity = 0usize;
97	let mut reads_http_body = false;
98	if let Some(pred) = &raw.match_predicate {
99		// Bound predicate nesting depth before any recursive walker
100		// (here, in lower, or in collect_levels) touches the tree — a
101		// pathologically nested operator-authored rule should fail
102		// loud at compile, not crash the recursive walks at runtime.
103		crate::predicate::check_max_depth(pred)
104			.map_err(|e| Error::compile(format!("rule {:?}: {}", raw.name, e)))?;
105		walk_predicate(pred, &mut |p| match p {
106			Predicate::Check(c) => {
107				specificity += 1;
108				let lvl = field_path_inspection_level(&c.path);
109				if lvl > max_level {
110					max_level = lvl;
111				}
112				if matches!(c.path, FieldPath::HttpBody) {
113					reads_http_body = true;
114				}
115			}
116			Predicate::AnyOf(_) | Predicate::AllOf(_) | Predicate::Not(_) => {}
117		});
118	}
119
120	let mut needs_request_body = reads_http_body;
121	let mut needs_response_body = false;
122	for mw_ref in &raw.middleware_chain {
123		let meta = mw_meta
124			.get(&mw_ref.name)
125			.ok_or_else(|| Error::compile(format!("unknown middleware: {:?}", mw_ref.name)))?;
126		// Schema-check each middleware's args at the analyze stage so a
127		// rule with a misspelled key fails compile loudly, instead of
128		// surfacing at runtime when the middleware instantiates.
129		(meta.validate_args)(&mw_ref.args).map_err(|e| {
130			Error::compile(format!("rule {:?}: middleware {:?} args invalid: {e}", raw.name, mw_ref.name))
131		})?;
132		if meta.needs_body {
133			match meta.kind {
134				crate::middleware::MiddlewareKind::L7Request => needs_request_body = true,
135				crate::middleware::MiddlewareKind::L7Response => needs_response_body = true,
136				crate::middleware::MiddlewareKind::L4Peek | crate::middleware::MiddlewareKind::L4Bytes => {}
137			}
138		}
139	}
140
141	// Schema-check the terminator's args via `fetch_meta`. Unknown
142	// fetch kinds are reported here too, matching the way the link
143	// pass would surface them — but now with the rule name attached.
144	if let Some(kind) = fetch_kind {
145		let meta = fetch_meta.get(kind).ok_or_else(|| {
146			Error::compile(format!("rule {:?}: unknown fetch kind {:?}", raw.name, kind))
147		})?;
148		(meta.validate_args)(&raw.terminate.args).map_err(|e| {
149			Error::compile(format!("rule {:?}: terminate.args for {:?} invalid: {e}", raw.name, kind))
150		})?;
151	}
152
153	let posture = match fetch_phase {
154		FetchPhase::L4 if max_level <= InspectionLevel::L4Peek => Posture::L4,
155		FetchPhase::L4 => {
156			return Err(Error::compile(format!(
157				"rule {:?}: L7-level predicate on an L4 fetch is invalid",
158				raw.name
159			)));
160		}
161		FetchPhase::L7 => Posture::L7,
162	};
163
164	Ok(AnalyzedRule {
165		raw,
166		inspection_level: max_level,
167		specificity,
168		posture,
169		needs_request_body,
170		needs_response_body,
171	})
172}
173
/// Module-private classification of a rule's terminator, produced by
/// `fetch_phase_of` and consumed when `analyze_rule` picks a `Posture`.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
enum FetchPhase {
	/// The terminator is `FetchKind::L4Forward`.
	L4,
	/// Any other terminator (or none).
	L7,
}
179
180const fn fetch_phase_of(kind: Option<FetchKind>) -> FetchPhase {
181	match kind {
182		Some(FetchKind::L4Forward) => FetchPhase::L4,
183		_ => FetchPhase::L7,
184	}
185}
186
187/// Pre-order walk over a predicate tree using an explicit stack.
188///
189/// Depth is bounded by [`crate::predicate::MAX_PREDICATE_DEPTH`]
190/// thanks to the upstream `check_max_depth` guard in `analyze_rule`,
191/// but the iterative form keeps the walker independent of the system
192/// stack and matches the spec recommendation to mirror
193/// `check_acyclic`'s explicit-stack shape.
194fn walk_predicate(root: &Predicate, f: &mut impl FnMut(&Predicate)) {
195	let mut stack: Vec<&Predicate> = vec![root];
196	while let Some(p) = stack.pop() {
197		f(p);
198		match p {
199			Predicate::AnyOf(a) => {
200				for child in a.any_of.iter().rev() {
201					stack.push(child);
202				}
203			}
204			Predicate::AllOf(a) => {
205				for child in a.all_of.iter().rev() {
206					stack.push(child);
207				}
208			}
209			Predicate::Not(n) => stack.push(n.not.as_ref()),
210			Predicate::Check(_) => {}
211		}
212	}
213}
214
215const fn field_path_inspection_level(path: &FieldPath) -> InspectionLevel {
216	match path {
217		FieldPath::Transport
218		| FieldPath::RemoteIp
219		| FieldPath::RemotePort
220		| FieldPath::LocalIp
221		| FieldPath::LocalPort => InspectionLevel::L4Only,
222		FieldPath::Peek
223		| FieldPath::TlsSni
224		| FieldPath::TlsAlpn
225		| FieldPath::TlsVersion
226		| FieldPath::TlsPeerCertPresent
227		| FieldPath::TlsPeerCertSubjectCn
228		| FieldPath::TlsPeerCertSanDns
229		| FieldPath::TlsPeerCertFingerprintSha256
230		| FieldPath::TlsPeerCertSpkiSha256
231		| FieldPath::TlsPeerCertIssuerCn
232		| FieldPath::TlsPeerCertSerial => InspectionLevel::L4Peek,
233		FieldPath::HttpMethod
234		| FieldPath::HttpUriPath
235		| FieldPath::HttpUriQuery
236		| FieldPath::HttpHeader(_) => InspectionLevel::L7Header,
237		FieldPath::HttpBody => InspectionLevel::L7Body,
238	}
239}
240
// Unit tests for the analyze stage. They drive the public `analyze`
// entry point with rules parsed from JSON and stub metadata providers.
#[cfg(test)]
mod tests {
	use super::*;
	use crate::compile::expand::RawRuleSet;
	use crate::fetch::{FetchOutputModes, FetchPhase as FetchMetaPhase};
	use crate::metadata::{FetchMetadata, MiddlewareMetadata};
	use crate::middleware::MiddlewareKind;
	use serde_json::Value;

	/// Stub provider used by most tests; serves both middleware and
	/// fetch metadata.
	struct Providers;

	/// `validate_args` hook that accepts anything.
	fn validate_ok(_: &Value) -> Result<(), Error> {
		Ok(())
	}

	// Knows three middlewares: "req_plain" (request side, no body),
	// "req_body" (request side, needs body) and "resp_body" (response
	// side, needs body). Every other name is unknown.
	impl MiddlewareMetadataProvider for Providers {
		fn get(&self, name: &str) -> Option<MiddlewareMetadata> {
			match name {
				"req_plain" => Some(MiddlewareMetadata {
					kind: MiddlewareKind::L7Request,
					stateless: true,
					needs_body: false,
					validate_args: validate_ok,
				}),
				"req_body" => Some(MiddlewareMetadata {
					kind: MiddlewareKind::L7Request,
					stateless: true,
					needs_body: true,
					validate_args: validate_ok,
				}),
				"resp_body" => Some(MiddlewareMetadata {
					kind: MiddlewareKind::L7Response,
					stateless: true,
					needs_body: true,
					validate_args: validate_ok,
				}),
				_ => None,
			}
		}
	}

	// Knows every fetch kind and accepts any terminator args.
	impl FetchMetadataProvider for Providers {
		fn get(&self, kind: FetchKind) -> Option<FetchMetadata> {
			Some(FetchMetadata {
				kind,
				phase: match kind {
					FetchKind::L4Forward => FetchMetaPhase::L4,
					_ => FetchMetaPhase::L7,
				},
				output_modes: match kind {
					FetchKind::L4Forward => FetchOutputModes { response: false, tunnel: true },
					FetchKind::WebSocketUpgrade => FetchOutputModes { response: true, tunnel: true },
					_ => FetchOutputModes { response: true, tunnel: false },
				},
				validate_args: validate_ok,
			})
		}
	}

	/// Wrap rules in a `RawRuleSet` with no source files.
	fn set(rules: Vec<RawRule>) -> RawRuleSet {
		RawRuleSet { rules, source_files: vec![] }
	}

	/// Deserialize a rule from JSON, panicking on malformed input.
	fn parse_rule(j: serde_json::Value) -> RawRule {
		serde_json::from_value(j).expect("parse rule")
	}

	#[test]
	fn http_body_predicate_sets_request_body_flag_and_l7body_level() {
		let rule = parse_rule(serde_json::json!({
			"name": "r",
			"listen": [":443"],
			"match": { "http.body": { "contains": "admin" } },
			"terminate": { "type": "http_proxy" },
		}));
		let out = analyze(set(vec![rule]), &Providers, &Providers).expect("analyze");
		let a = &out.rules[0];
		assert!(a.needs_request_body);
		assert!(!a.needs_response_body);
		assert_eq!(a.inspection_level, InspectionLevel::L7Body);
		assert_eq!(a.posture, Posture::L7);
	}

	#[test]
	fn l7_request_needs_body_middleware_flags_request_side() {
		let rule = parse_rule(serde_json::json!({
			"name": "r",
			"listen": [":443"],
			"middleware_chain": [{ "use": "req_body" }],
			"terminate": { "type": "http_proxy" },
		}));
		let out = analyze(set(vec![rule]), &Providers, &Providers).expect("analyze");
		assert!(out.rules[0].needs_request_body);
		assert!(!out.rules[0].needs_response_body);
	}

	#[test]
	fn l7_response_needs_body_middleware_flags_response_side() {
		let rule = parse_rule(serde_json::json!({
			"name": "r",
			"listen": [":443"],
			"middleware_chain": [{ "use": "resp_body" }],
			"terminate": { "type": "http_proxy" },
		}));
		let out = analyze(set(vec![rule]), &Providers, &Providers).expect("analyze");
		assert!(!out.rules[0].needs_request_body);
		assert!(out.rules[0].needs_response_body);
	}

	#[test]
	fn l4_fetch_with_l7_predicate_errors() {
		let rule = parse_rule(serde_json::json!({
			"name": "r",
			"listen": [":22"],
			"match": { "http.method": { "equals": "GET" } },
			"terminate": { "type": "tcp_forward", "upstream": "10.0.0.1:22" },
		}));
		let err = analyze(set(vec![rule]), &Providers, &Providers).expect_err("must error");
		assert!(err.to_string().contains("L7-level predicate"));
	}

	#[test]
	fn unknown_middleware_name_errors() {
		let rule = parse_rule(serde_json::json!({
			"name": "r",
			"listen": [":443"],
			"middleware_chain": [{ "use": "does_not_exist" }],
			"terminate": { "type": "http_proxy" },
		}));
		let err = analyze(set(vec![rule]), &Providers, &Providers).expect_err("must error");
		assert!(err.to_string().contains("does_not_exist"));
	}

	#[test]
	fn rejects_middleware_args_failing_validate() {
		// `StrictProviders` registers a single middleware, "strict_args",
		// whose `validate_args` hook (`reject_null`) rejects null args.
		// The rule below gives that middleware no args at all, which
		// presumably deserializes to `Value::Null` — the expected
		// rejection depends on it.
		struct StrictProviders;
		fn reject_null(v: &Value) -> Result<(), Error> {
			if matches!(v, Value::Null) { Err(Error::compile("args must not be null")) } else { Ok(()) }
		}
		impl MiddlewareMetadataProvider for StrictProviders {
			fn get(&self, name: &str) -> Option<MiddlewareMetadata> {
				if name == "strict_args" {
					Some(MiddlewareMetadata {
						kind: MiddlewareKind::L7Request,
						stateless: true,
						needs_body: false,
						validate_args: reject_null,
					})
				} else {
					None
				}
			}
		}
		impl FetchMetadataProvider for StrictProviders {
			fn get(&self, kind: FetchKind) -> Option<FetchMetadata> {
				Some(FetchMetadata {
					kind,
					phase: FetchMetaPhase::L7,
					output_modes: FetchOutputModes { response: true, tunnel: false },
					validate_args: |_| Ok(()),
				})
			}
		}
		let rule = parse_rule(serde_json::json!({
			"name": "r",
			"listen": [":443"],
			"middleware_chain": [{ "use": "strict_args" }],
			"terminate": { "type": "http_proxy" },
		}));
		let err = analyze(set(vec![rule]), &StrictProviders, &StrictProviders)
			.expect_err("must reject bad middleware args");
		let msg = err.to_string();
		assert!(msg.contains("strict_args"), "{msg}");
		assert!(msg.contains("args invalid") || msg.contains("must not be null"), "{msg}");
	}

	#[test]
	fn rejects_terminate_args_failing_validate() {
		// The fetch-side `validate_args` hook (`require_port`) only
		// accepts an object carrying a "port" key; the rule's terminator
		// supplies none, so analysis must fail on `terminate.args`.
		struct StrictProviders;
		fn require_port(v: &Value) -> Result<(), Error> {
			let ok = matches!(v, Value::Object(m) if m.get("port").is_some());
			if ok { Ok(()) } else { Err(Error::compile("missing required `port` arg")) }
		}
		impl MiddlewareMetadataProvider for StrictProviders {
			fn get(&self, _: &str) -> Option<MiddlewareMetadata> {
				None
			}
		}
		impl FetchMetadataProvider for StrictProviders {
			fn get(&self, kind: FetchKind) -> Option<FetchMetadata> {
				Some(FetchMetadata {
					kind,
					phase: FetchMetaPhase::L7,
					output_modes: FetchOutputModes { response: true, tunnel: false },
					validate_args: require_port,
				})
			}
		}
		let rule = parse_rule(serde_json::json!({
			"name": "r",
			"listen": [":443"],
			"terminate": { "type": "http_proxy" },
		}));
		let err = analyze(set(vec![rule]), &StrictProviders, &StrictProviders)
			.expect_err("must reject missing terminate args");
		let msg = err.to_string();
		assert!(msg.contains("terminate.args"), "{msg}");
		assert!(msg.contains("missing required `port` arg"), "{msg}");
	}

	#[test]
	fn rejects_predicate_nested_deeper_than_max_predicate_depth() {
		// Build `not(not(not(... check ...)))` with `MAX_PREDICATE_DEPTH + 1`
		// `not` wrappers around the leaf — straight chains are the
		// easiest pathological shape.
		let depth = crate::predicate::MAX_PREDICATE_DEPTH + 1;
		let mut inner = serde_json::json!({ "tls.sni": { "equals": "a" } });
		for _ in 0..depth {
			inner = serde_json::json!({ "not": inner });
		}
		let raw = serde_json::json!({
			"name": "r",
			"listen": [":443"],
			"match": inner,
			"terminate": { "type": "http_proxy" },
		});
		let rule: crate::rule::RawRule = serde_json::from_value(raw).expect("parse");
		let err =
			analyze(set(vec![rule]), &Providers, &Providers).expect_err("deep predicate must reject");
		assert!(err.to_string().contains("MAX_PREDICATE_DEPTH"), "{err}");
	}

	#[test]
	fn accepts_predicate_at_max_predicate_depth() {
		// `MAX_PREDICATE_DEPTH - 1` `not` wrappers around a leaf check
		// must still compile. NOTE(review): this is "at the limit" only
		// if `check_max_depth` counts the leaf as a level — confirm
		// against its definition; a chain of exactly
		// `MAX_PREDICATE_DEPTH` wrappers is not exercised by either
		// depth test.
		let depth = crate::predicate::MAX_PREDICATE_DEPTH - 1;
		let mut inner = serde_json::json!({ "tls.sni": { "equals": "a" } });
		for _ in 0..depth {
			inner = serde_json::json!({ "not": inner });
		}
		let raw = serde_json::json!({
			"name": "r",
			"listen": [":443"],
			"match": inner,
			"terminate": { "type": "http_proxy" },
		});
		let rule: crate::rule::RawRule = serde_json::from_value(raw).expect("parse");
		analyze(set(vec![rule]), &Providers, &Providers).expect("at-limit predicate compiles");
	}

	#[test]
	fn specificity_counts_check_predicates() {
		// Two leaf checks under one `any_of`: only the leaves count, the
		// combinator itself adds nothing.
		let rule = parse_rule(serde_json::json!({
			"name": "r",
			"listen": [":443"],
			"match": {
				"any_of": [
					{ "tls.sni": { "equals": "a" } },
					{ "tls.sni": { "equals": "b" } },
				],
			},
			"terminate": { "type": "http_proxy" },
		}));
		let out = analyze(set(vec![rule]), &Providers, &Providers).expect("analyze");
		assert_eq!(out.rules[0].specificity, 2);
	}
}