Skip to main content

vane_core/compile/
analyze.rs

1use crate::compile::expand::RawRuleSet;
2use crate::error::Error;
3use crate::fetch::FetchKind;
4use crate::metadata::{FetchMetadataProvider, MiddlewareMetadataProvider};
5use crate::predicate::{FieldPath, Predicate};
6use crate::rule::RawRule;
7
/// How deep into a connection a rule must look before its match predicate
/// can be decided.
///
/// Variants are declared shallowest-first; the derived `Ord` relies on that
/// order so that "deeper inspection" compares as greater and the per-rule
/// maximum can be taken with `Ord::max`.
#[derive(Debug, Hash, PartialOrd, Ord, PartialEq, Eq, Clone, Copy)]
pub enum InspectionLevel {
	/// Transport, local/remote IP and port only.
	L4Only,
	/// Peeked bytes or TLS handshake / peer-certificate fields.
	L4Peek,
	/// HTTP method, URI path/query, or headers.
	L7Header,
	/// HTTP body.
	L7Body,
}
15
/// Layer a rule operates at, decided from its terminating fetch kind and
/// the depth of its match predicate during analysis.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub enum Posture {
	/// Layer-4 handling.
	L4,
	/// Layer-7 (HTTP) handling.
	L7,
}
21
22#[derive(Debug, Clone)]
23pub struct AnalyzedRule {
24	pub raw: RawRule,
25	pub inspection_level: InspectionLevel,
26	pub specificity: usize,
27	pub posture: Posture,
28	pub needs_request_body: bool,
29	pub needs_response_body: bool,
30}
31
32#[derive(Debug, Clone)]
33pub struct AnalyzedRuleSet {
34	pub rules: Vec<AnalyzedRule>,
35	pub source_files: Vec<std::path::PathBuf>,
36}
37
38/// Compute per-rule inspection level, specificity, posture (L4 vs L7), and
39/// `LazyBuffer` per-side buffer triggers.
40///
41/// # Errors
42/// Returns [`Error::compile`] when a referenced middleware name is missing
43/// from the provider registry (so compile-time analysis cannot decide what
44/// phase it sits in or whether it buffers the body).
45pub fn analyze(
46	set: RawRuleSet,
47	mw_meta: &dyn MiddlewareMetadataProvider,
48	fetch_meta: &dyn FetchMetadataProvider,
49) -> Result<AnalyzedRuleSet, Error> {
50	let mut analyzed = Vec::with_capacity(set.rules.len());
51	for raw in set.rules {
52		analyzed.push(analyze_rule(raw, mw_meta, fetch_meta)?);
53	}
54	Ok(AnalyzedRuleSet { rules: analyzed, source_files: set.source_files })
55}
56
57fn analyze_rule(
58	raw: RawRule,
59	mw_meta: &dyn MiddlewareMetadataProvider,
60	fetch_meta: &dyn FetchMetadataProvider,
61) -> Result<AnalyzedRule, Error> {
62	// Per-rule TLS validation runs at the analyze stage so the lower
63	// pass — which aggregates resolved specs into per-listener pools —
64	// can assume each `TlsConfig` is internally consistent. Surfacing
65	// the violation through the rule name keeps multi-file configs
66	// debuggable.
67	if let Some(tls) = raw.tls.as_ref() {
68		tls.validate().map_err(|e| Error::compile(format!("rule {:?}: {}", raw.name, e)))?;
69	}
70
71	let fetch_kind = Some(raw.terminate.kind);
72	let fetch_phase = fetch_phase_of(fetch_kind);
73
74	let mut max_level = InspectionLevel::L4Only;
75	let mut specificity = 0usize;
76	let mut reads_http_body = false;
77	if let Some(pred) = &raw.match_predicate {
78		walk_predicate(pred, &mut |p| match p {
79			Predicate::Check(c) => {
80				specificity += 1;
81				let lvl = field_path_inspection_level(&c.path);
82				if lvl > max_level {
83					max_level = lvl;
84				}
85				if matches!(c.path, FieldPath::HttpBody) {
86					reads_http_body = true;
87				}
88			}
89			Predicate::AnyOf(_) | Predicate::AllOf(_) | Predicate::Not(_) => {}
90		});
91	}
92
93	let mut needs_request_body = reads_http_body;
94	let mut needs_response_body = false;
95	for mw_ref in &raw.middleware_chain {
96		let meta = mw_meta
97			.get(&mw_ref.name)
98			.ok_or_else(|| Error::compile(format!("unknown middleware: {:?}", mw_ref.name)))?;
99		if meta.needs_body {
100			match meta.kind {
101				crate::middleware::MiddlewareKind::L7Request => needs_request_body = true,
102				crate::middleware::MiddlewareKind::L7Response => needs_response_body = true,
103				crate::middleware::MiddlewareKind::L4Peek | crate::middleware::MiddlewareKind::L4Bytes => {}
104			}
105		}
106	}
107
108	// fetch_meta is consulted so unknown kinds fail compile consistently with
109	// how link will fail later; the metadata itself is not currently consumed
110	// in analyze (phase comes from the fixed FetchKind table below).
111	let _ = fetch_meta;
112
113	let posture = match fetch_phase {
114		FetchPhase::L4 if max_level <= InspectionLevel::L4Peek => Posture::L4,
115		FetchPhase::L4 => {
116			return Err(Error::compile(format!(
117				"rule {:?}: L7-level predicate on an L4 fetch is invalid",
118				raw.name
119			)));
120		}
121		FetchPhase::L7 => Posture::L7,
122	};
123
124	Ok(AnalyzedRule {
125		raw,
126		inspection_level: max_level,
127		specificity,
128		posture,
129		needs_request_body,
130		needs_response_body,
131	})
132}
133
/// Layer a terminating fetch runs at, according to the fixed table in
/// `fetch_phase_of`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum FetchPhase {
	/// Layer-4 fetch.
	L4,
	/// Layer-7 fetch.
	L7,
}
139
140const fn fetch_phase_of(kind: Option<FetchKind>) -> FetchPhase {
141	match kind {
142		Some(FetchKind::L4Forward) => FetchPhase::L4,
143		_ => FetchPhase::L7,
144	}
145}
146
147fn walk_predicate(p: &Predicate, f: &mut impl FnMut(&Predicate)) {
148	f(p);
149	match p {
150		Predicate::AnyOf(a) => {
151			for child in &a.any_of {
152				walk_predicate(child, f);
153			}
154		}
155		Predicate::AllOf(a) => {
156			for child in &a.all_of {
157				walk_predicate(child, f);
158			}
159		}
160		Predicate::Not(n) => walk_predicate(&n.not, f),
161		Predicate::Check(_) => {}
162	}
163}
164
165const fn field_path_inspection_level(path: &FieldPath) -> InspectionLevel {
166	match path {
167		FieldPath::Transport
168		| FieldPath::RemoteIp
169		| FieldPath::RemotePort
170		| FieldPath::LocalIp
171		| FieldPath::LocalPort => InspectionLevel::L4Only,
172		FieldPath::Peek
173		| FieldPath::TlsSni
174		| FieldPath::TlsAlpn
175		| FieldPath::TlsVersion
176		| FieldPath::TlsPeerCertPresent
177		| FieldPath::TlsPeerCertSubjectCn
178		| FieldPath::TlsPeerCertSanDns
179		| FieldPath::TlsPeerCertFingerprintSha256
180		| FieldPath::TlsPeerCertSpkiSha256
181		| FieldPath::TlsPeerCertIssuerCn
182		| FieldPath::TlsPeerCertSerial => InspectionLevel::L4Peek,
183		FieldPath::HttpMethod
184		| FieldPath::HttpUriPath
185		| FieldPath::HttpUriQuery
186		| FieldPath::HttpHeader(_) => InspectionLevel::L7Header,
187		FieldPath::HttpBody => InspectionLevel::L7Body,
188	}
189}
190
#[cfg(test)]
mod tests {
	use super::*;
	use crate::compile::expand::RawRuleSet;
	use crate::fetch::{FetchOutputModes, FetchPhase as FetchMetaPhase};
	use crate::metadata::{FetchMetadata, MiddlewareMetadata};
	use crate::middleware::MiddlewareKind;
	use serde_json::Value;

	/// Stub registry implementing both metadata providers for these tests.
	struct Providers;

	// Assigned to the metadata `validate_args` fields, so it keeps the
	// validator signature (including the `Result`) even though it never fails.
	#[allow(clippy::unnecessary_wraps)]
	fn validate_ok(_: &Value) -> Result<(), Error> {
		Ok(())
	}

	/// Stateless middleware metadata of `kind`, optionally needing the body.
	fn stateless_mw(kind: MiddlewareKind, needs_body: bool) -> MiddlewareMetadata {
		MiddlewareMetadata { kind, stateless: true, needs_body, validate_args: validate_ok }
	}

	impl MiddlewareMetadataProvider for Providers {
		fn get(&self, name: &str) -> Option<MiddlewareMetadata> {
			let meta = match name {
				"req_plain" => stateless_mw(MiddlewareKind::L7Request, false),
				"req_body" => stateless_mw(MiddlewareKind::L7Request, true),
				"resp_body" => stateless_mw(MiddlewareKind::L7Response, true),
				_ => return None,
			};
			Some(meta)
		}
	}

	impl FetchMetadataProvider for Providers {
		fn get(&self, kind: FetchKind) -> Option<FetchMetadata> {
			// Only the TCP forward is L4 (tunnel only). WebSocket upgrades can
			// both respond and tunnel; everything else just responds.
			let (phase, output_modes) = match kind {
				FetchKind::L4Forward => {
					(FetchMetaPhase::L4, FetchOutputModes { response: false, tunnel: true })
				}
				FetchKind::WebSocketUpgrade => {
					(FetchMetaPhase::L7, FetchOutputModes { response: true, tunnel: true })
				}
				_ => (FetchMetaPhase::L7, FetchOutputModes { response: true, tunnel: false }),
			};
			Some(FetchMetadata { kind, phase, output_modes, validate_args: validate_ok })
		}
	}

	fn set(rules: Vec<RawRule>) -> RawRuleSet {
		RawRuleSet { rules, source_files: vec![] }
	}

	fn parse_rule(j: Value) -> RawRule {
		serde_json::from_value(j).expect("parse rule")
	}

	/// Parse `rule` and run it alone through `analyze` with the stub registry.
	fn analyze_single(rule: Value) -> Result<AnalyzedRuleSet, Error> {
		analyze(set(vec![parse_rule(rule)]), &Providers, &Providers)
	}

	#[test]
	fn http_body_predicate_sets_request_body_flag_and_l7body_level() {
		let out = analyze_single(serde_json::json!({
			"name": "r",
			"listen": [":443"],
			"match": { "http.body": { "contains": "admin" } },
			"terminate": { "type": "http_proxy" },
		}))
		.expect("analyze");
		let rule = &out.rules[0];
		assert!(rule.needs_request_body);
		assert!(!rule.needs_response_body);
		assert_eq!(rule.inspection_level, InspectionLevel::L7Body);
		assert_eq!(rule.posture, Posture::L7);
	}

	#[test]
	fn l7_request_needs_body_middleware_flags_request_side() {
		let out = analyze_single(serde_json::json!({
			"name": "r",
			"listen": [":443"],
			"middleware_chain": [{ "use": "req_body" }],
			"terminate": { "type": "http_proxy" },
		}))
		.expect("analyze");
		let rule = &out.rules[0];
		assert!(rule.needs_request_body);
		assert!(!rule.needs_response_body);
	}

	#[test]
	fn l7_response_needs_body_middleware_flags_response_side() {
		let out = analyze_single(serde_json::json!({
			"name": "r",
			"listen": [":443"],
			"middleware_chain": [{ "use": "resp_body" }],
			"terminate": { "type": "http_proxy" },
		}))
		.expect("analyze");
		let rule = &out.rules[0];
		assert!(!rule.needs_request_body);
		assert!(rule.needs_response_body);
	}

	#[test]
	fn l4_fetch_with_l7_predicate_errors() {
		let err = analyze_single(serde_json::json!({
			"name": "r",
			"listen": [":22"],
			"match": { "http.method": { "equals": "GET" } },
			"terminate": { "type": "tcp_forward", "upstream": "10.0.0.1:22" },
		}))
		.expect_err("must error");
		assert!(err.to_string().contains("L7-level predicate"));
	}

	#[test]
	fn unknown_middleware_name_errors() {
		let err = analyze_single(serde_json::json!({
			"name": "r",
			"listen": [":443"],
			"middleware_chain": [{ "use": "does_not_exist" }],
			"terminate": { "type": "http_proxy" },
		}))
		.expect_err("must error");
		assert!(err.to_string().contains("does_not_exist"));
	}

	#[test]
	fn specificity_counts_check_predicates() {
		let out = analyze_single(serde_json::json!({
			"name": "r",
			"listen": [":443"],
			"match": {
				"any_of": [
					{ "tls.sni": { "equals": "a" } },
					{ "tls.sni": { "equals": "b" } },
				],
			},
			"terminate": { "type": "http_proxy" },
		}))
		.expect("analyze");
		assert_eq!(out.rules[0].specificity, 2);
	}
}