1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use tokio::io::{AsyncRead, AsyncWrite};
7use tokio::sync::oneshot;
8use tokio_util::sync::CancellationToken;
9
10use crate::body::{Request, Response};
11use crate::conn_context::ConnContext;
12use crate::error::Error;
13use crate::flow_ctx::FlowCtx;
14use crate::l4::L4Conn;
15use crate::middleware::CloseReason;
16
17#[async_trait]
18pub trait L7Fetch: Send + Sync {
19 async fn fetch(
20 &self,
21 req: Request,
22 conn: &Arc<ConnContext>,
23 ctx: &mut FlowCtx,
24 ) -> Result<L7FetchOutput, Error>;
25}
26
27#[async_trait]
28pub trait L4Fetch: Send + Sync {
29 async fn fetch(
30 &self,
31 l4: L4Conn,
32 conn: &Arc<ConnContext>,
33 ctx: &mut FlowCtx,
34 ) -> Result<Tunnel, Error>;
35}
36
37pub enum L7FetchOutput {
38 Response(Response),
39 Tunnel(Tunnel),
40}
41
42pub enum Tunnel {
53 Bidi {
54 client: Box<dyn AsyncReadWrite + Send>,
55 upstream: Box<dyn AsyncReadWrite + Send>,
56 close_reason_tx: Option<oneshot::Sender<CloseReason>>,
57 },
58 Udp(UdpTunnel),
59}
60
61pub struct UdpTunnel {
71 pub join: Pin<Box<dyn Future<Output = CloseReason> + Send>>,
72 pub cancel: CancellationToken,
73}
74
75pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin {}
80impl<T: AsyncRead + AsyncWrite + Unpin + ?Sized> AsyncReadWrite for T {}
81
82#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
83pub enum FetchKind {
84 HttpProxy,
85 HttpSynthesize,
86 WebSocketUpgrade,
87 L4Forward,
88}
89
90impl FetchKind {
91 #[must_use]
97 pub const fn phase(self) -> FetchPhase {
98 match self {
99 Self::L4Forward => FetchPhase::L4,
100 Self::HttpProxy | Self::HttpSynthesize | Self::WebSocketUpgrade => FetchPhase::L7,
101 }
102 }
103}
104
105#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
106pub enum FetchPhase {
107 L4,
108 L7,
109}
110
111#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
112pub struct FetchOutputModes {
113 pub response: bool,
114 pub tunnel: bool,
115}
116
117#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
118pub struct SymbolicFetchRef {
119 pub kind: FetchKind,
120 pub args: serde_json::Value,
121 #[serde(default)]
127 pub retry_buffer_required: bool,
128}
129
130#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
131pub enum Terminator {
132 WriteHttpResponse,
133 ByteTunnel,
134 Close,
135}
136
/// Limits passed to [`HttpFetchBackend::fetch`] alongside each request.
///
/// NOTE(review): how each limit is enforced is up to the backend; the field
/// notes below describe only types and defaults visible in this file.
#[derive(Debug, Clone)]
pub struct HttpFetchLimits {
    /// Body size cap in bytes. Default: 1 MiB (`1024 * 1024`).
    pub max_body_bytes: u64,
    /// Timeout in milliseconds, if any. Default: `None`.
    pub timeout_ms: Option<u32>,
    /// Redirect-following setting, if any. Default: `Some(5)`.
    pub follow_redirects: Option<u32>,
    /// Whether insecure fetches are permitted. Default: `false`.
    pub allow_insecure: bool,
}

impl Default for HttpFetchLimits {
    /// Builds the conservative baseline: 1 MiB body cap, no timeout,
    /// `Some(5)` for redirects, insecure fetches disallowed.
    fn default() -> Self {
        const DEFAULT_MAX_BODY_BYTES: u64 = 1 << 20;
        const DEFAULT_FOLLOW_REDIRECTS: u32 = 5;

        Self {
            allow_insecure: false,
            follow_redirects: Some(DEFAULT_FOLLOW_REDIRECTS),
            timeout_ms: None,
            max_body_bytes: DEFAULT_MAX_BODY_BYTES,
        }
    }
}
161
/// Fully buffered outbound HTTP request handed to an [`HttpFetchBackend`].
#[derive(Debug)]
pub struct HttpFetchRequest {
    /// HTTP method, as a string.
    pub method: String,
    /// Target URL, as a string.
    pub url: String,
    /// Header name/value pairs, in order; duplicates are representable.
    pub headers: Vec<(String, String)>,
    /// Request body bytes (empty for no body).
    pub body: Vec<u8>,
    /// Optional timeout in milliseconds.
    ///
    /// NOTE(review): presumably a per-request override of
    /// `HttpFetchLimits::timeout_ms`; precedence is decided by the backend.
    pub timeout_ms: Option<u32>,
    /// Optional redirect-following setting (see the note on `timeout_ms`).
    pub follow_redirects: Option<u32>,
    /// Optional TLS verification switch; `None` leaves the choice to the backend.
    pub verify_tls: Option<bool>,
}
175
/// Fully buffered HTTP response returned by an [`HttpFetchBackend`].
#[derive(Debug)]
pub struct HttpFetchResponse {
    /// Numeric HTTP status code.
    pub status: u16,
    /// Header name/value pairs, in order; duplicates are representable.
    pub headers: Vec<(String, String)>,
    /// Response body bytes.
    pub body: Vec<u8>,
}
183
184#[derive(Debug, thiserror::Error)]
188pub enum HttpFetchError {
189 #[error("dns failure: {0}")]
190 DnsFailure(String),
191 #[error("connection refused")]
192 ConnectionRefused,
193 #[error("timeout")]
194 Timeout,
195 #[error("tls error: {0}")]
196 TlsError(String),
197 #[error("pool exhausted")]
198 PoolExhausted,
199 #[error("body too large")]
200 BodyTooLarge,
201 #[error("not allowed: {0}")]
202 NotAllowed(String),
203 #[error("insecure rejected")]
204 InsecureRejected,
205 #[error("internal: {0}")]
206 Internal(String),
207}
208
209#[async_trait]
215pub trait HttpFetchBackend: Send + Sync {
216 async fn fetch(
217 &self,
218 req: HttpFetchRequest,
219 limits: HttpFetchLimits,
220 ) -> Result<HttpFetchResponse, HttpFetchError>;
221}
222
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::io;
    use std::net::SocketAddr;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use std::time::Instant;

    use parking_lot::Mutex;
    use serde_json::json;
    use tokio::io::ReadBuf;
    use tokio_util::sync::CancellationToken;

    use super::*;
    use crate::body::{Body, Request, Response};
    use crate::conn_context::{ConnId, Transport};
    use crate::flow_log::{FlowLogEvent, FlowLogSink};

    /// Stream stub for type-level tests: every poll completes immediately.
    ///
    /// Reads leave the buffer untouched and writes report the whole input as
    /// accepted, so no test ever blocks on it.
    struct NoopStream;

    impl AsyncRead for NoopStream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncWrite for NoopStream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Flow-log sink that discards every event.
    struct NullSink;
    impl FlowLogSink for NullSink {
        fn emit(&self, _event: FlowLogEvent) {}
    }

    /// `L7Fetch` stub that always answers with an empty-bodied 200 response.
    struct SynthOk;
    #[async_trait]
    impl L7Fetch for SynthOk {
        async fn fetch(
            &self,
            _req: Request,
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx,
        ) -> Result<L7FetchOutput, Error> {
            let resp: Response = http::Response::builder().status(200).body(Body::Empty).expect("build");
            Ok(L7FetchOutput::Response(resp))
        }
    }

    /// `L4Fetch` stub that always returns a `Bidi` tunnel over two `NoopStream`s.
    struct L4Nop;
    #[async_trait]
    impl L4Fetch for L4Nop {
        async fn fetch(
            &self,
            _l4: L4Conn,
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx,
        ) -> Result<Tunnel, Error> {
            let (tx, _rx) = oneshot::channel::<CloseReason>();
            Ok(Tunnel::Bidi {
                client: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
                upstream: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
                close_reason_tx: Some(tx),
            })
        }
    }

    /// Compile-time probe: only type-checks when `F` is `Send`.
    fn assert_send<F: Send>(_: &F) {}

    /// Builds a throwaway TCP `ConnContext` bound to `127.0.0.1:0` on both ends.
    fn make_conn_context() -> Arc<ConnContext> {
        let addr: SocketAddr = "127.0.0.1:0".parse().expect("parse addr");
        Arc::new(ConnContext {
            id: ConnId(0),
            remote: addr,
            local: addr,
            transport: Transport::Tcp,
            entered_at: Instant::now(),
            tls: Mutex::new(None),
            http_version: std::sync::OnceLock::new(),
            user: Mutex::new(http::Extensions::new()),
        })
    }

    #[test]
    fn async_read_write_blanket_accepts_async_io_type() {
        let _: Box<dyn AsyncReadWrite + Send> = Box::new(NoopStream);
    }

    #[test]
    fn l7_fetch_output_response_variant_constructs() {
        let resp: Response =
            http::Response::builder().status(200).body(Body::Empty).expect("build response");
        match L7FetchOutput::Response(resp) {
            L7FetchOutput::Response(_) => {}
            L7FetchOutput::Tunnel(_) => panic!("unexpected tunnel variant"),
        }
    }

    #[test]
    fn tunnel_bidi_builds_from_paired_async_io_streams() {
        let (tx, _rx) = oneshot::channel::<CloseReason>();
        let tunnel = Tunnel::Bidi {
            client: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
            upstream: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
            close_reason_tx: Some(tx),
        };
        let _ = L7FetchOutput::Tunnel(tunnel);
    }

    #[test]
    fn tunnel_udp_builds_from_join_future_and_cancel_token() {
        let cancel = CancellationToken::new();
        let join: Pin<Box<dyn Future<Output = CloseReason> + Send>> =
            Box::pin(async move { CloseReason::Graceful });
        let tunnel = Tunnel::Udp(UdpTunnel { join, cancel });
        let _ = L7FetchOutput::Tunnel(tunnel);
    }

    #[test]
    fn l7_fetch_is_constructible_as_arc_dyn_send_sync() {
        let f: Arc<dyn L7Fetch + Send + Sync> = Arc::new(SynthOk);
        let _: Arc<dyn L7Fetch> = f;
    }

    #[test]
    fn l4_fetch_is_constructible_as_arc_dyn_send_sync() {
        let f: Arc<dyn L4Fetch + Send + Sync> = Arc::new(L4Nop);
        let _: Arc<dyn L4Fetch> = f;
    }

    #[test]
    fn l7_fetch_fetch_returns_send_future() {
        let f: Arc<dyn L7Fetch> = Arc::new(SynthOk);
        let conn = make_conn_context();
        let mut ctx = FlowCtx {
            span: tracing::Span::none(),
            log: Arc::new(NullSink),
            cancel: CancellationToken::new(),
            verbosity: crate::flow_log::FlowLogVerbosity::Trajectory,
            trajectory: crate::flow_log::TrajectoryBuilder::new(conn.id, crate::ir::NodeId::new(0), 0),
        };
        let req: Request = http::Request::builder().uri("/").body(Body::Empty).expect("build req");
        // The explicit annotation pins the exact future type `async_trait`
        // is expected to hand back; the future is never polled.
        let fut: Pin<Box<dyn Future<Output = Result<L7FetchOutput, Error>> + Send + '_>> =
            f.fetch(req, &conn, &mut ctx);
        assert_send(&fut);
        drop(fut);
    }

    #[test]
    fn fetch_kind_serde_round_trip_per_variant() {
        for k in [
            FetchKind::HttpProxy,
            FetchKind::HttpSynthesize,
            FetchKind::WebSocketUpgrade,
            FetchKind::L4Forward,
        ] {
            let encoded = serde_json::to_string(&k).expect("serialize");
            let decoded: FetchKind = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, k);
        }
    }

    #[test]
    fn fetch_phase_serde_round_trip_per_variant() {
        for p in [FetchPhase::L4, FetchPhase::L7] {
            let encoded = serde_json::to_string(&p).expect("serialize");
            let decoded: FetchPhase = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, p);
        }
    }

    #[test]
    fn terminator_serde_round_trip_per_variant() {
        // Every variant is listed, including `Close`, so the test matches its name.
        for t in [Terminator::WriteHttpResponse, Terminator::ByteTunnel, Terminator::Close] {
            let encoded = serde_json::to_string(&t).expect("serialize");
            let decoded: Terminator = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, t);
        }
    }

    #[test]
    fn fetch_output_modes_serde_round_trip_http_shapes() {
        let http_only = FetchOutputModes { response: true, tunnel: false };
        let ws = FetchOutputModes { response: true, tunnel: true };
        let l4 = FetchOutputModes { response: false, tunnel: true };
        for modes in [http_only, ws, l4] {
            let encoded = serde_json::to_string(&modes).expect("serialize");
            let decoded: FetchOutputModes = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, modes);
        }
    }

    #[test]
    fn symbolic_fetch_ref_clone_preserves_fields() {
        // A non-default flag value makes a clone that drops the field fail.
        let r = SymbolicFetchRef {
            kind: FetchKind::HttpProxy,
            args: json!({ "upstream": "127.0.0.1:8080" }),
            retry_buffer_required: true,
        };
        let cloned = r.clone();
        assert_eq!(cloned.kind, r.kind);
        assert_eq!(cloned.args, r.args);
        assert_eq!(cloned.retry_buffer_required, r.retry_buffer_required);
        let _ = format!("{r:?}");
    }

    #[test]
    fn symbolic_fetch_ref_accepts_each_kind() {
        for kind in [
            FetchKind::HttpProxy,
            FetchKind::HttpSynthesize,
            FetchKind::WebSocketUpgrade,
            FetchKind::L4Forward,
        ] {
            let _ =
                SymbolicFetchRef { kind, args: serde_json::Value::Null, retry_buffer_required: false };
        }
    }

    #[test]
    fn symbolic_fetch_ref_round_trip_preserves_kind_and_args() {
        // `true` is used because `false` would also come back from the
        // `#[serde(default)]` fallback if the field were never serialized.
        let r = SymbolicFetchRef {
            kind: FetchKind::WebSocketUpgrade,
            args: json!({ "upstream": "127.0.0.1:9000" }),
            retry_buffer_required: true,
        };
        let encoded = serde_json::to_string(&r).expect("serialize");
        let decoded: SymbolicFetchRef = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.kind, r.kind);
        assert_eq!(decoded.args, r.args);
        assert_eq!(decoded.retry_buffer_required, r.retry_buffer_required);
    }

    #[test]
    fn symbolic_fetch_ref_missing_retry_flag_defaults_to_false() {
        let decoded: SymbolicFetchRef =
            serde_json::from_str(r#"{"kind":"HttpProxy","args":{}}"#).expect("deserialize");
        assert_eq!(decoded.kind, FetchKind::HttpProxy);
        assert_eq!(decoded.args, json!({}));
        assert!(!decoded.retry_buffer_required);
    }
}