1use std::hash::Hash;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5
6use crate::body::{Request, Response};
7use crate::conn_context::ConnContext;
8use crate::error::Error;
9use crate::flow_ctx::FlowCtx;
10use crate::ir::NodeId;
11use crate::l4::L4Conn;
12
13#[async_trait]
14pub trait L4PeekMiddleware: Send + Sync {
15 async fn run(
16 &self,
17 peek: &[u8],
18 conn: &Arc<ConnContext>,
19 ctx: &mut FlowCtx,
20 ) -> Result<Decision, Error>;
21}
22
23#[async_trait]
24pub trait L4BytesMiddleware: Send + Sync {
25 async fn run(
26 &self,
27 l4: &mut L4Conn,
28 conn: &Arc<ConnContext>,
29 ctx: &mut FlowCtx,
30 ) -> Result<Decision, Error>;
31}
32
33#[async_trait]
34pub trait L7RequestMiddleware: Send + Sync {
35 async fn run(
36 &self,
37 req: &mut Request,
38 conn: &Arc<ConnContext>,
39 ctx: &mut FlowCtx,
40 ) -> Result<Decision, Error>;
41
42 fn needs_body(&self) -> bool {
43 false
44 }
45}
46
47#[async_trait]
48pub trait L7ResponseMiddleware: Send + Sync {
49 async fn run(
50 &self,
51 resp: &mut Response,
52 conn: &Arc<ConnContext>,
53 ctx: &mut FlowCtx,
54 ) -> Result<Decision, Error>;
55
56 fn needs_body(&self) -> bool {
57 false
58 }
59}
60
61pub enum Decision {
62 Continue,
63 Short(ShortCircuit),
64}
65
66pub enum ShortCircuit {
67 Response(Response),
68 Close(CloseReason),
69}
70
/// Why a connection is being closed.
///
/// Message payloads are `Cow<'static, str>`, so a string literal can be used
/// without allocating, while a runtime-built `String` is also accepted.
#[derive(Debug, Clone)]
pub enum CloseReason {
    /// Normal close.
    Graceful,
    /// Closed by policy, with a human-readable explanation.
    PolicyDenied(std::borrow::Cow<'static, str>),
    /// Closed due to a protocol violation, with a human-readable explanation.
    ProtocolError(std::borrow::Cow<'static, str>),
    /// Closed due to cancellation.
    Cancelled,
}
83
84#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
85pub enum MiddlewareKind {
86 L4Peek,
87 L4Bytes,
88 L7Request,
89 L7Response,
90}
91
92#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
93pub struct SymbolicMiddlewareRef {
94 pub name: Arc<str>,
95 pub args: serde_json::Value,
96 pub kind: MiddlewareKind,
97 pub stateless: bool,
98 pub needs_body: bool,
99 pub on_error: Option<NodeId>,
100}
101
102impl PartialEq for SymbolicMiddlewareRef {
103 fn eq(&self, other: &Self) -> bool {
104 self.name == other.name
105 && self.kind == other.kind
106 && self.stateless == other.stateless
107 && self.needs_body == other.needs_body
108 && self.on_error == other.on_error
109 && canonical_json_eq(&self.args, &other.args)
110 }
111}
112
113impl Eq for SymbolicMiddlewareRef {}
114
115impl std::hash::Hash for SymbolicMiddlewareRef {
116 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
117 self.name.hash(state);
118 self.kind.hash(state);
119 self.stateless.hash(state);
120 self.needs_body.hash(state);
121 self.on_error.hash(state);
122 hash_canonical_json(&self.args, state);
123 }
124}
125
126fn canonical_json_eq(a: &serde_json::Value, b: &serde_json::Value) -> bool {
127 use serde_json::Value;
128 match (a, b) {
129 (Value::Null, Value::Null) => true,
130 (Value::Bool(x), Value::Bool(y)) => x == y,
131 (Value::Number(x), Value::Number(y)) => x == y,
132 (Value::String(x), Value::String(y)) => x == y,
133 (Value::Array(xs), Value::Array(ys)) => {
134 xs.len() == ys.len() && xs.iter().zip(ys).all(|(x, y)| canonical_json_eq(x, y))
135 }
136 (Value::Object(xs), Value::Object(ys)) if xs.len() == ys.len() => {
137 xs.iter().all(|(k, v)| ys.get(k).is_some_and(|w| canonical_json_eq(v, w)))
138 }
139 _ => false,
140 }
141}
142
143fn hash_canonical_json<H: std::hash::Hasher>(v: &serde_json::Value, state: &mut H) {
144 use serde_json::Value;
145 match v {
146 Value::Null => 0u8.hash(state),
147 Value::Bool(b) => {
148 1u8.hash(state);
149 b.hash(state);
150 }
151 Value::Number(n) => {
152 2u8.hash(state);
153 n.to_string().hash(state);
154 }
155 Value::String(s) => {
156 3u8.hash(state);
157 s.hash(state);
158 }
159 Value::Array(xs) => {
160 4u8.hash(state);
161 xs.len().hash(state);
162 for x in xs {
163 hash_canonical_json(x, state);
164 }
165 }
166 Value::Object(xs) => {
167 5u8.hash(state);
168 let mut keys: Vec<&String> = xs.keys().collect();
169 keys.sort();
170 keys.len().hash(state);
171 for k in keys {
172 k.hash(state);
173 hash_canonical_json(&xs[k], state);
174 }
175 }
176 }
177}
178
#[cfg(test)]
mod tests {
    use std::collections::hash_map::DefaultHasher;
    use std::future::Future;
    use std::hash::{Hash, Hasher};
    use std::net::SocketAddr;
    use std::pin::Pin;
    use std::time::Instant;

    use parking_lot::Mutex;
    use serde_json::json;
    use tokio_util::sync::CancellationToken;

    use super::*;
    use crate::conn_context::{ConnId, Transport};
    use crate::flow_log::{FlowLogEvent, FlowLogSink};

    // No-op middlewares: each implementation just returns `Decision::Continue`.

    struct PassPeek;
    #[async_trait]
    impl L4PeekMiddleware for PassPeek {
        async fn run(
            &self,
            _peek: &[u8],
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx,
        ) -> Result<Decision, Error> {
            Ok(Decision::Continue)
        }
    }

    struct PassBytes;
    #[async_trait]
    impl L4BytesMiddleware for PassBytes {
        async fn run(
            &self,
            _l4: &mut L4Conn,
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx,
        ) -> Result<Decision, Error> {
            Ok(Decision::Continue)
        }
    }

    struct PassReq;
    #[async_trait]
    impl L7RequestMiddleware for PassReq {
        async fn run(
            &self,
            _req: &mut Request,
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx,
        ) -> Result<Decision, Error> {
            Ok(Decision::Continue)
        }
    }

    struct PassResp;
    #[async_trait]
    impl L7ResponseMiddleware for PassResp {
        async fn run(
            &self,
            _resp: &mut Response,
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx,
        ) -> Result<Decision, Error> {
            Ok(Decision::Continue)
        }
    }

    /// Compiles only when `F: Send`.
    fn assert_send<F: Send>(_: &F) {}

    /// Flow-log sink that discards every event.
    struct NullSink;
    impl FlowLogSink for NullSink {
        fn emit(&self, _event: FlowLogEvent) {}
    }

    /// Builds a `ConnContext` for a TCP connection on 127.0.0.1:0 with empty
    /// TLS and user state.
    fn make_conn_context() -> Arc<ConnContext> {
        let addr: SocketAddr = "127.0.0.1:0".parse().expect("parse addr");
        Arc::new(ConnContext {
            id: ConnId(0),
            remote: addr,
            local: addr,
            transport: Transport::Tcp,
            entered_at: Instant::now(),
            tls: Mutex::new(None),
            http_version: std::sync::OnceLock::new(),
            user: Mutex::new(http::Extensions::new()),
        })
    }

    #[test]
    fn l4_peek_is_constructible_as_arc_dyn_send_sync() {
        let m: Arc<dyn L4PeekMiddleware + Send + Sync> = Arc::new(PassPeek);
        let _: Arc<dyn L4PeekMiddleware> = m;
    }

    #[test]
    fn l4_bytes_is_constructible_as_arc_dyn_send_sync() {
        let m: Arc<dyn L4BytesMiddleware + Send + Sync> = Arc::new(PassBytes);
        let _: Arc<dyn L4BytesMiddleware> = m;
    }

    #[test]
    fn l7_request_is_constructible_as_arc_dyn_send_sync() {
        let m: Arc<dyn L7RequestMiddleware + Send + Sync> = Arc::new(PassReq);
        let _: Arc<dyn L7RequestMiddleware> = m;
    }

    #[test]
    fn l7_response_is_constructible_as_arc_dyn_send_sync() {
        let m: Arc<dyn L7ResponseMiddleware + Send + Sync> = Arc::new(PassResp);
        let _: Arc<dyn L7ResponseMiddleware> = m;
    }

    /// Builds a `FlowCtx` with a no-op span and sink and a fresh cancellation
    /// token, set to `Trajectory` verbosity.
    fn make_flow_ctx(conn_id: ConnId) -> FlowCtx {
        FlowCtx {
            span: tracing::Span::none(),
            log: Arc::new(NullSink),
            cancel: CancellationToken::new(),
            verbosity: crate::flow_log::FlowLogVerbosity::Trajectory,
            trajectory: crate::flow_log::TrajectoryBuilder::new(
                conn_id,
                crate::ir::NodeId::new(0),
                0,
            ),
        }
    }

    #[test]
    fn l4_peek_run_returns_send_future() {
        let m: Arc<dyn L4PeekMiddleware> = Arc::new(PassPeek);
        let conn = make_conn_context();
        let mut ctx = make_flow_ctx(conn.id);
        let peek: &[u8] = &[];
        let fut: Pin<Box<dyn Future<Output = Result<Decision, Error>> + Send + '_>> =
            m.run(peek, &conn, &mut ctx);
        assert_send(&fut);
        drop(fut);
    }

    #[test]
    fn l7_request_run_returns_send_future() {
        let m: Arc<dyn L7RequestMiddleware> = Arc::new(PassReq);
        let conn = make_conn_context();
        let mut ctx = make_flow_ctx(conn.id);
        let mut req: Request = http::Request::builder()
            .uri("/")
            .body(crate::body::Body::Empty)
            .expect("build req");
        let fut: Pin<Box<dyn Future<Output = Result<Decision, Error>> + Send + '_>> =
            m.run(&mut req, &conn, &mut ctx);
        assert_send(&fut);
        drop(fut);
    }

    #[test]
    fn l7_response_run_returns_send_future() {
        let m: Arc<dyn L7ResponseMiddleware> = Arc::new(PassResp);
        let conn = make_conn_context();
        let mut ctx = make_flow_ctx(conn.id);
        let mut resp: Response = http::Response::builder()
            .status(200)
            .body(crate::body::Body::Empty)
            .expect("build resp");
        let fut: Pin<Box<dyn Future<Output = Result<Decision, Error>> + Send + '_>> =
            m.run(&mut resp, &conn, &mut ctx);
        assert_send(&fut);
        drop(fut);
    }

    #[test]
    fn l7_request_needs_body_defaults_to_false() {
        assert!(!L7RequestMiddleware::needs_body(&PassReq));
    }

    #[test]
    fn l7_response_needs_body_defaults_to_false() {
        assert!(!L7ResponseMiddleware::needs_body(&PassResp));
    }

    #[test]
    fn middleware_kind_serde_round_trip_per_variant() {
        for k in [
            MiddlewareKind::L4Peek,
            MiddlewareKind::L4Bytes,
            MiddlewareKind::L7Request,
            MiddlewareKind::L7Response,
        ] {
            let encoded = serde_json::to_string(&k).expect("serialize");
            let decoded: MiddlewareKind = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, k);
        }
    }

    #[test]
    fn decision_and_shortcircuit_construct_per_variant() {
        let _ = Decision::Continue;
        let _ = Decision::Short(ShortCircuit::Close(CloseReason::Graceful));
        // Previously missing: the `ShortCircuit::Response` variant.
        let resp: Response = http::Response::builder()
            .status(200)
            .body(crate::body::Body::Empty)
            .expect("build resp");
        let _ = Decision::Short(ShortCircuit::Response(resp));
        let _ = ShortCircuit::Close(CloseReason::PolicyDenied("over quota".into()));
        let _ = ShortCircuit::Close(CloseReason::ProtocolError("bad frame".into()));
        let _ = ShortCircuit::Close(CloseReason::Cancelled);
    }

    #[test]
    fn close_reason_construct_per_variant() {
        let _ = CloseReason::Graceful;
        let _ = CloseReason::PolicyDenied(std::borrow::Cow::Borrowed("over quota"));
        let _ = CloseReason::ProtocolError(std::borrow::Cow::Owned(String::from("bad frame")));
        let _ = CloseReason::Cancelled;
    }

    /// Hashes `v` with a fresh `DefaultHasher`.
    fn hash_of<T: Hash>(v: &T) -> u64 {
        let mut h = DefaultHasher::new();
        v.hash(&mut h);
        h.finish()
    }

    /// Baseline reference: stateless L7 request middleware "rate_limit" with
    /// the given args.
    fn sym_ref(args: serde_json::Value) -> SymbolicMiddlewareRef {
        SymbolicMiddlewareRef {
            name: Arc::from("rate_limit"),
            args,
            kind: MiddlewareKind::L7Request,
            stateless: true,
            needs_body: false,
            on_error: None,
        }
    }

    #[test]
    fn symbolic_ref_args_hash_is_object_key_order_insensitive() {
        let mut a = serde_json::Map::new();
        a.insert("a".to_string(), json!(1));
        a.insert("b".to_string(), json!(2));
        let mut b = serde_json::Map::new();
        b.insert("b".to_string(), json!(2));
        b.insert("a".to_string(), json!(1));

        let lhs = sym_ref(serde_json::Value::Object(a));
        let rhs = sym_ref(serde_json::Value::Object(b));

        assert_eq!(lhs, rhs);
        assert_eq!(hash_of(&lhs), hash_of(&rhs));
    }

    #[test]
    fn symbolic_ref_nested_object_key_order_is_ignored() {
        let lhs = sym_ref(json!({ "outer": { "x": 1, "y": 2 } }));
        let mut inner = serde_json::Map::new();
        inner.insert("y".to_string(), json!(2));
        inner.insert("x".to_string(), json!(1));
        let mut outer = serde_json::Map::new();
        outer.insert("outer".to_string(), serde_json::Value::Object(inner));
        let rhs = sym_ref(serde_json::Value::Object(outer));

        assert_eq!(lhs, rhs);
        assert_eq!(hash_of(&lhs), hash_of(&rhs));
    }

    #[test]
    fn symbolic_ref_arrays_are_order_sensitive() {
        let lhs = sym_ref(json!({ "xs": [1, 2] }));
        let rhs = sym_ref(json!({ "xs": [2, 1] }));
        assert_ne!(lhs, rhs);
    }

    #[test]
    fn symbolic_ref_differs_on_name() {
        let a = sym_ref(json!({}));
        let mut b = sym_ref(json!({}));
        b.name = Arc::from("other");
        assert_ne!(a, b);
    }

    #[test]
    fn symbolic_ref_differs_on_kind() {
        let a = sym_ref(json!({}));
        let mut b = sym_ref(json!({}));
        b.kind = MiddlewareKind::L4Peek;
        assert_ne!(a, b);
    }

    #[test]
    fn symbolic_ref_differs_on_stateless() {
        let a = sym_ref(json!({}));
        let mut b = sym_ref(json!({}));
        b.stateless = false;
        assert_ne!(a, b);
    }

    #[test]
    fn symbolic_ref_differs_on_needs_body() {
        let a = sym_ref(json!({}));
        let mut b = sym_ref(json!({}));
        b.needs_body = true;
        assert_ne!(a, b);
    }

    #[test]
    fn symbolic_ref_differs_on_on_error() {
        let a = sym_ref(json!({}));
        let mut b = sym_ref(json!({}));
        b.on_error = Some(NodeId::new(3));
        assert_ne!(a, b);
    }

    #[test]
    fn symbolic_ref_same_name_but_distinct_args_are_unequal() {
        let a = sym_ref(json!({ "limit": 100 }));
        let b = sym_ref(json!({ "limit": 200 }));
        assert_ne!(a, b);
    }

    #[test]
    fn symbolic_middleware_ref_round_trip_preserves_all_fields() {
        let m = SymbolicMiddlewareRef {
            name: Arc::from("rate_limit"),
            args: json!({ "rate": 100 }),
            kind: MiddlewareKind::L7Request,
            stateless: false,
            needs_body: false,
            on_error: Some(NodeId::new(5)),
        };
        let encoded = serde_json::to_string(&m).expect("serialize");
        let decoded: SymbolicMiddlewareRef = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.name, m.name);
        assert_eq!(decoded.kind, m.kind);
        assert_eq!(decoded.stateless, m.stateless);
        assert_eq!(decoded.needs_body, m.needs_body);
        assert_eq!(decoded.on_error, m.on_error);
        assert_eq!(decoded, m);
    }

    #[test]
    fn symbolic_middleware_ref_round_trip_args_are_canonical_key_order_insensitive() {
        let mut obj = serde_json::Map::new();
        obj.insert("b".to_string(), json!(1));
        obj.insert("a".to_string(), json!(2));
        let m = SymbolicMiddlewareRef {
            name: Arc::from("mw"),
            args: serde_json::Value::Object(obj),
            kind: MiddlewareKind::L7Request,
            stateless: true,
            needs_body: false,
            on_error: None,
        };
        let encoded = serde_json::to_string(&m).expect("serialize");
        let decoded: SymbolicMiddlewareRef = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded, m);
    }
}