Skip to main content

zlayer_proxy/stream/
tcp.rs

1//! TCP stream proxy service
2//!
3//! Implements raw TCP proxying with a standalone `serve()` method.
4//! Provides bidirectional tunneling between clients and backends.
5//!
6//! Optionally terminates TLS at the proxy (driven by the endpoint's
7//! `stream.tls` config) and/or prepends a PROXY protocol v2 header to the
8//! upstream connection (driven by `stream.proxy_protocol`).
9
10use std::net::SocketAddr;
11use std::sync::Arc;
12use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
13use tokio::net::TcpListener;
14use tokio_rustls::TlsAcceptor;
15
16use super::registry::StreamRegistry;
17
18/// Build a [`TlsAcceptor`] for L4 TLS termination from a dynamic SNI cert
19/// resolver.
20///
21/// This reuses the same rustls server-config shape as the L7 HTTPS listener
22/// (`with_no_client_auth().with_cert_resolver(..)`), so an L4 TCP endpoint with
23/// `stream.tls = true` terminates TLS using the same shared certificate set
24/// (ACME / hot-loaded certs) as HTTPS endpoints. Pass the `ProxyManager`'s
25/// shared `Arc<SniCertResolver>` here.
26#[must_use]
27pub fn tls_acceptor_from_resolver(
28    resolver: Arc<dyn rustls::server::ResolvesServerCert>,
29) -> TlsAcceptor {
30    let config = rustls::ServerConfig::builder()
31        .with_no_client_auth()
32        .with_cert_resolver(resolver);
33    TlsAcceptor::from(Arc::new(config))
34}
35
36/// TCP stream proxy service
37///
38/// Listens on a port and proxies TCP connections to registered backends
39/// using round-robin load balancing.
40pub struct TcpStreamService {
41    registry: Arc<StreamRegistry>,
42    listen_port: u16,
43    /// When `Some`, the proxy terminates TLS on accepted client connections
44    /// using this acceptor (built from the shared SNI cert resolver) and
45    /// relays the decrypted plaintext to the backend.
46    tls_acceptor: Option<TlsAcceptor>,
47    /// When `true`, prepend a PROXY protocol v2 header to the upstream
48    /// connection so the backend can recover the real client address.
49    proxy_protocol: bool,
50    /// Listener local address, captured at `serve()` time. Used as the
51    /// "destination" address in the PROXY protocol header.
52    local_addr: std::sync::OnceLock<SocketAddr>,
53}
54
55impl TcpStreamService {
56    /// Create a new TCP stream service
57    #[must_use]
58    pub fn new(registry: Arc<StreamRegistry>, listen_port: u16) -> Self {
59        Self {
60            registry,
61            listen_port,
62            tls_acceptor: None,
63            proxy_protocol: false,
64            local_addr: std::sync::OnceLock::new(),
65        }
66    }
67
68    /// Enable TLS termination using the given acceptor (builder-style).
69    #[must_use]
70    pub fn with_tls_acceptor(mut self, acceptor: TlsAcceptor) -> Self {
71        self.tls_acceptor = Some(acceptor);
72        self
73    }
74
75    /// Enable PROXY protocol v2 toward the upstream backend (builder-style).
76    #[must_use]
77    pub fn with_proxy_protocol(mut self, enabled: bool) -> Self {
78        self.proxy_protocol = enabled;
79        self
80    }
81
82    /// Get the listen port
83    #[must_use]
84    pub fn port(&self) -> u16 {
85        self.listen_port
86    }
87
88    /// Get a reference to the registry
89    #[must_use]
90    pub fn registry(&self) -> &Arc<StreamRegistry> {
91        &self.registry
92    }
93
94    /// Run a standalone TCP accept loop on the given listener.
95    ///
96    /// For each accepted connection, resolves a backend from the registry and
97    /// spawns a task to perform bidirectional tunneling. This method runs
98    /// indefinitely until the listener encounters a fatal error.
99    pub async fn serve(self: Arc<Self>, listener: TcpListener) {
100        // Capture the listener local address once so the PROXY protocol header
101        // can report the correct destination address/port.
102        if let Ok(addr) = listener.local_addr() {
103            let _ = self.local_addr.set(addr);
104        }
105
106        tracing::info!(
107            port = self.listen_port,
108            tls = self.tls_acceptor.is_some(),
109            proxy_protocol = self.proxy_protocol,
110            "TCP stream proxy listening"
111        );
112
113        loop {
114            let (client_stream, client_addr) = match listener.accept().await {
115                Ok(conn) => conn,
116                Err(e) => {
117                    // Transient errors (too many open files, etc.) -- log and retry
118                    tracing::warn!(
119                        port = self.listen_port,
120                        error = %e,
121                        "TCP accept error, retrying"
122                    );
123                    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
124                    continue;
125                }
126            };
127
128            let svc = Arc::clone(&self);
129            tokio::spawn(async move {
130                svc.handle_raw_connection(client_stream, client_addr).await;
131            });
132        }
133    }
134
135    /// Handle a single raw TCP connection (resolve backend, tunnel).
136    async fn handle_raw_connection(
137        &self,
138        client_stream: tokio::net::TcpStream,
139        client_addr: SocketAddr,
140    ) {
141        // Resolve service for this port
142        let Some(service) = self.registry.resolve_tcp(self.listen_port) else {
143            tracing::warn!(
144                port = self.listen_port,
145                client = %client_addr,
146                "No service registered for TCP port"
147            );
148            return;
149        };
150
151        // Select backend using round-robin
152        let Some(backend) = service.select_backend() else {
153            tracing::warn!(
154                port = self.listen_port,
155                service = %service.name,
156                client = %client_addr,
157                "No backends available for TCP service"
158            );
159            return;
160        };
161
162        tracing::debug!(
163            port = self.listen_port,
164            service = %service.name,
165            client = %client_addr,
166            backend = %backend,
167            "Proxying TCP connection"
168        );
169
170        // Connect to the upstream backend
171        let mut upstream = match tokio::net::TcpStream::connect(backend).await {
172            Ok(stream) => stream,
173            Err(e) => {
174                tracing::warn!(
175                    error = %e,
176                    backend = %backend,
177                    service = %service.name,
178                    client = %client_addr,
179                    "Failed to connect to TCP backend"
180                );
181                return;
182            }
183        };
184
185        // When PROXY protocol is enabled, emit a v2 header to the upstream so
186        // the backend can recover the real client address. The destination is
187        // the listener's local address (captured at serve()).
188        if self.proxy_protocol {
189            let dst = self
190                .local_addr
191                .get()
192                .copied()
193                .unwrap_or_else(|| SocketAddr::new(backend.ip(), self.listen_port));
194            let header = build_proxy_protocol_v2_header(client_addr, dst);
195            if let Err(e) = upstream.write_all(&header).await {
196                tracing::warn!(
197                    error = %e,
198                    backend = %backend,
199                    service = %service.name,
200                    client = %client_addr,
201                    "Failed to write PROXY protocol header to backend"
202                );
203                return;
204            }
205        }
206
207        // Terminate TLS if configured, then relay the resulting plaintext
208        // stream against the upstream; otherwise relay raw.
209        if let Some(acceptor) = &self.tls_acceptor {
210            match acceptor.accept(client_stream).await {
211                Ok(tls_stream) => {
212                    Self::duplex(tls_stream, upstream).await;
213                }
214                Err(e) => {
215                    tracing::warn!(
216                        error = %e,
217                        service = %service.name,
218                        client = %client_addr,
219                        "TLS handshake with client failed"
220                    );
221                }
222            }
223        } else {
224            Self::duplex(client_stream, upstream).await;
225        }
226    }
227
228    /// Bidirectional data copy between a downstream (client-facing) and an
229    /// upstream (backend-facing) stream.
230    ///
231    /// Generic over any `AsyncRead + AsyncWrite` so it can relay either a raw
232    /// `TcpStream` or a TLS-terminated `TlsStream` on the downstream side.
233    /// Uses `tokio::io::copy_bidirectional` for efficient proxying.
234    async fn duplex<D, U>(mut downstream: D, mut upstream: U)
235    where
236        D: AsyncRead + AsyncWrite + Unpin,
237        U: AsyncRead + AsyncWrite + Unpin,
238    {
239        match tokio::io::copy_bidirectional(&mut downstream, &mut upstream).await {
240            Ok((down_to_up, up_to_down)) => {
241                tracing::debug!(
242                    down_to_up = down_to_up,
243                    up_to_down = up_to_down,
244                    "TCP tunnel closed"
245                );
246            }
247            Err(e) => {
248                tracing::debug!(error = %e, "TCP tunnel error");
249            }
250        }
251    }
252
253    /// `pub(crate)` bidirectional splice between a downstream and upstream
254    /// stream, reusing the same [`copy_bidirectional`](tokio::io::copy_bidirectional)
255    /// machinery as [`Self::duplex`].
256    ///
257    /// Exposed so the HTTPS ingress (`server.rs`) can splice an unmanaged SNI
258    /// connection straight through to its real upstream without terminating TLS.
259    pub(crate) async fn splice<D, U>(downstream: D, upstream: U)
260    where
261        D: AsyncRead + AsyncWrite + Unpin,
262        U: AsyncRead + AsyncWrite + Unpin,
263    {
264        Self::duplex(downstream, upstream).await;
265    }
266}
267
/// Build a PROXY protocol v2 header describing a proxied TCP connection from
/// `src` (the real client) to `dst` (the proxy's listener address).
///
/// The header layout (see the `HAProxy` PROXY protocol v2 spec):
/// - 12-byte signature `0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A`
/// - byte 13: version/command `0x21` (v2, PROXY command)
/// - byte 14: address family + transport — `0x11` (`AF_INET` + STREAM) or
///   `0x21` (`AF_INET6` + STREAM)
/// - bytes 15-16: big-endian length of the following address block
///   (12 for IPv4, 36 for IPv6)
/// - address block: src IP, dst IP, src port, dst port (all big-endian)
///
/// The address family is chosen from the client (`src`) address. When `src`
/// and `dst` families differ, `dst` is coerced to the client's family so the
/// header stays internally consistent:
/// - IPv6 client, IPv4 `dst`: `dst` is written as an IPv4-mapped IPv6 address
///   (`::ffff:a.b.c.d`).
/// - IPv4 client, IPv4-mapped IPv6 `dst`: the embedded IPv4 address is
///   recovered and written.
/// - IPv4 client, any other IPv6 `dst`: no IPv4 form exists, so `0.0.0.0` is
///   written.
///
/// The destination port is always taken from `dst` unchanged.
///
/// Returns the complete header: 28 bytes for IPv4, 52 bytes for IPv6.
#[must_use]
pub fn build_proxy_protocol_v2_header(src: SocketAddr, dst: SocketAddr) -> Vec<u8> {
    const SIG: [u8; 12] = [
        0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
    ];
    /// Version 2 in the high nibble, PROXY command in the low nibble.
    const VERSION_COMMAND: u8 = 0x21;
    /// `AF_INET` + STREAM.
    const FAMILY_TCP4: u8 = 0x11;
    /// `AF_INET6` + STREAM.
    const FAMILY_TCP6: u8 = 0x21;
    /// Address block sizes: two addresses plus two 16-bit ports.
    const ADDR_BLOCK_V4: u16 = 12;
    const ADDR_BLOCK_V6: u16 = 36;
    /// Signature + version/command + family + length field.
    const FIXED_LEN: usize = 16;

    // Size the buffer for the family actually emitted so the IPv6 header
    // (52 bytes) does not force a reallocation.
    let addr_block_len = match src {
        SocketAddr::V4(_) => ADDR_BLOCK_V4,
        SocketAddr::V6(_) => ADDR_BLOCK_V6,
    };
    let mut out = Vec::with_capacity(FIXED_LEN + usize::from(addr_block_len));
    out.extend_from_slice(&SIG);
    out.push(VERSION_COMMAND);

    match src {
        SocketAddr::V4(src_v4) => {
            out.push(FAMILY_TCP4);
            out.extend_from_slice(&ADDR_BLOCK_V4.to_be_bytes());

            let dst_ip = match dst {
                SocketAddr::V4(d) => *d.ip(),
                // An IPv4-mapped destination still carries a real IPv4
                // address; only a native IPv6 destination has no IPv4 form.
                SocketAddr::V6(d) => d
                    .ip()
                    .to_ipv4_mapped()
                    .unwrap_or(std::net::Ipv4Addr::UNSPECIFIED),
            };
            out.extend_from_slice(&src_v4.ip().octets());
            out.extend_from_slice(&dst_ip.octets());
            out.extend_from_slice(&src_v4.port().to_be_bytes());
            out.extend_from_slice(&dst.port().to_be_bytes());
        }
        SocketAddr::V6(src_v6) => {
            out.push(FAMILY_TCP6);
            out.extend_from_slice(&ADDR_BLOCK_V6.to_be_bytes());

            let dst_ip = match dst {
                SocketAddr::V6(d) => *d.ip(),
                SocketAddr::V4(d) => d.ip().to_ipv6_mapped(),
            };
            out.extend_from_slice(&src_v6.ip().octets());
            out.extend_from_slice(&dst_ip.octets());
            out.extend_from_slice(&src_v6.port().to_be_bytes());
            out.extend_from_slice(&dst.port().to_be_bytes());
        }
    }

    out
}
324
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};

    /// Fixed 12-byte PROXY protocol v2 signature every header starts with.
    const SIGNATURE: [u8; 12] = [
        0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
    ];

    /// Pins the complete byte sequence of an IPv4 header.
    #[test]
    fn proxy_protocol_v2_ipv4_exact_bytes() {
        let client = SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 50), 0xABCD);
        let listener = SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 5432);
        let header = build_proxy_protocol_v2_header(client.into(), listener.into());

        let mut expected = SIGNATURE.to_vec();
        expected.push(0x21); // v2 + PROXY
        expected.push(0x11); // AF_INET + STREAM
        expected.extend_from_slice(&[0x00, 0x0C]); // addr block length = 12
        expected.extend_from_slice(&[192, 168, 1, 50]); // src IP
        expected.extend_from_slice(&[10, 0, 0, 1]); // dst IP
        expected.extend_from_slice(&[0xAB, 0xCD]); // src port 0xABCD
        expected.extend_from_slice(&[0x15, 0x38]); // dst port 5432

        assert_eq!(header, expected);
        assert_eq!(header.len(), 16 + 12);
    }

    /// Checks the fixed fields, source address and ports of an IPv6 header.
    #[test]
    fn proxy_protocol_v2_ipv6_shape() {
        let client_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let client = SocketAddr::V6(SocketAddrV6::new(client_ip, 7777, 0, 0));
        let listener = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8888, 0, 0));
        let header = build_proxy_protocol_v2_header(client, listener);

        // 12 sig + 1 ver/cmd + 1 fam + 2 len + 36 addr block = 52 bytes
        assert_eq!(header.len(), 16 + 36);
        assert_eq!(&header[..12], &SIGNATURE);
        assert_eq!(header[12], 0x21); // v2 + PROXY
        assert_eq!(header[13], 0x21); // AF_INET6 + STREAM
        assert_eq!(&header[14..16], &36u16.to_be_bytes());
        // src IP occupies bytes 16..32
        assert_eq!(&header[16..32], &client_ip.octets());
        // the two ports close the header
        assert_eq!(&header[48..50], &7777u16.to_be_bytes());
        assert_eq!(&header[50..52], &8888u16.to_be_bytes());
    }
}
380}