1use crate::compile::expand::RawRuleSet;
2use crate::error::{Diagnostics, Error};
3use crate::fetch::FetchKind;
4use crate::metadata::{FetchMetadataProvider, MiddlewareMetadataProvider};
5use crate::predicate::{FieldPath, Predicate};
6use crate::rule::RawRule;
7
/// How deep into a connection a rule's match predicate has to look.
///
/// Variants are declared in increasing order and the derived `Ord` follows declaration
/// order. `analyze_rule` keeps the maximum level over all checks in a predicate and compares
/// it against `L4Peek` to decide whether an L4 fetch is allowed, so do not reorder variants.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
pub enum InspectionLevel {
    /// Connection fields only: transport and local/remote IP and port.
    L4Only,
    /// The `Peek` field and the `Tls*` fields (SNI, ALPN, version, peer-certificate data).
    L4Peek,
    /// HTTP method, URI path, URI query, or a header.
    L7Header,
    /// The HTTP body.
    L7Body,
}
15
/// The layer an analyzed rule operates at.
///
/// `analyze_rule` assigns `L4` only when the terminating fetch is `FetchKind::L4Forward`
/// and the predicate needs at most `InspectionLevel::L4Peek`; every other accepted rule
/// gets `L7`.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub enum Posture {
    /// Rule is handled at the transport layer.
    L4,
    /// Rule is handled at the HTTP layer.
    L7,
}
21
/// A rule that passed analysis, together with the facts derived from it.
#[derive(Debug, Clone)]
pub struct AnalyzedRule {
    /// The original rule, moved in unchanged.
    pub raw: RawRule,
    /// Highest inspection level required by any check in the match predicate;
    /// `L4Only` when the rule has no predicate.
    pub inspection_level: InspectionLevel,
    /// Number of leaf checks in the match predicate (combinators are not counted).
    pub specificity: usize,
    /// Layer the rule runs at (see `Posture`).
    pub posture: Posture,
    /// True when the predicate reads `http.body` or any body-needing L7 request middleware
    /// is in the chain.
    pub needs_request_body: bool,
    /// True when any body-needing L7 response middleware is in the chain.
    pub needs_response_body: bool,
}
31
/// Output of analysis over a whole `RawRuleSet`.
#[derive(Debug, Clone)]
pub struct AnalyzedRuleSet {
    /// Successfully analyzed rules, in the same relative order as the input rules.
    pub rules: Vec<AnalyzedRule>,
    /// Source files carried over unchanged from the input rule set.
    pub source_files: Vec<std::path::PathBuf>,
}
37
38pub fn analyze(
46 set: RawRuleSet,
47 mw_meta: &dyn MiddlewareMetadataProvider,
48 fetch_meta: &dyn FetchMetadataProvider,
49) -> Result<AnalyzedRuleSet, Error> {
50 let (rules, d) = analyze_collecting(set, mw_meta, fetch_meta);
51 d.into_result(rules).map_err(Error::from)
52}
53
54pub fn analyze_collecting(
63 set: RawRuleSet,
64 mw_meta: &dyn MiddlewareMetadataProvider,
65 fetch_meta: &dyn FetchMetadataProvider,
66) -> (AnalyzedRuleSet, Diagnostics) {
67 let mut analyzed = Vec::with_capacity(set.rules.len());
68 let mut d = Diagnostics::new();
69 for raw in set.rules {
70 match analyze_rule(raw, mw_meta, fetch_meta) {
71 Ok(rule) => analyzed.push(rule),
72 Err(e) => d.push(e),
73 }
74 }
75 (AnalyzedRuleSet { rules: analyzed, source_files: set.source_files }, d)
76}
77
78fn analyze_rule(
79 raw: RawRule,
80 mw_meta: &dyn MiddlewareMetadataProvider,
81 fetch_meta: &dyn FetchMetadataProvider,
82) -> Result<AnalyzedRule, Error> {
83 if let Some(tls) = raw.tls.as_ref() {
89 tls.validate().map_err(|e| Error::compile(format!("rule {:?}: {}", raw.name, e)))?;
90 }
91
92 let fetch_kind = Some(raw.terminate.kind);
93 let fetch_phase = fetch_phase_of(fetch_kind);
94
95 let mut max_level = InspectionLevel::L4Only;
96 let mut specificity = 0usize;
97 let mut reads_http_body = false;
98 if let Some(pred) = &raw.match_predicate {
99 crate::predicate::check_max_depth(pred)
104 .map_err(|e| Error::compile(format!("rule {:?}: {}", raw.name, e)))?;
105 walk_predicate(pred, &mut |p| match p {
106 Predicate::Check(c) => {
107 specificity += 1;
108 let lvl = field_path_inspection_level(&c.path);
109 if lvl > max_level {
110 max_level = lvl;
111 }
112 if matches!(c.path, FieldPath::HttpBody) {
113 reads_http_body = true;
114 }
115 }
116 Predicate::AnyOf(_) | Predicate::AllOf(_) | Predicate::Not(_) => {}
117 });
118 }
119
120 let mut needs_request_body = reads_http_body;
121 let mut needs_response_body = false;
122 for mw_ref in &raw.middleware_chain {
123 let meta = mw_meta
124 .get(&mw_ref.name)
125 .ok_or_else(|| Error::compile(format!("unknown middleware: {:?}", mw_ref.name)))?;
126 (meta.validate_args)(&mw_ref.args).map_err(|e| {
130 Error::compile(format!("rule {:?}: middleware {:?} args invalid: {e}", raw.name, mw_ref.name))
131 })?;
132 if meta.needs_body {
133 match meta.kind {
134 crate::middleware::MiddlewareKind::L7Request => needs_request_body = true,
135 crate::middleware::MiddlewareKind::L7Response => needs_response_body = true,
136 crate::middleware::MiddlewareKind::L4Peek | crate::middleware::MiddlewareKind::L4Bytes => {}
137 }
138 }
139 }
140
141 if let Some(kind) = fetch_kind {
145 let meta = fetch_meta.get(kind).ok_or_else(|| {
146 Error::compile(format!("rule {:?}: unknown fetch kind {:?}", raw.name, kind))
147 })?;
148 (meta.validate_args)(&raw.terminate.args).map_err(|e| {
149 Error::compile(format!("rule {:?}: terminate.args for {:?} invalid: {e}", raw.name, kind))
150 })?;
151 }
152
153 let posture = match fetch_phase {
154 FetchPhase::L4 if max_level <= InspectionLevel::L4Peek => Posture::L4,
155 FetchPhase::L4 => {
156 return Err(Error::compile(format!(
157 "rule {:?}: L7-level predicate on an L4 fetch is invalid",
158 raw.name
159 )));
160 }
161 FetchPhase::L7 => Posture::L7,
162 };
163
164 Ok(AnalyzedRule {
165 raw,
166 inspection_level: max_level,
167 specificity,
168 posture,
169 needs_request_body,
170 needs_response_body,
171 })
172}
173
/// Layer at which a rule's terminating fetch runs, as classified by `fetch_phase_of`.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
enum FetchPhase {
    /// Transport-level fetch (only `FetchKind::L4Forward`).
    L4,
    /// Every other fetch kind.
    L7,
}
179
180const fn fetch_phase_of(kind: Option<FetchKind>) -> FetchPhase {
181 match kind {
182 Some(FetchKind::L4Forward) => FetchPhase::L4,
183 _ => FetchPhase::L7,
184 }
185}
186
187fn walk_predicate(root: &Predicate, f: &mut impl FnMut(&Predicate)) {
195 let mut stack: Vec<&Predicate> = vec![root];
196 while let Some(p) = stack.pop() {
197 f(p);
198 match p {
199 Predicate::AnyOf(a) => {
200 for child in a.any_of.iter().rev() {
201 stack.push(child);
202 }
203 }
204 Predicate::AllOf(a) => {
205 for child in a.all_of.iter().rev() {
206 stack.push(child);
207 }
208 }
209 Predicate::Not(n) => stack.push(n.not.as_ref()),
210 Predicate::Check(_) => {}
211 }
212 }
213}
214
215const fn field_path_inspection_level(path: &FieldPath) -> InspectionLevel {
216 match path {
217 FieldPath::Transport
218 | FieldPath::RemoteIp
219 | FieldPath::RemotePort
220 | FieldPath::LocalIp
221 | FieldPath::LocalPort => InspectionLevel::L4Only,
222 FieldPath::Peek
223 | FieldPath::TlsSni
224 | FieldPath::TlsAlpn
225 | FieldPath::TlsVersion
226 | FieldPath::TlsPeerCertPresent
227 | FieldPath::TlsPeerCertSubjectCn
228 | FieldPath::TlsPeerCertSanDns
229 | FieldPath::TlsPeerCertFingerprintSha256
230 | FieldPath::TlsPeerCertSpkiSha256
231 | FieldPath::TlsPeerCertIssuerCn
232 | FieldPath::TlsPeerCertSerial => InspectionLevel::L4Peek,
233 FieldPath::HttpMethod
234 | FieldPath::HttpUriPath
235 | FieldPath::HttpUriQuery
236 | FieldPath::HttpHeader(_) => InspectionLevel::L7Header,
237 FieldPath::HttpBody => InspectionLevel::L7Body,
238 }
239}
240
#[cfg(test)]
mod tests {
    // Unit tests for rule analysis. `Providers` is a permissive fake: it knows the
    // middlewares `req_plain`, `req_body` and `resp_body`, returns metadata for every fetch
    // kind, and accepts any args. Tests needing stricter validation define a local provider.
    use super::*;
    use crate::compile::expand::RawRuleSet;
    use crate::fetch::{FetchOutputModes, FetchPhase as FetchMetaPhase};
    use crate::metadata::{FetchMetadata, MiddlewareMetadata};
    use crate::middleware::MiddlewareKind;
    use serde_json::Value;

    /// Permissive metadata provider shared by most tests.
    struct Providers;

    /// Argument validator that accepts any value.
    fn validate_ok(_: &Value) -> Result<(), Error> {
        Ok(())
    }

    impl MiddlewareMetadataProvider for Providers {
        fn get(&self, name: &str) -> Option<MiddlewareMetadata> {
            match name {
                "req_plain" => Some(MiddlewareMetadata {
                    kind: MiddlewareKind::L7Request,
                    stateless: true,
                    needs_body: false,
                    validate_args: validate_ok,
                }),
                "req_body" => Some(MiddlewareMetadata {
                    kind: MiddlewareKind::L7Request,
                    stateless: true,
                    needs_body: true,
                    validate_args: validate_ok,
                }),
                "resp_body" => Some(MiddlewareMetadata {
                    kind: MiddlewareKind::L7Response,
                    stateless: true,
                    needs_body: true,
                    validate_args: validate_ok,
                }),
                _ => None,
            }
        }
    }

    impl FetchMetadataProvider for Providers {
        fn get(&self, kind: FetchKind) -> Option<FetchMetadata> {
            Some(FetchMetadata {
                kind,
                phase: match kind {
                    FetchKind::L4Forward => FetchMetaPhase::L4,
                    _ => FetchMetaPhase::L7,
                },
                output_modes: match kind {
                    FetchKind::L4Forward => FetchOutputModes { response: false, tunnel: true },
                    FetchKind::WebSocketUpgrade => FetchOutputModes { response: true, tunnel: true },
                    _ => FetchOutputModes { response: true, tunnel: false },
                },
                validate_args: validate_ok,
            })
        }
    }

    /// Wraps `rules` in a `RawRuleSet` with no source files.
    fn set(rules: Vec<RawRule>) -> RawRuleSet {
        RawRuleSet { rules, source_files: vec![] }
    }

    /// Deserializes a rule from JSON; panics on malformed test input.
    fn parse_rule(j: serde_json::Value) -> RawRule {
        serde_json::from_value(j).expect("parse rule")
    }

    #[test]
    fn http_body_predicate_sets_request_body_flag_and_l7body_level() {
        let rule = parse_rule(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "match": { "http.body": { "contains": "admin" } },
            "terminate": { "type": "http_proxy" },
        }));
        let out = analyze(set(vec![rule]), &Providers, &Providers).expect("analyze");
        let a = &out.rules[0];
        assert!(a.needs_request_body);
        assert!(!a.needs_response_body);
        assert_eq!(a.inspection_level, InspectionLevel::L7Body);
        assert_eq!(a.posture, Posture::L7);
    }

    #[test]
    fn l7_request_needs_body_middleware_flags_request_side() {
        let rule = parse_rule(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "middleware_chain": [{ "use": "req_body" }],
            "terminate": { "type": "http_proxy" },
        }));
        let out = analyze(set(vec![rule]), &Providers, &Providers).expect("analyze");
        assert!(out.rules[0].needs_request_body);
        assert!(!out.rules[0].needs_response_body);
    }

    #[test]
    fn l7_response_needs_body_middleware_flags_response_side() {
        let rule = parse_rule(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "middleware_chain": [{ "use": "resp_body" }],
            "terminate": { "type": "http_proxy" },
        }));
        let out = analyze(set(vec![rule]), &Providers, &Providers).expect("analyze");
        assert!(!out.rules[0].needs_request_body);
        assert!(out.rules[0].needs_response_body);
    }

    #[test]
    fn l4_fetch_with_l7_predicate_errors() {
        let rule = parse_rule(serde_json::json!({
            "name": "r",
            "listen": [":22"],
            "match": { "http.method": { "equals": "GET" } },
            "terminate": { "type": "tcp_forward", "upstream": "10.0.0.1:22" },
        }));
        let err = analyze(set(vec![rule]), &Providers, &Providers).expect_err("must error");
        assert!(err.to_string().contains("L7-level predicate"));
    }

    // The accepted counterpart of the test above: an L4Peek-level predicate is allowed on
    // an L4 fetch and yields an L4 posture.
    #[test]
    fn l4_fetch_with_l4_peek_predicate_gets_l4_posture() {
        let rule = parse_rule(serde_json::json!({
            "name": "r",
            "listen": [":22"],
            "match": { "tls.sni": { "equals": "example.com" } },
            "terminate": { "type": "tcp_forward", "upstream": "10.0.0.1:22" },
        }));
        let out = analyze(set(vec![rule]), &Providers, &Providers).expect("analyze");
        let a = &out.rules[0];
        assert_eq!(a.posture, Posture::L4);
        assert_eq!(a.inspection_level, InspectionLevel::L4Peek);
        assert!(!a.needs_request_body);
        assert!(!a.needs_response_body);
    }

    #[test]
    fn rule_without_predicate_has_zero_specificity_and_l4_only_level() {
        let rule = parse_rule(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "terminate": { "type": "http_proxy" },
        }));
        let out = analyze(set(vec![rule]), &Providers, &Providers).expect("analyze");
        let a = &out.rules[0];
        assert_eq!(a.specificity, 0);
        assert_eq!(a.inspection_level, InspectionLevel::L4Only);
        assert_eq!(a.posture, Posture::L7);
    }

    #[test]
    fn unknown_middleware_name_errors() {
        let rule = parse_rule(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "middleware_chain": [{ "use": "does_not_exist" }],
            "terminate": { "type": "http_proxy" },
        }));
        let err = analyze(set(vec![rule]), &Providers, &Providers).expect_err("must error");
        assert!(err.to_string().contains("does_not_exist"));
    }

    // A rejected rule must not prevent its valid siblings from being analyzed.
    #[test]
    fn analyze_collecting_keeps_valid_rules_alongside_rejected_ones() {
        let bad = parse_rule(serde_json::json!({
            "name": "bad",
            "listen": [":443"],
            "middleware_chain": [{ "use": "does_not_exist" }],
            "terminate": { "type": "http_proxy" },
        }));
        let good = parse_rule(serde_json::json!({
            "name": "good",
            "listen": [":443"],
            "terminate": { "type": "http_proxy" },
        }));
        let (out, _diagnostics) =
            analyze_collecting(set(vec![bad, good]), &Providers, &Providers);
        assert_eq!(out.rules.len(), 1);
        let kept = format!("{:?}", out.rules[0].raw.name);
        assert!(kept.contains("good"), "{kept}");
    }

    #[test]
    fn rejects_middleware_args_failing_validate() {
        struct StrictProviders;
        fn reject_null(v: &Value) -> Result<(), Error> {
            if matches!(v, Value::Null) { Err(Error::compile("args must not be null")) } else { Ok(()) }
        }
        impl MiddlewareMetadataProvider for StrictProviders {
            fn get(&self, name: &str) -> Option<MiddlewareMetadata> {
                if name == "strict_args" {
                    Some(MiddlewareMetadata {
                        kind: MiddlewareKind::L7Request,
                        stateless: true,
                        needs_body: false,
                        validate_args: reject_null,
                    })
                } else {
                    None
                }
            }
        }
        impl FetchMetadataProvider for StrictProviders {
            fn get(&self, kind: FetchKind) -> Option<FetchMetadata> {
                Some(FetchMetadata {
                    kind,
                    phase: FetchMetaPhase::L7,
                    output_modes: FetchOutputModes { response: true, tunnel: false },
                    validate_args: |_| Ok(()),
                })
            }
        }
        let rule = parse_rule(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "middleware_chain": [{ "use": "strict_args" }],
            "terminate": { "type": "http_proxy" },
        }));
        let err = analyze(set(vec![rule]), &StrictProviders, &StrictProviders)
            .expect_err("must reject bad middleware args");
        let msg = err.to_string();
        assert!(msg.contains("strict_args"), "{msg}");
        assert!(msg.contains("args invalid") || msg.contains("must not be null"), "{msg}");
    }

    #[test]
    fn rejects_terminate_args_failing_validate() {
        struct StrictProviders;
        fn require_port(v: &Value) -> Result<(), Error> {
            let ok = matches!(v, Value::Object(m) if m.get("port").is_some());
            if ok { Ok(()) } else { Err(Error::compile("missing required `port` arg")) }
        }
        impl MiddlewareMetadataProvider for StrictProviders {
            fn get(&self, _: &str) -> Option<MiddlewareMetadata> {
                None
            }
        }
        impl FetchMetadataProvider for StrictProviders {
            fn get(&self, kind: FetchKind) -> Option<FetchMetadata> {
                Some(FetchMetadata {
                    kind,
                    phase: FetchMetaPhase::L7,
                    output_modes: FetchOutputModes { response: true, tunnel: false },
                    validate_args: require_port,
                })
            }
        }
        let rule = parse_rule(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "terminate": { "type": "http_proxy" },
        }));
        let err = analyze(set(vec![rule]), &StrictProviders, &StrictProviders)
            .expect_err("must reject missing terminate args");
        let msg = err.to_string();
        assert!(msg.contains("terminate.args"), "{msg}");
        assert!(msg.contains("missing required `port` arg"), "{msg}");
    }

    #[test]
    fn rejects_predicate_nested_deeper_than_max_predicate_depth() {
        let depth = crate::predicate::MAX_PREDICATE_DEPTH + 1;
        let mut inner = serde_json::json!({ "tls.sni": { "equals": "a" } });
        for _ in 0..depth {
            inner = serde_json::json!({ "not": inner });
        }
        let raw = serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "match": inner,
            "terminate": { "type": "http_proxy" },
        });
        let rule: crate::rule::RawRule = serde_json::from_value(raw).expect("parse");
        let err =
            analyze(set(vec![rule]), &Providers, &Providers).expect_err("deep predicate must reject");
        assert!(err.to_string().contains("MAX_PREDICATE_DEPTH"), "{err}");
    }

    #[test]
    fn accepts_predicate_at_max_predicate_depth() {
        let depth = crate::predicate::MAX_PREDICATE_DEPTH - 1;
        let mut inner = serde_json::json!({ "tls.sni": { "equals": "a" } });
        for _ in 0..depth {
            inner = serde_json::json!({ "not": inner });
        }
        let raw = serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "match": inner,
            "terminate": { "type": "http_proxy" },
        });
        let rule: crate::rule::RawRule = serde_json::from_value(raw).expect("parse");
        analyze(set(vec![rule]), &Providers, &Providers).expect("at-limit predicate compiles");
    }

    #[test]
    fn specificity_counts_check_predicates() {
        let rule = parse_rule(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "match": {
                "any_of": [
                    { "tls.sni": { "equals": "a" } },
                    { "tls.sni": { "equals": "b" } },
                ],
            },
            "terminate": { "type": "http_proxy" },
        }));
        let out = analyze(set(vec![rule]), &Providers, &Providers).expect("analyze");
        assert_eq!(out.rules[0].specificity, 2);
    }

    // Checks nested under `not` still count toward specificity and raise the level.
    #[test]
    fn specificity_and_level_see_checks_under_not() {
        let rule = parse_rule(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "match": {
                "any_of": [
                    { "tls.sni": { "equals": "a" } },
                    { "not": { "http.method": { "equals": "GET" } } },
                ],
            },
            "terminate": { "type": "http_proxy" },
        }));
        let out = analyze(set(vec![rule]), &Providers, &Providers).expect("analyze");
        let a = &out.rules[0];
        assert_eq!(a.specificity, 2);
        assert_eq!(a.inspection_level, InspectionLevel::L7Header);
        assert!(!a.needs_request_body);
    }
}