1use crate::compile::expand::RawRuleSet;
2use crate::error::Error;
3use crate::fetch::FetchKind;
4use crate::metadata::{FetchMetadataProvider, MiddlewareMetadataProvider};
5use crate::predicate::{FieldPath, Predicate};
6use crate::rule::RawRule;
7
/// How much of a connection a rule's match predicate needs to see.
///
/// Variants are declared from shallowest to deepest and the derived `Ord`
/// follows declaration order, so `a < b` means "`a` needs less than `b`".
/// Rule analysis depends on this ordering: it keeps the maximum level seen
/// across a predicate tree and compares it against `L4Peek`.
#[derive(Debug, Hash, PartialOrd, Ord, PartialEq, Eq, Clone, Copy)]
pub enum InspectionLevel {
    /// Level assigned to transport / address / port checks.
    L4Only,
    /// Level assigned to peeked-byte and TLS checks.
    L4Peek,
    /// Level assigned to HTTP method, URI and header checks.
    L7Header,
    /// Level assigned to HTTP body checks.
    L7Body,
}
15
/// The layer a rule is served at, as chosen by rule analysis.
///
/// `L4` is only chosen for an L4 fetch whose predicate needs at most
/// [`InspectionLevel::L4Peek`]; every L7 fetch gets `L7`.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub enum Posture {
    /// Layer-4 posture.
    L4,
    /// Layer-7 posture.
    L7,
}
21
22#[derive(Debug, Clone)]
23pub struct AnalyzedRule {
24 pub raw: RawRule,
25 pub inspection_level: InspectionLevel,
26 pub specificity: usize,
27 pub posture: Posture,
28 pub needs_request_body: bool,
29 pub needs_response_body: bool,
30}
31
32#[derive(Debug, Clone)]
33pub struct AnalyzedRuleSet {
34 pub rules: Vec<AnalyzedRule>,
35 pub source_files: Vec<std::path::PathBuf>,
36}
37
38pub fn analyze(
46 set: RawRuleSet,
47 mw_meta: &dyn MiddlewareMetadataProvider,
48 fetch_meta: &dyn FetchMetadataProvider,
49) -> Result<AnalyzedRuleSet, Error> {
50 let mut analyzed = Vec::with_capacity(set.rules.len());
51 for raw in set.rules {
52 analyzed.push(analyze_rule(raw, mw_meta, fetch_meta)?);
53 }
54 Ok(AnalyzedRuleSet { rules: analyzed, source_files: set.source_files })
55}
56
57fn analyze_rule(
58 raw: RawRule,
59 mw_meta: &dyn MiddlewareMetadataProvider,
60 fetch_meta: &dyn FetchMetadataProvider,
61) -> Result<AnalyzedRule, Error> {
62 if let Some(tls) = raw.tls.as_ref() {
68 tls.validate().map_err(|e| Error::compile(format!("rule {:?}: {}", raw.name, e)))?;
69 }
70
71 let fetch_kind = Some(raw.terminate.kind);
72 let fetch_phase = fetch_phase_of(fetch_kind);
73
74 let mut max_level = InspectionLevel::L4Only;
75 let mut specificity = 0usize;
76 let mut reads_http_body = false;
77 if let Some(pred) = &raw.match_predicate {
78 walk_predicate(pred, &mut |p| match p {
79 Predicate::Check(c) => {
80 specificity += 1;
81 let lvl = field_path_inspection_level(&c.path);
82 if lvl > max_level {
83 max_level = lvl;
84 }
85 if matches!(c.path, FieldPath::HttpBody) {
86 reads_http_body = true;
87 }
88 }
89 Predicate::AnyOf(_) | Predicate::AllOf(_) | Predicate::Not(_) => {}
90 });
91 }
92
93 let mut needs_request_body = reads_http_body;
94 let mut needs_response_body = false;
95 for mw_ref in &raw.middleware_chain {
96 let meta = mw_meta
97 .get(&mw_ref.name)
98 .ok_or_else(|| Error::compile(format!("unknown middleware: {:?}", mw_ref.name)))?;
99 if meta.needs_body {
100 match meta.kind {
101 crate::middleware::MiddlewareKind::L7Request => needs_request_body = true,
102 crate::middleware::MiddlewareKind::L7Response => needs_response_body = true,
103 crate::middleware::MiddlewareKind::L4Peek | crate::middleware::MiddlewareKind::L4Bytes => {}
104 }
105 }
106 }
107
108 let _ = fetch_meta;
112
113 let posture = match fetch_phase {
114 FetchPhase::L4 if max_level <= InspectionLevel::L4Peek => Posture::L4,
115 FetchPhase::L4 => {
116 return Err(Error::compile(format!(
117 "rule {:?}: L7-level predicate on an L4 fetch is invalid",
118 raw.name
119 )));
120 }
121 FetchPhase::L7 => Posture::L7,
122 };
123
124 Ok(AnalyzedRule {
125 raw,
126 inspection_level: max_level,
127 specificity,
128 posture,
129 needs_request_body,
130 needs_response_body,
131 })
132}
133
/// Layer a rule's fetch operates at, as used for posture selection.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum FetchPhase {
    /// Layer-4 fetch.
    L4,
    /// Layer-7 fetch.
    L7,
}
139
140const fn fetch_phase_of(kind: Option<FetchKind>) -> FetchPhase {
141 match kind {
142 Some(FetchKind::L4Forward) => FetchPhase::L4,
143 _ => FetchPhase::L7,
144 }
145}
146
147fn walk_predicate(p: &Predicate, f: &mut impl FnMut(&Predicate)) {
148 f(p);
149 match p {
150 Predicate::AnyOf(a) => {
151 for child in &a.any_of {
152 walk_predicate(child, f);
153 }
154 }
155 Predicate::AllOf(a) => {
156 for child in &a.all_of {
157 walk_predicate(child, f);
158 }
159 }
160 Predicate::Not(n) => walk_predicate(&n.not, f),
161 Predicate::Check(_) => {}
162 }
163}
164
165const fn field_path_inspection_level(path: &FieldPath) -> InspectionLevel {
166 match path {
167 FieldPath::Transport
168 | FieldPath::RemoteIp
169 | FieldPath::RemotePort
170 | FieldPath::LocalIp
171 | FieldPath::LocalPort => InspectionLevel::L4Only,
172 FieldPath::Peek
173 | FieldPath::TlsSni
174 | FieldPath::TlsAlpn
175 | FieldPath::TlsVersion
176 | FieldPath::TlsPeerCertPresent
177 | FieldPath::TlsPeerCertSubjectCn
178 | FieldPath::TlsPeerCertSanDns
179 | FieldPath::TlsPeerCertFingerprintSha256
180 | FieldPath::TlsPeerCertSpkiSha256
181 | FieldPath::TlsPeerCertIssuerCn
182 | FieldPath::TlsPeerCertSerial => InspectionLevel::L4Peek,
183 FieldPath::HttpMethod
184 | FieldPath::HttpUriPath
185 | FieldPath::HttpUriQuery
186 | FieldPath::HttpHeader(_) => InspectionLevel::L7Header,
187 FieldPath::HttpBody => InspectionLevel::L7Body,
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compile::expand::RawRuleSet;
    use crate::fetch::{FetchOutputModes, FetchPhase as FetchMetaPhase};
    use crate::metadata::{FetchMetadata, MiddlewareMetadata};
    use crate::middleware::MiddlewareKind;
    use serde_json::Value;

    /// Stub implementing both metadata providers for the tests below.
    struct Providers;

    // The signature is fixed by the `validate_args` field this is assigned to,
    // so the always-`Ok` `Result` return cannot be dropped.
    #[allow(clippy::unnecessary_wraps)]
    fn validate_ok(_: &Value) -> Result<(), Error> {
        Ok(())
    }

    impl MiddlewareMetadataProvider for Providers {
        fn get(&self, name: &str) -> Option<MiddlewareMetadata> {
            match name {
                "req_plain" => Some(MiddlewareMetadata {
                    kind: MiddlewareKind::L7Request,
                    stateless: true,
                    needs_body: false,
                    validate_args: validate_ok,
                }),
                "req_body" => Some(MiddlewareMetadata {
                    kind: MiddlewareKind::L7Request,
                    stateless: true,
                    needs_body: true,
                    validate_args: validate_ok,
                }),
                "resp_body" => Some(MiddlewareMetadata {
                    kind: MiddlewareKind::L7Response,
                    stateless: true,
                    needs_body: true,
                    validate_args: validate_ok,
                }),
                _ => None,
            }
        }
    }

    impl FetchMetadataProvider for Providers {
        fn get(&self, kind: FetchKind) -> Option<FetchMetadata> {
            Some(FetchMetadata {
                kind,
                phase: match kind {
                    FetchKind::L4Forward => FetchMetaPhase::L4,
                    _ => FetchMetaPhase::L7,
                },
                output_modes: match kind {
                    FetchKind::L4Forward => FetchOutputModes { response: false, tunnel: true },
                    FetchKind::WebSocketUpgrade => FetchOutputModes { response: true, tunnel: true },
                    _ => FetchOutputModes { response: true, tunnel: false },
                },
                validate_args: validate_ok,
            })
        }
    }

    fn set(rules: Vec<RawRule>) -> RawRuleSet {
        RawRuleSet { rules, source_files: vec![] }
    }

    fn parse_rule(j: serde_json::Value) -> RawRule {
        serde_json::from_value(j).expect("parse rule")
    }

    #[test]
    fn http_body_predicate_sets_request_body_flag_and_l7body_level() {
        let rule = parse_rule(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "match": { "http.body": { "contains": "admin" } },
            "terminate": { "type": "http_proxy" },
        }));
        let out = analyze(set(vec![rule]), &Providers, &Providers).expect("analyze");
        let a = &out.rules[0];
        assert!(a.needs_request_body);
        assert!(!a.needs_response_body);
        assert_eq!(a.inspection_level, InspectionLevel::L7Body);
        assert_eq!(a.posture, Posture::L7);
    }

    #[test]
    fn l7_request_needs_body_middleware_flags_request_side() {
        let rule = parse_rule(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "middleware_chain": [{ "use": "req_body" }],
            "terminate": { "type": "http_proxy" },
        }));
        let out = analyze(set(vec![rule]), &Providers, &Providers).expect("analyze");
        assert!(out.rules[0].needs_request_body);
        assert!(!out.rules[0].needs_response_body);
    }

    #[test]
    fn l7_response_needs_body_middleware_flags_response_side() {
        let rule = parse_rule(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "middleware_chain": [{ "use": "resp_body" }],
            "terminate": { "type": "http_proxy" },
        }));
        let out = analyze(set(vec![rule]), &Providers, &Providers).expect("analyze");
        assert!(!out.rules[0].needs_request_body);
        assert!(out.rules[0].needs_response_body);
    }

    #[test]
    fn l4_fetch_with_l7_predicate_errors() {
        let rule = parse_rule(serde_json::json!({
            "name": "r",
            "listen": [":22"],
            "match": { "http.method": { "equals": "GET" } },
            "terminate": { "type": "tcp_forward", "upstream": "10.0.0.1:22" },
        }));
        let err = analyze(set(vec![rule]), &Providers, &Providers).expect_err("must error");
        assert!(err.to_string().contains("L7-level predicate"));
    }

    /// The success path of an L4 fetch: a predicate needing no more than a
    /// peek must yield `Posture::L4` rather than an error.
    #[test]
    fn l4_fetch_with_peek_level_predicate_gets_l4_posture() {
        let rule = parse_rule(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "match": { "tls.sni": { "equals": "a" } },
            "terminate": { "type": "tcp_forward", "upstream": "10.0.0.1:443" },
        }));
        let out = analyze(set(vec![rule]), &Providers, &Providers).expect("analyze");
        let a = &out.rules[0];
        assert_eq!(a.posture, Posture::L4);
        assert_eq!(a.inspection_level, InspectionLevel::L4Peek);
        assert_eq!(a.specificity, 1);
        assert!(!a.needs_request_body);
        assert!(!a.needs_response_body);
    }

    #[test]
    fn unknown_middleware_name_errors() {
        let rule = parse_rule(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "middleware_chain": [{ "use": "does_not_exist" }],
            "terminate": { "type": "http_proxy" },
        }));
        let err = analyze(set(vec![rule]), &Providers, &Providers).expect_err("must error");
        assert!(err.to_string().contains("does_not_exist"));
    }

    #[test]
    fn specificity_counts_check_predicates() {
        let rule = parse_rule(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "match": {
                "any_of": [
                    { "tls.sni": { "equals": "a" } },
                    { "tls.sni": { "equals": "b" } },
                ],
            },
            "terminate": { "type": "http_proxy" },
        }));
        let out = analyze(set(vec![rule]), &Providers, &Providers).expect("analyze");
        assert_eq!(out.rules[0].specificity, 2);
    }

    /// A rule with no `match` contributes no checks: zero specificity and the
    /// shallowest inspection level.
    #[test]
    fn rule_without_match_has_zero_specificity_and_l4only_level() {
        let rule = parse_rule(serde_json::json!({
            "name": "r",
            "listen": [":443"],
            "terminate": { "type": "http_proxy" },
        }));
        let out = analyze(set(vec![rule]), &Providers, &Providers).expect("analyze");
        let a = &out.rules[0];
        assert_eq!(a.specificity, 0);
        assert_eq!(a.inspection_level, InspectionLevel::L4Only);
        assert_eq!(a.posture, Posture::L7);
        assert!(!a.needs_request_body);
        assert!(!a.needs_response_body);
    }
}