1use std::hash::Hash;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5
6use crate::body::{Request, Response};
7use crate::conn_context::ConnContext;
8use crate::error::Error;
9use crate::flow_ctx::FlowCtx;
10use crate::ir::NodeId;
11use crate::l4::L4Conn;
12
/// Middleware that inspects a read-only byte slice at layer 4 and decides
/// whether processing continues.
///
/// Implementors must be `Send + Sync`. Because of `#[async_trait]`, `run`
/// returns a boxed `Send` future, and trait objects
/// (`Arc<dyn L4PeekMiddleware>`) are usable.
#[async_trait]
pub trait L4PeekMiddleware: Send + Sync {
    /// Examines `peek` in the context of connection `conn` and flow `ctx`.
    ///
    /// NOTE(review): `peek` is presumably bytes peeked from the connection
    /// without consuming them — confirm against the caller.
    ///
    /// Returns [`Decision::Continue`] to proceed, [`Decision::Short`] to
    /// short-circuit, or an [`Error`].
    async fn run(
        &self,
        peek: &[u8],
        conn: &Arc<ConnContext>,
        ctx: &mut FlowCtx<'_>,
    ) -> Result<Decision, Error>;
}
22
/// Middleware that receives mutable access to a layer-4 connection and
/// decides whether processing continues.
///
/// Implementors must be `Send + Sync`. Because of `#[async_trait]`, `run`
/// returns a boxed `Send` future, and `Arc<dyn L4BytesMiddleware>` is usable.
#[async_trait]
pub trait L4BytesMiddleware: Send + Sync {
    /// Runs against `l4` for connection `conn` within flow `ctx`.
    ///
    /// NOTE(review): presumably implementations may read/write the raw byte
    /// stream through `l4` — confirm against `L4Conn`'s API.
    ///
    /// Returns [`Decision::Continue`] to proceed, [`Decision::Short`] to
    /// short-circuit, or an [`Error`].
    async fn run(
        &self,
        l4: &mut L4Conn,
        conn: &Arc<ConnContext>,
        ctx: &mut FlowCtx<'_>,
    ) -> Result<Decision, Error>;
}
32
/// Middleware that may inspect or modify an L7 request and decides whether
/// processing continues.
///
/// Implementors must be `Send + Sync`. Because of `#[async_trait]`, `run`
/// returns a boxed `Send` future, and `Arc<dyn L7RequestMiddleware>` is usable.
#[async_trait]
pub trait L7RequestMiddleware: Send + Sync {
    /// Runs against the mutable request `req` for connection `conn` within
    /// flow `ctx`.
    ///
    /// Returns [`Decision::Continue`] to proceed, [`Decision::Short`] to
    /// short-circuit, or an [`Error`].
    async fn run(
        &self,
        req: &mut Request,
        conn: &Arc<ConnContext>,
        ctx: &mut FlowCtx<'_>,
    ) -> Result<Decision, Error>;

    /// Whether this middleware needs the request body. Defaults to `false`.
    ///
    /// NOTE(review): presumably `true` makes the pipeline buffer the body
    /// before calling `run` — confirm against the caller.
    fn needs_body(&self) -> bool {
        false
    }
}
46
/// Middleware that may inspect or modify an L7 response and decides whether
/// processing continues.
///
/// Implementors must be `Send + Sync`. Because of `#[async_trait]`, `run`
/// returns a boxed `Send` future, and `Arc<dyn L7ResponseMiddleware>` is usable.
#[async_trait]
pub trait L7ResponseMiddleware: Send + Sync {
    /// Runs against the mutable response `resp` for connection `conn` within
    /// flow `ctx`.
    ///
    /// Returns [`Decision::Continue`] to proceed, [`Decision::Short`] to
    /// short-circuit, or an [`Error`].
    async fn run(
        &self,
        resp: &mut Response,
        conn: &Arc<ConnContext>,
        ctx: &mut FlowCtx<'_>,
    ) -> Result<Decision, Error>;

    /// Whether this middleware needs the response body. Defaults to `false`.
    ///
    /// NOTE(review): presumably `true` makes the pipeline buffer the body
    /// before calling `run` — confirm against the caller.
    fn needs_body(&self) -> bool {
        false
    }
}
60
/// Outcome returned by every middleware `run` method in this module.
pub enum Decision {
    /// Let processing continue (presumably with the next stage — confirm).
    Continue,
    /// Stop normal processing and apply the carried [`ShortCircuit`].
    Short(ShortCircuit),
}
65
/// How a middleware short-circuits the flow via [`Decision::Short`].
pub enum ShortCircuit {
    /// Answer with this response (presumably instead of forwarding — confirm).
    Response(Response),
    /// Close the connection, recording the given reason.
    Close(CloseReason),
}
70
/// Why a connection is being closed by [`ShortCircuit::Close`].
///
/// Message-carrying variants use `Cow<'static, str>` so callers can pass
/// either a static literal or an owned `String` without forcing allocation.
#[derive(Clone, Debug)]
pub enum CloseReason {
    /// A normal, orderly close.
    Graceful,
    /// Closed because a policy rejected the flow; carries a human-readable reason.
    PolicyDenied(std::borrow::Cow<'static, str>),
    /// Closed because of a protocol violation; carries a human-readable reason.
    ProtocolError(std::borrow::Cow<'static, str>),
    /// Closed due to cancellation.
    Cancelled,
}
83
/// Identifies which of this module's middleware traits a middleware implements.
///
/// Serializable with serde so it can appear in [`SymbolicMiddlewareRef`].
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
pub enum MiddlewareKind {
    /// Corresponds to [`L4PeekMiddleware`].
    L4Peek,
    /// Corresponds to [`L4BytesMiddleware`].
    L4Bytes,
    /// Corresponds to [`L7RequestMiddleware`].
    L7Request,
    /// Corresponds to [`L7ResponseMiddleware`].
    L7Response,
}
91
/// A serializable, by-name reference to a middleware plus its JSON arguments.
///
/// `PartialEq`, `Eq` and `Hash` are implemented by hand so that `args` is
/// compared and hashed canonically: JSON object key order is ignored, while
/// array element order is significant.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct SymbolicMiddlewareRef {
    /// Middleware name (e.g. `"rate_limit"`).
    pub name: Arc<str>,
    /// Free-form JSON arguments passed to the middleware.
    pub args: serde_json::Value,
    /// Which middleware trait this reference targets.
    pub kind: MiddlewareKind,
    // NOTE(review): presumably marks the middleware as holding no per-flow
    // state — confirm.
    pub stateless: bool,
    // NOTE(review): presumably mirrors the traits' `needs_body` hook — confirm.
    pub needs_body: bool,
    // NOTE(review): presumably the IR node to route to if the middleware
    // fails — confirm.
    pub on_error: Option<NodeId>,
}
101
102impl PartialEq for SymbolicMiddlewareRef {
103 fn eq(&self, other: &Self) -> bool {
104 self.name == other.name
105 && self.kind == other.kind
106 && self.stateless == other.stateless
107 && self.needs_body == other.needs_body
108 && self.on_error == other.on_error
109 && canonical_json_eq(&self.args, &other.args)
110 }
111}
112
// Marker impl: `eq` is a full equivalence relation.
// NOTE(review): reflexivity relies on `canonical_json_eq` comparing numbers with
// `serde_json::Number`'s `==`, which is presumably reflexive because serde_json
// numbers cannot be NaN — confirm if number handling changes.
impl Eq for SymbolicMiddlewareRef {}
114
115impl std::hash::Hash for SymbolicMiddlewareRef {
116 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
117 self.name.hash(state);
118 self.kind.hash(state);
119 self.stateless.hash(state);
120 self.needs_body.hash(state);
121 self.on_error.hash(state);
122 hash_canonical_json(&self.args, state);
123 }
124}
125
126fn canonical_json_eq(a: &serde_json::Value, b: &serde_json::Value) -> bool {
127 use serde_json::Value;
128 match (a, b) {
129 (Value::Null, Value::Null) => true,
130 (Value::Bool(x), Value::Bool(y)) => x == y,
131 (Value::Number(x), Value::Number(y)) => x == y,
132 (Value::String(x), Value::String(y)) => x == y,
133 (Value::Array(xs), Value::Array(ys)) => {
134 xs.len() == ys.len() && xs.iter().zip(ys).all(|(x, y)| canonical_json_eq(x, y))
135 }
136 (Value::Object(xs), Value::Object(ys)) if xs.len() == ys.len() => {
137 xs.iter().all(|(k, v)| ys.get(k).is_some_and(|w| canonical_json_eq(v, w)))
138 }
139 _ => false,
140 }
141}
142
143fn hash_canonical_json<H: std::hash::Hasher>(v: &serde_json::Value, state: &mut H) {
144 use serde_json::Value;
145 match v {
146 Value::Null => 0u8.hash(state),
147 Value::Bool(b) => {
148 1u8.hash(state);
149 b.hash(state);
150 }
151 Value::Number(n) => {
152 2u8.hash(state);
153 n.to_string().hash(state);
154 }
155 Value::String(s) => {
156 3u8.hash(state);
157 s.hash(state);
158 }
159 Value::Array(xs) => {
160 4u8.hash(state);
161 xs.len().hash(state);
162 for x in xs {
163 hash_canonical_json(x, state);
164 }
165 }
166 Value::Object(xs) => {
167 5u8.hash(state);
168 let mut keys: Vec<&String> = xs.keys().collect();
169 keys.sort();
170 keys.len().hash(state);
171 for k in keys {
172 k.hash(state);
173 hash_canonical_json(&xs[k], state);
174 }
175 }
176 }
177}
178
// Unit tests covering:
// - object-safety and `Send`-future guarantees of the middleware traits;
// - default `needs_body` values;
// - serde round-trips;
// - the canonical `PartialEq`/`Hash` behaviour of `SymbolicMiddlewareRef`.
#[cfg(test)]
mod tests {
    use std::collections::hash_map::DefaultHasher;
    use std::future::Future;
    use std::hash::{Hash, Hasher};
    use std::net::SocketAddr;
    use std::pin::Pin;
    use std::time::Instant;

    use parking_lot::Mutex;
    use serde_json::json;
    use tokio_util::sync::CancellationToken;

    use super::*;
    use crate::conn_context::{ConnId, Transport};
    use crate::flow_log::{FlowLogEvent, FlowLogSink};

    // Pass-through implementations of each middleware trait: every `run`
    // ignores its arguments and returns `Decision::Continue`.
    struct PassPeek;
    #[async_trait]
    impl L4PeekMiddleware for PassPeek {
        async fn run(
            &self,
            _peek: &[u8],
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx<'_>,
        ) -> Result<Decision, Error> {
            Ok(Decision::Continue)
        }
    }

    struct PassBytes;
    #[async_trait]
    impl L4BytesMiddleware for PassBytes {
        async fn run(
            &self,
            _l4: &mut L4Conn,
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx<'_>,
        ) -> Result<Decision, Error> {
            Ok(Decision::Continue)
        }
    }

    struct PassReq;
    #[async_trait]
    impl L7RequestMiddleware for PassReq {
        async fn run(
            &self,
            _req: &mut Request,
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx<'_>,
        ) -> Result<Decision, Error> {
            Ok(Decision::Continue)
        }
    }

    struct PassResp;
    #[async_trait]
    impl L7ResponseMiddleware for PassResp {
        async fn run(
            &self,
            _resp: &mut Response,
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx<'_>,
        ) -> Result<Decision, Error> {
            Ok(Decision::Continue)
        }
    }

    // Compiles only when `F: Send`; used as a static assertion on futures.
    fn assert_send<F: Send>(_: &F) {}

    // A `FlowLogSink` that discards every event.
    struct NullSink;
    impl FlowLogSink for NullSink {
        fn emit(&self, _event: FlowLogEvent) {}
    }

    // Builds a minimal TCP `ConnContext` bound to 127.0.0.1:0, with no TLS
    // state, an unset HTTP version, and empty user extensions.
    fn make_conn_context() -> Arc<ConnContext> {
        let addr: SocketAddr = "127.0.0.1:0".parse().expect("parse addr");
        Arc::new(ConnContext {
            id: ConnId(0),
            remote: addr,
            local: addr,
            transport: Transport::Tcp,
            entered_at: Instant::now(),
            tls: Mutex::new(None),
            http_version: std::sync::OnceLock::new(),
            user: Mutex::new(http::Extensions::new()),
        })
    }

    // The four tests below check that each trait is object-safe and that the
    // `Send + Sync` supertraits let `dyn Trait + Send + Sync` coerce to
    // plain `dyn Trait`.
    #[test]
    fn l4_peek_is_constructible_as_arc_dyn_send_sync() {
        let m: Arc<dyn L4PeekMiddleware + Send + Sync> = Arc::new(PassPeek);
        let _: Arc<dyn L4PeekMiddleware> = m;
    }

    #[test]
    fn l4_bytes_is_constructible_as_arc_dyn_send_sync() {
        let m: Arc<dyn L4BytesMiddleware + Send + Sync> = Arc::new(PassBytes);
        let _: Arc<dyn L4BytesMiddleware> = m;
    }

    #[test]
    fn l7_request_is_constructible_as_arc_dyn_send_sync() {
        let m: Arc<dyn L7RequestMiddleware + Send + Sync> = Arc::new(PassReq);
        let _: Arc<dyn L7RequestMiddleware> = m;
    }

    #[test]
    fn l7_response_is_constructible_as_arc_dyn_send_sync() {
        let m: Arc<dyn L7ResponseMiddleware + Send + Sync> = Arc::new(PassResp);
        let _: Arc<dyn L7ResponseMiddleware> = m;
    }

    // The `*_run_returns_send_future` tests build a `FlowCtx` from local
    // state. They bind `run`'s result to an explicitly `Send` boxed future,
    // which fails to compile if the trait's future is not `Send`. The future
    // is dropped without being polled.
    #[test]
    fn l4_peek_run_returns_send_future() {
        let m: Arc<dyn L4PeekMiddleware> = Arc::new(PassPeek);
        let conn = make_conn_context();
        let mut sink = NullSink;
        let mut span = tracing::Span::none();
        let cancel = CancellationToken::new();
        let mut ctx = FlowCtx {
            span: &mut span,
            log: &mut sink as &mut dyn FlowLogSink,
            cancel: &cancel,
            verbosity: crate::flow_log::FlowLogVerbosity::Trajectory,
            trajectory: crate::flow_log::TrajectoryBuilder::new(conn.id, crate::ir::NodeId::new(0), 0),
        };
        let peek: &[u8] = &[];
        let fut: Pin<Box<dyn Future<Output = Result<Decision, Error>> + Send + '_>> =
            m.run(peek, &conn, &mut ctx);
        assert_send(&fut);
        drop(fut);
    }

    #[test]
    fn l7_request_run_returns_send_future() {
        let m: Arc<dyn L7RequestMiddleware> = Arc::new(PassReq);
        let conn = make_conn_context();
        let mut sink = NullSink;
        let mut span = tracing::Span::none();
        let cancel = CancellationToken::new();
        let mut ctx = FlowCtx {
            span: &mut span,
            log: &mut sink as &mut dyn FlowLogSink,
            cancel: &cancel,
            verbosity: crate::flow_log::FlowLogVerbosity::Trajectory,
            trajectory: crate::flow_log::TrajectoryBuilder::new(conn.id, crate::ir::NodeId::new(0), 0),
        };
        let mut req: Request =
            http::Request::builder().uri("/").body(crate::body::Body::Empty).expect("build req");
        let fut: Pin<Box<dyn Future<Output = Result<Decision, Error>> + Send + '_>> =
            m.run(&mut req, &conn, &mut ctx);
        assert_send(&fut);
        drop(fut);
    }

    #[test]
    fn l7_response_run_returns_send_future() {
        let m: Arc<dyn L7ResponseMiddleware> = Arc::new(PassResp);
        let conn = make_conn_context();
        let mut sink = NullSink;
        let mut span = tracing::Span::none();
        let cancel = CancellationToken::new();
        let mut ctx = FlowCtx {
            span: &mut span,
            log: &mut sink as &mut dyn FlowLogSink,
            cancel: &cancel,
            verbosity: crate::flow_log::FlowLogVerbosity::Trajectory,
            trajectory: crate::flow_log::TrajectoryBuilder::new(conn.id, crate::ir::NodeId::new(0), 0),
        };
        let mut resp: Response =
            http::Response::builder().status(200).body(crate::body::Body::Empty).expect("build resp");
        let fut: Pin<Box<dyn Future<Output = Result<Decision, Error>> + Send + '_>> =
            m.run(&mut resp, &conn, &mut ctx);
        assert_send(&fut);
        drop(fut);
    }

    // Implementors that do not override `needs_body` get `false`.
    #[test]
    fn l7_request_needs_body_defaults_to_false() {
        assert!(!L7RequestMiddleware::needs_body(&PassReq));
    }

    #[test]
    fn l7_response_needs_body_defaults_to_false() {
        assert!(!L7ResponseMiddleware::needs_body(&PassResp));
    }

    // Every `MiddlewareKind` variant survives a JSON serialize/deserialize cycle.
    #[test]
    fn middleware_kind_serde_round_trip_per_variant() {
        for k in [
            MiddlewareKind::L4Peek,
            MiddlewareKind::L4Bytes,
            MiddlewareKind::L7Request,
            MiddlewareKind::L7Response,
        ] {
            let encoded = serde_json::to_string(&k).expect("serialize");
            let decoded: MiddlewareKind = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, k);
        }
    }

    // Construction smoke tests: these only need to compile.
    #[test]
    fn decision_and_shortcircuit_construct_per_variant() {
        let _ = Decision::Continue;
        let _ = Decision::Short(ShortCircuit::Close(CloseReason::Graceful));
        let _ = ShortCircuit::Close(CloseReason::PolicyDenied("over quota".into()));
        let _ = ShortCircuit::Close(CloseReason::ProtocolError("bad frame".into()));
    }

    // Message variants accept both borrowed and owned `Cow`s.
    #[test]
    fn close_reason_construct_per_variant() {
        let _ = CloseReason::Graceful;
        let _ = CloseReason::PolicyDenied(std::borrow::Cow::Borrowed("over quota"));
        let _ = CloseReason::ProtocolError(std::borrow::Cow::Owned(String::from("bad frame")));
        let _ = CloseReason::Cancelled;
    }

    // Hashes `v` with a fresh `DefaultHasher`.
    fn hash_of<T: Hash>(v: &T) -> u64 {
        let mut h = DefaultHasher::new();
        v.hash(&mut h);
        h.finish()
    }

    // A baseline `rate_limit` / `L7Request` ref with the given `args`; tests
    // mutate single fields to check that each one participates in equality.
    fn sym_ref(args: serde_json::Value) -> SymbolicMiddlewareRef {
        SymbolicMiddlewareRef {
            name: Arc::from("rate_limit"),
            args,
            kind: MiddlewareKind::L7Request,
            stateless: true,
            needs_body: false,
            on_error: None,
        }
    }

    // Same entries inserted in different orders: equal, and equal hashes.
    #[test]
    fn symbolic_ref_args_hash_is_object_key_order_insensitive() {
        let mut a = serde_json::Map::new();
        a.insert("a".to_string(), json!(1));
        a.insert("b".to_string(), json!(2));
        let mut b = serde_json::Map::new();
        b.insert("b".to_string(), json!(2));
        b.insert("a".to_string(), json!(1));

        let lhs = sym_ref(serde_json::Value::Object(a));
        let rhs = sym_ref(serde_json::Value::Object(b));

        assert_eq!(lhs, rhs);
        assert_eq!(hash_of(&lhs), hash_of(&rhs));
    }

    // Key-order insensitivity also applies to nested objects.
    #[test]
    fn symbolic_ref_nested_object_key_order_is_ignored() {
        let lhs = sym_ref(json!({ "outer": { "x": 1, "y": 2 } }));
        let mut inner = serde_json::Map::new();
        inner.insert("y".to_string(), json!(2));
        inner.insert("x".to_string(), json!(1));
        let mut outer = serde_json::Map::new();
        outer.insert("outer".to_string(), serde_json::Value::Object(inner));
        let rhs = sym_ref(serde_json::Value::Object(outer));

        assert_eq!(lhs, rhs);
        assert_eq!(hash_of(&lhs), hash_of(&rhs));
    }

    // Array element order, unlike object key order, is significant.
    #[test]
    fn symbolic_ref_arrays_are_order_sensitive() {
        let lhs = sym_ref(json!({ "xs": [1, 2] }));
        let rhs = sym_ref(json!({ "xs": [2, 1] }));
        assert_ne!(lhs, rhs);
    }

    // Each `symbolic_ref_differs_on_*` test flips one field and expects inequality.
    #[test]
    fn symbolic_ref_differs_on_name() {
        let a = sym_ref(json!({}));
        let mut b = sym_ref(json!({}));
        b.name = Arc::from("other");
        assert_ne!(a, b);
    }

    #[test]
    fn symbolic_ref_differs_on_kind() {
        let a = sym_ref(json!({}));
        let mut b = sym_ref(json!({}));
        b.kind = MiddlewareKind::L4Peek;
        assert_ne!(a, b);
    }

    #[test]
    fn symbolic_ref_differs_on_stateless() {
        let a = sym_ref(json!({}));
        let mut b = sym_ref(json!({}));
        b.stateless = false;
        assert_ne!(a, b);
    }

    #[test]
    fn symbolic_ref_differs_on_needs_body() {
        let a = sym_ref(json!({}));
        let mut b = sym_ref(json!({}));
        b.needs_body = true;
        assert_ne!(a, b);
    }

    #[test]
    fn symbolic_ref_differs_on_on_error() {
        let a = sym_ref(json!({}));
        let mut b = sym_ref(json!({}));
        b.on_error = Some(NodeId::new(3));
        assert_ne!(a, b);
    }

    // Same name but different argument values are distinct refs.
    #[test]
    fn symbolic_ref_same_name_but_distinct_args_are_unequal() {
        let a = sym_ref(json!({ "limit": 100 }));
        let b = sym_ref(json!({ "limit": 200 }));
        assert_ne!(a, b);
    }

    // JSON round-trip keeps every field, checked individually and via `==`.
    #[test]
    fn symbolic_middleware_ref_round_trip_preserves_all_fields() {
        let m = SymbolicMiddlewareRef {
            name: Arc::from("rate_limit"),
            args: json!({ "rate": 100 }),
            kind: MiddlewareKind::L7Request,
            stateless: false,
            needs_body: false,
            on_error: Some(NodeId::new(5)),
        };
        let encoded = serde_json::to_string(&m).expect("serialize");
        let decoded: SymbolicMiddlewareRef = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.name, m.name);
        assert_eq!(decoded.kind, m.kind);
        assert_eq!(decoded.stateless, m.stateless);
        assert_eq!(decoded.needs_body, m.needs_body);
        assert_eq!(decoded.on_error, m.on_error);
        assert_eq!(decoded, m);
    }

    // Args built in non-sorted key order still compare equal after a round-trip.
    #[test]
    fn symbolic_middleware_ref_round_trip_args_are_canonical_key_order_insensitive() {
        let mut obj = serde_json::Map::new();
        obj.insert("b".to_string(), json!(1));
        obj.insert("a".to_string(), json!(2));
        let m = SymbolicMiddlewareRef {
            name: Arc::from("mw"),
            args: serde_json::Value::Object(obj),
            kind: MiddlewareKind::L7Request,
            stateless: true,
            needs_body: false,
            on_error: None,
        };
        let encoded = serde_json::to_string(&m).expect("serialize");
        let decoded: SymbolicMiddlewareRef = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded, m);
    }
}