1use std::hash::Hash;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5
6use crate::body::{Request, Response};
7use crate::conn_context::ConnContext;
8use crate::error::Error;
9use crate::flow_ctx::FlowCtx;
10use crate::ir::NodeId;
11use crate::l4::L4Conn;
12
13#[async_trait]
14pub trait L4PeekMiddleware: Send + Sync {
15 async fn run(
16 &self,
17 peek: &[u8],
18 conn: &Arc<ConnContext>,
19 ctx: &mut FlowCtx,
20 ) -> Result<Decision, Error>;
21}
22
23#[async_trait]
24pub trait L4BytesMiddleware: Send + Sync {
25 async fn run(
26 &self,
27 l4: &mut L4Conn,
28 conn: &Arc<ConnContext>,
29 ctx: &mut FlowCtx,
30 ) -> Result<Decision, Error>;
31}
32
33#[async_trait]
34pub trait L7RequestMiddleware: Send + Sync {
35 async fn run(
36 &self,
37 req: &mut Request,
38 conn: &Arc<ConnContext>,
39 ctx: &mut FlowCtx,
40 ) -> Result<Decision, Error>;
41}
42
43#[async_trait]
44pub trait L7ResponseMiddleware: Send + Sync {
45 async fn run(
46 &self,
47 resp: &mut Response,
48 conn: &Arc<ConnContext>,
49 ctx: &mut FlowCtx,
50 ) -> Result<Decision, Error>;
51}
52
53#[non_exhaustive]
57pub enum Decision {
58 Continue,
59 Short(ShortCircuit),
60}
61
62#[non_exhaustive]
67pub enum ShortCircuit {
68 Response(Response),
69 Close(CloseReason),
70}
71
/// Reason attached to [`ShortCircuit::Close`].
///
/// The two message-carrying variants hold a `Cow<'static, str>`, so a string literal can be
/// stored without allocating while an owned `String` is still accepted.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum CloseReason {
    /// Close with no further detail.
    Graceful,
    /// Close because a policy rejected the flow; the payload is a free-form message.
    PolicyDenied(std::borrow::Cow<'static, str>),
    /// Close because of a protocol-level problem; the payload is a free-form message.
    ProtocolError(std::borrow::Cow<'static, str>),
    /// Close because the flow was cancelled.
    Cancelled,
}
89
90#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
91pub enum MiddlewareKind {
92 L4Peek,
93 L4Bytes,
94 L7Request,
95 L7Response,
96}
97
98#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
99pub struct SymbolicMiddlewareRef {
100 pub name: Arc<str>,
101 pub args: serde_json::Value,
102 pub kind: MiddlewareKind,
103 pub stateless: bool,
104 pub needs_body: bool,
105 pub on_error: Option<NodeId>,
106}
107
108impl PartialEq for SymbolicMiddlewareRef {
109 fn eq(&self, other: &Self) -> bool {
110 self.name == other.name
111 && self.kind == other.kind
112 && self.stateless == other.stateless
113 && self.needs_body == other.needs_body
114 && self.on_error == other.on_error
115 && canonical_json_eq(&self.args, &other.args)
116 }
117}
118
119impl Eq for SymbolicMiddlewareRef {}
120
121impl std::hash::Hash for SymbolicMiddlewareRef {
122 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
123 self.name.hash(state);
124 self.kind.hash(state);
125 self.stateless.hash(state);
126 self.needs_body.hash(state);
127 self.on_error.hash(state);
128 hash_canonical_json(&self.args, state);
129 }
130}
131
132fn canonical_json_eq(a: &serde_json::Value, b: &serde_json::Value) -> bool {
133 use serde_json::Value;
134 match (a, b) {
135 (Value::Null, Value::Null) => true,
136 (Value::Bool(x), Value::Bool(y)) => x == y,
137 (Value::Number(x), Value::Number(y)) => x == y,
138 (Value::String(x), Value::String(y)) => x == y,
139 (Value::Array(xs), Value::Array(ys)) => {
140 xs.len() == ys.len() && xs.iter().zip(ys).all(|(x, y)| canonical_json_eq(x, y))
141 }
142 (Value::Object(xs), Value::Object(ys)) if xs.len() == ys.len() => {
143 xs.iter().all(|(k, v)| ys.get(k).is_some_and(|w| canonical_json_eq(v, w)))
144 }
145 _ => false,
146 }
147}
148
149fn hash_canonical_json<H: std::hash::Hasher>(v: &serde_json::Value, state: &mut H) {
154 let mut buf = String::new();
155 crate::canonical::write_into_lossy(&mut buf, v);
156 buf.hash(state);
157}
158
/// Unit tests for the middleware traits, the decision enums, and the canonical equality and
/// hashing of [`SymbolicMiddlewareRef`].
#[cfg(test)]
mod tests {
    use std::collections::hash_map::DefaultHasher;
    use std::future::Future;
    use std::hash::{Hash, Hasher};
    use std::net::SocketAddr;
    use std::pin::Pin;
    use std::time::Instant;

    use parking_lot::Mutex;
    use serde_json::json;
    use tokio_util::sync::CancellationToken;

    use super::*;
    use crate::conn_context::{ConnId, Transport};
    use crate::flow_log::{FlowLogEvent, FlowLogSink};

    /// No-op `L4PeekMiddleware` that always returns `Decision::Continue`.
    struct PassPeek;
    #[async_trait]
    impl L4PeekMiddleware for PassPeek {
        async fn run(
            &self,
            _peek: &[u8],
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx,
        ) -> Result<Decision, Error> {
            Ok(Decision::Continue)
        }
    }

    /// No-op `L4BytesMiddleware` that always returns `Decision::Continue`.
    struct PassBytes;
    #[async_trait]
    impl L4BytesMiddleware for PassBytes {
        async fn run(
            &self,
            _l4: &mut L4Conn,
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx,
        ) -> Result<Decision, Error> {
            Ok(Decision::Continue)
        }
    }

    /// No-op `L7RequestMiddleware` that always returns `Decision::Continue`.
    struct PassReq;
    #[async_trait]
    impl L7RequestMiddleware for PassReq {
        async fn run(
            &self,
            _req: &mut Request,
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx,
        ) -> Result<Decision, Error> {
            Ok(Decision::Continue)
        }
    }

    /// No-op `L7ResponseMiddleware` that always returns `Decision::Continue`.
    struct PassResp;
    #[async_trait]
    impl L7ResponseMiddleware for PassResp {
        async fn run(
            &self,
            _resp: &mut Response,
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx,
        ) -> Result<Decision, Error> {
            Ok(Decision::Continue)
        }
    }

    /// Compile-time check: only type-checks when the referenced value's type is `Send`.
    fn assert_send<F: Send>(_: &F) {}

    /// Flow-log sink that discards every event.
    struct NullSink;
    impl FlowLogSink for NullSink {
        fn emit(&self, _event: FlowLogEvent) {}
    }

    /// Builds a throwaway TCP `ConnContext` whose local and remote address are both
    /// `127.0.0.1:0`.
    fn make_conn_context() -> Arc<ConnContext> {
        let addr: SocketAddr = "127.0.0.1:0".parse().expect("parse addr");
        Arc::new(ConnContext {
            id: ConnId(0),
            remote: addr,
            local: addr,
            transport: Transport::Tcp,
            entered_at: Instant::now(),
            tls: Mutex::new(None),
            http_version: std::sync::OnceLock::new(),
            user: Mutex::new(http::Extensions::new()),
        })
    }

    // The four tests below are compile-time checks: each trait must be object-safe and its
    // trait object must coerce between the explicit `+ Send + Sync` form and the bare form.

    #[test]
    fn l4_peek_is_constructible_as_arc_dyn_send_sync() {
        let m: Arc<dyn L4PeekMiddleware + Send + Sync> = Arc::new(PassPeek);
        let _: Arc<dyn L4PeekMiddleware> = m;
    }

    #[test]
    fn l4_bytes_is_constructible_as_arc_dyn_send_sync() {
        let m: Arc<dyn L4BytesMiddleware + Send + Sync> = Arc::new(PassBytes);
        let _: Arc<dyn L4BytesMiddleware> = m;
    }

    #[test]
    fn l7_request_is_constructible_as_arc_dyn_send_sync() {
        let m: Arc<dyn L7RequestMiddleware + Send + Sync> = Arc::new(PassReq);
        let _: Arc<dyn L7RequestMiddleware> = m;
    }

    #[test]
    fn l7_response_is_constructible_as_arc_dyn_send_sync() {
        let m: Arc<dyn L7ResponseMiddleware + Send + Sync> = Arc::new(PassResp);
        let _: Arc<dyn L7ResponseMiddleware> = m;
    }

    /// Builds a `FlowCtx` with a disabled span, a discarding log sink and fresh cancellation
    /// tokens, for the given connection id.
    fn make_flow_ctx(conn_id: ConnId) -> FlowCtx {
        FlowCtx {
            span: tracing::Span::none(),
            log: Arc::new(NullSink),
            cancel: CancellationToken::new(),
            accept_cancel: CancellationToken::new(),
            verbosity: crate::flow_log::FlowLogVerbosity::Trajectory,
            trajectory: crate::flow_log::TrajectoryBuilder::new(
                conn_id,
                crate::ir::NodeId::new(0),
                0,
            ),
        }
    }

    // The `*_run_returns_send_future` tests never poll the future. The explicit type annotation
    // is the assertion: `run` through a trait object must yield a pinned, boxed, `Send` future.

    #[test]
    fn l4_peek_run_returns_send_future() {
        let m: Arc<dyn L4PeekMiddleware> = Arc::new(PassPeek);
        let conn = make_conn_context();
        let mut ctx = make_flow_ctx(conn.id);
        let peek: &[u8] = &[];
        let fut: Pin<Box<dyn Future<Output = Result<Decision, Error>> + Send + '_>> =
            m.run(peek, &conn, &mut ctx);
        assert_send(&fut);
        drop(fut);
    }

    #[test]
    fn l7_request_run_returns_send_future() {
        let m: Arc<dyn L7RequestMiddleware> = Arc::new(PassReq);
        let conn = make_conn_context();
        let mut ctx = make_flow_ctx(conn.id);
        let mut req: Request =
            http::Request::builder().uri("/").body(crate::body::Body::Empty).expect("build req");
        let fut: Pin<Box<dyn Future<Output = Result<Decision, Error>> + Send + '_>> =
            m.run(&mut req, &conn, &mut ctx);
        assert_send(&fut);
        drop(fut);
    }

    #[test]
    fn l7_response_run_returns_send_future() {
        let m: Arc<dyn L7ResponseMiddleware> = Arc::new(PassResp);
        let conn = make_conn_context();
        let mut ctx = make_flow_ctx(conn.id);
        let mut resp: Response = http::Response::builder()
            .status(200)
            .body(crate::body::Body::Empty)
            .expect("build resp");
        let fut: Pin<Box<dyn Future<Output = Result<Decision, Error>> + Send + '_>> =
            m.run(&mut resp, &conn, &mut ctx);
        assert_send(&fut);
        drop(fut);
    }

    /// Every `MiddlewareKind` variant survives a JSON encode/decode cycle.
    #[test]
    fn middleware_kind_serde_round_trip_per_variant() {
        for k in [
            MiddlewareKind::L4Peek,
            MiddlewareKind::L4Bytes,
            MiddlewareKind::L7Request,
            MiddlewareKind::L7Response,
        ] {
            let encoded = serde_json::to_string(&k).expect("serialize");
            let decoded: MiddlewareKind = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, k);
        }
    }

    /// Every `Decision` and `ShortCircuit` variant can be constructed.
    #[test]
    fn decision_and_shortcircuit_construct_per_variant() {
        let _ = Decision::Continue;
        let _ = Decision::Short(ShortCircuit::Close(CloseReason::Graceful));
        let _ = ShortCircuit::Close(CloseReason::PolicyDenied("over quota".into()));
        let _ = ShortCircuit::Close(CloseReason::ProtocolError("bad frame".into()));

        // Cover the response-carrying variant too, so the test matches its per-variant name.
        let resp: Response = http::Response::builder()
            .status(200)
            .body(crate::body::Body::Empty)
            .expect("build resp");
        let _ = Decision::Short(ShortCircuit::Response(resp));
    }

    /// Every `CloseReason` variant can be constructed, from both borrowed and owned messages.
    #[test]
    fn close_reason_construct_per_variant() {
        let _ = CloseReason::Graceful;
        let _ = CloseReason::PolicyDenied(std::borrow::Cow::Borrowed("over quota"));
        let _ = CloseReason::ProtocolError(std::borrow::Cow::Owned(String::from("bad frame")));
        let _ = CloseReason::Cancelled;
    }

    /// Hashes `v` with a fresh `DefaultHasher` and returns the digest.
    fn hash_of<T: Hash>(v: &T) -> u64 {
        let mut h = DefaultHasher::new();
        v.hash(&mut h);
        h.finish()
    }

    /// Builds a baseline reference that differs between calls only in `args`.
    fn sym_ref(args: serde_json::Value) -> SymbolicMiddlewareRef {
        SymbolicMiddlewareRef {
            name: Arc::from("rate_limit"),
            args,
            kind: MiddlewareKind::L7Request,
            stateless: true,
            needs_body: false,
            on_error: None,
        }
    }

    /// Two objects with the same entries inserted in opposite order are equal and hash alike.
    #[test]
    fn symbolic_ref_args_hash_is_object_key_order_insensitive() {
        let mut a = serde_json::Map::new();
        a.insert("a".to_string(), json!(1));
        a.insert("b".to_string(), json!(2));
        let mut b = serde_json::Map::new();
        b.insert("b".to_string(), json!(2));
        b.insert("a".to_string(), json!(1));

        let lhs = sym_ref(serde_json::Value::Object(a));
        let rhs = sym_ref(serde_json::Value::Object(b));

        assert_eq!(lhs, rhs);
        assert_eq!(hash_of(&lhs), hash_of(&rhs));
    }

    /// Key-order insensitivity also holds for an object nested inside another object.
    #[test]
    fn symbolic_ref_nested_object_key_order_is_ignored() {
        let lhs = sym_ref(json!({ "outer": { "x": 1, "y": 2 } }));
        let mut inner = serde_json::Map::new();
        inner.insert("y".to_string(), json!(2));
        inner.insert("x".to_string(), json!(1));
        let mut outer = serde_json::Map::new();
        outer.insert("outer".to_string(), serde_json::Value::Object(inner));
        let rhs = sym_ref(serde_json::Value::Object(outer));

        assert_eq!(lhs, rhs);
        assert_eq!(hash_of(&lhs), hash_of(&rhs));
    }

    /// Unlike object keys, array element order is significant.
    #[test]
    fn symbolic_ref_arrays_are_order_sensitive() {
        let lhs = sym_ref(json!({ "xs": [1, 2] }));
        let rhs = sym_ref(json!({ "xs": [2, 1] }));
        assert_ne!(lhs, rhs);
    }

    // Each `symbolic_ref_differs_on_*` test flips exactly one field of the baseline and expects
    // inequality, so no field can silently drop out of the hand-written `PartialEq`.

    #[test]
    fn symbolic_ref_differs_on_name() {
        let a = sym_ref(json!({}));
        let mut b = sym_ref(json!({}));
        b.name = Arc::from("other");
        assert_ne!(a, b);
    }

    #[test]
    fn symbolic_ref_differs_on_kind() {
        let a = sym_ref(json!({}));
        let mut b = sym_ref(json!({}));
        b.kind = MiddlewareKind::L4Peek;
        assert_ne!(a, b);
    }

    #[test]
    fn symbolic_ref_differs_on_stateless() {
        let a = sym_ref(json!({}));
        let mut b = sym_ref(json!({}));
        b.stateless = false;
        assert_ne!(a, b);
    }

    #[test]
    fn symbolic_ref_differs_on_needs_body() {
        let a = sym_ref(json!({}));
        let mut b = sym_ref(json!({}));
        b.needs_body = true;
        assert_ne!(a, b);
    }

    #[test]
    fn symbolic_ref_differs_on_on_error() {
        let a = sym_ref(json!({}));
        let mut b = sym_ref(json!({}));
        b.on_error = Some(NodeId::new(3));
        assert_ne!(a, b);
    }

    /// The same middleware name with different argument values is a different reference.
    #[test]
    fn symbolic_ref_same_name_but_distinct_args_are_unequal() {
        let a = sym_ref(json!({ "limit": 100 }));
        let b = sym_ref(json!({ "limit": 200 }));
        assert_ne!(a, b);
    }

    /// A JSON encode/decode cycle preserves each field individually and the value as a whole.
    #[test]
    fn symbolic_middleware_ref_round_trip_preserves_all_fields() {
        let m = SymbolicMiddlewareRef {
            name: Arc::from("rate_limit"),
            args: json!({ "rate": 100 }),
            kind: MiddlewareKind::L7Request,
            stateless: false,
            needs_body: false,
            on_error: Some(NodeId::new(5)),
        };
        let encoded = serde_json::to_string(&m).expect("serialize");
        let decoded: SymbolicMiddlewareRef = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.name, m.name);
        assert_eq!(decoded.kind, m.kind);
        assert_eq!(decoded.stateless, m.stateless);
        assert_eq!(decoded.needs_body, m.needs_body);
        assert_eq!(decoded.on_error, m.on_error);
        assert_eq!(decoded, m);
    }

    /// The round trip stays equal when `args` was built with keys inserted out of sorted order.
    #[test]
    fn symbolic_middleware_ref_round_trip_args_are_canonical_key_order_insensitive() {
        let mut obj = serde_json::Map::new();
        obj.insert("b".to_string(), json!(1));
        obj.insert("a".to_string(), json!(2));
        let m = SymbolicMiddlewareRef {
            name: Arc::from("mw"),
            args: serde_json::Value::Object(obj),
            kind: MiddlewareKind::L7Request,
            stateless: true,
            needs_body: false,
            on_error: None,
        };
        let encoded = serde_json::to_string(&m).expect("serialize");
        let decoded: SymbolicMiddlewareRef = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded, m);
    }
}