Skip to main content

zlayer_tunnel/client/
proxy.rs

//! Local service proxy for tunnel client
//!
//! Handles proxying between the tunnel data channel and local services.
//! When the tunnel client receives an `IncomingConnection` event, it needs to:
//! 1. Connect to the local service (`127.0.0.1:local_port`)
//! 2. Open a data channel back to the server
//! 3. Proxy data bidirectionally between them
8
9use std::sync::Arc;
10use std::time::Duration;
11
12use dashmap::DashMap;
13use tokio::io::{AsyncReadExt, AsyncWriteExt};
14use tokio::net::TcpStream;
15use uuid::Uuid;
16
17use crate::{Result, ServiceConfig, TunnelError};
18
19/// Manages proxying connections to local services
20pub struct LocalProxy {
21    /// Service configurations by `service_id`
22    services: DashMap<Uuid, ServiceConfig>,
23
24    /// Active connections (wrapped in Arc for sharing with spawned tasks)
25    connections: Arc<DashMap<Uuid, ConnectionHandle>>,
26
27    /// Connection timeout
28    connect_timeout: Duration,
29}
30
31struct ConnectionHandle {
32    /// Abort handle for the proxy task
33    abort_handle: tokio::task::AbortHandle,
34}
35
36impl LocalProxy {
37    /// Create a new local proxy with the specified connection timeout
38    #[must_use]
39    pub fn new(connect_timeout: Duration) -> Self {
40        Self {
41            services: DashMap::new(),
42            connections: Arc::new(DashMap::new()),
43            connect_timeout,
44        }
45    }
46
47    /// Register a service for proxying
48    pub fn register_service(&self, service_id: Uuid, config: ServiceConfig) {
49        self.services.insert(service_id, config);
50    }
51
52    /// Unregister a service
53    pub fn unregister_service(&self, service_id: Uuid) {
54        self.services.remove(&service_id);
55    }
56
57    /// Handle an incoming connection request
58    ///
59    /// This connects to the local service and returns a channel for the data stream.
60    /// The caller is responsible for connecting the other end (server data channel).
61    ///
62    /// # Errors
63    ///
64    /// Returns an error if:
65    /// - The `service_id` is not registered
66    /// - Connection to the local service times out
67    /// - Connection to the local service fails
68    pub async fn handle_connection(
69        &self,
70        service_id: Uuid,
71        connection_id: Uuid,
72    ) -> Result<TcpStream> {
73        let config = self
74            .services
75            .get(&service_id)
76            .ok_or_else(|| TunnelError::registry(format!("Unknown service: {service_id}")))?
77            .clone();
78
79        let addr = format!("127.0.0.1:{}", config.local_port);
80
81        let local_stream = tokio::time::timeout(self.connect_timeout, TcpStream::connect(&addr))
82            .await
83            .map_err(|_| TunnelError::timeout())?
84            .map_err(|e| TunnelError::Connection { source: e })?;
85
86        tracing::debug!(
87            service_id = %service_id,
88            connection_id = %connection_id,
89            local_addr = %addr,
90            "Connected to local service"
91        );
92
93        Ok(local_stream)
94    }
95
96    /// Proxy data between two streams bidirectionally
97    ///
98    /// Returns when either side closes or an error occurs.
99    /// Returns the number of bytes (`sent_to_remote`, `received_from_remote`).
100    ///
101    /// # Errors
102    ///
103    /// This method does not return errors; it gracefully handles stream
104    /// closure and I/O errors by terminating the proxy loop.
105    pub async fn proxy_streams(local: TcpStream, remote: TcpStream) -> Result<(u64, u64)> {
106        let (mut local_read, mut local_write) = local.into_split();
107        let (mut remote_read, mut remote_write) = remote.into_split();
108
109        let local_to_remote = async {
110            let mut buf = vec![0u8; 8192];
111            let mut total = 0u64;
112            loop {
113                match local_read.read(&mut buf).await {
114                    Ok(0) | Err(_) => break,
115                    Ok(n) => {
116                        if remote_write.write_all(&buf[..n]).await.is_err() {
117                            break;
118                        }
119                        total += n as u64;
120                    }
121                }
122            }
123            let _ = remote_write.shutdown().await;
124            total
125        };
126
127        let remote_to_local = async {
128            let mut buf = vec![0u8; 8192];
129            let mut total = 0u64;
130            loop {
131                match remote_read.read(&mut buf).await {
132                    Ok(0) | Err(_) => break,
133                    Ok(n) => {
134                        if local_write.write_all(&buf[..n]).await.is_err() {
135                            break;
136                        }
137                        total += n as u64;
138                    }
139                }
140            }
141            let _ = local_write.shutdown().await;
142            total
143        };
144
145        let (sent, received) = tokio::join!(local_to_remote, remote_to_local);
146
147        Ok((sent, received))
148    }
149
150    /// Start a proxy task for a connection
151    ///
152    /// Spawns a background task that proxies data between the local and remote streams.
153    /// The task is automatically cleaned up when the proxy completes or is cancelled.
154    pub fn start_proxy(
155        &self,
156        connection_id: Uuid,
157        local: TcpStream,
158        remote: TcpStream,
159    ) -> tokio::task::JoinHandle<Result<(u64, u64)>> {
160        let connections = Arc::clone(&self.connections);
161
162        let handle = tokio::spawn(async move {
163            let result = Self::proxy_streams(local, remote).await;
164            connections.remove(&connection_id);
165            result
166        });
167
168        self.connections.insert(
169            connection_id,
170            ConnectionHandle {
171                abort_handle: handle.abort_handle(),
172            },
173        );
174
175        handle
176    }
177
178    /// Cancel a proxy connection
179    pub fn cancel_connection(&self, connection_id: Uuid) {
180        if let Some((_, handle)) = self.connections.remove(&connection_id) {
181            handle.abort_handle.abort();
182        }
183    }
184
185    /// Get count of active connections
186    #[must_use]
187    pub fn connection_count(&self) -> usize {
188        self.connections.len()
189    }
190
191    /// Get count of registered services
192    #[must_use]
193    pub fn service_count(&self) -> usize {
194        self.services.len()
195    }
196
197    /// Clean up all connections
198    pub fn shutdown(&self) {
199        for item in self.connections.iter() {
200            item.abort_handle.abort();
201        }
202        self.connections.clear();
203    }
204}
205
206impl Drop for LocalProxy {
207    fn drop(&mut self) {
208        self.shutdown();
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// Create a connected pair of loopback TCP streams `(client, server)`.
    async fn tcp_pair() -> (TcpStream, TcpStream) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind loopback listener");
        let addr = listener.local_addr().expect("listener address");
        let (client, accepted) = tokio::join!(TcpStream::connect(addr), listener.accept());
        (
            client.expect("connect to loopback listener"),
            accepted.expect("accept loopback connection").0,
        )
    }

    #[test]
    fn test_register_and_unregister_services() {
        let proxy = LocalProxy::new(Duration::from_secs(5));

        let service_id1 = Uuid::new_v4();
        let service_id2 = Uuid::new_v4();
        let config1 = ServiceConfig::tcp("ssh", 22);
        let config2 = ServiceConfig::tcp("postgres", 5432);

        // Register services
        proxy.register_service(service_id1, config1);
        proxy.register_service(service_id2, config2);
        assert_eq!(proxy.service_count(), 2);

        // Verify services are registered
        assert!(proxy.services.contains_key(&service_id1));
        assert!(proxy.services.contains_key(&service_id2));

        // Unregister one service
        proxy.unregister_service(service_id1);
        assert_eq!(proxy.service_count(), 1);
        assert!(!proxy.services.contains_key(&service_id1));
        assert!(proxy.services.contains_key(&service_id2));

        // Unregister the other
        proxy.unregister_service(service_id2);
        assert_eq!(proxy.service_count(), 0);
    }

    #[test]
    fn test_service_and_connection_counts() {
        let proxy = LocalProxy::new(Duration::from_secs(5));

        assert_eq!(proxy.service_count(), 0);
        assert_eq!(proxy.connection_count(), 0);

        let service_id = Uuid::new_v4();
        proxy.register_service(service_id, ServiceConfig::tcp("test", 8080));
        assert_eq!(proxy.service_count(), 1);
        assert_eq!(proxy.connection_count(), 0);

        proxy.unregister_service(service_id);
        assert_eq!(proxy.service_count(), 0);
    }

    #[test]
    fn test_cancel_connection() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        let connection_id = Uuid::new_v4();

        // Cancelling a non-existent connection should not panic
        proxy.cancel_connection(connection_id);
        assert_eq!(proxy.connection_count(), 0);
    }

    #[test]
    fn test_shutdown_clears_all() {
        let proxy = LocalProxy::new(Duration::from_secs(5));

        // Register some services
        for _ in 0..5 {
            proxy.register_service(Uuid::new_v4(), ServiceConfig::tcp("test", 8080));
        }
        assert_eq!(proxy.service_count(), 5);

        // Shutdown clears connections (services remain registered)
        proxy.shutdown();
        assert_eq!(proxy.connection_count(), 0);
        // Note: shutdown only clears connections, not services
        assert_eq!(proxy.service_count(), 5);
    }

    #[tokio::test]
    async fn test_handle_connection_unknown_service() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        let service_id = Uuid::new_v4();
        let connection_id = Uuid::new_v4();

        let result = proxy.handle_connection(service_id, connection_id).await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert!(matches!(err, TunnelError::Registry { .. }));
    }

    // On loopback a closed port is normally refused immediately, so this mostly
    // exercises the `Connection` path; `Timeout` is accepted as well.
    #[tokio::test]
    async fn test_handle_connection_unreachable_port() {
        let proxy = LocalProxy::new(Duration::from_millis(100));
        let service_id = Uuid::new_v4();
        let connection_id = Uuid::new_v4();

        // Register a service pointing to a port that's unlikely to be listening
        proxy.register_service(service_id, ServiceConfig::tcp("test", 65432));

        let result = proxy.handle_connection(service_id, connection_id).await;
        assert!(result.is_err());

        // Should be either timeout or connection refused
        let err = result.unwrap_err();
        assert!(matches!(
            err,
            TunnelError::Timeout | TunnelError::Connection { .. }
        ));
    }

    #[test]
    fn test_unregister_nonexistent_service() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        let service_id = Uuid::new_v4();

        // Should not panic
        proxy.unregister_service(service_id);
        assert_eq!(proxy.service_count(), 0);
    }

    #[test]
    fn test_register_overwrites_existing() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        let service_id = Uuid::new_v4();

        proxy.register_service(service_id, ServiceConfig::tcp("first", 8080));
        assert_eq!(proxy.service_count(), 1);

        // Register with same ID should overwrite
        proxy.register_service(service_id, ServiceConfig::tcp("second", 9090));
        assert_eq!(proxy.service_count(), 1);

        // Verify the config was updated
        let config = proxy.services.get(&service_id).unwrap();
        assert_eq!(config.name, "second");
        assert_eq!(config.local_port, 9090);
    }

    #[tokio::test]
    async fn test_proxy_streams_forwards_both_directions() {
        // `local_app` plays the local service, `remote_app` the tunnel server side.
        let (mut local_app, local) = tcp_pair().await;
        let (remote, mut remote_app) = tcp_pair().await;

        let proxy = tokio::spawn(LocalProxy::proxy_streams(local, remote));

        local_app.write_all(b"hello").await.unwrap();
        local_app.shutdown().await.unwrap();
        remote_app.write_all(b"world!").await.unwrap();
        remote_app.shutdown().await.unwrap();

        let mut at_remote = Vec::new();
        remote_app.read_to_end(&mut at_remote).await.unwrap();
        let mut at_local = Vec::new();
        local_app.read_to_end(&mut at_local).await.unwrap();

        assert_eq!(at_remote, b"hello");
        assert_eq!(at_local, b"world!");

        let outcome = tokio::time::timeout(Duration::from_secs(5), proxy)
            .await
            .expect("proxy did not finish")
            .expect("proxy task panicked");
        assert!(matches!(outcome, Ok((5, 6))));
    }

    #[tokio::test]
    async fn test_start_proxy_tracks_and_cleans_up() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        let (local_app, local) = tcp_pair().await;
        let (remote, remote_app) = tcp_pair().await;
        let connection_id = Uuid::new_v4();

        let handle = proxy.start_proxy(connection_id, local, remote);
        assert_eq!(proxy.connection_count(), 1);

        // Closing both outer ends drives both directions to EOF.
        drop(local_app);
        drop(remote_app);

        let outcome = tokio::time::timeout(Duration::from_secs(5), handle)
            .await
            .expect("proxy did not finish")
            .expect("proxy task panicked");
        assert!(matches!(outcome, Ok((0, 0))));
        assert_eq!(proxy.connection_count(), 0);
    }

    #[tokio::test]
    async fn test_cancel_active_connection() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        // Keep the outer ends open so the proxy would otherwise run indefinitely.
        let (_local_app, local) = tcp_pair().await;
        let (remote, _remote_app) = tcp_pair().await;
        let connection_id = Uuid::new_v4();

        let handle = proxy.start_proxy(connection_id, local, remote);
        assert_eq!(proxy.connection_count(), 1);

        proxy.cancel_connection(connection_id);
        assert_eq!(proxy.connection_count(), 0);

        match handle.await {
            Err(e) => assert!(e.is_cancelled()),
            Ok(_) => panic!("cancelled proxy task should not complete"),
        }
    }

    #[tokio::test]
    async fn test_shutdown_aborts_active_connections() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        let (_local_app, local) = tcp_pair().await;
        let (remote, _remote_app) = tcp_pair().await;

        let handle = proxy.start_proxy(Uuid::new_v4(), local, remote);
        assert_eq!(proxy.connection_count(), 1);

        proxy.shutdown();
        assert_eq!(proxy.connection_count(), 0);
        assert!(matches!(handle.await, Err(e) if e.is_cancelled()));
    }
}