Skip to main content

vane_core/
middleware.rs

1use std::hash::Hash;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5
6use crate::body::{Request, Response};
7use crate::conn_context::ConnContext;
8use crate::error::Error;
9use crate::flow_ctx::FlowCtx;
10use crate::ir::NodeId;
11use crate::l4::L4Conn;
12
13#[async_trait]
14pub trait L4PeekMiddleware: Send + Sync {
15	async fn run(
16		&self,
17		peek: &[u8],
18		conn: &Arc<ConnContext>,
19		ctx: &mut FlowCtx,
20	) -> Result<Decision, Error>;
21}
22
23#[async_trait]
24pub trait L4BytesMiddleware: Send + Sync {
25	async fn run(
26		&self,
27		l4: &mut L4Conn,
28		conn: &Arc<ConnContext>,
29		ctx: &mut FlowCtx,
30	) -> Result<Decision, Error>;
31}
32
33#[async_trait]
34pub trait L7RequestMiddleware: Send + Sync {
35	async fn run(
36		&self,
37		req: &mut Request,
38		conn: &Arc<ConnContext>,
39		ctx: &mut FlowCtx,
40	) -> Result<Decision, Error>;
41
42	fn needs_body(&self) -> bool {
43		false
44	}
45}
46
47#[async_trait]
48pub trait L7ResponseMiddleware: Send + Sync {
49	async fn run(
50		&self,
51		resp: &mut Response,
52		conn: &Arc<ConnContext>,
53		ctx: &mut FlowCtx,
54	) -> Result<Decision, Error>;
55
56	fn needs_body(&self) -> bool {
57		false
58	}
59}
60
61pub enum Decision {
62	Continue,
63	Short(ShortCircuit),
64}
65
66pub enum ShortCircuit {
67	Response(Response),
68	Close(CloseReason),
69}
70
/// Why a connection is being closed.
///
/// The two free-text variants hold a `Cow<'static, str>` so callers can pass
/// either a string literal or an owned `String`.
#[derive(Debug, Clone)]
pub enum CloseReason {
	/// Ordinary close, e.g. the client reached EOF (contrast with
	/// `Cancelled` below).
	Graceful,
	/// A policy refused the flow; the text explains why (e.g. "over quota").
	PolicyDenied(std::borrow::Cow<'static, str>),
	/// The traffic violated the protocol; the text describes the violation
	/// (e.g. "bad frame").
	ProtocolError(std::borrow::Cow<'static, str>),
	/// Daemon-initiated cancellation — listener `force_cancel` fired during
	/// shutdown drain (01-topology.md § _Listener lifecycle_), or any other
	/// `ctx.cancel.cancelled()` propagation. Distinct from `Graceful` so
	/// management observers can distinguish "client EOF'd" from "daemon
	/// pulled the plug while in-flight."
	Cancelled,
}
83
84#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
85pub enum MiddlewareKind {
86	L4Peek,
87	L4Bytes,
88	L7Request,
89	L7Response,
90}
91
92#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
93pub struct SymbolicMiddlewareRef {
94	pub name: Arc<str>,
95	pub args: serde_json::Value,
96	pub kind: MiddlewareKind,
97	pub stateless: bool,
98	pub needs_body: bool,
99	pub on_error: Option<NodeId>,
100}
101
102impl PartialEq for SymbolicMiddlewareRef {
103	fn eq(&self, other: &Self) -> bool {
104		self.name == other.name
105			&& self.kind == other.kind
106			&& self.stateless == other.stateless
107			&& self.needs_body == other.needs_body
108			&& self.on_error == other.on_error
109			&& canonical_json_eq(&self.args, &other.args)
110	}
111}
112
113impl Eq for SymbolicMiddlewareRef {}
114
115impl std::hash::Hash for SymbolicMiddlewareRef {
116	fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
117		self.name.hash(state);
118		self.kind.hash(state);
119		self.stateless.hash(state);
120		self.needs_body.hash(state);
121		self.on_error.hash(state);
122		hash_canonical_json(&self.args, state);
123	}
124}
125
126fn canonical_json_eq(a: &serde_json::Value, b: &serde_json::Value) -> bool {
127	use serde_json::Value;
128	match (a, b) {
129		(Value::Null, Value::Null) => true,
130		(Value::Bool(x), Value::Bool(y)) => x == y,
131		(Value::Number(x), Value::Number(y)) => x == y,
132		(Value::String(x), Value::String(y)) => x == y,
133		(Value::Array(xs), Value::Array(ys)) => {
134			xs.len() == ys.len() && xs.iter().zip(ys).all(|(x, y)| canonical_json_eq(x, y))
135		}
136		(Value::Object(xs), Value::Object(ys)) if xs.len() == ys.len() => {
137			xs.iter().all(|(k, v)| ys.get(k).is_some_and(|w| canonical_json_eq(v, w)))
138		}
139		_ => false,
140	}
141}
142
143fn hash_canonical_json<H: std::hash::Hasher>(v: &serde_json::Value, state: &mut H) {
144	use serde_json::Value;
145	match v {
146		Value::Null => 0u8.hash(state),
147		Value::Bool(b) => {
148			1u8.hash(state);
149			b.hash(state);
150		}
151		Value::Number(n) => {
152			2u8.hash(state);
153			n.to_string().hash(state);
154		}
155		Value::String(s) => {
156			3u8.hash(state);
157			s.hash(state);
158		}
159		Value::Array(xs) => {
160			4u8.hash(state);
161			xs.len().hash(state);
162			for x in xs {
163				hash_canonical_json(x, state);
164			}
165		}
166		Value::Object(xs) => {
167			5u8.hash(state);
168			let mut keys: Vec<&String> = xs.keys().collect();
169			keys.sort();
170			keys.len().hash(state);
171			for k in keys {
172				k.hash(state);
173				hash_canonical_json(&xs[k], state);
174			}
175		}
176	}
177}
178
#[cfg(test)]
mod tests {
	use std::collections::hash_map::DefaultHasher;
	use std::future::Future;
	use std::hash::{Hash, Hasher};
	use std::net::SocketAddr;
	use std::pin::Pin;
	use std::time::Instant;

	use parking_lot::Mutex;
	use serde_json::json;
	use tokio_util::sync::CancellationToken;

	use super::*;
	use crate::conn_context::{ConnId, Transport};
	use crate::flow_log::{FlowLogEvent, FlowLogSink};

	// Trivial implementations of each middleware trait. Every `run` returns
	// `Decision::Continue` without touching its arguments; they exist so the
	// tests below have concrete types to coerce into trait objects.

	struct PassPeek;
	#[async_trait]
	impl L4PeekMiddleware for PassPeek {
		async fn run(
			&self,
			_peek: &[u8],
			_conn: &Arc<ConnContext>,
			_ctx: &mut FlowCtx,
		) -> Result<Decision, Error> {
			Ok(Decision::Continue)
		}
	}

	struct PassBytes;
	#[async_trait]
	impl L4BytesMiddleware for PassBytes {
		async fn run(
			&self,
			_l4: &mut L4Conn,
			_conn: &Arc<ConnContext>,
			_ctx: &mut FlowCtx,
		) -> Result<Decision, Error> {
			Ok(Decision::Continue)
		}
	}

	struct PassReq;
	#[async_trait]
	impl L7RequestMiddleware for PassReq {
		async fn run(
			&self,
			_req: &mut Request,
			_conn: &Arc<ConnContext>,
			_ctx: &mut FlowCtx,
		) -> Result<Decision, Error> {
			Ok(Decision::Continue)
		}
	}

	struct PassResp;
	#[async_trait]
	impl L7ResponseMiddleware for PassResp {
		async fn run(
			&self,
			_resp: &mut Response,
			_conn: &Arc<ConnContext>,
			_ctx: &mut FlowCtx,
		) -> Result<Decision, Error> {
			Ok(Decision::Continue)
		}
	}

	// Compile-time assertion helper: the type `F` must be `Send`. `async_trait`
	// rewrites `async fn run(...)` to return `Pin<Box<dyn Future + Send>>`, so
	// every `run` invocation's future must satisfy this bound — that is the
	// load-bearing contract per 04-middleware.md § _Async Send via async_trait_.
	fn assert_send<F: Send>(_: &F) {}

	/// Flow-log sink that discards every event.
	struct NullSink;
	impl FlowLogSink for NullSink {
		fn emit(&self, _event: FlowLogEvent) {}
	}

	/// Builds a minimal TCP `ConnContext` whose local and remote addresses are
	/// both `127.0.0.1:0`.
	fn make_conn_context() -> Arc<ConnContext> {
		let addr: SocketAddr = "127.0.0.1:0".parse().expect("parse addr");
		Arc::new(ConnContext {
			id: ConnId(0),
			remote: addr,
			local: addr,
			transport: Transport::Tcp,
			entered_at: Instant::now(),
			tls: Mutex::new(None),
			http_version: std::sync::OnceLock::new(),
			user: Mutex::new(http::Extensions::new()),
		})
	}

	// `async_trait` makes these traits dyn-compatible. `MiddlewareInst` stores
	// each variant as `Arc<dyn Trait>` per 04-middleware.md § _Symbolic forms_
	// and § _Async Send via async_trait_; constructing that exact shape from a
	// concrete impl is the contract we guard.

	#[test]
	fn l4_peek_is_constructible_as_arc_dyn_send_sync() {
		let m: Arc<dyn L4PeekMiddleware + Send + Sync> = Arc::new(PassPeek);
		// The trait-object Arc coerces to the bare `Arc<dyn Trait>` shape used
		// by `MiddlewareInst::L4Peek(Arc<dyn L4PeekMiddleware>)` in engine.
		let _: Arc<dyn L4PeekMiddleware> = m;
	}

	#[test]
	fn l4_bytes_is_constructible_as_arc_dyn_send_sync() {
		let m: Arc<dyn L4BytesMiddleware + Send + Sync> = Arc::new(PassBytes);
		let _: Arc<dyn L4BytesMiddleware> = m;
	}

	#[test]
	fn l7_request_is_constructible_as_arc_dyn_send_sync() {
		let m: Arc<dyn L7RequestMiddleware + Send + Sync> = Arc::new(PassReq);
		let _: Arc<dyn L7RequestMiddleware> = m;
	}

	#[test]
	fn l7_response_is_constructible_as_arc_dyn_send_sync() {
		let m: Arc<dyn L7ResponseMiddleware + Send + Sync> = Arc::new(PassResp);
		let _: Arc<dyn L7ResponseMiddleware> = m;
	}

	/// Builds a `FlowCtx` with a disabled span, a discarding sink, a fresh
	/// cancellation token, and trajectory-level verbosity.
	fn make_flow_ctx(conn_id: ConnId) -> FlowCtx {
		FlowCtx {
			span: tracing::Span::none(),
			log: Arc::new(NullSink),
			cancel: CancellationToken::new(),
			verbosity: crate::flow_log::FlowLogVerbosity::Trajectory,
			trajectory: crate::flow_log::TrajectoryBuilder::new(conn_id, crate::ir::NodeId::new(0), 0),
		}
	}

	// NOTE(review): there is no `run_returns_send_future` test for
	// `L4BytesMiddleware`; presumably constructing an `L4Conn` here is
	// impractical — confirm, and add one if feasible.

	#[test]
	fn l4_peek_run_returns_send_future() {
		let m: Arc<dyn L4PeekMiddleware> = Arc::new(PassPeek);
		let conn = make_conn_context();
		let mut ctx = make_flow_ctx(conn.id);
		let peek: &[u8] = &[];
		// Exact-type coercion into `Pin<Box<dyn Future + Send>>` — the async_trait
		// signature. Fails to compile if a future becomes `!Send`.
		let fut: Pin<Box<dyn Future<Output = Result<Decision, Error>> + Send + '_>> =
			m.run(peek, &conn, &mut ctx);
		assert_send(&fut);
		drop(fut);
	}

	#[test]
	fn l7_request_run_returns_send_future() {
		let m: Arc<dyn L7RequestMiddleware> = Arc::new(PassReq);
		let conn = make_conn_context();
		let mut ctx = make_flow_ctx(conn.id);
		let mut req: Request =
			http::Request::builder().uri("/").body(crate::body::Body::Empty).expect("build req");
		let fut: Pin<Box<dyn Future<Output = Result<Decision, Error>> + Send + '_>> =
			m.run(&mut req, &conn, &mut ctx);
		assert_send(&fut);
		drop(fut);
	}

	#[test]
	fn l7_response_run_returns_send_future() {
		let m: Arc<dyn L7ResponseMiddleware> = Arc::new(PassResp);
		let conn = make_conn_context();
		let mut ctx = make_flow_ctx(conn.id);
		let mut resp: Response =
			http::Response::builder().status(200).body(crate::body::Body::Empty).expect("build resp");
		let fut: Pin<Box<dyn Future<Output = Result<Decision, Error>> + Send + '_>> =
			m.run(&mut resp, &conn, &mut ctx);
		assert_send(&fut);
		drop(fut);
	}

	#[test]
	fn l7_request_needs_body_defaults_to_false() {
		assert!(!L7RequestMiddleware::needs_body(&PassReq));
	}

	#[test]
	fn l7_response_needs_body_defaults_to_false() {
		assert!(!L7ResponseMiddleware::needs_body(&PassResp));
	}

	#[test]
	fn middleware_kind_serde_round_trip_per_variant() {
		for k in [
			MiddlewareKind::L4Peek,
			MiddlewareKind::L4Bytes,
			MiddlewareKind::L7Request,
			MiddlewareKind::L7Response,
		] {
			let encoded = serde_json::to_string(&k).expect("serialize");
			let decoded: MiddlewareKind = serde_json::from_str(&encoded).expect("deserialize");
			assert_eq!(decoded, k);
		}
	}

	#[test]
	fn decision_and_shortcircuit_construct_per_variant() {
		let _ = Decision::Continue;
		let _ = Decision::Short(ShortCircuit::Close(CloseReason::Graceful));
		let _ = ShortCircuit::Close(CloseReason::PolicyDenied("over quota".into()));
		let _ = ShortCircuit::Close(CloseReason::ProtocolError("bad frame".into()));
		let _ = ShortCircuit::Close(CloseReason::Cancelled);
		// Cover the `Response` short-circuit as well, wrapped in a `Decision`.
		let resp: Response =
			http::Response::builder().status(403).body(crate::body::Body::Empty).expect("build resp");
		let _ = Decision::Short(ShortCircuit::Response(resp));
	}

	#[test]
	fn close_reason_construct_per_variant() {
		let _ = CloseReason::Graceful;
		let _ = CloseReason::PolicyDenied(std::borrow::Cow::Borrowed("over quota"));
		let _ = CloseReason::ProtocolError(std::borrow::Cow::Owned(String::from("bad frame")));
		let _ = CloseReason::Cancelled;
	}

	/// Hashes `v` with a fresh `DefaultHasher`.
	fn hash_of<T: Hash>(v: &T) -> u64 {
		let mut h = DefaultHasher::new();
		v.hash(&mut h);
		h.finish()
	}

	/// A stateless `rate_limit` L7-request reference carrying `args`, with no
	/// body requirement and no `on_error` node.
	fn sym_ref(args: serde_json::Value) -> SymbolicMiddlewareRef {
		SymbolicMiddlewareRef {
			name: Arc::from("rate_limit"),
			args,
			kind: MiddlewareKind::L7Request,
			stateless: true,
			needs_body: false,
			on_error: None,
		}
	}

	#[test]
	fn symbolic_ref_args_hash_is_object_key_order_insensitive() {
		// Manually build both maps with opposite insertion orders to defeat
		// serde_json::from_str's preserve-insertion-order backend.
		let mut a = serde_json::Map::new();
		a.insert("a".to_string(), json!(1));
		a.insert("b".to_string(), json!(2));
		let mut b = serde_json::Map::new();
		b.insert("b".to_string(), json!(2));
		b.insert("a".to_string(), json!(1));

		let lhs = sym_ref(serde_json::Value::Object(a));
		let rhs = sym_ref(serde_json::Value::Object(b));

		assert_eq!(lhs, rhs);
		assert_eq!(hash_of(&lhs), hash_of(&rhs));
	}

	#[test]
	fn symbolic_ref_nested_object_key_order_is_ignored() {
		let lhs = sym_ref(json!({ "outer": { "x": 1, "y": 2 } }));
		// Build the inner map with swapped order by hand.
		let mut inner = serde_json::Map::new();
		inner.insert("y".to_string(), json!(2));
		inner.insert("x".to_string(), json!(1));
		let mut outer = serde_json::Map::new();
		outer.insert("outer".to_string(), serde_json::Value::Object(inner));
		let rhs = sym_ref(serde_json::Value::Object(outer));

		assert_eq!(lhs, rhs);
		assert_eq!(hash_of(&lhs), hash_of(&rhs));
	}

	#[test]
	fn symbolic_ref_arrays_are_order_sensitive() {
		let lhs = sym_ref(json!({ "xs": [1, 2] }));
		let rhs = sym_ref(json!({ "xs": [2, 1] }));
		assert_ne!(lhs, rhs);
	}

	#[test]
	fn symbolic_ref_differs_on_name() {
		let a = sym_ref(json!({}));
		let mut b = sym_ref(json!({}));
		b.name = Arc::from("other");
		assert_ne!(a, b);
	}

	#[test]
	fn symbolic_ref_differs_on_kind() {
		let a = sym_ref(json!({}));
		let mut b = sym_ref(json!({}));
		b.kind = MiddlewareKind::L4Peek;
		assert_ne!(a, b);
	}

	#[test]
	fn symbolic_ref_differs_on_stateless() {
		let a = sym_ref(json!({}));
		let mut b = sym_ref(json!({}));
		b.stateless = false;
		assert_ne!(a, b);
	}

	#[test]
	fn symbolic_ref_differs_on_needs_body() {
		let a = sym_ref(json!({}));
		let mut b = sym_ref(json!({}));
		b.needs_body = true;
		assert_ne!(a, b);
	}

	#[test]
	fn symbolic_ref_differs_on_on_error() {
		let a = sym_ref(json!({}));
		let mut b = sym_ref(json!({}));
		b.on_error = Some(NodeId::new(3));
		assert_ne!(a, b);
	}

	#[test]
	fn symbolic_ref_same_name_but_distinct_args_are_unequal() {
		let a = sym_ref(json!({ "limit": 100 }));
		let b = sym_ref(json!({ "limit": 200 }));
		assert_ne!(a, b);
	}

	#[test]
	fn symbolic_ref_number_and_string_args_are_unequal() {
		// Same textual payload, different JSON types: must not compare equal.
		let a = sym_ref(json!({ "limit": 1 }));
		let b = sym_ref(json!({ "limit": "1" }));
		assert_ne!(a, b);
	}

	// Dry-run JSON wire-format contract: SymbolicMiddlewareRef participates
	// in the compiled-form JSON per 02-flow.md § _The compiled form_. The
	// whole struct uses derive(Serialize/Deserialize); all fields round-trip.
	// PartialEq uses canonical-json equality on `args`, so key-order
	// perturbation must still compare equal after a round-trip.

	#[test]
	fn symbolic_middleware_ref_round_trip_preserves_all_fields() {
		let m = SymbolicMiddlewareRef {
			name: Arc::from("rate_limit"),
			args: json!({ "rate": 100 }),
			kind: MiddlewareKind::L7Request,
			stateless: false,
			needs_body: false,
			on_error: Some(NodeId::new(5)),
		};
		let encoded = serde_json::to_string(&m).expect("serialize");
		let decoded: SymbolicMiddlewareRef = serde_json::from_str(&encoded).expect("deserialize");
		assert_eq!(decoded.name, m.name);
		assert_eq!(decoded.kind, m.kind);
		assert_eq!(decoded.stateless, m.stateless);
		assert_eq!(decoded.needs_body, m.needs_body);
		assert_eq!(decoded.on_error, m.on_error);
		assert_eq!(decoded, m);
	}

	#[test]
	fn symbolic_middleware_ref_round_trip_args_are_canonical_key_order_insensitive() {
		// Build an args value whose serialized form has a deliberate key order.
		let mut obj = serde_json::Map::new();
		obj.insert("b".to_string(), json!(1));
		obj.insert("a".to_string(), json!(2));
		let m = SymbolicMiddlewareRef {
			name: Arc::from("mw"),
			args: serde_json::Value::Object(obj),
			kind: MiddlewareKind::L7Request,
			stateless: true,
			needs_body: false,
			on_error: None,
		};
		let encoded = serde_json::to_string(&m).expect("serialize");
		let decoded: SymbolicMiddlewareRef = serde_json::from_str(&encoded).expect("deserialize");
		// PartialEq on SymbolicMiddlewareRef uses canonical-json equality on args,
		// so any post-round-trip key reshuffling remains == to the original.
		assert_eq!(decoded, m);
	}
}