1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use bytes::Bytes;
7use tokio::io::{AsyncRead, AsyncWrite};
8use tokio::sync::oneshot;
9use tokio_util::sync::CancellationToken;
10
11use crate::body::{Request, Response};
12use crate::conn_context::ConnContext;
13use crate::error::Error;
14use crate::flow_ctx::FlowCtx;
15use crate::l4::L4Conn;
16use crate::middleware::CloseReason;
17
18#[async_trait]
19pub trait L7Fetch: Send + Sync {
20 async fn fetch(
21 &self,
22 req: Request,
23 conn: &Arc<ConnContext>,
24 ctx: &mut FlowCtx,
25 ) -> Result<L7FetchOutput, Error>;
26}
27
28#[async_trait]
29pub trait L4Fetch: Send + Sync {
30 async fn fetch(
31 &self,
32 l4: L4Conn,
33 conn: &Arc<ConnContext>,
34 ctx: &mut FlowCtx,
35 ) -> Result<Tunnel, Error>;
36}
37
38pub enum L7FetchOutput {
39 Response(Response),
40 Tunnel(Tunnel),
41}
42
43pub enum Tunnel {
54 Bidi {
55 client: Box<dyn AsyncReadWrite + Send>,
56 upstream: Box<dyn AsyncReadWrite + Send>,
57 close_reason_tx: Option<oneshot::Sender<CloseReason>>,
58 },
59 SpliceBidi {
69 client: tokio::net::TcpStream,
70 upstream: tokio::net::TcpStream,
71 close_reason_tx: Option<oneshot::Sender<CloseReason>>,
72 },
73 Udp(UdpTunnel),
74}
75
76pub struct UdpTunnel {
86 pub join: Pin<Box<dyn Future<Output = CloseReason> + Send>>,
87 pub cancel: CancellationToken,
88}
89
90pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin {}
95impl<T: AsyncRead + AsyncWrite + Unpin + ?Sized> AsyncReadWrite for T {}
96
97#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
98pub enum FetchKind {
99 HttpProxy,
100 HttpSynthesize,
101 WebSocketUpgrade,
102 L4Forward,
103 AcmeChallenge,
111}
112
113impl FetchKind {
114 #[must_use]
120 pub const fn phase(self) -> FetchPhase {
121 match self {
122 Self::L4Forward => FetchPhase::L4,
123 Self::HttpProxy | Self::HttpSynthesize | Self::WebSocketUpgrade | Self::AcmeChallenge => {
124 FetchPhase::L7
125 }
126 }
127 }
128}
129
130#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
131pub enum FetchPhase {
132 L4,
133 L7,
134}
135
136#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
137pub struct FetchOutputModes {
138 pub response: bool,
139 pub tunnel: bool,
140}
141
142#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
143pub struct SymbolicFetchRef {
144 pub kind: FetchKind,
145 pub args: serde_json::Value,
146 #[serde(default)]
152 pub retry_buffer_required: bool,
153 #[serde(default)]
164 pub allow_zero_rtt: Option<bool>,
165}
166
167#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
168pub enum Terminator {
169 WriteHttpResponse,
170 ByteTunnel,
171 Close,
172}
173
/// Limits and policy handed to an `HttpFetchBackend` alongside each request.
#[derive(Clone, Debug)]
pub struct HttpFetchLimits {
    /// Maximum body size, in bytes.
    pub max_body_bytes: u64,
    /// Timeout in milliseconds; what `None` means is up to the backend — confirm there.
    pub timeout_ms: Option<u32>,
    /// Redirect-follow limit; what `None` means is up to the backend — confirm there.
    pub follow_redirects: Option<u32>,
    /// Whether insecure connections are permitted.
    pub allow_insecure: bool,
}

impl Default for HttpFetchLimits {
    /// 1 MiB body cap, `timeout_ms` unset, up to 5 redirects, insecure disallowed.
    fn default() -> Self {
        let one_mib: u64 = 1 << 20;
        Self {
            allow_insecure: false,
            follow_redirects: Some(5),
            timeout_ms: None,
            max_body_bytes: one_mib,
        }
    }
}
198
/// Owned, string-typed description of an outbound HTTP request.
#[derive(Debug)]
pub struct HttpFetchRequest {
    /// HTTP method name, e.g. `"GET"`.
    pub method: String,
    /// Request target.
    pub url: String,
    /// Header name/value pairs; one pair per value, so names may repeat.
    pub headers: Vec<(String, String)>,
    /// Raw request body.
    pub body: Vec<u8>,
    /// Per-request timeout in milliseconds, if any.
    pub timeout_ms: Option<u32>,
    /// Per-request redirect-follow setting, if any.
    pub follow_redirects: Option<u32>,
    /// Per-request TLS-verification setting, if any.
    pub verify_tls: Option<bool>,
}
216
217impl HttpFetchRequest {
218 #[must_use]
225 pub fn from_http_request(req: &http::Request<Bytes>) -> Self {
226 let headers = req
227 .headers()
228 .iter()
229 .filter_map(|(name, value)| {
230 value.to_str().ok().map(|v| (name.as_str().to_owned(), v.to_owned()))
231 })
232 .collect();
233 Self {
234 method: req.method().as_str().to_owned(),
235 url: req.uri().to_string(),
236 headers,
237 body: req.body().to_vec(),
238 timeout_ms: None,
239 follow_redirects: None,
240 verify_tls: None,
241 }
242 }
243}
244
245impl TryFrom<&HttpFetchRequest> for http::Request<Bytes> {
246 type Error = http::Error;
247
248 fn try_from(value: &HttpFetchRequest) -> Result<Self, Self::Error> {
253 let mut builder =
254 http::Request::builder().method(value.method.as_str()).uri(value.url.as_str());
255 for (name, val) in &value.headers {
256 builder = builder.header(name.as_str(), val.as_str());
257 }
258 builder.body(Bytes::from(value.body.clone()))
259 }
260}
261
/// Owned, string-typed description of an HTTP response.
#[derive(Debug)]
pub struct HttpFetchResponse {
    /// Numeric status code, e.g. `200`.
    pub status: u16,
    /// Header name/value pairs; one pair per value, so names may repeat.
    pub headers: Vec<(String, String)>,
    /// Raw response body.
    pub body: Vec<u8>,
}
269
270impl HttpFetchResponse {
271 #[must_use]
275 pub fn from_http_response(resp: &http::Response<Bytes>) -> Self {
276 let headers = resp
277 .headers()
278 .iter()
279 .filter_map(|(name, value)| {
280 value.to_str().ok().map(|v| (name.as_str().to_owned(), v.to_owned()))
281 })
282 .collect();
283 Self { status: resp.status().as_u16(), headers, body: resp.body().to_vec() }
284 }
285}
286
287#[derive(Debug, thiserror::Error)]
291pub enum HttpFetchError {
292 #[error("dns failure: {0}")]
293 DnsFailure(String),
294 #[error("connection refused")]
295 ConnectionRefused,
296 #[error("timeout")]
297 Timeout,
298 #[error("tls error: {0}")]
299 TlsError(String),
300 #[error("pool exhausted")]
301 PoolExhausted,
302 #[error("body too large")]
303 BodyTooLarge,
304 #[error("not allowed: {0}")]
305 NotAllowed(String),
306 #[error("insecure rejected")]
307 InsecureRejected,
308 #[error("internal: {0}")]
309 Internal(String),
310}
311
312#[async_trait]
318pub trait HttpFetchBackend: Send + Sync {
319 async fn fetch(
320 &self,
321 req: HttpFetchRequest,
322 limits: HttpFetchLimits,
323 ) -> Result<HttpFetchResponse, HttpFetchError>;
324}
325
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::io;
    use std::net::SocketAddr;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use std::time::Instant;

    use parking_lot::Mutex;
    use serde_json::json;
    use tokio::io::ReadBuf;
    use tokio_util::sync::CancellationToken;

    use super::*;
    use crate::body::{Body, Request, Response};
    use crate::conn_context::{ConnId, Transport};
    use crate::flow_log::{FlowLogEvent, FlowLogSink};

    /// Stream whose reads complete immediately without filling the buffer and whose
    /// writes accept (and discard) every byte.
    struct NoopStream;

    impl AsyncRead for NoopStream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncWrite for NoopStream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Flow-log sink that ignores every event.
    struct NullSink;
    impl FlowLogSink for NullSink {
        fn emit(&self, _event: FlowLogEvent) {}
    }

    /// L7 fetch that always answers with an empty `200` response.
    struct SynthOk;
    #[async_trait]
    impl L7Fetch for SynthOk {
        async fn fetch(
            &self,
            _req: Request,
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx,
        ) -> Result<L7FetchOutput, Error> {
            let resp: Response =
                http::Response::builder().status(200).body(Body::Empty).expect("build");
            Ok(L7FetchOutput::Response(resp))
        }
    }

    /// L4 fetch that always returns a `Bidi` tunnel over two `NoopStream`s.
    struct L4Nop;
    #[async_trait]
    impl L4Fetch for L4Nop {
        async fn fetch(
            &self,
            _l4: L4Conn,
            _conn: &Arc<ConnContext>,
            _ctx: &mut FlowCtx,
        ) -> Result<Tunnel, Error> {
            let (tx, _rx) = oneshot::channel::<crate::middleware::CloseReason>();
            Ok(Tunnel::Bidi {
                client: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
                upstream: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
                close_reason_tx: Some(tx),
            })
        }
    }

    /// Compiles only if `F: Send`; the argument exists purely for type inference.
    fn assert_send<F: Send>(_: &F) {}

    fn make_conn_context() -> Arc<ConnContext> {
        let addr: SocketAddr = "127.0.0.1:0".parse().expect("parse addr");
        Arc::new(ConnContext {
            id: ConnId(0),
            remote: addr,
            local: addr,
            transport: Transport::Tcp,
            entered_at: Instant::now(),
            tls: Mutex::new(None),
            http_version: std::sync::OnceLock::new(),
            user: Mutex::new(http::Extensions::new()),
        })
    }

    #[test]
    fn async_read_write_blanket_accepts_async_io_type() {
        let _: Box<dyn AsyncReadWrite + Send> = Box::new(NoopStream);
    }

    #[test]
    fn l7_fetch_output_response_variant_constructs() {
        let resp: Response =
            http::Response::builder().status(200).body(Body::Empty).expect("build response");
        match L7FetchOutput::Response(resp) {
            L7FetchOutput::Response(_) => {}
            L7FetchOutput::Tunnel(_) => panic!("unexpected tunnel variant"),
        }
    }

    #[test]
    fn tunnel_bidi_builds_from_paired_async_io_streams() {
        let (tx, _rx) = oneshot::channel::<crate::middleware::CloseReason>();
        let tunnel = Tunnel::Bidi {
            client: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
            upstream: Box::new(NoopStream) as Box<dyn AsyncReadWrite + Send>,
            close_reason_tx: Some(tx),
        };
        let _ = L7FetchOutput::Tunnel(tunnel);
    }

    #[test]
    fn tunnel_udp_builds_from_join_future_and_cancel_token() {
        let cancel = CancellationToken::new();
        let join: Pin<Box<dyn Future<Output = CloseReason> + Send>> =
            Box::pin(async move { CloseReason::Graceful });
        let tunnel = Tunnel::Udp(UdpTunnel { join, cancel });
        let _ = L7FetchOutput::Tunnel(tunnel);
    }

    #[test]
    fn l7_fetch_is_constructible_as_arc_dyn_send_sync() {
        let f: Arc<dyn L7Fetch + Send + Sync> = Arc::new(SynthOk);
        let _: Arc<dyn L7Fetch> = f;
    }

    #[test]
    fn l4_fetch_is_constructible_as_arc_dyn_send_sync() {
        let f: Arc<dyn L4Fetch + Send + Sync> = Arc::new(L4Nop);
        let _: Arc<dyn L4Fetch> = f;
    }

    #[test]
    fn l7_fetch_fetch_returns_send_future() {
        let f: Arc<dyn L7Fetch> = Arc::new(SynthOk);
        let conn = make_conn_context();
        let mut ctx = FlowCtx {
            span: tracing::Span::none(),
            log: Arc::new(NullSink),
            cancel: CancellationToken::new(),
            accept_cancel: CancellationToken::new(),
            verbosity: crate::flow_log::FlowLogVerbosity::Trajectory,
            trajectory: crate::flow_log::TrajectoryBuilder::new(
                conn.id,
                crate::ir::NodeId::new(0),
                0,
            ),
        };
        let req: Request = http::Request::builder().uri("/").body(Body::Empty).expect("build req");
        let fut: Pin<Box<dyn Future<Output = Result<L7FetchOutput, Error>> + Send + '_>> =
            f.fetch(req, &conn, &mut ctx);
        assert_send(&fut);
        drop(fut);
    }

    #[test]
    fn fetch_kind_serde_round_trip_per_variant() {
        for k in [
            FetchKind::HttpProxy,
            FetchKind::HttpSynthesize,
            FetchKind::WebSocketUpgrade,
            FetchKind::L4Forward,
            FetchKind::AcmeChallenge,
        ] {
            let encoded = serde_json::to_string(&k).expect("serialize");
            let decoded: FetchKind = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, k);
        }
    }

    #[test]
    fn fetch_kind_phase_only_l4_forward_is_l4() {
        for k in [
            FetchKind::HttpProxy,
            FetchKind::HttpSynthesize,
            FetchKind::WebSocketUpgrade,
            FetchKind::AcmeChallenge,
        ] {
            assert_eq!(k.phase(), FetchPhase::L7, "{k:?} should be an L7 fetch");
        }
        assert_eq!(FetchKind::L4Forward.phase(), FetchPhase::L4);
    }

    #[test]
    fn fetch_phase_serde_round_trip_per_variant() {
        for p in [FetchPhase::L4, FetchPhase::L7] {
            let encoded = serde_json::to_string(&p).expect("serialize");
            let decoded: FetchPhase = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, p);
        }
    }

    #[test]
    fn terminator_serde_round_trip_per_variant() {
        // Keep this list in sync with `Terminator`; `Close` was previously missing.
        for t in [Terminator::WriteHttpResponse, Terminator::ByteTunnel, Terminator::Close] {
            let encoded = serde_json::to_string(&t).expect("serialize");
            let decoded: Terminator = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, t);
        }
    }

    #[test]
    fn fetch_output_modes_serde_round_trip_http_shapes() {
        let http_only = FetchOutputModes { response: true, tunnel: false };
        let ws = FetchOutputModes { response: true, tunnel: true };
        let l4 = FetchOutputModes { response: false, tunnel: true };
        for modes in [http_only, ws, l4] {
            let encoded = serde_json::to_string(&modes).expect("serialize");
            let decoded: FetchOutputModes = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, modes);
        }
    }

    #[test]
    fn symbolic_fetch_ref_clone_preserves_fields() {
        // Non-default values so a clone that reset a field would be caught.
        let r = SymbolicFetchRef {
            kind: FetchKind::HttpProxy,
            args: json!({ "upstream": "127.0.0.1:8080" }),
            retry_buffer_required: true,
            allow_zero_rtt: Some(false),
        };
        let cloned = r.clone();
        assert_eq!(cloned.kind, r.kind);
        assert_eq!(cloned.args, r.args);
        assert_eq!(cloned.retry_buffer_required, r.retry_buffer_required);
        assert_eq!(cloned.allow_zero_rtt, r.allow_zero_rtt);
        let _ = format!("{r:?}");
    }

    #[test]
    fn symbolic_fetch_ref_accepts_each_kind() {
        for kind in [
            FetchKind::HttpProxy,
            FetchKind::HttpSynthesize,
            FetchKind::WebSocketUpgrade,
            FetchKind::L4Forward,
            FetchKind::AcmeChallenge,
        ] {
            let _ = SymbolicFetchRef {
                kind,
                args: serde_json::Value::Null,
                retry_buffer_required: false,
                allow_zero_rtt: None,
            };
        }
    }

    #[test]
    fn symbolic_fetch_ref_round_trip_preserves_all_fields() {
        let r = SymbolicFetchRef {
            kind: FetchKind::WebSocketUpgrade,
            args: json!({ "upstream": "127.0.0.1:9000" }),
            retry_buffer_required: true,
            allow_zero_rtt: Some(true),
        };
        let encoded = serde_json::to_string(&r).expect("serialize");
        let decoded: SymbolicFetchRef = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.kind, r.kind);
        assert_eq!(decoded.args, r.args);
        assert_eq!(decoded.retry_buffer_required, r.retry_buffer_required);
        assert_eq!(decoded.allow_zero_rtt, r.allow_zero_rtt);
    }

    #[test]
    fn symbolic_fetch_ref_missing_optional_fields_use_defaults() {
        let decoded: SymbolicFetchRef =
            serde_json::from_value(json!({ "kind": "HttpProxy", "args": null }))
                .expect("deserialize");
        assert_eq!(decoded.kind, FetchKind::HttpProxy);
        assert!(decoded.args.is_null());
        assert!(!decoded.retry_buffer_required);
        assert_eq!(decoded.allow_zero_rtt, None);
    }
}