// vane_core/middleware.rs

1use std::hash::Hash;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5
6use crate::body::{Request, Response};
7use crate::conn_context::ConnContext;
8use crate::error::Error;
9use crate::flow_ctx::FlowCtx;
10use crate::ir::NodeId;
11use crate::l4::L4Conn;
12
13#[async_trait]
14pub trait L4PeekMiddleware: Send + Sync {
15	async fn run(
16		&self,
17		peek: &[u8],
18		conn: &Arc<ConnContext>,
19		ctx: &mut FlowCtx<'_>,
20	) -> Result<Decision, Error>;
21}
22
23#[async_trait]
24pub trait L4BytesMiddleware: Send + Sync {
25	async fn run(
26		&self,
27		l4: &mut L4Conn,
28		conn: &Arc<ConnContext>,
29		ctx: &mut FlowCtx<'_>,
30	) -> Result<Decision, Error>;
31}
32
33#[async_trait]
34pub trait L7RequestMiddleware: Send + Sync {
35	async fn run(
36		&self,
37		req: &mut Request,
38		conn: &Arc<ConnContext>,
39		ctx: &mut FlowCtx<'_>,
40	) -> Result<Decision, Error>;
41
42	fn needs_body(&self) -> bool {
43		false
44	}
45}
46
47#[async_trait]
48pub trait L7ResponseMiddleware: Send + Sync {
49	async fn run(
50		&self,
51		resp: &mut Response,
52		conn: &Arc<ConnContext>,
53		ctx: &mut FlowCtx<'_>,
54	) -> Result<Decision, Error>;
55
56	fn needs_body(&self) -> bool {
57		false
58	}
59}
60
61pub enum Decision {
62	Continue,
63	Short(ShortCircuit),
64}
65
66pub enum ShortCircuit {
67	Response(Response),
68	Close(CloseReason),
69}
70
/// Why a connection is being closed.
#[derive(Debug, Clone)]
pub enum CloseReason {
	/// Orderly close (e.g. the client reached EOF).
	Graceful,
	/// A policy rejected the flow; carries a human-readable reason
	/// (for example `"over quota"`).
	PolicyDenied(std::borrow::Cow<'static, str>),
	/// A protocol-level error occurred; carries a description
	/// (for example `"bad frame"`).
	ProtocolError(std::borrow::Cow<'static, str>),
	/// The daemon cancelled the flow: the listener's `force_cancel` fired while
	/// draining during shutdown (01-topology.md § _Listener lifecycle_), or some
	/// other `ctx.cancel.cancelled()` propagated. Kept separate from `Graceful`
	/// so management observers can tell "client EOF'd" apart from "daemon
	/// pulled the plug while in-flight."
	Cancelled,
}
83
84#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
85pub enum MiddlewareKind {
86	L4Peek,
87	L4Bytes,
88	L7Request,
89	L7Response,
90}
91
92#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
93pub struct SymbolicMiddlewareRef {
94	pub name: Arc<str>,
95	pub args: serde_json::Value,
96	pub kind: MiddlewareKind,
97	pub stateless: bool,
98	pub needs_body: bool,
99	pub on_error: Option<NodeId>,
100}
101
102impl PartialEq for SymbolicMiddlewareRef {
103	fn eq(&self, other: &Self) -> bool {
104		self.name == other.name
105			&& self.kind == other.kind
106			&& self.stateless == other.stateless
107			&& self.needs_body == other.needs_body
108			&& self.on_error == other.on_error
109			&& canonical_json_eq(&self.args, &other.args)
110	}
111}
112
113impl Eq for SymbolicMiddlewareRef {}
114
115impl std::hash::Hash for SymbolicMiddlewareRef {
116	fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
117		self.name.hash(state);
118		self.kind.hash(state);
119		self.stateless.hash(state);
120		self.needs_body.hash(state);
121		self.on_error.hash(state);
122		hash_canonical_json(&self.args, state);
123	}
124}
125
126fn canonical_json_eq(a: &serde_json::Value, b: &serde_json::Value) -> bool {
127	use serde_json::Value;
128	match (a, b) {
129		(Value::Null, Value::Null) => true,
130		(Value::Bool(x), Value::Bool(y)) => x == y,
131		(Value::Number(x), Value::Number(y)) => x == y,
132		(Value::String(x), Value::String(y)) => x == y,
133		(Value::Array(xs), Value::Array(ys)) => {
134			xs.len() == ys.len() && xs.iter().zip(ys).all(|(x, y)| canonical_json_eq(x, y))
135		}
136		(Value::Object(xs), Value::Object(ys)) if xs.len() == ys.len() => {
137			xs.iter().all(|(k, v)| ys.get(k).is_some_and(|w| canonical_json_eq(v, w)))
138		}
139		_ => false,
140	}
141}
142
143fn hash_canonical_json<H: std::hash::Hasher>(v: &serde_json::Value, state: &mut H) {
144	use serde_json::Value;
145	match v {
146		Value::Null => 0u8.hash(state),
147		Value::Bool(b) => {
148			1u8.hash(state);
149			b.hash(state);
150		}
151		Value::Number(n) => {
152			2u8.hash(state);
153			n.to_string().hash(state);
154		}
155		Value::String(s) => {
156			3u8.hash(state);
157			s.hash(state);
158		}
159		Value::Array(xs) => {
160			4u8.hash(state);
161			xs.len().hash(state);
162			for x in xs {
163				hash_canonical_json(x, state);
164			}
165		}
166		Value::Object(xs) => {
167			5u8.hash(state);
168			let mut keys: Vec<&String> = xs.keys().collect();
169			keys.sort();
170			keys.len().hash(state);
171			for k in keys {
172				k.hash(state);
173				hash_canonical_json(&xs[k], state);
174			}
175		}
176	}
177}
178
// Contract tests for the middleware traits and the symbolic reference type:
// dyn-compatibility, `Send` futures, default hooks, serde round-trips, and
// key-order-insensitive equality/hashing of `SymbolicMiddlewareRef`.
#[cfg(test)]
mod tests {
	use std::collections::hash_map::DefaultHasher;
	use std::future::Future;
	use std::hash::{Hash, Hasher};
	use std::net::SocketAddr;
	use std::pin::Pin;
	use std::time::Instant;

	use parking_lot::Mutex;
	use serde_json::json;
	use tokio_util::sync::CancellationToken;

	use super::*;
	use crate::conn_context::{ConnId, Transport};
	use crate::flow_log::{FlowLogEvent, FlowLogSink};

	/// No-op peek middleware: always `Decision::Continue`.
	struct PassPeek;
	#[async_trait]
	impl L4PeekMiddleware for PassPeek {
		async fn run(
			&self,
			_peek: &[u8],
			_conn: &Arc<ConnContext>,
			_ctx: &mut FlowCtx<'_>,
		) -> Result<Decision, Error> {
			Ok(Decision::Continue)
		}
	}

	/// No-op bytes middleware: always `Decision::Continue`.
	struct PassBytes;
	#[async_trait]
	impl L4BytesMiddleware for PassBytes {
		async fn run(
			&self,
			_l4: &mut L4Conn,
			_conn: &Arc<ConnContext>,
			_ctx: &mut FlowCtx<'_>,
		) -> Result<Decision, Error> {
			Ok(Decision::Continue)
		}
	}

	/// No-op request middleware: always `Decision::Continue`.
	struct PassReq;
	#[async_trait]
	impl L7RequestMiddleware for PassReq {
		async fn run(
			&self,
			_req: &mut Request,
			_conn: &Arc<ConnContext>,
			_ctx: &mut FlowCtx<'_>,
		) -> Result<Decision, Error> {
			Ok(Decision::Continue)
		}
	}

	/// No-op response middleware: always `Decision::Continue`.
	struct PassResp;
	#[async_trait]
	impl L7ResponseMiddleware for PassResp {
		async fn run(
			&self,
			_resp: &mut Response,
			_conn: &Arc<ConnContext>,
			_ctx: &mut FlowCtx<'_>,
		) -> Result<Decision, Error> {
			Ok(Decision::Continue)
		}
	}

	// Compile-time assertion helper: the type `F` must be `Send`. `async_trait`
	// rewrites `async fn run(...)` to return `Pin<Box<dyn Future + Send>>`, so
	// every `run` invocation's future must satisfy this bound — that is the
	// load-bearing contract per 04-middleware.md § _Async Send via async_trait_.
	fn assert_send<F: Send>(_: &F) {}

	/// Flow-log sink that discards every event.
	struct NullSink;
	impl FlowLogSink for NullSink {
		fn emit(&self, _event: FlowLogEvent) {}
	}

	/// Minimal TCP `ConnContext` on 127.0.0.1:0 with no TLS state.
	fn make_conn_context() -> Arc<ConnContext> {
		let addr: SocketAddr = "127.0.0.1:0".parse().expect("parse addr");
		Arc::new(ConnContext {
			id: ConnId(0),
			remote: addr,
			local: addr,
			transport: Transport::Tcp,
			entered_at: Instant::now(),
			tls: Mutex::new(None),
			http_version: std::sync::OnceLock::new(),
			user: Mutex::new(http::Extensions::new()),
		})
	}

	// `async_trait` makes these traits dyn-compatible. `MiddlewareInst` stores
	// each variant as `Arc<dyn Trait>` per 04-middleware.md § _Symbolic forms_
	// and § _Async Send via async_trait_; constructing that exact shape from a
	// concrete impl is the contract we guard.

	#[test]
	fn l4_peek_is_constructible_as_arc_dyn_send_sync() {
		let m: Arc<dyn L4PeekMiddleware + Send + Sync> = Arc::new(PassPeek);
		// The trait-object Arc coerces to the bare `Arc<dyn Trait>` shape used
		// by `MiddlewareInst::L4Peek(Arc<dyn L4PeekMiddleware>)` in engine.
		let _: Arc<dyn L4PeekMiddleware> = m;
	}

	#[test]
	fn l4_bytes_is_constructible_as_arc_dyn_send_sync() {
		let m: Arc<dyn L4BytesMiddleware + Send + Sync> = Arc::new(PassBytes);
		let _: Arc<dyn L4BytesMiddleware> = m;
	}

	#[test]
	fn l7_request_is_constructible_as_arc_dyn_send_sync() {
		let m: Arc<dyn L7RequestMiddleware + Send + Sync> = Arc::new(PassReq);
		let _: Arc<dyn L7RequestMiddleware> = m;
	}

	#[test]
	fn l7_response_is_constructible_as_arc_dyn_send_sync() {
		let m: Arc<dyn L7ResponseMiddleware + Send + Sync> = Arc::new(PassResp);
		let _: Arc<dyn L7ResponseMiddleware> = m;
	}

	// NOTE(review): there is no `l4_bytes_run_returns_send_future` because these
	// fixtures cannot build an `L4Conn`; the L4Bytes future is only covered by
	// the `async_trait` signature itself. Add one if an `L4Conn` test fixture
	// becomes available.

	#[test]
	fn l4_peek_run_returns_send_future() {
		let m: Arc<dyn L4PeekMiddleware> = Arc::new(PassPeek);
		let conn = make_conn_context();
		let mut sink = NullSink;
		let mut span = tracing::Span::none();
		let cancel = CancellationToken::new();
		let mut ctx = FlowCtx {
			span: &mut span,
			log: &mut sink as &mut dyn FlowLogSink,
			cancel: &cancel,
			verbosity: crate::flow_log::FlowLogVerbosity::Trajectory,
			trajectory: crate::flow_log::TrajectoryBuilder::new(conn.id, crate::ir::NodeId::new(0), 0),
		};
		let peek: &[u8] = &[];
		// Exact-type coercion into `Pin<Box<dyn Future + Send>>` — the async_trait
		// signature. Fails to compile if a future becomes `!Send`.
		let fut: Pin<Box<dyn Future<Output = Result<Decision, Error>> + Send + '_>> =
			m.run(peek, &conn, &mut ctx);
		assert_send(&fut);
		drop(fut);
	}

	#[test]
	fn l7_request_run_returns_send_future() {
		let m: Arc<dyn L7RequestMiddleware> = Arc::new(PassReq);
		let conn = make_conn_context();
		let mut sink = NullSink;
		let mut span = tracing::Span::none();
		let cancel = CancellationToken::new();
		let mut ctx = FlowCtx {
			span: &mut span,
			log: &mut sink as &mut dyn FlowLogSink,
			cancel: &cancel,
			verbosity: crate::flow_log::FlowLogVerbosity::Trajectory,
			trajectory: crate::flow_log::TrajectoryBuilder::new(conn.id, crate::ir::NodeId::new(0), 0),
		};
		let mut req: Request =
			http::Request::builder().uri("/").body(crate::body::Body::Empty).expect("build req");
		let fut: Pin<Box<dyn Future<Output = Result<Decision, Error>> + Send + '_>> =
			m.run(&mut req, &conn, &mut ctx);
		assert_send(&fut);
		drop(fut);
	}

	#[test]
	fn l7_response_run_returns_send_future() {
		let m: Arc<dyn L7ResponseMiddleware> = Arc::new(PassResp);
		let conn = make_conn_context();
		let mut sink = NullSink;
		let mut span = tracing::Span::none();
		let cancel = CancellationToken::new();
		let mut ctx = FlowCtx {
			span: &mut span,
			log: &mut sink as &mut dyn FlowLogSink,
			cancel: &cancel,
			verbosity: crate::flow_log::FlowLogVerbosity::Trajectory,
			trajectory: crate::flow_log::TrajectoryBuilder::new(conn.id, crate::ir::NodeId::new(0), 0),
		};
		let mut resp: Response =
			http::Response::builder().status(200).body(crate::body::Body::Empty).expect("build resp");
		let fut: Pin<Box<dyn Future<Output = Result<Decision, Error>> + Send + '_>> =
			m.run(&mut resp, &conn, &mut ctx);
		assert_send(&fut);
		drop(fut);
	}

	#[test]
	fn l7_request_needs_body_defaults_to_false() {
		assert!(!L7RequestMiddleware::needs_body(&PassReq));
	}

	#[test]
	fn l7_response_needs_body_defaults_to_false() {
		assert!(!L7ResponseMiddleware::needs_body(&PassResp));
	}

	#[test]
	fn middleware_kind_serde_round_trip_per_variant() {
		for k in [
			MiddlewareKind::L4Peek,
			MiddlewareKind::L4Bytes,
			MiddlewareKind::L7Request,
			MiddlewareKind::L7Response,
		] {
			let encoded = serde_json::to_string(&k).expect("serialize");
			let decoded: MiddlewareKind = serde_json::from_str(&encoded).expect("deserialize");
			assert_eq!(decoded, k);
		}
	}

	#[test]
	fn decision_and_shortcircuit_construct_per_variant() {
		// Covers every `Decision` and every `ShortCircuit` variant, including
		// `ShortCircuit::Response`, which the Close-only cases below miss.
		let resp: Response =
			http::Response::builder().status(403).body(crate::body::Body::Empty).expect("build resp");
		let _ = Decision::Continue;
		let _ = Decision::Short(ShortCircuit::Response(resp));
		let _ = Decision::Short(ShortCircuit::Close(CloseReason::Graceful));
		let _ = ShortCircuit::Close(CloseReason::PolicyDenied("over quota".into()));
		let _ = ShortCircuit::Close(CloseReason::ProtocolError("bad frame".into()));
	}

	#[test]
	fn close_reason_construct_per_variant() {
		let _ = CloseReason::Graceful;
		let _ = CloseReason::PolicyDenied(std::borrow::Cow::Borrowed("over quota"));
		let _ = CloseReason::ProtocolError(std::borrow::Cow::Owned(String::from("bad frame")));
		let _ = CloseReason::Cancelled;
	}

	/// Hash `v` with a fresh `DefaultHasher`.
	fn hash_of<T: Hash>(v: &T) -> u64 {
		let mut h = DefaultHasher::new();
		v.hash(&mut h);
		h.finish()
	}

	/// A fixed `rate_limit` L7 request reference with the given `args`.
	fn sym_ref(args: serde_json::Value) -> SymbolicMiddlewareRef {
		SymbolicMiddlewareRef {
			name: Arc::from("rate_limit"),
			args,
			kind: MiddlewareKind::L7Request,
			stateless: true,
			needs_body: false,
			on_error: None,
		}
	}

	#[test]
	fn symbolic_ref_args_hash_is_object_key_order_insensitive() {
		// Manually build both maps with opposite insertion orders to defeat
		// serde_json::from_str's preserve-insertion-order backend.
		let mut a = serde_json::Map::new();
		a.insert("a".to_string(), json!(1));
		a.insert("b".to_string(), json!(2));
		let mut b = serde_json::Map::new();
		b.insert("b".to_string(), json!(2));
		b.insert("a".to_string(), json!(1));

		let lhs = sym_ref(serde_json::Value::Object(a));
		let rhs = sym_ref(serde_json::Value::Object(b));

		assert_eq!(lhs, rhs);
		assert_eq!(hash_of(&lhs), hash_of(&rhs));
	}

	#[test]
	fn symbolic_ref_nested_object_key_order_is_ignored() {
		let lhs = sym_ref(json!({ "outer": { "x": 1, "y": 2 } }));
		// Build the inner map with swapped order by hand.
		let mut inner = serde_json::Map::new();
		inner.insert("y".to_string(), json!(2));
		inner.insert("x".to_string(), json!(1));
		let mut outer = serde_json::Map::new();
		outer.insert("outer".to_string(), serde_json::Value::Object(inner));
		let rhs = sym_ref(serde_json::Value::Object(outer));

		assert_eq!(lhs, rhs);
		assert_eq!(hash_of(&lhs), hash_of(&rhs));
	}

	#[test]
	fn symbolic_ref_arrays_are_order_sensitive() {
		let lhs = sym_ref(json!({ "xs": [1, 2] }));
		let rhs = sym_ref(json!({ "xs": [2, 1] }));
		assert_ne!(lhs, rhs);
	}

	#[test]
	fn symbolic_ref_differs_on_name() {
		let a = sym_ref(json!({}));
		let mut b = sym_ref(json!({}));
		b.name = Arc::from("other");
		assert_ne!(a, b);
	}

	#[test]
	fn symbolic_ref_differs_on_kind() {
		let a = sym_ref(json!({}));
		let mut b = sym_ref(json!({}));
		b.kind = MiddlewareKind::L4Peek;
		assert_ne!(a, b);
	}

	#[test]
	fn symbolic_ref_differs_on_stateless() {
		let a = sym_ref(json!({}));
		let mut b = sym_ref(json!({}));
		b.stateless = false;
		assert_ne!(a, b);
	}

	#[test]
	fn symbolic_ref_differs_on_needs_body() {
		let a = sym_ref(json!({}));
		let mut b = sym_ref(json!({}));
		b.needs_body = true;
		assert_ne!(a, b);
	}

	#[test]
	fn symbolic_ref_differs_on_on_error() {
		let a = sym_ref(json!({}));
		let mut b = sym_ref(json!({}));
		b.on_error = Some(NodeId::new(3));
		assert_ne!(a, b);
	}

	#[test]
	fn symbolic_ref_same_name_but_distinct_args_are_unequal() {
		let a = sym_ref(json!({ "limit": 100 }));
		let b = sym_ref(json!({ "limit": 200 }));
		assert_ne!(a, b);
	}

	// Dry-run JSON wire-format contract: SymbolicMiddlewareRef participates
	// in the compiled-form JSON per 02-flow.md § _The compiled form_. The
	// whole struct uses derive(Serialize/Deserialize); all fields round-trip.
	// PartialEq uses canonical-json equality on `args`, so key-order
	// perturbation must still compare equal after a round-trip.

	#[test]
	fn symbolic_middleware_ref_round_trip_preserves_all_fields() {
		let m = SymbolicMiddlewareRef {
			name: Arc::from("rate_limit"),
			args: json!({ "rate": 100 }),
			kind: MiddlewareKind::L7Request,
			stateless: false,
			needs_body: false,
			on_error: Some(NodeId::new(5)),
		};
		let encoded = serde_json::to_string(&m).expect("serialize");
		let decoded: SymbolicMiddlewareRef = serde_json::from_str(&encoded).expect("deserialize");
		assert_eq!(decoded.name, m.name);
		assert_eq!(decoded.kind, m.kind);
		assert_eq!(decoded.stateless, m.stateless);
		assert_eq!(decoded.needs_body, m.needs_body);
		assert_eq!(decoded.on_error, m.on_error);
		assert_eq!(decoded, m);
	}

	#[test]
	fn symbolic_middleware_ref_round_trip_args_are_canonical_key_order_insensitive() {
		// Build an args value whose serialized form has a deliberate key order.
		let mut obj = serde_json::Map::new();
		obj.insert("b".to_string(), json!(1));
		obj.insert("a".to_string(), json!(2));
		let m = SymbolicMiddlewareRef {
			name: Arc::from("mw"),
			args: serde_json::Value::Object(obj),
			kind: MiddlewareKind::L7Request,
			stateless: true,
			needs_body: false,
			on_error: None,
		};
		let encoded = serde_json::to_string(&m).expect("serialize");
		let decoded: SymbolicMiddlewareRef = serde_json::from_str(&encoded).expect("deserialize");
		// PartialEq on SymbolicMiddlewareRef uses canonical-json equality on args,
		// so any post-round-trip key reshuffling remains == to the original.
		assert_eq!(decoded, m);
	}
}