1use crate::compile::expand::RawRuleSet;
2use crate::error::Error;
3use crate::fetch::FetchKind;
4use crate::metadata::{FetchMetadataProvider, MiddlewareMetadataProvider};
5use crate::predicate::{FieldPath, Predicate};
6use crate::rule::RawRule;
7
/// How deep into a connection a predicate field requires the engine to look.
///
/// Variants are declared from shallowest to deepest. The derived `Ord`
/// follows declaration order, so two levels can be compared directly
/// (`<`, `<=`, `max`, ...) to find the deeper requirement.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum InspectionLevel {
    /// Level of the transport, remote/local IP and remote/local port fields.
    L4Only,
    /// Level of the peek field and of every TLS field.
    L4Peek,
    /// Level of the HTTP method, URI path, URI query and header fields.
    L7Header,
    /// Level of the HTTP body field.
    L7Body,
}
15
/// Layer at which an analyzed rule is handled.
///
/// The analyzer assigns `L4` to rules whose fetch runs in the L4 phase and
/// `L7` to rules whose fetch runs in the L7 phase.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Posture {
    /// The rule terminates in an L4-phase fetch.
    L4,
    /// The rule terminates in an L7-phase fetch.
    L7,
}
21
22#[derive(Debug, Clone)]
23pub struct AnalyzedRule {
24 pub raw: RawRule,
25 pub inspection_level: InspectionLevel,
26 pub specificity: usize,
27 pub posture: Posture,
28 pub needs_request_body: bool,
29 pub needs_response_body: bool,
30}
31
32#[derive(Debug, Clone)]
33pub struct AnalyzedRuleSet {
34 pub rules: Vec<AnalyzedRule>,
35 pub source_files: Vec<std::path::PathBuf>,
36}
37
38pub fn analyze(
46 set: RawRuleSet,
47 mw_meta: &dyn MiddlewareMetadataProvider,
48 fetch_meta: &dyn FetchMetadataProvider,
49) -> Result<AnalyzedRuleSet, Error> {
50 let mut analyzed = Vec::with_capacity(set.rules.len());
51 for raw in set.rules {
52 analyzed.push(analyze_rule(raw, mw_meta, fetch_meta)?);
53 }
54 Ok(AnalyzedRuleSet { rules: analyzed, source_files: set.source_files })
55}
56
57fn analyze_rule(
58 raw: RawRule,
59 mw_meta: &dyn MiddlewareMetadataProvider,
60 fetch_meta: &dyn FetchMetadataProvider,
61) -> Result<AnalyzedRule, Error> {
62 let fetch_kind = Some(raw.terminate.kind);
63 let fetch_phase = fetch_phase_of(fetch_kind);
64
65 let mut max_level = InspectionLevel::L4Only;
66 let mut specificity = 0usize;
67 let mut reads_http_body = false;
68 if let Some(pred) = &raw.match_predicate {
69 walk_predicate(pred, &mut |p| match p {
70 Predicate::Check(c) => {
71 specificity += 1;
72 let lvl = field_path_inspection_level(&c.path);
73 if lvl > max_level {
74 max_level = lvl;
75 }
76 if matches!(c.path, FieldPath::HttpBody) {
77 reads_http_body = true;
78 }
79 }
80 Predicate::AnyOf(_) | Predicate::AllOf(_) | Predicate::Not(_) => {}
81 });
82 }
83
84 let mut needs_request_body = reads_http_body;
85 let mut needs_response_body = false;
86 for mw_ref in &raw.middleware_chain {
87 let meta = mw_meta
88 .get(&mw_ref.name)
89 .ok_or_else(|| Error::compile(format!("unknown middleware: {:?}", mw_ref.name)))?;
90 if meta.needs_body {
91 match meta.kind {
92 crate::middleware::MiddlewareKind::L7Request => needs_request_body = true,
93 crate::middleware::MiddlewareKind::L7Response => needs_response_body = true,
94 crate::middleware::MiddlewareKind::L4Peek | crate::middleware::MiddlewareKind::L4Bytes => {}
95 }
96 }
97 }
98
99 let _ = fetch_meta;
103
104 let posture = match fetch_phase {
105 FetchPhase::L4 if max_level <= InspectionLevel::L4Peek => Posture::L4,
106 FetchPhase::L4 => {
107 return Err(Error::compile(format!(
108 "rule {:?}: L7-level predicate on an L4 fetch is invalid",
109 raw.name
110 )));
111 }
112 FetchPhase::L7 => Posture::L7,
113 };
114
115 Ok(AnalyzedRule {
116 raw,
117 inspection_level: max_level,
118 specificity,
119 posture,
120 needs_request_body,
121 needs_response_body,
122 })
123}
124
/// Phase a rule's fetch runs in, as classified by `fetch_phase_of`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum FetchPhase {
    /// The fetch is an L4 forward.
    L4,
    /// Any other fetch.
    L7,
}
130
131const fn fetch_phase_of(kind: Option<FetchKind>) -> FetchPhase {
132 match kind {
133 Some(FetchKind::L4Forward) => FetchPhase::L4,
134 _ => FetchPhase::L7,
135 }
136}
137
138fn walk_predicate(p: &Predicate, f: &mut impl FnMut(&Predicate)) {
139 f(p);
140 match p {
141 Predicate::AnyOf(a) => {
142 for child in &a.any_of {
143 walk_predicate(child, f);
144 }
145 }
146 Predicate::AllOf(a) => {
147 for child in &a.all_of {
148 walk_predicate(child, f);
149 }
150 }
151 Predicate::Not(n) => walk_predicate(&n.not, f),
152 Predicate::Check(_) => {}
153 }
154}
155
156const fn field_path_inspection_level(path: &FieldPath) -> InspectionLevel {
157 match path {
158 FieldPath::Transport
159 | FieldPath::RemoteIp
160 | FieldPath::RemotePort
161 | FieldPath::LocalIp
162 | FieldPath::LocalPort => InspectionLevel::L4Only,
163 FieldPath::Peek
164 | FieldPath::TlsSni
165 | FieldPath::TlsAlpn
166 | FieldPath::TlsVersion
167 | FieldPath::TlsPeerCertPresent
168 | FieldPath::TlsPeerCertSubjectCn
169 | FieldPath::TlsPeerCertSanDns
170 | FieldPath::TlsPeerCertFingerprintSha256
171 | FieldPath::TlsPeerCertSpkiSha256
172 | FieldPath::TlsPeerCertIssuerCn
173 | FieldPath::TlsPeerCertSerial => InspectionLevel::L4Peek,
174 FieldPath::HttpMethod
175 | FieldPath::HttpUriPath
176 | FieldPath::HttpUriQuery
177 | FieldPath::HttpHeader(_) => InspectionLevel::L7Header,
178 FieldPath::HttpBody => InspectionLevel::L7Body,
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compile::expand::RawRuleSet;
    use crate::fetch::{FetchOutputModes, FetchPhase as FetchMetaPhase};
    use crate::metadata::{FetchMetadata, MiddlewareMetadata};
    use crate::middleware::MiddlewareKind;
    use serde_json::Value;

    /// Stub that serves both middleware and fetch metadata to the analyzer.
    struct Providers;

    /// Argument validator that accepts anything.
    #[allow(clippy::unnecessary_wraps)]
    fn validate_ok(_: &Value) -> Result<(), Error> {
        Ok(())
    }

    /// Builds stateless middleware metadata of the given kind.
    fn stateless_mw(kind: MiddlewareKind, needs_body: bool) -> MiddlewareMetadata {
        MiddlewareMetadata { kind, stateless: true, needs_body, validate_args: validate_ok }
    }

    impl MiddlewareMetadataProvider for Providers {
        fn get(&self, name: &str) -> Option<MiddlewareMetadata> {
            let (kind, needs_body) = match name {
                "req_plain" => (MiddlewareKind::L7Request, false),
                "req_body" => (MiddlewareKind::L7Request, true),
                "resp_body" => (MiddlewareKind::L7Response, true),
                _ => return None,
            };
            Some(stateless_mw(kind, needs_body))
        }
    }

    impl FetchMetadataProvider for Providers {
        fn get(&self, kind: FetchKind) -> Option<FetchMetadata> {
            let phase = if matches!(kind, FetchKind::L4Forward) {
                FetchMetaPhase::L4
            } else {
                FetchMetaPhase::L7
            };
            let output_modes = match kind {
                FetchKind::L4Forward => FetchOutputModes { response: false, tunnel: true },
                FetchKind::WebSocketUpgrade => FetchOutputModes { response: true, tunnel: true },
                _ => FetchOutputModes { response: true, tunnel: false },
            };
            Some(FetchMetadata { kind, phase, output_modes, validate_args: validate_ok })
        }
    }

    /// Wraps `rules` in a rule set with no source files.
    fn set(rules: Vec<RawRule>) -> RawRuleSet {
        RawRuleSet { rules, source_files: Vec::new() }
    }

    /// Deserializes a JSON rule, panicking on malformed fixtures.
    fn parse_rule(j: serde_json::Value) -> RawRule {
        serde_json::from_value(j).expect("parse rule")
    }

    /// Parses `j` as a single rule and analyzes it with the stub providers.
    fn analyze_json(j: Value) -> Result<AnalyzedRuleSet, Error> {
        analyze(set(vec![parse_rule(j)]), &Providers, &Providers)
    }

    #[test]
    fn http_body_predicate_sets_request_body_flag_and_l7body_level() {
        let out = analyze_json(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "match": { "http.body": { "contains": "admin" } },
            "terminate": { "type": "http_proxy" },
        }))
        .expect("analyze");
        let a = &out.rules[0];
        assert!(a.needs_request_body);
        assert!(!a.needs_response_body);
        assert_eq!(a.inspection_level, InspectionLevel::L7Body);
        assert_eq!(a.posture, Posture::L7);
    }

    #[test]
    fn l7_request_needs_body_middleware_flags_request_side() {
        let out = analyze_json(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "middleware_chain": [{ "use": "req_body" }],
            "terminate": { "type": "http_proxy" },
        }))
        .expect("analyze");
        assert!(out.rules[0].needs_request_body);
        assert!(!out.rules[0].needs_response_body);
    }

    #[test]
    fn l7_response_needs_body_middleware_flags_response_side() {
        let out = analyze_json(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "middleware_chain": [{ "use": "resp_body" }],
            "terminate": { "type": "http_proxy" },
        }))
        .expect("analyze");
        assert!(!out.rules[0].needs_request_body);
        assert!(out.rules[0].needs_response_body);
    }

    #[test]
    fn l4_fetch_with_l7_predicate_errors() {
        let err = analyze_json(serde_json::json!({
            "name": "r",
            "listen": [":22"],
            "match": { "http.method": { "equals": "GET" } },
            "terminate": { "type": "tcp_forward", "upstream": "10.0.0.1:22" },
        }))
        .expect_err("must error");
        assert!(err.to_string().contains("L7-level predicate"));
    }

    #[test]
    fn unknown_middleware_name_errors() {
        let err = analyze_json(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "middleware_chain": [{ "use": "does_not_exist" }],
            "terminate": { "type": "http_proxy" },
        }))
        .expect_err("must error");
        assert!(err.to_string().contains("does_not_exist"));
    }

    #[test]
    fn specificity_counts_check_predicates() {
        let out = analyze_json(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "match": {
                "any_of": [
                    { "tls.sni": { "equals": "a" } },
                    { "tls.sni": { "equals": "b" } },
                ],
            },
            "terminate": { "type": "http_proxy" },
        }))
        .expect("analyze");
        assert_eq!(out.rules[0].specificity, 2);
    }
}