1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use tokio::io::{AsyncRead, AsyncWrite};
7use tokio::sync::oneshot;
8use tokio_util::sync::CancellationToken;
9
10use crate::body::{Request, Response};
11use crate::conn_context::ConnContext;
12use crate::error::Error;
13use crate::flow_ctx::FlowCtx;
14use crate::l4::L4Conn;
15use crate::middleware::CloseReason;
16
17#[async_trait]
18pub trait L7Fetch: Send + Sync {
19 async fn fetch(
20 &self,
21 req: Request,
22 conn: &Arc<ConnContext>,
23 ctx: &mut FlowCtx,
24 ) -> Result<L7FetchOutput, Error>;
25}
26
27#[async_trait]
28pub trait L4Fetch: Send + Sync {
29 async fn fetch(
30 &self,
31 l4: L4Conn,
32 conn: &Arc<ConnContext>,
33 ctx: &mut FlowCtx,
34 ) -> Result<Tunnel, Error>;
35}
36
37pub enum L7FetchOutput {
38 Response(Response),
39 Tunnel(Tunnel),
40}
41
42pub enum Tunnel {
53 Bidi {
54 client: Box<dyn AsyncReadWrite + Send>,
55 upstream: Box<dyn AsyncReadWrite + Send>,
56 close_reason_tx: Option<oneshot::Sender<CloseReason>>,
57 },
58 Udp(UdpTunnel),
59}
60
61pub struct UdpTunnel {
71 pub join: Pin<Box<dyn Future<Output = CloseReason> + Send>>,
72 pub cancel: CancellationToken,
73}
74
75pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin {}
80impl<T: AsyncRead + AsyncWrite + Unpin + ?Sized> AsyncReadWrite for T {}
81
82#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
83pub enum FetchKind {
84 HttpProxy,
85 HttpSynthesize,
86 WebSocketUpgrade,
87 L4Forward,
88 AcmeChallenge,
96}
97
98impl FetchKind {
99 #[must_use]
105 pub const fn phase(self) -> FetchPhase {
106 match self {
107 Self::L4Forward => FetchPhase::L4,
108 Self::HttpProxy | Self::HttpSynthesize | Self::WebSocketUpgrade | Self::AcmeChallenge => {
109 FetchPhase::L7
110 }
111 }
112 }
113}
114
115#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
116pub enum FetchPhase {
117 L4,
118 L7,
119}
120
121#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
122pub struct FetchOutputModes {
123 pub response: bool,
124 pub tunnel: bool,
125}
126
127#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
128pub struct SymbolicFetchRef {
129 pub kind: FetchKind,
130 pub args: serde_json::Value,
131 #[serde(default)]
137 pub retry_buffer_required: bool,
138 #[serde(default)]
149 pub allow_zero_rtt: Option<bool>,
150}
151
152#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
153pub enum Terminator {
154 WriteHttpResponse,
155 ByteTunnel,
156 Close,
157}
158
/// Limits applied by an [`HttpFetchBackend`] to a single fetch.
#[derive(Clone, Debug)]
pub struct HttpFetchLimits {
    /// Maximum accepted body size, in bytes.
    pub max_body_bytes: u64,
    /// Optional timeout in milliseconds; `None` means no limit is set here.
    pub timeout_ms: Option<u32>,
    /// Optional redirect-follow limit.
    pub follow_redirects: Option<u32>,
    /// Whether insecure (e.g. unverified TLS) fetches are permitted.
    pub allow_insecure: bool,
}
172
173impl Default for HttpFetchLimits {
174 fn default() -> Self {
175 Self {
176 max_body_bytes: 1024 * 1024,
177 timeout_ms: None,
178 follow_redirects: Some(5),
179 allow_insecure: false,
180 }
181 }
182}
183
/// An outbound HTTP request handed to an [`HttpFetchBackend`].
#[derive(Debug)]
pub struct HttpFetchRequest {
    /// HTTP method, as a string (e.g. `"GET"`).
    pub method: String,
    /// Target URL.
    pub url: String,
    /// Header name/value pairs in order.
    pub headers: Vec<(String, String)>,
    /// Raw request body.
    pub body: Vec<u8>,
    /// Optional per-request timeout in milliseconds.
    pub timeout_ms: Option<u32>,
    /// Optional per-request redirect-follow limit.
    pub follow_redirects: Option<u32>,
    /// Optional per-request TLS verification override.
    pub verify_tls: Option<bool>,
}
197
/// The response returned by an [`HttpFetchBackend`].
#[derive(Debug)]
pub struct HttpFetchResponse {
    /// Numeric HTTP status code.
    pub status: u16,
    /// Header name/value pairs in order.
    pub headers: Vec<(String, String)>,
    /// Raw response body.
    pub body: Vec<u8>,
}
205
206#[derive(Debug, thiserror::Error)]
210pub enum HttpFetchError {
211 #[error("dns failure: {0}")]
212 DnsFailure(String),
213 #[error("connection refused")]
214 ConnectionRefused,
215 #[error("timeout")]
216 Timeout,
217 #[error("tls error: {0}")]
218 TlsError(String),
219 #[error("pool exhausted")]
220 PoolExhausted,
221 #[error("body too large")]
222 BodyTooLarge,
223 #[error("not allowed: {0}")]
224 NotAllowed(String),
225 #[error("insecure rejected")]
226 InsecureRejected,
227 #[error("internal: {0}")]
228 Internal(String),
229}
230
231#[async_trait]
237pub trait HttpFetchBackend: Send + Sync {
238 async fn fetch(
239 &self,
240 req: HttpFetchRequest,
241 limits: HttpFetchLimits,
242 ) -> Result<HttpFetchResponse, HttpFetchError>;
243}
244
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::io;
    use std::net::SocketAddr;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use std::time::Instant;

    use parking_lot::Mutex;
    use serde_json::json;
    use tokio::io::ReadBuf;
    use tokio_util::sync::CancellationToken;

    use super::*;
    use crate::body::{Body, Request, Response};
    use crate::conn_context::{ConnId, Transport};
    use crate::flow_log::{FlowLogEvent, FlowLogSink};

    /// Every `FetchKind` variant, in declaration order. Per-variant tests
    /// iterate this list; `all_fetch_kinds_lists_each_variant_once` keeps it in
    /// sync with the enum.
    const ALL_FETCH_KINDS: [FetchKind; 5] = [
        FetchKind::HttpProxy,
        FetchKind::HttpSynthesize,
        FetchKind::WebSocketUpgrade,
        FetchKind::L4Forward,
        FetchKind::AcmeChallenge,
    ];

    /// Every `Terminator` variant, in declaration order (kept in sync by
    /// `all_terminators_lists_each_variant_once`).
    const ALL_TERMINATORS: [Terminator; 3] =
        [Terminator::WriteHttpResponse, Terminator::ByteTunnel, Terminator::Close];

    /// Exhaustive match: adding a `FetchKind` variant fails to compile here,
    /// prompting an update of `ALL_FETCH_KINDS`.
    fn fetch_kind_index(kind: FetchKind) -> usize {
        match kind {
            FetchKind::HttpProxy => 0,
            FetchKind::HttpSynthesize => 1,
            FetchKind::WebSocketUpgrade => 2,
            FetchKind::L4Forward => 3,
            FetchKind::AcmeChallenge => 4,
        }
    }

    /// Exhaustive match: adding a `Terminator` variant fails to compile here,
    /// prompting an update of `ALL_TERMINATORS`.
    fn terminator_index(t: Terminator) -> usize {
        match t {
            Terminator::WriteHttpResponse => 0,
            Terminator::ByteTunnel => 1,
            Terminator::Close => 2,
        }
    }

    /// Stream that reads EOF immediately and accepts every write.
    struct NoopStream;

    impl AsyncRead for NoopStream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncWrite for NoopStream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Flow-log sink that discards every event.
    struct NullSink;
    impl FlowLogSink for NullSink {
        fn emit(&self, _event: FlowLogEvent) {}
    }

    /// L7 fetch that always answers `200` with an empty body.
    struct SynthOk;
    #[async_trait]
    impl L7Fetch for SynthOk {
        async fn fetch(
            &self,
            _req: Request,
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx,
        ) -> Result<L7FetchOutput, Error> {
            let resp: Response = http::Response::builder().status(200).body(Body::Empty).expect("build");
            Ok(L7FetchOutput::Response(resp))
        }
    }

    /// L4 fetch that returns a bidi tunnel of two no-op streams.
    struct L4Nop;
    #[async_trait]
    impl L4Fetch for L4Nop {
        async fn fetch(
            &self,
            _l4: L4Conn,
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx,
        ) -> Result<Tunnel, Error> {
            let (tx, _rx) = oneshot::channel::<crate::middleware::CloseReason>();
            Ok(Tunnel::Bidi {
                client: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
                upstream: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
                close_reason_tx: Some(tx),
            })
        }
    }

    /// Compile-time check that `F: Send`.
    fn assert_send<F: Send>(_: &F) {}

    /// Minimal loopback TCP connection context for exercising fetch traits.
    fn make_conn_context() -> Arc<ConnContext> {
        let addr: SocketAddr = "127.0.0.1:0".parse().expect("parse addr");
        Arc::new(ConnContext {
            id: ConnId(0),
            remote: addr,
            local: addr,
            transport: Transport::Tcp,
            entered_at: Instant::now(),
            tls: Mutex::new(None),
            http_version: std::sync::OnceLock::new(),
            user: Mutex::new(http::Extensions::new()),
        })
    }

    #[test]
    fn all_fetch_kinds_lists_each_variant_once() {
        for (i, kind) in ALL_FETCH_KINDS.iter().copied().enumerate() {
            assert_eq!(fetch_kind_index(kind), i, "{kind:?}");
        }
    }

    #[test]
    fn all_terminators_lists_each_variant_once() {
        for (i, t) in ALL_TERMINATORS.iter().copied().enumerate() {
            assert_eq!(terminator_index(t), i, "{t:?}");
        }
    }

    #[test]
    fn async_read_write_blanket_accepts_async_io_type() {
        let _: Box<dyn AsyncReadWrite + Send> = Box::new(NoopStream);
    }

    #[test]
    fn l7_fetch_output_response_variant_constructs() {
        let resp: Response =
            http::Response::builder().status(200).body(Body::Empty).expect("build response");
        match L7FetchOutput::Response(resp) {
            L7FetchOutput::Response(_) => {}
            L7FetchOutput::Tunnel(_) => panic!("unexpected tunnel variant"),
        }
    }

    #[test]
    fn tunnel_bidi_builds_from_paired_async_io_streams() {
        let (tx, _rx) = oneshot::channel::<crate::middleware::CloseReason>();
        let tunnel = Tunnel::Bidi {
            client: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
            upstream: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
            close_reason_tx: Some(tx),
        };
        let _ = L7FetchOutput::Tunnel(tunnel);
    }

    #[test]
    fn tunnel_udp_builds_from_join_future_and_cancel_token() {
        let cancel = CancellationToken::new();
        let join: Pin<Box<dyn Future<Output = CloseReason> + Send>> =
            Box::pin(async move { CloseReason::Graceful });
        let tunnel = Tunnel::Udp(UdpTunnel { join, cancel });
        let _ = L7FetchOutput::Tunnel(tunnel);
    }

    #[test]
    fn l7_fetch_is_constructible_as_arc_dyn_send_sync() {
        let f: Arc<dyn L7Fetch + Send + Sync> = Arc::new(SynthOk);
        let _: Arc<dyn L7Fetch> = f;
    }

    #[test]
    fn l4_fetch_is_constructible_as_arc_dyn_send_sync() {
        let f: Arc<dyn L4Fetch + Send + Sync> = Arc::new(L4Nop);
        let _: Arc<dyn L4Fetch> = f;
    }

    #[test]
    fn l7_fetch_fetch_returns_send_future() {
        let f: Arc<dyn L7Fetch> = Arc::new(SynthOk);
        let conn = make_conn_context();
        let mut ctx = FlowCtx {
            span: tracing::Span::none(),
            log: Arc::new(NullSink),
            cancel: CancellationToken::new(),
            verbosity: crate::flow_log::FlowLogVerbosity::Trajectory,
            trajectory: crate::flow_log::TrajectoryBuilder::new(conn.id, crate::ir::NodeId::new(0), 0),
        };
        let req: Request = http::Request::builder().uri("/").body(Body::Empty).expect("build req");
        // The future is only built (never polled); the point is the `Send` bound.
        let fut: Pin<Box<dyn Future<Output = Result<L7FetchOutput, Error>> + Send + '_>> =
            f.fetch(req, &conn, &mut ctx);
        assert_send(&fut);
        drop(fut);
    }

    #[test]
    fn fetch_kind_serde_round_trip_per_variant() {
        for k in ALL_FETCH_KINDS {
            let encoded = serde_json::to_string(&k).expect("serialize");
            let decoded: FetchKind = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, k);
        }
    }

    #[test]
    fn fetch_kind_phase_per_variant() {
        // Stated independently of `phase()` so a mis-classification is caught.
        let expected = [
            (FetchKind::HttpProxy, FetchPhase::L7),
            (FetchKind::HttpSynthesize, FetchPhase::L7),
            (FetchKind::WebSocketUpgrade, FetchPhase::L7),
            (FetchKind::L4Forward, FetchPhase::L4),
            (FetchKind::AcmeChallenge, FetchPhase::L7),
        ];
        assert_eq!(expected.len(), ALL_FETCH_KINDS.len());
        for (kind, phase) in expected {
            assert_eq!(kind.phase(), phase, "{kind:?}");
        }
    }

    #[test]
    fn fetch_phase_serde_round_trip_per_variant() {
        for p in [FetchPhase::L4, FetchPhase::L7] {
            let encoded = serde_json::to_string(&p).expect("serialize");
            let decoded: FetchPhase = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, p);
        }
    }

    #[test]
    fn terminator_serde_round_trip_per_variant() {
        // Previously omitted `Terminator::Close`; now covers every variant.
        for t in ALL_TERMINATORS {
            let encoded = serde_json::to_string(&t).expect("serialize");
            let decoded: Terminator = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, t);
        }
    }

    #[test]
    fn fetch_output_modes_serde_round_trip_http_shapes() {
        let http_only = FetchOutputModes { response: true, tunnel: false };
        let ws = FetchOutputModes { response: true, tunnel: true };
        let l4 = FetchOutputModes { response: false, tunnel: true };
        for modes in [http_only, ws, l4] {
            let encoded = serde_json::to_string(&modes).expect("serialize");
            let decoded: FetchOutputModes = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, modes);
        }
    }

    #[test]
    fn symbolic_fetch_ref_clone_preserves_fields() {
        // Non-default flag values so a dropped field would be detected.
        let r = SymbolicFetchRef {
            kind: FetchKind::HttpProxy,
            args: json!({ "upstream": "127.0.0.1:8080" }),
            retry_buffer_required: true,
            allow_zero_rtt: Some(true),
        };
        let cloned = r.clone();
        assert_eq!(cloned.kind, r.kind);
        assert_eq!(cloned.args, r.args);
        assert_eq!(cloned.retry_buffer_required, r.retry_buffer_required);
        assert_eq!(cloned.allow_zero_rtt, r.allow_zero_rtt);
        let _ = format!("{r:?}");
    }

    #[test]
    fn symbolic_fetch_ref_accepts_each_kind() {
        // Previously omitted `FetchKind::AcmeChallenge`.
        for kind in ALL_FETCH_KINDS {
            let r = SymbolicFetchRef {
                kind,
                args: serde_json::Value::Null,
                retry_buffer_required: false,
                allow_zero_rtt: None,
            };
            assert_eq!(r.kind, kind);
        }
    }

    #[test]
    fn symbolic_fetch_ref_round_trip_preserves_kind_and_args() {
        let r = SymbolicFetchRef {
            kind: FetchKind::WebSocketUpgrade,
            args: json!({ "upstream": "127.0.0.1:9000" }),
            retry_buffer_required: true,
            allow_zero_rtt: Some(false),
        };
        let encoded = serde_json::to_string(&r).expect("serialize");
        let decoded: SymbolicFetchRef = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.kind, r.kind);
        assert_eq!(decoded.args, r.args);
        assert_eq!(decoded.retry_buffer_required, r.retry_buffer_required);
        assert_eq!(decoded.allow_zero_rtt, r.allow_zero_rtt);
    }

    #[test]
    fn symbolic_fetch_ref_optional_fields_default_when_absent() {
        // Both flag fields carry `#[serde(default)]`, so they may be omitted.
        let value = json!({ "kind": FetchKind::HttpProxy, "args": null });
        let decoded: SymbolicFetchRef = serde_json::from_value(value).expect("deserialize");
        assert_eq!(decoded.kind, FetchKind::HttpProxy);
        assert_eq!(decoded.args, serde_json::Value::Null);
        assert!(!decoded.retry_buffer_required);
        assert_eq!(decoded.allow_zero_rtt, None);
    }
}