zlayer_tunnel/client/
proxy.rs1use std::sync::Arc;
10use std::time::Duration;
11
12use dashmap::DashMap;
13use tokio::io::{AsyncReadExt, AsyncWriteExt};
14use tokio::net::TcpStream;
15use uuid::Uuid;
16
17use crate::{Result, ServiceConfig, TunnelError};
18
19pub struct LocalProxy {
21 services: DashMap<Uuid, ServiceConfig>,
23
24 connections: Arc<DashMap<Uuid, ConnectionHandle>>,
26
27 connect_timeout: Duration,
29}
30
31struct ConnectionHandle {
32 abort_handle: tokio::task::AbortHandle,
34}
35
36impl LocalProxy {
37 #[must_use]
39 pub fn new(connect_timeout: Duration) -> Self {
40 Self {
41 services: DashMap::new(),
42 connections: Arc::new(DashMap::new()),
43 connect_timeout,
44 }
45 }
46
47 pub fn register_service(&self, service_id: Uuid, config: ServiceConfig) {
49 self.services.insert(service_id, config);
50 }
51
52 pub fn unregister_service(&self, service_id: Uuid) {
54 self.services.remove(&service_id);
55 }
56
57 pub async fn handle_connection(
69 &self,
70 service_id: Uuid,
71 connection_id: Uuid,
72 ) -> Result<TcpStream> {
73 let config = self
74 .services
75 .get(&service_id)
76 .ok_or_else(|| TunnelError::registry(format!("Unknown service: {service_id}")))?
77 .clone();
78
79 let addr = format!("127.0.0.1:{}", config.local_port);
80
81 let local_stream = tokio::time::timeout(self.connect_timeout, TcpStream::connect(&addr))
82 .await
83 .map_err(|_| TunnelError::timeout())?
84 .map_err(|e| TunnelError::Connection { source: e })?;
85
86 tracing::debug!(
87 service_id = %service_id,
88 connection_id = %connection_id,
89 local_addr = %addr,
90 "Connected to local service"
91 );
92
93 Ok(local_stream)
94 }
95
96 pub async fn proxy_streams(local: TcpStream, remote: TcpStream) -> Result<(u64, u64)> {
106 let (mut local_read, mut local_write) = local.into_split();
107 let (mut remote_read, mut remote_write) = remote.into_split();
108
109 let local_to_remote = async {
110 let mut buf = vec![0u8; 8192];
111 let mut total = 0u64;
112 loop {
113 match local_read.read(&mut buf).await {
114 Ok(0) | Err(_) => break,
115 Ok(n) => {
116 if remote_write.write_all(&buf[..n]).await.is_err() {
117 break;
118 }
119 total += n as u64;
120 }
121 }
122 }
123 let _ = remote_write.shutdown().await;
124 total
125 };
126
127 let remote_to_local = async {
128 let mut buf = vec![0u8; 8192];
129 let mut total = 0u64;
130 loop {
131 match remote_read.read(&mut buf).await {
132 Ok(0) | Err(_) => break,
133 Ok(n) => {
134 if local_write.write_all(&buf[..n]).await.is_err() {
135 break;
136 }
137 total += n as u64;
138 }
139 }
140 }
141 let _ = local_write.shutdown().await;
142 total
143 };
144
145 let (sent, received) = tokio::join!(local_to_remote, remote_to_local);
146
147 Ok((sent, received))
148 }
149
150 pub fn start_proxy(
155 &self,
156 connection_id: Uuid,
157 local: TcpStream,
158 remote: TcpStream,
159 ) -> tokio::task::JoinHandle<Result<(u64, u64)>> {
160 let connections = Arc::clone(&self.connections);
161
162 let handle = tokio::spawn(async move {
163 let result = Self::proxy_streams(local, remote).await;
164 connections.remove(&connection_id);
165 result
166 });
167
168 self.connections.insert(
169 connection_id,
170 ConnectionHandle {
171 abort_handle: handle.abort_handle(),
172 },
173 );
174
175 handle
176 }
177
178 pub fn cancel_connection(&self, connection_id: Uuid) {
180 if let Some((_, handle)) = self.connections.remove(&connection_id) {
181 handle.abort_handle.abort();
182 }
183 }
184
185 #[must_use]
187 pub fn connection_count(&self) -> usize {
188 self.connections.len()
189 }
190
191 #[must_use]
193 pub fn service_count(&self) -> usize {
194 self.services.len()
195 }
196
197 pub fn shutdown(&self) {
199 for item in self.connections.iter() {
200 item.abort_handle.abort();
201 }
202 self.connections.clear();
203 }
204}
205
206impl Drop for LocalProxy {
207 fn drop(&mut self) {
208 self.shutdown();
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a connected `(client, server)` TCP pair over loopback.
    async fn tcp_pair() -> (TcpStream, TcpStream) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind loopback listener");
        let addr = listener.local_addr().expect("listener address");
        let (client, accepted) = tokio::join!(TcpStream::connect(addr), listener.accept());
        (
            client.expect("connect to listener"),
            accepted.expect("accept connection").0,
        )
    }

    /// Awaits a spawned task and asserts that it was cancelled, not completed.
    async fn assert_cancelled<T>(handle: tokio::task::JoinHandle<T>) {
        match handle.await {
            Err(err) => assert!(err.is_cancelled()),
            Ok(_) => panic!("proxy task should have been aborted"),
        }
    }

    #[test]
    fn test_register_and_unregister_services() {
        let proxy = LocalProxy::new(Duration::from_secs(5));

        let service_id1 = Uuid::new_v4();
        let service_id2 = Uuid::new_v4();
        let config1 = ServiceConfig::tcp("ssh", 22);
        let config2 = ServiceConfig::tcp("postgres", 5432);

        proxy.register_service(service_id1, config1);
        proxy.register_service(service_id2, config2);
        assert_eq!(proxy.service_count(), 2);

        assert!(proxy.services.contains_key(&service_id1));
        assert!(proxy.services.contains_key(&service_id2));

        proxy.unregister_service(service_id1);
        assert_eq!(proxy.service_count(), 1);
        assert!(!proxy.services.contains_key(&service_id1));
        assert!(proxy.services.contains_key(&service_id2));

        proxy.unregister_service(service_id2);
        assert_eq!(proxy.service_count(), 0);
    }

    #[test]
    fn test_service_and_connection_counts() {
        let proxy = LocalProxy::new(Duration::from_secs(5));

        assert_eq!(proxy.service_count(), 0);
        assert_eq!(proxy.connection_count(), 0);

        let service_id = Uuid::new_v4();
        proxy.register_service(service_id, ServiceConfig::tcp("test", 8080));
        assert_eq!(proxy.service_count(), 1);
        assert_eq!(proxy.connection_count(), 0);

        proxy.unregister_service(service_id);
        assert_eq!(proxy.service_count(), 0);
    }

    #[test]
    fn test_cancel_connection() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        let connection_id = Uuid::new_v4();

        // Cancelling an unknown connection is a no-op.
        proxy.cancel_connection(connection_id);
        assert_eq!(proxy.connection_count(), 0);
    }

    #[test]
    fn test_shutdown_clears_all() {
        let proxy = LocalProxy::new(Duration::from_secs(5));

        for _ in 0..5 {
            proxy.register_service(Uuid::new_v4(), ServiceConfig::tcp("test", 8080));
        }
        assert_eq!(proxy.service_count(), 5);

        proxy.shutdown();
        assert_eq!(proxy.connection_count(), 0);
        // Shutdown only tears down connections; services stay registered.
        assert_eq!(proxy.service_count(), 5);
    }

    #[tokio::test]
    async fn test_handle_connection_unknown_service() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        let service_id = Uuid::new_v4();
        let connection_id = Uuid::new_v4();

        let result = proxy.handle_connection(service_id, connection_id).await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert!(matches!(err, TunnelError::Registry { .. }));
    }

    #[tokio::test]
    async fn test_handle_connection_timeout() {
        let proxy = LocalProxy::new(Duration::from_millis(100));
        let service_id = Uuid::new_v4();
        let connection_id = Uuid::new_v4();

        // Assumes nothing is listening on this port.
        proxy.register_service(service_id, ServiceConfig::tcp("test", 65432));

        let result = proxy.handle_connection(service_id, connection_id).await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        // A refused connect and a timed-out connect are both acceptable here.
        assert!(matches!(
            err,
            TunnelError::Timeout | TunnelError::Connection { .. }
        ));
    }

    #[test]
    fn test_unregister_nonexistent_service() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        let service_id = Uuid::new_v4();

        proxy.unregister_service(service_id);
        assert_eq!(proxy.service_count(), 0);
    }

    #[test]
    fn test_register_overwrites_existing() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        let service_id = Uuid::new_v4();

        proxy.register_service(service_id, ServiceConfig::tcp("first", 8080));
        assert_eq!(proxy.service_count(), 1);

        proxy.register_service(service_id, ServiceConfig::tcp("second", 9090));
        assert_eq!(proxy.service_count(), 1);

        let config = proxy.services.get(&service_id).unwrap();
        assert_eq!(config.name, "second");
        assert_eq!(config.local_port, 9090);
    }

    #[tokio::test]
    async fn test_proxy_streams_forwards_both_directions() {
        let (mut local_peer, local) = tcp_pair().await;
        let (mut remote_peer, remote) = tcp_pair().await;
        let proxy_task = tokio::spawn(LocalProxy::proxy_streams(local, remote));

        local_peer.write_all(b"hello").await.unwrap();
        let mut up = [0u8; 5];
        remote_peer.read_exact(&mut up).await.unwrap();
        assert_eq!(&up, b"hello");

        remote_peer.write_all(b"world!").await.unwrap();
        let mut down = [0u8; 6];
        local_peer.read_exact(&mut down).await.unwrap();
        assert_eq!(&down, b"world!");

        // Half-close both sides so each copy direction sees EOF and finishes.
        local_peer.shutdown().await.unwrap();
        remote_peer.shutdown().await.unwrap();

        let totals = match proxy_task.await {
            Ok(Ok(totals)) => totals,
            _ => panic!("proxy task did not complete successfully"),
        };
        assert_eq!(totals, (5, 6));
    }

    #[tokio::test]
    async fn test_start_proxy_releases_connection_when_streams_close() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        let (local_peer, local) = tcp_pair().await;
        let (remote_peer, remote) = tcp_pair().await;
        let connection_id = Uuid::new_v4();

        let handle = proxy.start_proxy(connection_id, local, remote);
        assert_eq!(proxy.connection_count(), 1);

        drop(local_peer);
        drop(remote_peer);

        let totals = match handle.await {
            Ok(Ok(totals)) => totals,
            _ => panic!("proxy task did not complete successfully"),
        };
        assert_eq!(totals, (0, 0));
        assert_eq!(proxy.connection_count(), 0);
    }

    #[tokio::test]
    async fn test_cancel_connection_aborts_running_proxy() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        // Keep the peers alive so the proxy would otherwise run indefinitely.
        let (_local_peer, local) = tcp_pair().await;
        let (_remote_peer, remote) = tcp_pair().await;
        let connection_id = Uuid::new_v4();

        let handle = proxy.start_proxy(connection_id, local, remote);
        assert_eq!(proxy.connection_count(), 1);

        proxy.cancel_connection(connection_id);
        assert_eq!(proxy.connection_count(), 0);
        assert_cancelled(handle).await;
    }

    #[tokio::test]
    async fn test_shutdown_aborts_running_proxies() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        let (_local_peer, local) = tcp_pair().await;
        let (_remote_peer, remote) = tcp_pair().await;

        let handle = proxy.start_proxy(Uuid::new_v4(), local, remote);
        assert_eq!(proxy.connection_count(), 1);

        proxy.shutdown();
        assert_eq!(proxy.connection_count(), 0);
        assert_cancelled(handle).await;
    }
}