Skip to main content

vane_core/
middleware.rs

1use std::hash::Hash;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5
6use crate::body::{Request, Response};
7use crate::conn_context::ConnContext;
8use crate::error::Error;
9use crate::flow_ctx::FlowCtx;
10use crate::ir::NodeId;
11use crate::l4::L4Conn;
12
13#[async_trait]
14pub trait L4PeekMiddleware: Send + Sync {
15	async fn run(
16		&self,
17		peek: &[u8],
18		conn: &Arc<ConnContext>,
19		ctx: &mut FlowCtx,
20	) -> Result<Decision, Error>;
21}
22
23#[async_trait]
24pub trait L4BytesMiddleware: Send + Sync {
25	async fn run(
26		&self,
27		l4: &mut L4Conn,
28		conn: &Arc<ConnContext>,
29		ctx: &mut FlowCtx,
30	) -> Result<Decision, Error>;
31}
32
33#[async_trait]
34pub trait L7RequestMiddleware: Send + Sync {
35	async fn run(
36		&self,
37		req: &mut Request,
38		conn: &Arc<ConnContext>,
39		ctx: &mut FlowCtx,
40	) -> Result<Decision, Error>;
41}
42
43#[async_trait]
44pub trait L7ResponseMiddleware: Send + Sync {
45	async fn run(
46		&self,
47		resp: &mut Response,
48		conn: &Arc<ConnContext>,
49		ctx: &mut FlowCtx,
50	) -> Result<Decision, Error>;
51}
52
53/// Middleware verdict. `#[non_exhaustive]` so future verdicts (e.g.
54/// a future suspend/replay shape) can land in vane-core without a
55/// downstream-breaking match.
56#[non_exhaustive]
57pub enum Decision {
58	Continue,
59	Short(ShortCircuit),
60}
61
62/// Short-circuit branch for [`Decision::Short`]. `#[non_exhaustive]`
63/// keeps the engine's switch resilient to new short-circuit shapes
64/// (synthetic 1xx interim responses, in-flow rewrites) being added
65/// without forcing every consumer to recompile.
66#[non_exhaustive]
67pub enum ShortCircuit {
68	Response(Response),
69	Close(CloseReason),
70}
71
/// Why a connection closed.
///
/// Marked `#[non_exhaustive]` so further operator-visible reasons (quota
/// hit, rate-limit trip, ...) can be introduced without breaking downstream
/// observers that pattern-match on this enum.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum CloseReason {
	/// Ordinary close — what management observers read as "client EOF'd",
	/// as opposed to [`CloseReason::Cancelled`].
	Graceful,
	/// Closed because a policy denied the flow; carries a human-readable
	/// detail (e.g. "over quota").
	PolicyDenied(std::borrow::Cow<'static, str>),
	/// Closed because of a protocol-level error; carries a human-readable
	/// detail (e.g. "bad frame").
	ProtocolError(std::borrow::Cow<'static, str>),
	/// Daemon-initiated cancellation: the listener's `force_cancel` fired
	/// during shutdown drain (spec/topology.md § _Listener lifecycle_), or
	/// any other `ctx.cancel.cancelled()` propagation. Kept distinct from
	/// `Graceful` so management observers can tell "client EOF'd" apart
	/// from "daemon pulled the plug while in-flight."
	Cancelled,
}
89
90#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
91pub enum MiddlewareKind {
92	L4Peek,
93	L4Bytes,
94	L7Request,
95	L7Response,
96}
97
98#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
99pub struct SymbolicMiddlewareRef {
100	pub name: Arc<str>,
101	pub args: serde_json::Value,
102	pub kind: MiddlewareKind,
103	pub stateless: bool,
104	pub needs_body: bool,
105	pub on_error: Option<NodeId>,
106}
107
108impl PartialEq for SymbolicMiddlewareRef {
109	fn eq(&self, other: &Self) -> bool {
110		self.name == other.name
111			&& self.kind == other.kind
112			&& self.stateless == other.stateless
113			&& self.needs_body == other.needs_body
114			&& self.on_error == other.on_error
115			&& canonical_json_eq(&self.args, &other.args)
116	}
117}
118
119impl Eq for SymbolicMiddlewareRef {}
120
121impl std::hash::Hash for SymbolicMiddlewareRef {
122	fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
123		self.name.hash(state);
124		self.kind.hash(state);
125		self.stateless.hash(state);
126		self.needs_body.hash(state);
127		self.on_error.hash(state);
128		hash_canonical_json(&self.args, state);
129	}
130}
131
132fn canonical_json_eq(a: &serde_json::Value, b: &serde_json::Value) -> bool {
133	use serde_json::Value;
134	match (a, b) {
135		(Value::Null, Value::Null) => true,
136		(Value::Bool(x), Value::Bool(y)) => x == y,
137		(Value::Number(x), Value::Number(y)) => x == y,
138		(Value::String(x), Value::String(y)) => x == y,
139		(Value::Array(xs), Value::Array(ys)) => {
140			xs.len() == ys.len() && xs.iter().zip(ys).all(|(x, y)| canonical_json_eq(x, y))
141		}
142		(Value::Object(xs), Value::Object(ys)) if xs.len() == ys.len() => {
143			xs.iter().all(|(k, v)| ys.get(k).is_some_and(|w| canonical_json_eq(v, w)))
144		}
145		_ => false,
146	}
147}
148
149/// Hash the canonical byte form of `v` (produced by the workspace
150/// `crate::canonical` serializer). Routing every hash-cons through the
151/// same canonicalizer keeps middleware/fetch arg dedup in lockstep
152/// with `FlowGraphMeta::version_hash`.
153fn hash_canonical_json<H: std::hash::Hasher>(v: &serde_json::Value, state: &mut H) {
154	let mut buf = String::new();
155	crate::canonical::write_into_lossy(&mut buf, v);
156	buf.hash(state);
157}
158
#[cfg(test)]
mod tests {
	use std::collections::hash_map::DefaultHasher;
	use std::future::Future;
	use std::hash::{Hash, Hasher};
	use std::net::SocketAddr;
	use std::pin::Pin;
	use std::time::Instant;

	use parking_lot::Mutex;
	use serde_json::json;
	use tokio_util::sync::CancellationToken;

	use super::*;
	use crate::conn_context::{ConnId, Transport};
	use crate::flow_log::{FlowLogEvent, FlowLogSink};

	struct PassPeek;
	#[async_trait]
	impl L4PeekMiddleware for PassPeek {
		async fn run(
			&self,
			_peek: &[u8],
			_conn: &Arc<ConnContext>,
			_ctx: &mut FlowCtx,
		) -> Result<Decision, Error> {
			Ok(Decision::Continue)
		}
	}

	struct PassBytes;
	#[async_trait]
	impl L4BytesMiddleware for PassBytes {
		async fn run(
			&self,
			_l4: &mut L4Conn,
			_conn: &Arc<ConnContext>,
			_ctx: &mut FlowCtx,
		) -> Result<Decision, Error> {
			Ok(Decision::Continue)
		}
	}

	struct PassReq;
	#[async_trait]
	impl L7RequestMiddleware for PassReq {
		async fn run(
			&self,
			_req: &mut Request,
			_conn: &Arc<ConnContext>,
			_ctx: &mut FlowCtx,
		) -> Result<Decision, Error> {
			Ok(Decision::Continue)
		}
	}

	struct PassResp;
	#[async_trait]
	impl L7ResponseMiddleware for PassResp {
		async fn run(
			&self,
			_resp: &mut Response,
			_conn: &Arc<ConnContext>,
			_ctx: &mut FlowCtx,
		) -> Result<Decision, Error> {
			Ok(Decision::Continue)
		}
	}

	// Compile-time assertion helper: the type `F` must be `Send`. `async_trait`
	// rewrites `async fn run(...)` to return `Pin<Box<dyn Future + Send>>`, so
	// every `run` invocation's future must satisfy this bound — that is the
	// load-bearing contract per spec/crates/engine.md § _Middleware_.
	fn assert_send<F: Send>(_: &F) {}

	struct NullSink;
	impl FlowLogSink for NullSink {
		fn emit(&self, _event: FlowLogEvent) {}
	}

	fn make_conn_context() -> Arc<ConnContext> {
		let addr: SocketAddr = "127.0.0.1:0".parse().expect("parse addr");
		Arc::new(ConnContext {
			id: ConnId(0),
			remote: addr,
			local: addr,
			transport: Transport::Tcp,
			entered_at: Instant::now(),
			tls: Mutex::new(None),
			http_version: std::sync::OnceLock::new(),
			user: Mutex::new(http::Extensions::new()),
		})
	}

	// `async_trait` makes these traits dyn-compatible. `MiddlewareInst` stores
	// each variant as `Arc<dyn Trait>` per `spec/crates/engine.md` § _Middleware_;
	// constructing that exact shape from a concrete impl is the contract we
	// guard.

	#[test]
	fn l4_peek_is_constructible_as_arc_dyn_send_sync() {
		let m: Arc<dyn L4PeekMiddleware + Send + Sync> = Arc::new(PassPeek);
		// The trait-object Arc coerces to the bare `Arc<dyn Trait>` shape used
		// by `MiddlewareInst::L4Peek(Arc<dyn L4PeekMiddleware>)` in engine.
		let _: Arc<dyn L4PeekMiddleware> = m;
	}

	#[test]
	fn l4_bytes_is_constructible_as_arc_dyn_send_sync() {
		let m: Arc<dyn L4BytesMiddleware + Send + Sync> = Arc::new(PassBytes);
		let _: Arc<dyn L4BytesMiddleware> = m;
	}

	#[test]
	fn l7_request_is_constructible_as_arc_dyn_send_sync() {
		let m: Arc<dyn L7RequestMiddleware + Send + Sync> = Arc::new(PassReq);
		let _: Arc<dyn L7RequestMiddleware> = m;
	}

	#[test]
	fn l7_response_is_constructible_as_arc_dyn_send_sync() {
		let m: Arc<dyn L7ResponseMiddleware + Send + Sync> = Arc::new(PassResp);
		let _: Arc<dyn L7ResponseMiddleware> = m;
	}

	fn make_flow_ctx(conn_id: ConnId) -> FlowCtx {
		FlowCtx {
			span: tracing::Span::none(),
			log: Arc::new(NullSink),
			cancel: CancellationToken::new(),
			accept_cancel: CancellationToken::new(),
			verbosity: crate::flow_log::FlowLogVerbosity::Trajectory,
			trajectory: crate::flow_log::TrajectoryBuilder::new(conn_id, crate::ir::NodeId::new(0), 0),
		}
	}

	#[test]
	fn l4_peek_run_returns_send_future() {
		let m: Arc<dyn L4PeekMiddleware> = Arc::new(PassPeek);
		let conn = make_conn_context();
		let mut ctx = make_flow_ctx(conn.id);
		let peek: &[u8] = &[];
		// Exact-type coercion into `Pin<Box<dyn Future + Send>>` — the async_trait
		// signature. Fails to compile if a future becomes `!Send`.
		let fut: Pin<Box<dyn Future<Output = Result<Decision, Error>> + Send + '_>> =
			m.run(peek, &conn, &mut ctx);
		assert_send(&fut);
		drop(fut);
	}

	#[test]
	fn l7_request_run_returns_send_future() {
		let m: Arc<dyn L7RequestMiddleware> = Arc::new(PassReq);
		let conn = make_conn_context();
		let mut ctx = make_flow_ctx(conn.id);
		let mut req: Request =
			http::Request::builder().uri("/").body(crate::body::Body::Empty).expect("build req");
		let fut: Pin<Box<dyn Future<Output = Result<Decision, Error>> + Send + '_>> =
			m.run(&mut req, &conn, &mut ctx);
		assert_send(&fut);
		drop(fut);
	}

	#[test]
	fn l7_response_run_returns_send_future() {
		let m: Arc<dyn L7ResponseMiddleware> = Arc::new(PassResp);
		let conn = make_conn_context();
		let mut ctx = make_flow_ctx(conn.id);
		let mut resp: Response =
			http::Response::builder().status(200).body(crate::body::Body::Empty).expect("build resp");
		let fut: Pin<Box<dyn Future<Output = Result<Decision, Error>> + Send + '_>> =
			m.run(&mut resp, &conn, &mut ctx);
		assert_send(&fut);
		drop(fut);
	}

	// `needs_body` is not a method on the middleware traits. The
	// `MiddlewareMetadata` side is the single source of truth (the compile
	// pass reads it to inject the right `collect_body_before` edges), so
	// there is nothing trait-level to test here.

	#[test]
	fn middleware_kind_serde_round_trip_per_variant() {
		for k in [
			MiddlewareKind::L4Peek,
			MiddlewareKind::L4Bytes,
			MiddlewareKind::L7Request,
			MiddlewareKind::L7Response,
		] {
			let encoded = serde_json::to_string(&k).expect("serialize");
			let decoded: MiddlewareKind = serde_json::from_str(&encoded).expect("deserialize");
			assert_eq!(decoded, k);
		}
	}

	#[test]
	fn decision_and_shortcircuit_construct_per_variant() {
		let _ = Decision::Continue;
		let _ = Decision::Short(ShortCircuit::Close(CloseReason::Graceful));
		let _ = ShortCircuit::Close(CloseReason::PolicyDenied("over quota".into()));
		let _ = ShortCircuit::Close(CloseReason::ProtocolError("bad frame".into()));
		let _ = ShortCircuit::Close(CloseReason::Cancelled);
		// The `Response` branch was previously never constructed here.
		let resp: Response =
			http::Response::builder().status(200).body(crate::body::Body::Empty).expect("build resp");
		let _ = Decision::Short(ShortCircuit::Response(resp));
	}

	#[test]
	fn close_reason_construct_per_variant() {
		let _ = CloseReason::Graceful;
		let _ = CloseReason::PolicyDenied(std::borrow::Cow::Borrowed("over quota"));
		let _ = CloseReason::ProtocolError(std::borrow::Cow::Owned(String::from("bad frame")));
		let _ = CloseReason::Cancelled;
	}

	fn hash_of<T: Hash>(v: &T) -> u64 {
		let mut h = DefaultHasher::new();
		v.hash(&mut h);
		h.finish()
	}

	fn sym_ref(args: serde_json::Value) -> SymbolicMiddlewareRef {
		SymbolicMiddlewareRef {
			name: Arc::from("rate_limit"),
			args,
			kind: MiddlewareKind::L7Request,
			stateless: true,
			needs_body: false,
			on_error: None,
		}
	}

	#[test]
	fn symbolic_ref_args_hash_is_object_key_order_insensitive() {
		// Manually build both maps with opposite insertion orders to defeat
		// serde_json::from_str's preserve-insertion-order backend.
		let mut a = serde_json::Map::new();
		a.insert("a".to_string(), json!(1));
		a.insert("b".to_string(), json!(2));
		let mut b = serde_json::Map::new();
		b.insert("b".to_string(), json!(2));
		b.insert("a".to_string(), json!(1));

		let lhs = sym_ref(serde_json::Value::Object(a));
		let rhs = sym_ref(serde_json::Value::Object(b));

		assert_eq!(lhs, rhs);
		assert_eq!(hash_of(&lhs), hash_of(&rhs));
	}

	#[test]
	fn symbolic_ref_nested_object_key_order_is_ignored() {
		let lhs = sym_ref(json!({ "outer": { "x": 1, "y": 2 } }));
		// Build the inner map with swapped order by hand.
		let mut inner = serde_json::Map::new();
		inner.insert("y".to_string(), json!(2));
		inner.insert("x".to_string(), json!(1));
		let mut outer = serde_json::Map::new();
		outer.insert("outer".to_string(), serde_json::Value::Object(inner));
		let rhs = sym_ref(serde_json::Value::Object(outer));

		assert_eq!(lhs, rhs);
		assert_eq!(hash_of(&lhs), hash_of(&rhs));
	}

	#[test]
	fn symbolic_ref_arrays_are_order_sensitive() {
		let lhs = sym_ref(json!({ "xs": [1, 2] }));
		let rhs = sym_ref(json!({ "xs": [2, 1] }));
		assert_ne!(lhs, rhs);
	}

	#[test]
	fn symbolic_ref_differs_on_name() {
		let a = sym_ref(json!({}));
		let mut b = sym_ref(json!({}));
		b.name = Arc::from("other");
		assert_ne!(a, b);
	}

	#[test]
	fn symbolic_ref_differs_on_kind() {
		let a = sym_ref(json!({}));
		let mut b = sym_ref(json!({}));
		b.kind = MiddlewareKind::L4Peek;
		assert_ne!(a, b);
	}

	#[test]
	fn symbolic_ref_differs_on_stateless() {
		let a = sym_ref(json!({}));
		let mut b = sym_ref(json!({}));
		b.stateless = false;
		assert_ne!(a, b);
	}

	#[test]
	fn symbolic_ref_differs_on_needs_body() {
		let a = sym_ref(json!({}));
		let mut b = sym_ref(json!({}));
		b.needs_body = true;
		assert_ne!(a, b);
	}

	#[test]
	fn symbolic_ref_differs_on_on_error() {
		let a = sym_ref(json!({}));
		let mut b = sym_ref(json!({}));
		b.on_error = Some(NodeId::new(3));
		assert_ne!(a, b);
	}

	#[test]
	fn symbolic_ref_same_name_but_distinct_args_are_unequal() {
		let a = sym_ref(json!({ "limit": 100 }));
		let b = sym_ref(json!({ "limit": 200 }));
		assert_ne!(a, b);
	}

	#[test]
	fn symbolic_ref_object_subset_args_are_unequal() {
		// `canonical_json_eq` only looks up one side's keys in the other, so its
		// length guard is what stops a key-wise subset from comparing equal.
		// Check both directions, at the top level and one level down.
		let small = sym_ref(json!({ "a": 1 }));
		let big = sym_ref(json!({ "a": 1, "b": 2 }));
		assert_ne!(small, big);
		assert_ne!(big, small);

		let nested_small = sym_ref(json!({ "outer": { "x": 1 } }));
		let nested_big = sym_ref(json!({ "outer": { "x": 1, "y": 2 } }));
		assert_ne!(nested_small, nested_big);
		assert_ne!(nested_big, nested_small);
	}

	#[test]
	fn symbolic_ref_args_differ_on_json_type() {
		let number = sym_ref(json!({ "limit": 1 }));
		let string = sym_ref(json!({ "limit": "1" }));
		assert_ne!(number, string);
	}

	// Dry-run JSON wire-format contract: SymbolicMiddlewareRef participates
	// in the compiled-form JSON per spec/flow-model.md § _The compiled form_. The
	// whole struct uses derive(Serialize/Deserialize); all fields round-trip.
	// PartialEq uses canonical-json equality on `args`, so key-order
	// perturbation must still compare equal after a round-trip. Because
	// Eq/Hash must agree, the hashes must match as well.

	#[test]
	fn symbolic_middleware_ref_round_trip_preserves_all_fields() {
		let m = SymbolicMiddlewareRef {
			name: Arc::from("rate_limit"),
			args: json!({ "rate": 100 }),
			kind: MiddlewareKind::L7Request,
			stateless: false,
			needs_body: false,
			on_error: Some(NodeId::new(5)),
		};
		let encoded = serde_json::to_string(&m).expect("serialize");
		let decoded: SymbolicMiddlewareRef = serde_json::from_str(&encoded).expect("deserialize");
		assert_eq!(decoded.name, m.name);
		assert_eq!(decoded.kind, m.kind);
		assert_eq!(decoded.stateless, m.stateless);
		assert_eq!(decoded.needs_body, m.needs_body);
		assert_eq!(decoded.on_error, m.on_error);
		assert_eq!(decoded, m);
		assert_eq!(hash_of(&decoded), hash_of(&m));
	}

	#[test]
	fn symbolic_middleware_ref_round_trip_args_are_canonical_key_order_insensitive() {
		// Build an args value whose serialized form has a deliberate key order.
		let mut obj = serde_json::Map::new();
		obj.insert("b".to_string(), json!(1));
		obj.insert("a".to_string(), json!(2));
		let m = SymbolicMiddlewareRef {
			name: Arc::from("mw"),
			args: serde_json::Value::Object(obj),
			kind: MiddlewareKind::L7Request,
			stateless: true,
			needs_body: false,
			on_error: None,
		};
		let encoded = serde_json::to_string(&m).expect("serialize");
		let decoded: SymbolicMiddlewareRef = serde_json::from_str(&encoded).expect("deserialize");
		// PartialEq on SymbolicMiddlewareRef uses canonical-json equality on args,
		// so any post-round-trip key reshuffling remains == to the original, and
		// the canonical hash must agree with that equality.
		assert_eq!(decoded, m);
		assert_eq!(hash_of(&decoded), hash_of(&m));
	}
}