1use std::sync::Arc;
2
3use async_trait::async_trait;
4use tokio::io::{AsyncRead, AsyncWrite};
5use tokio::sync::oneshot;
6
7use crate::body::{Request, Response};
8use crate::conn_context::ConnContext;
9use crate::error::Error;
10use crate::flow_ctx::FlowCtx;
11use crate::l4::L4Conn;
12use crate::middleware::CloseReason;
13
14#[async_trait]
15pub trait L7Fetch: Send + Sync {
16 async fn fetch(
17 &self,
18 req: Request,
19 conn: &Arc<ConnContext>,
20 ctx: &mut FlowCtx,
21 ) -> Result<L7FetchOutput, Error>;
22}
23
24#[async_trait]
25pub trait L4Fetch: Send + Sync {
26 async fn fetch(
27 &self,
28 l4: L4Conn,
29 conn: &Arc<ConnContext>,
30 ctx: &mut FlowCtx,
31 ) -> Result<Tunnel, Error>;
32}
33
34pub enum L7FetchOutput {
35 Response(Response),
36 Tunnel(Tunnel),
37}
38
39pub struct Tunnel {
40 pub client: Box<dyn AsyncReadWrite + Send>,
41 pub upstream: Box<dyn AsyncReadWrite + Send>,
42 pub close_reason_tx: Option<oneshot::Sender<CloseReason>>,
43}
44
45pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin {}
50impl<T: AsyncRead + AsyncWrite + Unpin + ?Sized> AsyncReadWrite for T {}
51
52#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
53pub enum FetchKind {
54 HttpProxy,
55 HttpSynthesize,
56 WebSocketUpgrade,
57 L4Forward,
58}
59
60#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
61pub enum FetchPhase {
62 L4,
63 L7,
64}
65
66#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
67pub struct FetchOutputModes {
68 pub response: bool,
69 pub tunnel: bool,
70}
71
72#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
73pub struct SymbolicFetchRef {
74 pub kind: FetchKind,
75 pub args: serde_json::Value,
76}
77
78#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
79pub enum Terminator {
80 WriteHttpResponse,
81 ByteTunnel,
82 Close,
83}
84
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::io;
    use std::net::SocketAddr;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use std::time::Instant;

    use parking_lot::Mutex;
    use serde_json::json;
    use tokio::io::ReadBuf;
    use tokio_util::sync::CancellationToken;

    use super::*;
    use crate::body::{Body, Request, Response};
    use crate::conn_context::{ConnId, Transport};
    use crate::flow_log::{FlowLogEvent, FlowLogSink};

    /// Every `FetchKind` variant. The per-variant tests share this list so
    /// they cannot drift apart. Extend it when a variant is added.
    const ALL_FETCH_KINDS: [FetchKind; 4] = [
        FetchKind::HttpProxy,
        FetchKind::HttpSynthesize,
        FetchKind::WebSocketUpgrade,
        FetchKind::L4Forward,
    ];

    /// Every `Terminator` variant, `Close` included. The original
    /// round-trip test omitted `Close`.
    const ALL_TERMINATORS: [Terminator; 3] =
        [Terminator::WriteHttpResponse, Terminator::ByteTunnel, Terminator::Close];

    /// Stream stub. Reads fill nothing, which signals EOF, and writes accept
    /// the whole buffer.
    struct NoopStream;

    impl AsyncRead for NoopStream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncWrite for NoopStream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Flow-log sink that discards every event.
    struct NullSink;
    impl FlowLogSink for NullSink {
        fn emit(&self, _event: FlowLogEvent) {}
    }

    /// `L7Fetch` stub that always answers with an empty `200` response.
    struct SynthOk;
    #[async_trait]
    impl L7Fetch for SynthOk {
        async fn fetch(
            &self,
            _req: Request,
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx,
        ) -> Result<L7FetchOutput, Error> {
            let resp: Response =
                http::Response::builder().status(200).body(Body::Empty).expect("build");
            Ok(L7FetchOutput::Response(resp))
        }
    }

    /// `L4Fetch` stub that returns a tunnel made of two `NoopStream`s.
    struct L4Nop;
    #[async_trait]
    impl L4Fetch for L4Nop {
        async fn fetch(
            &self,
            _l4: L4Conn,
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx,
        ) -> Result<Tunnel, Error> {
            let (tx, _rx) = oneshot::channel::<crate::middleware::CloseReason>();
            Ok(Tunnel {
                client: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
                upstream: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
                close_reason_tx: Some(tx),
            })
        }
    }

    /// Compile-time check. It only type-checks when `F: Send`.
    fn assert_send<F: Send>(_: &F) {}

    /// Builds a loopback `ConnContext` with empty TLS and user state.
    fn make_conn_context() -> Arc<ConnContext> {
        let addr: SocketAddr = "127.0.0.1:0".parse().expect("parse addr");
        Arc::new(ConnContext {
            id: ConnId(0),
            remote: addr,
            local: addr,
            transport: Transport::Tcp,
            entered_at: Instant::now(),
            tls: Mutex::new(None),
            http_version: std::sync::OnceLock::new(),
            user: Mutex::new(http::Extensions::new()),
        })
    }

    #[test]
    fn async_read_write_blanket_accepts_async_io_type() {
        let _: Box<dyn AsyncReadWrite + Send> = Box::new(NoopStream);
    }

    #[test]
    fn l7_fetch_output_response_variant_constructs() {
        let resp: Response =
            http::Response::builder().status(200).body(Body::Empty).expect("build response");
        assert!(matches!(L7FetchOutput::Response(resp), L7FetchOutput::Response(_)));
    }

    #[test]
    fn tunnel_builds_from_paired_async_io_streams() {
        let (tx, _rx) = oneshot::channel::<crate::middleware::CloseReason>();
        let tunnel = Tunnel {
            client: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
            upstream: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
            close_reason_tx: Some(tx),
        };
        assert!(tunnel.close_reason_tx.is_some());
        assert!(matches!(L7FetchOutput::Tunnel(tunnel), L7FetchOutput::Tunnel(_)));
    }

    // The trait's `Send + Sync` supertraits let the explicit-bounds object
    // coerce to the plain `dyn L7Fetch` form.
    #[test]
    fn l7_fetch_is_constructible_as_arc_dyn_send_sync() {
        let f: Arc<dyn L7Fetch + Send + Sync> = Arc::new(SynthOk);
        let _: Arc<dyn L7Fetch> = f;
    }

    #[test]
    fn l4_fetch_is_constructible_as_arc_dyn_send_sync() {
        let f: Arc<dyn L4Fetch + Send + Sync> = Arc::new(L4Nop);
        let _: Arc<dyn L4Fetch> = f;
    }

    #[test]
    fn l7_fetch_fetch_returns_send_future() {
        let f: Arc<dyn L7Fetch> = Arc::new(SynthOk);
        let conn = make_conn_context();
        let mut ctx = FlowCtx {
            span: tracing::Span::none(),
            log: Arc::new(NullSink),
            cancel: CancellationToken::new(),
            verbosity: crate::flow_log::FlowLogVerbosity::Trajectory,
            trajectory: crate::flow_log::TrajectoryBuilder::new(
                conn.id,
                crate::ir::NodeId::new(0),
                0,
            ),
        };
        let req: Request = http::Request::builder().uri("/").body(Body::Empty).expect("build req");
        // The annotated `+ Send` bound is itself the check. The future is
        // never polled.
        let fut: Pin<Box<dyn Future<Output = Result<L7FetchOutput, Error>> + Send + '_>> =
            f.fetch(req, &conn, &mut ctx);
        assert_send(&fut);
        drop(fut);
    }

    #[test]
    fn fetch_kind_serde_round_trip_per_variant() {
        for k in ALL_FETCH_KINDS {
            let encoded = serde_json::to_string(&k).expect("serialize");
            let decoded: FetchKind = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, k);
        }
    }

    #[test]
    fn fetch_phase_serde_round_trip_per_variant() {
        for p in [FetchPhase::L4, FetchPhase::L7] {
            let encoded = serde_json::to_string(&p).expect("serialize");
            let decoded: FetchPhase = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, p);
        }
    }

    #[test]
    fn terminator_serde_round_trip_per_variant() {
        for t in ALL_TERMINATORS {
            let encoded = serde_json::to_string(&t).expect("serialize");
            let decoded: Terminator = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, t);
        }
    }

    #[test]
    fn fetch_output_modes_serde_round_trip_http_shapes() {
        // Plain HTTP: response only.
        let http_only = FetchOutputModes { response: true, tunnel: false };
        // WebSocket-style: may answer with a response or upgrade to a tunnel.
        let ws = FetchOutputModes { response: true, tunnel: true };
        // L4: tunnel only.
        let l4 = FetchOutputModes { response: false, tunnel: true };
        for modes in [http_only, ws, l4] {
            let encoded = serde_json::to_string(&modes).expect("serialize");
            let decoded: FetchOutputModes = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, modes);
        }
    }

    #[test]
    fn symbolic_fetch_ref_clone_preserves_fields() {
        let r = SymbolicFetchRef {
            kind: FetchKind::HttpProxy,
            args: json!({ "upstream": "127.0.0.1:8080" }),
        };
        let cloned = r.clone();
        assert_eq!(cloned.kind, r.kind);
        assert_eq!(cloned.args, r.args);
        // The derived Debug output names the kind variant.
        assert!(format!("{r:?}").contains("HttpProxy"));
    }

    #[test]
    fn symbolic_fetch_ref_accepts_each_kind() {
        for kind in ALL_FETCH_KINDS {
            let r = SymbolicFetchRef { kind, args: serde_json::Value::Null };
            assert_eq!(r.kind, kind);
            assert!(r.args.is_null());
        }
    }

    #[test]
    fn symbolic_fetch_ref_round_trip_preserves_kind_and_args() {
        let r = SymbolicFetchRef {
            kind: FetchKind::WebSocketUpgrade,
            args: json!({ "upstream": "127.0.0.1:9000" }),
        };
        let encoded = serde_json::to_string(&r).expect("serialize");
        let decoded: SymbolicFetchRef = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.kind, r.kind);
        assert_eq!(decoded.args, r.args);
    }
}