1use std::sync::Arc;
2
3use async_trait::async_trait;
4use tokio::io::{AsyncRead, AsyncWrite};
5use tokio::sync::oneshot;
6
7use crate::body::{Request, Response};
8use crate::conn_context::ConnContext;
9use crate::error::Error;
10use crate::flow_ctx::FlowCtx;
11use crate::l4::L4Conn;
12use crate::middleware::CloseReason;
13
14#[async_trait]
15pub trait L7Fetch: Send + Sync {
16 async fn fetch(
17 &self,
18 req: Request,
19 conn: &Arc<ConnContext>,
20 ctx: &mut FlowCtx<'_>,
21 ) -> Result<L7FetchOutput, Error>;
22}
23
24#[async_trait]
25pub trait L4Fetch: Send + Sync {
26 async fn fetch(
27 &self,
28 l4: L4Conn,
29 conn: &Arc<ConnContext>,
30 ctx: &mut FlowCtx<'_>,
31 ) -> Result<Tunnel, Error>;
32}
33
34pub enum L7FetchOutput {
35 Response(Response),
36 Tunnel(Tunnel),
37}
38
39pub struct Tunnel {
40 pub client: Box<dyn AsyncReadWrite + Send>,
41 pub upstream: Box<dyn AsyncReadWrite + Send>,
42 pub close_reason_tx: Option<oneshot::Sender<CloseReason>>,
43}
44
45pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin {}
50impl<T: AsyncRead + AsyncWrite + Unpin + ?Sized> AsyncReadWrite for T {}
51
52#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
53pub enum FetchKind {
54 HttpProxy,
55 HttpSynthesize,
56 WebSocketUpgrade,
57 L4Forward,
58}
59
60#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
61pub enum FetchPhase {
62 L4,
63 L7,
64}
65
66#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
67pub struct FetchOutputModes {
68 pub response: bool,
69 pub tunnel: bool,
70}
71
72#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
73pub struct SymbolicFetchRef {
74 pub kind: FetchKind,
75 pub args: serde_json::Value,
76}
77
78#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
79pub enum Terminator {
80 WriteHttpResponse,
81 ByteTunnel,
82 Close,
83}
84
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::io;
    use std::net::SocketAddr;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use std::time::Instant;

    use parking_lot::Mutex;
    use serde_json::json;
    use tokio::io::ReadBuf;
    use tokio_util::sync::CancellationToken;

    use super::*;
    use crate::body::{Body, Request, Response};
    use crate::conn_context::{ConnId, Transport};
    use crate::flow_log::{FlowLogEvent, FlowLogSink};

    /// Stream stub whose every operation completes immediately: reads leave
    /// the buffer untouched and writes claim the whole input was accepted.
    struct NoopStream;

    impl AsyncRead for NoopStream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            // Ready without filling `_buf`: zero bytes are produced.
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncWrite for NoopStream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            // Pretend every byte was written; the data is discarded.
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Flow-log sink that drops every event.
    struct NullSink;
    impl FlowLogSink for NullSink {
        fn emit(&self, _event: FlowLogEvent) {}
    }

    /// `L7Fetch` stub that always answers with an empty-bodied 200 response.
    struct SynthOk;
    #[async_trait]
    impl L7Fetch for SynthOk {
        async fn fetch(
            &self,
            _req: Request,
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx<'_>,
        ) -> Result<L7FetchOutput, Error> {
            let resp: Response =
                http::Response::builder().status(200).body(Body::Empty).expect("build");
            Ok(L7FetchOutput::Response(resp))
        }
    }

    /// `L4Fetch` stub that ignores its input and returns a no-op tunnel.
    struct L4Nop;
    #[async_trait]
    impl L4Fetch for L4Nop {
        async fn fetch(
            &self,
            _l4: L4Conn,
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx<'_>,
        ) -> Result<Tunnel, Error> {
            Ok(noop_tunnel())
        }
    }

    /// Compile-time probe: only type-checks when `F` is `Send`.
    fn assert_send<F: Send>(_: &F) {}

    /// Builds a `Tunnel` over two `NoopStream`s with a populated close-reason
    /// sender. The matching receiver is dropped on return.
    fn noop_tunnel() -> Tunnel {
        let (tx, _rx) = oneshot::channel::<CloseReason>();
        Tunnel {
            client: Box::new(NoopStream),
            upstream: Box::new(NoopStream),
            close_reason_tx: Some(tx),
        }
    }

    /// Builds a minimal TCP `ConnContext` on `127.0.0.1:0` for both endpoints.
    fn make_conn_context() -> Arc<ConnContext> {
        let addr: SocketAddr = "127.0.0.1:0".parse().expect("parse addr");
        Arc::new(ConnContext {
            id: ConnId(0),
            remote: addr,
            local: addr,
            transport: Transport::Tcp,
            entered_at: Instant::now(),
            tls: Mutex::new(None),
            http_version: std::sync::OnceLock::new(),
            user: Mutex::new(http::Extensions::new()),
        })
    }

    /// Serializes `value` to JSON, parses it back, and asserts that the
    /// decoded value equals the input.
    fn assert_json_round_trip<T>(value: &T)
    where
        T: serde::Serialize + serde::de::DeserializeOwned + PartialEq + std::fmt::Debug,
    {
        let encoded = serde_json::to_string(value).expect("serialize");
        let decoded: T = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(&decoded, value);
    }

    /// The blanket impl makes any `AsyncRead + AsyncWrite + Unpin` type usable
    /// as a boxed `AsyncReadWrite` trait object.
    #[test]
    fn async_read_write_blanket_accepts_async_io_type() {
        let _: Box<dyn AsyncReadWrite + Send> = Box::new(NoopStream);
    }

    #[test]
    fn l7_fetch_output_response_variant_constructs() {
        let resp: Response =
            http::Response::builder().status(200).body(Body::Empty).expect("build response");
        match L7FetchOutput::Response(resp) {
            L7FetchOutput::Response(_) => {}
            L7FetchOutput::Tunnel(_) => panic!("unexpected tunnel variant"),
        }
    }

    #[test]
    fn tunnel_builds_from_paired_async_io_streams() {
        let _ = L7FetchOutput::Tunnel(noop_tunnel());
    }

    /// `L7Fetch` is object-safe and its supertraits cover `Send + Sync`.
    #[test]
    fn l7_fetch_is_constructible_as_arc_dyn_send_sync() {
        let f: Arc<dyn L7Fetch + Send + Sync> = Arc::new(SynthOk);
        let _: Arc<dyn L7Fetch> = f;
    }

    /// `L4Fetch` is object-safe and its supertraits cover `Send + Sync`.
    #[test]
    fn l4_fetch_is_constructible_as_arc_dyn_send_sync() {
        let f: Arc<dyn L4Fetch + Send + Sync> = Arc::new(L4Nop);
        let _: Arc<dyn L4Fetch> = f;
    }

    /// Type-level check that the future returned through `dyn L7Fetch` is
    /// `Send`. The future is never polled; it is only built and dropped.
    #[test]
    fn l7_fetch_fetch_returns_send_future() {
        let f: Arc<dyn L7Fetch> = Arc::new(SynthOk);
        let conn = make_conn_context();
        let mut sink = NullSink;
        let mut span = tracing::Span::none();
        let cancel = CancellationToken::new();
        let mut ctx = FlowCtx {
            span: &mut span,
            log: &mut sink as &mut dyn FlowLogSink,
            cancel: &cancel,
            verbosity: crate::flow_log::FlowLogVerbosity::Trajectory,
            trajectory: crate::flow_log::TrajectoryBuilder::new(
                conn.id,
                crate::ir::NodeId::new(0),
                0,
            ),
        };
        let req: Request = http::Request::builder().uri("/").body(Body::Empty).expect("build req");
        // The explicit annotation is the assertion: it only compiles if the
        // returned future is a boxed `Send` future.
        let fut: Pin<Box<dyn Future<Output = Result<L7FetchOutput, Error>> + Send + '_>> =
            f.fetch(req, &conn, &mut ctx);
        assert_send(&fut);
        drop(fut);
    }

    #[test]
    fn fetch_kind_serde_round_trip_per_variant() {
        for kind in [
            FetchKind::HttpProxy,
            FetchKind::HttpSynthesize,
            FetchKind::WebSocketUpgrade,
            FetchKind::L4Forward,
        ] {
            assert_json_round_trip(&kind);
        }
    }

    #[test]
    fn fetch_phase_serde_round_trip_per_variant() {
        for phase in [FetchPhase::L4, FetchPhase::L7] {
            assert_json_round_trip(&phase);
        }
    }

    #[test]
    fn terminator_serde_round_trip_per_variant() {
        // `Close` was previously missing here, so one variant went unchecked
        // despite the test's name.
        for terminator in [Terminator::WriteHttpResponse, Terminator::ByteTunnel, Terminator::Close]
        {
            assert_json_round_trip(&terminator);
        }
    }

    #[test]
    fn fetch_output_modes_serde_round_trip_http_shapes() {
        let http_only = FetchOutputModes { response: true, tunnel: false };
        let ws = FetchOutputModes { response: true, tunnel: true };
        let l4 = FetchOutputModes { response: false, tunnel: true };
        for modes in [http_only, ws, l4] {
            assert_json_round_trip(&modes);
        }
    }

    #[test]
    fn symbolic_fetch_ref_clone_preserves_fields() {
        let r = SymbolicFetchRef {
            kind: FetchKind::HttpProxy,
            args: json!({ "upstream": "127.0.0.1:8080" }),
        };
        let cloned = r.clone();
        assert_eq!(cloned.kind, r.kind);
        assert_eq!(cloned.args, r.args);
        // Exercises the derived `Debug` impl.
        let _ = format!("{r:?}");
    }

    #[test]
    fn symbolic_fetch_ref_accepts_each_kind() {
        for kind in [
            FetchKind::HttpProxy,
            FetchKind::HttpSynthesize,
            FetchKind::WebSocketUpgrade,
            FetchKind::L4Forward,
        ] {
            let _ = SymbolicFetchRef { kind, args: serde_json::Value::Null };
        }
    }

    /// `SymbolicFetchRef` has no `PartialEq`, so the round trip is checked
    /// field by field rather than through `assert_json_round_trip`.
    #[test]
    fn symbolic_fetch_ref_round_trip_preserves_kind_and_args() {
        let r = SymbolicFetchRef {
            kind: FetchKind::WebSocketUpgrade,
            args: json!({ "upstream": "127.0.0.1:9000" }),
        };
        let encoded = serde_json::to_string(&r).expect("serialize");
        let decoded: SymbolicFetchRef = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.kind, r.kind);
        assert_eq!(decoded.args, r.args);
    }
}