// siphon_server/tcp_plane.rs

use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};

use anyhow::Result;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
use tokio::task::JoinHandle;

use siphon_protocol::ServerMessage;

use crate::router::Router;
use crate::state::{PortAllocator, StreamIdGenerator, TcpConnectionHandle, TcpConnectionRegistry};

14/// TCP data plane for direct TCP tunnel connections
15pub struct TcpPlane {
16    router: Arc<Router>,
17    port_allocator: Arc<PortAllocator>,
18    tcp_registry: TcpConnectionRegistry,
19    stream_id_gen: Arc<StreamIdGenerator>,
20}
21
22impl TcpPlane {
23    pub fn new(
24        router: Arc<Router>,
25        port_allocator: Arc<PortAllocator>,
26        tcp_registry: TcpConnectionRegistry,
27        stream_id_gen: Arc<StreamIdGenerator>,
28    ) -> Arc<Self> {
29        Arc::new(Self {
30            router,
31            port_allocator,
32            tcp_registry,
33            stream_id_gen,
34        })
35    }
36
37    /// Allocate a port and start listening for TCP connections
38    pub async fn allocate_and_listen(self: Arc<Self>, subdomain: String) -> Result<u16> {
39        let port = self
40            .port_allocator
41            .allocate()
42            .ok_or_else(|| anyhow::anyhow!("No available ports"))?;
43
44        let addr: SocketAddr = format!("0.0.0.0:{}", port).parse()?;
45        let listener = TcpListener::bind(addr).await?;
46
47        tracing::info!(
48            "TCP plane listening on {} for subdomain {}",
49            addr,
50            subdomain
51        );
52
53        let this = self.clone();
54        let subdomain_clone = subdomain.clone();
55
56        // Spawn listener task
57        tokio::spawn(async move {
58            loop {
59                match listener.accept().await {
60                    Ok((stream, peer_addr)) => {
61                        tracing::info!(
62                            "TCP connection from {} for subdomain {}",
63                            peer_addr,
64                            subdomain_clone
65                        );
66                        let this = this.clone();
67                        let subdomain = subdomain_clone.clone();
68                        tokio::spawn(async move {
69                            if let Err(e) = this.handle_tcp_connection(stream, subdomain).await {
70                                tracing::error!("TCP connection error: {}", e);
71                            }
72                        });
73                    }
74                    Err(e) => {
75                        tracing::error!("TCP accept error: {}", e);
76                        break;
77                    }
78                }
79            }
80        });
81
82        Ok(port)
83    }
84
85    /// Handle an incoming TCP connection
86    async fn handle_tcp_connection(
87        self: Arc<Self>,
88        stream: TcpStream,
89        subdomain: String,
90    ) -> Result<()> {
91        let stream_id = self.stream_id_gen.next();
92        tracing::debug!("New TCP stream {} for subdomain {}", stream_id, subdomain);
93
94        // Get sender for this subdomain
95        let tunnel_sender = match self.router.get_sender(&subdomain) {
96            Some(s) => s,
97            None => {
98                tracing::warn!("No tunnel for subdomain: {}", subdomain);
99                return Ok(());
100            }
101        };
102
103        // Split the stream
104        let (mut read_half, mut write_half) = stream.into_split();
105
106        // Create channel for writing data back to this TCP connection
107        let (write_tx, mut write_rx) = mpsc::channel::<Vec<u8>>(32);
108
109        // Register this connection
110        self.tcp_registry.insert(
111            stream_id,
112            TcpConnectionHandle {
113                writer: write_tx,
114                subdomain: subdomain.clone(),
115            },
116        );
117
118        // Send TcpConnect to client
119        if let Err(e) = tunnel_sender
120            .send(ServerMessage::TcpConnect { stream_id })
121            .await
122        {
123            tracing::error!("Failed to send TcpConnect: {}", e);
124            self.tcp_registry.remove(&stream_id);
125            return Ok(());
126        }
127
128        // Spawn write task (receives data from tunnel client, writes to TCP)
129        let tcp_registry = self.tcp_registry.clone();
130        let tunnel_sender_clone = tunnel_sender.clone();
131        let write_task = tokio::spawn(async move {
132            while let Some(data) = write_rx.recv().await {
133                if let Err(e) = write_half.write_all(&data).await {
134                    tracing::error!("Failed to write to TCP stream {}: {}", stream_id, e);
135                    break;
136                }
137            }
138            // Connection closed, send TcpClose
139            let _ = tunnel_sender_clone
140                .send(ServerMessage::TcpClose { stream_id })
141                .await;
142            tcp_registry.remove(&stream_id);
143        });
144
145        // Read from TCP, send to tunnel
146        let mut buf = vec![0u8; 8192];
147        loop {
148            match read_half.read(&mut buf).await {
149                Ok(0) => {
150                    // EOF
151                    tracing::debug!("TCP stream {} closed by remote", stream_id);
152                    break;
153                }
154                Ok(n) => {
155                    let data = buf[..n].to_vec();
156                    if let Err(e) = tunnel_sender
157                        .send(ServerMessage::TcpData { stream_id, data })
158                        .await
159                    {
160                        tracing::error!("Failed to send TcpData: {}", e);
161                        break;
162                    }
163                }
164                Err(e) => {
165                    tracing::error!("TCP read error on stream {}: {}", stream_id, e);
166                    break;
167                }
168            }
169        }
170
171        // Clean up
172        self.tcp_registry.remove(&stream_id);
173        write_task.abort();
174
175        // Send TcpClose
176        let _ = tunnel_sender
177            .send(ServerMessage::TcpClose { stream_id })
178            .await;
179
180        Ok(())
181    }
182
183    /// Release a port when tunnel is closed
184    pub fn release_port(&self, port: u16) {
185        self.port_allocator.release(port);
186    }
187
188    /// Get write channel for a stream
189    pub fn get_writer(&self, stream_id: u64) -> Option<mpsc::Sender<Vec<u8>>> {
190        self.tcp_registry.get(&stream_id).map(|h| h.writer.clone())
191    }
192
193    /// Close a TCP connection
194    pub fn close_connection(&self, stream_id: u64) {
195        if let Some((_, handle)) = self.tcp_registry.remove(&stream_id) {
196            // Dropping the sender will cause the write task to exit
197            drop(handle);
198        }
199    }
200}