Skip to main content

vane_core/compile/
analyze.rs

1use crate::compile::expand::RawRuleSet;
2use crate::error::Error;
3use crate::fetch::FetchKind;
4use crate::metadata::{FetchMetadataProvider, MiddlewareMetadataProvider};
5use crate::predicate::{FieldPath, Predicate};
6use crate::rule::RawRule;
7
/// How deep into a connection a predicate field requires inspection.
///
/// Variants are declared from shallowest to deepest. The derived `Ord`
/// follows declaration order, and rule analysis relies on that ordering to
/// keep the maximum level seen across a rule's checks, so do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum InspectionLevel {
	/// Connection-level fields only (transport, remote/local IP and port).
	L4Only,
	/// Peeked bytes and TLS-derived fields (SNI, ALPN, peer certificate).
	L4Peek,
	/// HTTP method, URI path/query, and headers.
	L7Header,
	/// The HTTP body.
	L7Body,
}
15
/// The layer a rule is classified at by analysis.
///
/// `analyze_rule` assigns [`Posture::L4`] only to rules whose fetch is
/// `FetchKind::L4Forward` and whose predicate needs at most
/// [`InspectionLevel::L4Peek`]; rules with any other fetch kind get
/// [`Posture::L7`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Posture {
	/// L4 fetch with at most peek-level inspection.
	L4,
	/// Any non-L4 fetch.
	L7,
}
21
22#[derive(Debug, Clone)]
23pub struct AnalyzedRule {
24	pub raw: RawRule,
25	pub inspection_level: InspectionLevel,
26	pub specificity: usize,
27	pub posture: Posture,
28	pub needs_request_body: bool,
29	pub needs_response_body: bool,
30}
31
32#[derive(Debug, Clone)]
33pub struct AnalyzedRuleSet {
34	pub rules: Vec<AnalyzedRule>,
35	pub source_files: Vec<std::path::PathBuf>,
36}
37
38/// Compute per-rule inspection level, specificity, posture (L4 vs L7), and
39/// `LazyBuffer` per-side buffer triggers.
40///
41/// # Errors
42/// Returns [`Error::compile`] when a referenced middleware name is missing
43/// from the provider registry (so compile-time analysis cannot decide what
44/// phase it sits in or whether it buffers the body).
45pub fn analyze(
46	set: RawRuleSet,
47	mw_meta: &dyn MiddlewareMetadataProvider,
48	fetch_meta: &dyn FetchMetadataProvider,
49) -> Result<AnalyzedRuleSet, Error> {
50	let mut analyzed = Vec::with_capacity(set.rules.len());
51	for raw in set.rules {
52		analyzed.push(analyze_rule(raw, mw_meta, fetch_meta)?);
53	}
54	Ok(AnalyzedRuleSet { rules: analyzed, source_files: set.source_files })
55}
56
57fn analyze_rule(
58	raw: RawRule,
59	mw_meta: &dyn MiddlewareMetadataProvider,
60	fetch_meta: &dyn FetchMetadataProvider,
61) -> Result<AnalyzedRule, Error> {
62	let fetch_kind = Some(raw.terminate.kind);
63	let fetch_phase = fetch_phase_of(fetch_kind);
64
65	let mut max_level = InspectionLevel::L4Only;
66	let mut specificity = 0usize;
67	let mut reads_http_body = false;
68	if let Some(pred) = &raw.match_predicate {
69		walk_predicate(pred, &mut |p| match p {
70			Predicate::Check(c) => {
71				specificity += 1;
72				let lvl = field_path_inspection_level(&c.path);
73				if lvl > max_level {
74					max_level = lvl;
75				}
76				if matches!(c.path, FieldPath::HttpBody) {
77					reads_http_body = true;
78				}
79			}
80			Predicate::AnyOf(_) | Predicate::AllOf(_) | Predicate::Not(_) => {}
81		});
82	}
83
84	let mut needs_request_body = reads_http_body;
85	let mut needs_response_body = false;
86	for mw_ref in &raw.middleware_chain {
87		let meta = mw_meta
88			.get(&mw_ref.name)
89			.ok_or_else(|| Error::compile(format!("unknown middleware: {:?}", mw_ref.name)))?;
90		if meta.needs_body {
91			match meta.kind {
92				crate::middleware::MiddlewareKind::L7Request => needs_request_body = true,
93				crate::middleware::MiddlewareKind::L7Response => needs_response_body = true,
94				crate::middleware::MiddlewareKind::L4Peek | crate::middleware::MiddlewareKind::L4Bytes => {}
95			}
96		}
97	}
98
99	// fetch_meta is consulted so unknown kinds fail compile consistently with
100	// how link will fail later; the metadata itself is not currently consumed
101	// in analyze (phase comes from the fixed FetchKind table below).
102	let _ = fetch_meta;
103
104	let posture = match fetch_phase {
105		FetchPhase::L4 if max_level <= InspectionLevel::L4Peek => Posture::L4,
106		FetchPhase::L4 => {
107			return Err(Error::compile(format!(
108				"rule {:?}: L7-level predicate on an L4 fetch is invalid",
109				raw.name
110			)));
111		}
112		FetchPhase::L7 => Posture::L7,
113	};
114
115	Ok(AnalyzedRule {
116		raw,
117		inspection_level: max_level,
118		specificity,
119		posture,
120		needs_request_body,
121		needs_response_body,
122	})
123}
124
/// Layer at which a rule's terminating fetch operates, as used internally by
/// analysis to pick the rule's posture.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FetchPhase {
	/// The fetch is `FetchKind::L4Forward`.
	L4,
	/// Any other fetch kind (or none).
	L7,
}
130
131const fn fetch_phase_of(kind: Option<FetchKind>) -> FetchPhase {
132	match kind {
133		Some(FetchKind::L4Forward) => FetchPhase::L4,
134		_ => FetchPhase::L7,
135	}
136}
137
138fn walk_predicate(p: &Predicate, f: &mut impl FnMut(&Predicate)) {
139	f(p);
140	match p {
141		Predicate::AnyOf(a) => {
142			for child in &a.any_of {
143				walk_predicate(child, f);
144			}
145		}
146		Predicate::AllOf(a) => {
147			for child in &a.all_of {
148				walk_predicate(child, f);
149			}
150		}
151		Predicate::Not(n) => walk_predicate(&n.not, f),
152		Predicate::Check(_) => {}
153	}
154}
155
156const fn field_path_inspection_level(path: &FieldPath) -> InspectionLevel {
157	match path {
158		FieldPath::Transport
159		| FieldPath::RemoteIp
160		| FieldPath::RemotePort
161		| FieldPath::LocalIp
162		| FieldPath::LocalPort => InspectionLevel::L4Only,
163		FieldPath::Peek
164		| FieldPath::TlsSni
165		| FieldPath::TlsAlpn
166		| FieldPath::TlsVersion
167		| FieldPath::TlsPeerCertPresent
168		| FieldPath::TlsPeerCertSubjectCn
169		| FieldPath::TlsPeerCertSanDns
170		| FieldPath::TlsPeerCertFingerprintSha256
171		| FieldPath::TlsPeerCertSpkiSha256
172		| FieldPath::TlsPeerCertIssuerCn
173		| FieldPath::TlsPeerCertSerial => InspectionLevel::L4Peek,
174		FieldPath::HttpMethod
175		| FieldPath::HttpUriPath
176		| FieldPath::HttpUriQuery
177		| FieldPath::HttpHeader(_) => InspectionLevel::L7Header,
178		FieldPath::HttpBody => InspectionLevel::L7Body,
179	}
180}
181
#[cfg(test)]
mod tests {
	use super::*;
	use crate::compile::expand::RawRuleSet;
	use crate::fetch::{FetchOutputModes, FetchPhase as FetchMetaPhase};
	use crate::metadata::{FetchMetadata, MiddlewareMetadata};
	use crate::middleware::MiddlewareKind;
	use serde_json::{json, Value};

	/// Stub registry that implements both metadata provider traits.
	struct Providers;

	// Used as the `validate_args` callback in the metadata structs below, so
	// it keeps the `Result` return even though it never fails.
	#[allow(clippy::unnecessary_wraps)]
	fn validate_ok(_: &Value) -> Result<(), Error> {
		Ok(())
	}

	/// Metadata for a stateless middleware of the given kind.
	fn stateless_mw(kind: MiddlewareKind, needs_body: bool) -> MiddlewareMetadata {
		MiddlewareMetadata { kind, stateless: true, needs_body, validate_args: validate_ok }
	}

	impl MiddlewareMetadataProvider for Providers {
		fn get(&self, name: &str) -> Option<MiddlewareMetadata> {
			let meta = match name {
				"req_plain" => stateless_mw(MiddlewareKind::L7Request, false),
				"req_body" => stateless_mw(MiddlewareKind::L7Request, true),
				"resp_body" => stateless_mw(MiddlewareKind::L7Response, true),
				_ => return None,
			};
			Some(meta)
		}
	}

	impl FetchMetadataProvider for Providers {
		fn get(&self, kind: FetchKind) -> Option<FetchMetadata> {
			let (phase, output_modes) = match kind {
				FetchKind::L4Forward => {
					(FetchMetaPhase::L4, FetchOutputModes { response: false, tunnel: true })
				}
				FetchKind::WebSocketUpgrade => {
					(FetchMetaPhase::L7, FetchOutputModes { response: true, tunnel: true })
				}
				_ => (FetchMetaPhase::L7, FetchOutputModes { response: true, tunnel: false }),
			};
			Some(FetchMetadata { kind, phase, output_modes, validate_args: validate_ok })
		}
	}

	/// Parse `rule` as a single raw rule and run [`analyze`] on a rule set
	/// containing only that rule, using the stub providers.
	fn analyze_json(rule: Value) -> Result<AnalyzedRuleSet, Error> {
		let parsed: RawRule = serde_json::from_value(rule).expect("parse rule");
		let input = RawRuleSet { rules: vec![parsed], source_files: vec![] };
		analyze(input, &Providers, &Providers)
	}

	#[test]
	fn http_body_predicate_sets_request_body_flag_and_l7body_level() {
		let analyzed = analyze_json(json!({
			"name": "r",
			"listen": [":443"],
			"match": { "http.body": { "contains": "admin" } },
			"terminate": { "type": "http_proxy" },
		}))
		.expect("analyze");
		let rule = &analyzed.rules[0];
		assert!(rule.needs_request_body);
		assert!(!rule.needs_response_body);
		assert_eq!(rule.inspection_level, InspectionLevel::L7Body);
		assert_eq!(rule.posture, Posture::L7);
	}

	#[test]
	fn l7_request_needs_body_middleware_flags_request_side() {
		let analyzed = analyze_json(json!({
			"name": "r",
			"listen": [":443"],
			"middleware_chain": [{ "use": "req_body" }],
			"terminate": { "type": "http_proxy" },
		}))
		.expect("analyze");
		let rule = &analyzed.rules[0];
		assert!(rule.needs_request_body);
		assert!(!rule.needs_response_body);
	}

	#[test]
	fn l7_response_needs_body_middleware_flags_response_side() {
		let analyzed = analyze_json(json!({
			"name": "r",
			"listen": [":443"],
			"middleware_chain": [{ "use": "resp_body" }],
			"terminate": { "type": "http_proxy" },
		}))
		.expect("analyze");
		let rule = &analyzed.rules[0];
		assert!(!rule.needs_request_body);
		assert!(rule.needs_response_body);
	}

	#[test]
	fn l4_fetch_with_l7_predicate_errors() {
		let err = analyze_json(json!({
			"name": "r",
			"listen": [":22"],
			"match": { "http.method": { "equals": "GET" } },
			"terminate": { "type": "tcp_forward", "upstream": "10.0.0.1:22" },
		}))
		.expect_err("must error");
		assert!(err.to_string().contains("L7-level predicate"));
	}

	#[test]
	fn unknown_middleware_name_errors() {
		let err = analyze_json(json!({
			"name": "r",
			"listen": [":443"],
			"middleware_chain": [{ "use": "does_not_exist" }],
			"terminate": { "type": "http_proxy" },
		}))
		.expect_err("must error");
		assert!(err.to_string().contains("does_not_exist"));
	}

	#[test]
	fn specificity_counts_check_predicates() {
		let analyzed = analyze_json(json!({
			"name": "r",
			"listen": [":443"],
			"match": {
				"any_of": [
					{ "tls.sni": { "equals": "a" } },
					{ "tls.sni": { "equals": "b" } },
				],
			},
			"terminate": { "type": "http_proxy" },
		}))
		.expect("analyze");
		assert_eq!(analyzed.rules[0].specificity, 2);
	}
}
332}