Skip to main content

vane_core/
fetch.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use tokio::io::{AsyncRead, AsyncWrite};
5use tokio::sync::oneshot;
6
7use crate::body::{Request, Response};
8use crate::conn_context::ConnContext;
9use crate::error::Error;
10use crate::flow_ctx::FlowCtx;
11use crate::l4::L4Conn;
12use crate::middleware::CloseReason;
13
14#[async_trait]
15pub trait L7Fetch: Send + Sync {
16	async fn fetch(
17		&self,
18		req: Request,
19		conn: &Arc<ConnContext>,
20		ctx: &mut FlowCtx<'_>,
21	) -> Result<L7FetchOutput, Error>;
22}
23
24#[async_trait]
25pub trait L4Fetch: Send + Sync {
26	async fn fetch(
27		&self,
28		l4: L4Conn,
29		conn: &Arc<ConnContext>,
30		ctx: &mut FlowCtx<'_>,
31	) -> Result<Tunnel, Error>;
32}
33
34pub enum L7FetchOutput {
35	Response(Response),
36	Tunnel(Tunnel),
37}
38
39pub struct Tunnel {
40	pub client: Box<dyn AsyncReadWrite + Send>,
41	pub upstream: Box<dyn AsyncReadWrite + Send>,
42	pub close_reason_tx: Option<oneshot::Sender<CloseReason>>,
43}
44
45// `Unpin` is in the trait bound so `tokio::io::copy_bidirectional`
46// (used by `Terminator::ByteTunnel` in the engine) can drive the streams
47// directly. `TcpStream` / `UnixStream` / `tokio::io::DuplexStream` /
48// `tokio_rustls::TlsStream<T: Unpin>` all satisfy it.
49pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin {}
50impl<T: AsyncRead + AsyncWrite + Unpin + ?Sized> AsyncReadWrite for T {}
51
52#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
53pub enum FetchKind {
54	HttpProxy,
55	HttpSynthesize,
56	WebSocketUpgrade,
57	L4Forward,
58}
59
60#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
61pub enum FetchPhase {
62	L4,
63	L7,
64}
65
66#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
67pub struct FetchOutputModes {
68	pub response: bool,
69	pub tunnel: bool,
70}
71
72#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
73pub struct SymbolicFetchRef {
74	pub kind: FetchKind,
75	pub args: serde_json::Value,
76}
77
78#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
79pub enum Terminator {
80	WriteHttpResponse,
81	ByteTunnel,
82	Close,
83}
84
// Unit tests for the fetch surface: dyn-compatibility and `Send`-ness of the
// fetch traits, construction of the output types, and serde round-trips of
// the symbolic / compiled-form types.
#[cfg(test)]
mod tests {
	use std::future::Future;
	use std::io;
	use std::net::SocketAddr;
	use std::pin::Pin;
	use std::task::{Context, Poll};
	use std::time::Instant;

	use parking_lot::Mutex;
	use serde_json::json;
	use tokio::io::ReadBuf;
	use tokio_util::sync::CancellationToken;

	use super::*;
	use crate::body::{Body, Request, Response};
	use crate::conn_context::{ConnId, Transport};
	use crate::flow_log::{FlowLogEvent, FlowLogSink};

	// A runtime-free `AsyncRead + AsyncWrite` witness. `UnixStream::pair` and
	// `tokio::io::duplex` both require a running reactor; core tests
	// deliberately do not spin one up (16-crate-layout.md: no async-runtime
	// dep).
	struct NoopStream;

	impl AsyncRead for NoopStream {
		fn poll_read(
			self: Pin<&mut Self>,
			_cx: &mut Context<'_>,
			_buf: &mut ReadBuf<'_>,
		) -> Poll<io::Result<()>> {
			// Nothing written into `buf`: reads report EOF immediately.
			Poll::Ready(Ok(()))
		}
	}

	impl AsyncWrite for NoopStream {
		fn poll_write(
			self: Pin<&mut Self>,
			_cx: &mut Context<'_>,
			buf: &[u8],
		) -> Poll<io::Result<usize>> {
			// Accept (and discard) every byte.
			Poll::Ready(Ok(buf.len()))
		}

		fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
			Poll::Ready(Ok(()))
		}

		fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
			Poll::Ready(Ok(()))
		}
	}

	// Flow-log sink that drops every event.
	struct NullSink;
	impl FlowLogSink for NullSink {
		fn emit(&self, _event: FlowLogEvent) {}
	}

	// `L7Fetch` that always answers with an empty 200 response.
	struct SynthOk;
	#[async_trait]
	impl L7Fetch for SynthOk {
		async fn fetch(
			&self,
			_req: Request,
			_conn: &Arc<ConnContext>,
			_ctx: &mut FlowCtx<'_>,
		) -> Result<L7FetchOutput, Error> {
			let resp: Response = http::Response::builder().status(200).body(Body::Empty).expect("build");
			Ok(L7FetchOutput::Response(resp))
		}
	}

	// `L4Fetch` that returns a tunnel between two no-op streams.
	struct L4Nop;
	#[async_trait]
	impl L4Fetch for L4Nop {
		async fn fetch(
			&self,
			_l4: L4Conn,
			_conn: &Arc<ConnContext>,
			_ctx: &mut FlowCtx<'_>,
		) -> Result<Tunnel, Error> {
			let (tx, _rx) = oneshot::channel::<CloseReason>();
			Ok(Tunnel {
				client: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
				upstream: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
				close_reason_tx: Some(tx),
			})
		}
	}

	// Compiles only if `F: Send`.
	fn assert_send<F: Send>(_: &F) {}

	fn make_conn_context() -> Arc<ConnContext> {
		let addr: SocketAddr = "127.0.0.1:0".parse().expect("parse addr");
		Arc::new(ConnContext {
			id: ConnId(0),
			remote: addr,
			local: addr,
			transport: Transport::Tcp,
			entered_at: Instant::now(),
			tls: Mutex::new(None),
			http_version: std::sync::OnceLock::new(),
			user: Mutex::new(http::Extensions::new()),
		})
	}

	/// Serializes `value` to JSON, parses it back, and asserts the result
	/// equals the input.
	fn assert_json_round_trip<T>(value: T)
	where
		T: serde::Serialize + serde::de::DeserializeOwned + PartialEq + std::fmt::Debug,
	{
		let encoded = serde_json::to_string(&value).expect("serialize");
		let decoded: T = serde_json::from_str(&encoded).expect("deserialize");
		assert_eq!(decoded, value, "JSON round-trip changed the value (encoded as {encoded})");
	}

	// The `all_*` helpers below list every variant of an enum. Each has a
	// `match` with no wildcard arm, so adding a variant fails to compile until
	// the list is updated. Per-variant tests therefore cannot silently skip a
	// variant (as the `Terminator` round-trip once skipped `Close`).

	fn all_fetch_kinds() -> [FetchKind; 4] {
		let all = [
			FetchKind::HttpProxy,
			FetchKind::HttpSynthesize,
			FetchKind::WebSocketUpgrade,
			FetchKind::L4Forward,
		];
		for k in all {
			match k {
				FetchKind::HttpProxy
				| FetchKind::HttpSynthesize
				| FetchKind::WebSocketUpgrade
				| FetchKind::L4Forward => {}
			}
		}
		all
	}

	fn all_fetch_phases() -> [FetchPhase; 2] {
		let all = [FetchPhase::L4, FetchPhase::L7];
		for p in all {
			match p {
				FetchPhase::L4 | FetchPhase::L7 => {}
			}
		}
		all
	}

	fn all_terminators() -> [Terminator; 3] {
		let all = [Terminator::WriteHttpResponse, Terminator::ByteTunnel, Terminator::Close];
		for t in all {
			match t {
				Terminator::WriteHttpResponse | Terminator::ByteTunnel | Terminator::Close => {}
			}
		}
		all
	}

	#[test]
	fn async_read_write_blanket_accepts_async_io_type() {
		let _: Box<dyn AsyncReadWrite + Send> = Box::new(NoopStream);
	}

	#[test]
	fn l7_fetch_output_response_variant_constructs() {
		let resp: Response =
			http::Response::builder().status(200).body(Body::Empty).expect("build response");
		match L7FetchOutput::Response(resp) {
			L7FetchOutput::Response(_) => {}
			L7FetchOutput::Tunnel(_) => panic!("unexpected tunnel variant"),
		}
	}

	#[test]
	fn tunnel_builds_from_paired_async_io_streams() {
		let (tx, _rx) = oneshot::channel::<CloseReason>();
		let tunnel = Tunnel {
			client: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
			upstream: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
			close_reason_tx: Some(tx),
		};
		let _ = L7FetchOutput::Tunnel(tunnel);
	}

	// `async_trait` makes `L7Fetch` and `L4Fetch` dyn-compatible. `FetchInst`
	// stores them as `Arc<dyn _>` per 05-terminator.md § _Trait surface_;
	// constructing that exact shape from a concrete impl is the contract we
	// guard here.

	#[test]
	fn l7_fetch_is_constructible_as_arc_dyn_send_sync() {
		let f: Arc<dyn L7Fetch + Send + Sync> = Arc::new(SynthOk);
		let _: Arc<dyn L7Fetch> = f;
	}

	#[test]
	fn l4_fetch_is_constructible_as_arc_dyn_send_sync() {
		let f: Arc<dyn L4Fetch + Send + Sync> = Arc::new(L4Nop);
		let _: Arc<dyn L4Fetch> = f;
	}

	#[test]
	fn l7_fetch_fetch_returns_send_future() {
		let f: Arc<dyn L7Fetch> = Arc::new(SynthOk);
		let conn = make_conn_context();
		let mut sink = NullSink;
		let mut span = tracing::Span::none();
		let cancel = CancellationToken::new();
		let mut ctx = FlowCtx {
			span: &mut span,
			log: &mut sink as &mut dyn FlowLogSink,
			cancel: &cancel,
			verbosity: crate::flow_log::FlowLogVerbosity::Trajectory,
			trajectory: crate::flow_log::TrajectoryBuilder::new(conn.id, crate::ir::NodeId::new(0), 0),
		};
		let req: Request = http::Request::builder().uri("/").body(Body::Empty).expect("build req");
		// Exact-type coercion — async_trait rewrites `fetch` to return
		// `Pin<Box<dyn Future + Send>>`; this binding fails to compile if the
		// future ever loses `Send`.
		let fut: Pin<Box<dyn Future<Output = Result<L7FetchOutput, Error>> + Send + '_>> =
			f.fetch(req, &conn, &mut ctx);
		assert_send(&fut);
		drop(fut);
	}

	#[test]
	fn fetch_kind_serde_round_trip_per_variant() {
		for k in all_fetch_kinds() {
			assert_json_round_trip(k);
		}
	}

	#[test]
	fn fetch_phase_serde_round_trip_per_variant() {
		for p in all_fetch_phases() {
			assert_json_round_trip(p);
		}
	}

	#[test]
	fn terminator_serde_round_trip_per_variant() {
		// Covers `Close` too: the list comes from the exhaustive helper.
		for t in all_terminators() {
			assert_json_round_trip(t);
		}
	}

	#[test]
	fn fetch_output_modes_serde_round_trip_http_shapes() {
		// HttpProxy / HttpSynthesize: response-only.
		let http_only = FetchOutputModes { response: true, tunnel: false };
		// WebSocketUpgrade: both outputs, per the bi-outcome spec.
		let ws = FetchOutputModes { response: true, tunnel: true };
		// L4Forward: tunnel-only.
		let l4 = FetchOutputModes { response: false, tunnel: true };
		for modes in [http_only, ws, l4] {
			assert_json_round_trip(modes);
		}
	}

	#[test]
	fn symbolic_fetch_ref_clone_preserves_fields() {
		let r = SymbolicFetchRef {
			kind: FetchKind::HttpProxy,
			args: json!({ "upstream": "127.0.0.1:8080" }),
		};
		let cloned = r.clone();
		assert_eq!(cloned.kind, r.kind);
		assert_eq!(cloned.args, r.args);
		// Debug must be derivable for diagnostics.
		let _ = format!("{r:?}");
	}

	#[test]
	fn symbolic_fetch_ref_accepts_each_kind() {
		for kind in all_fetch_kinds() {
			let _ = SymbolicFetchRef { kind, args: serde_json::Value::Null };
		}
	}

	// Dry-run JSON wire-format contract: SymbolicFetchRef participates in
	// the compiled-form JSON per 02-flow.md § _The compiled form_. Both the
	// `kind` tag and the opaque `args` payload must round-trip.
	#[test]
	fn symbolic_fetch_ref_round_trip_preserves_kind_and_args() {
		let r = SymbolicFetchRef {
			kind: FetchKind::WebSocketUpgrade,
			args: json!({ "upstream": "127.0.0.1:9000" }),
		};
		let encoded = serde_json::to_string(&r).expect("serialize");
		let decoded: SymbolicFetchRef = serde_json::from_str(&encoded).expect("deserialize");
		assert_eq!(decoded.kind, r.kind);
		assert_eq!(decoded.args, r.args);
	}
}