siphon_server/
control_plane.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use anyhow::Result;
5use bytes::BytesMut;
6use cuid2::CuidConstructor;
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8use tokio::net::{TcpListener, TcpStream};
9use tokio::sync::mpsc;
10use tokio_rustls::TlsAcceptor;
11use tokio_util::codec::{Decoder, Encoder};
12
13use siphon_protocol::{ClientMessage, ServerMessage, TunnelCodec, TunnelType};
14
15use crate::dns_provider::DnsProvider;
16use crate::router::{Router, TunnelHandle};
17use crate::state::{HttpResponseData, ResponseRegistry, TcpConnectionRegistry};
18use crate::tcp_plane::TcpPlane;
19
20/// Control plane server that accepts tunnel client connections via mTLS
21pub struct ControlPlane {
22    router: Arc<Router>,
23    tls_acceptor: TlsAcceptor,
24    dns_provider: Arc<dyn DnsProvider>,
25    base_domain: String,
26    response_registry: ResponseRegistry,
27    tcp_plane: Arc<TcpPlane>,
28    tcp_registry: TcpConnectionRegistry,
29}
30
31impl ControlPlane {
32    pub fn new(
33        router: Arc<Router>,
34        tls_acceptor: TlsAcceptor,
35        dns_provider: Arc<dyn DnsProvider>,
36        base_domain: String,
37        response_registry: ResponseRegistry,
38        tcp_plane: Arc<TcpPlane>,
39        tcp_registry: TcpConnectionRegistry,
40    ) -> Arc<Self> {
41        Arc::new(Self {
42            router,
43            tls_acceptor,
44            dns_provider,
45            base_domain,
46            response_registry,
47            tcp_plane,
48            tcp_registry,
49        })
50    }
51
52    /// Start listening for tunnel client connections
53    pub async fn run(self: Arc<Self>, addr: SocketAddr) -> Result<()> {
54        let listener = TcpListener::bind(addr).await?;
55        tracing::info!("Control plane listening on {}", addr);
56        self.run_with_listener(listener).await
57    }
58
59    /// Start accepting connections from a pre-bound listener
60    ///
61    /// This is useful for testing where the caller wants to bind to an
62    /// ephemeral port and get the actual address before starting the server.
63    pub async fn run_with_listener(self: Arc<Self>, listener: TcpListener) -> Result<()> {
64        loop {
65            let (stream, peer_addr) = listener.accept().await?;
66            let this = self.clone();
67
68            tokio::spawn(async move {
69                if let Err(e) = this.handle_connection(stream, peer_addr).await {
70                    tracing::error!("Connection error from {}: {}", peer_addr, e);
71                }
72            });
73        }
74    }
75
76    async fn handle_connection(
77        self: Arc<Self>,
78        stream: TcpStream,
79        peer_addr: SocketAddr,
80    ) -> Result<()> {
81        tracing::info!("New connection from {}", peer_addr);
82
83        // Perform TLS handshake with client cert verification
84        let tls_stream = self.tls_acceptor.accept(stream).await?;
85        tracing::info!("TLS handshake complete with {}", peer_addr);
86
87        // Extract client identity from certificate
88        let client_id = extract_client_id(&tls_stream);
89        tracing::info!("Client identified as: {}", client_id);
90
91        // Split the stream for reading and writing
92        let (read_half, write_half) = tokio::io::split(tls_stream);
93
94        // Create channels for communication
95        let (tx, mut rx) = mpsc::channel::<ServerMessage>(32);
96
97        // Read loop: process incoming messages from client
98        let router = self.router.clone();
99        let dns_provider = self.dns_provider.clone();
100        let base_domain = self.base_domain.clone();
101        let client_id_clone = client_id.clone();
102        let response_registry = self.response_registry.clone();
103        let tcp_plane = self.tcp_plane.clone();
104        let _tcp_registry = self.tcp_registry.clone();
105
106        let mut codec = TunnelCodec::<ClientMessage>::new();
107        let mut read_buf = BytesMut::with_capacity(8192);
108
109        // State for this connection
110        let mut assigned_subdomain: Option<String> = None;
111        let mut assigned_tcp_port: Option<u16> = None;
112
113        // Spawn write task
114        let write_handle = tokio::spawn(async move {
115            let mut write_half = write_half;
116            let mut codec = TunnelCodec::<ServerMessage>::new();
117            let mut write_buf = BytesMut::with_capacity(8192);
118
119            while let Some(msg) = rx.recv().await {
120                write_buf.clear();
121                if let Err(e) = codec.encode(msg, &mut write_buf) {
122                    tracing::error!("Failed to encode message: {}", e);
123                    break;
124                }
125                if let Err(e) = write_half.write_all(&write_buf).await {
126                    tracing::error!("Failed to write message: {}", e);
127                    break;
128                }
129            }
130        });
131
132        // Read loop
133        let mut read_half = read_half;
134        loop {
135            // Read more data
136            match read_half.read_buf(&mut read_buf).await {
137                Ok(0) => {
138                    tracing::info!("Client {} disconnected", peer_addr);
139                    break;
140                }
141                Ok(_) => {}
142                Err(e) => {
143                    tracing::error!("Read error: {}", e);
144                    break;
145                }
146            };
147
148            // Try to decode messages
149            loop {
150                match codec.decode(&mut read_buf) {
151                    Ok(Some(msg)) => {
152                        match msg {
153                            ClientMessage::RequestTunnel {
154                                subdomain,
155                                tunnel_type,
156                                local_port,
157                            } => {
158                                tracing::info!(
159                                    "Tunnel request from {}: subdomain={:?}, type={:?}, local_port={}",
160                                    client_id_clone,
161                                    subdomain,
162                                    tunnel_type,
163                                    local_port
164                                );
165
166                                // Generate or validate subdomain
167                                let subdomain = subdomain.unwrap_or_else(|| {
168                                    // Generate random subdomain using cuid2 (always starts with a letter)
169                                    CuidConstructor::new().with_length(8).create_id()
170                                });
171
172                                // Validate subdomain format
173                                if !is_valid_subdomain(&subdomain) {
174                                    let _ = tx
175                                        .send(ServerMessage::TunnelDenied {
176                                            reason: "Invalid subdomain format".to_string(),
177                                        })
178                                        .await;
179                                    continue;
180                                }
181
182                                // Check availability
183                                if !router.is_available(&subdomain) {
184                                    let _ = tx
185                                        .send(ServerMessage::TunnelDenied {
186                                            reason: "Subdomain already in use".to_string(),
187                                        })
188                                        .await;
189                                    continue;
190                                }
191
192                                // For TCP tunnels, allocate a port first
193                                let tcp_port = if tunnel_type == TunnelType::Tcp {
194                                    match tcp_plane
195                                        .clone()
196                                        .allocate_and_listen(subdomain.clone())
197                                        .await
198                                    {
199                                        Ok(port) => Some(port),
200                                        Err(e) => {
201                                            tracing::error!("Failed to allocate TCP port: {}", e);
202                                            let _ = tx
203                                                .send(ServerMessage::TunnelDenied {
204                                                    reason: format!(
205                                                        "TCP port allocation failed: {}",
206                                                        e
207                                                    ),
208                                                })
209                                                .await;
210                                            continue;
211                                        }
212                                    }
213                                } else {
214                                    None
215                                };
216
217                                // Create DNS record
218                                let proxied = tunnel_type == TunnelType::Http;
219                                match dns_provider.create_record(&subdomain, proxied).await {
220                                    Ok(record_id) => {
221                                        // Create tunnel handle
222                                        let handle = TunnelHandle {
223                                            sender: tx.clone(),
224                                            client_id: client_id_clone.clone(),
225                                            tunnel_type: tunnel_type.clone(),
226                                            dns_record_id: Some(record_id),
227                                        };
228
229                                        // Register the tunnel
230                                        if let Err(e) =
231                                            router.register(subdomain.clone(), handle, tcp_port)
232                                        {
233                                            tracing::error!("Failed to register tunnel: {}", e);
234                                            // Release TCP port if allocated
235                                            if let Some(port) = tcp_port {
236                                                tcp_plane.release_port(port);
237                                            }
238                                            let _ = tx
239                                                .send(ServerMessage::TunnelDenied {
240                                                    reason: format!("Registration failed: {}", e),
241                                                })
242                                                .await;
243                                            continue;
244                                        }
245
246                                        assigned_subdomain = Some(subdomain.clone());
247                                        assigned_tcp_port = tcp_port;
248
249                                        let (full_url, response_port) = if tunnel_type
250                                            == TunnelType::Http
251                                        {
252                                            (format!("https://{}.{}", subdomain, base_domain), None)
253                                        } else {
254                                            (format!("{}.{}", subdomain, base_domain), tcp_port)
255                                        };
256
257                                        tracing::info!(
258                                            "Tunnel established: {} -> {} (port: {:?})",
259                                            full_url,
260                                            local_port,
261                                            response_port
262                                        );
263
264                                        let _ = tx
265                                            .send(ServerMessage::TunnelEstablished {
266                                                subdomain: subdomain.clone(),
267                                                url: full_url,
268                                                port: response_port,
269                                            })
270                                            .await;
271                                    }
272                                    Err(e) => {
273                                        tracing::error!("Failed to create DNS record: {}", e);
274                                        // Release TCP port if allocated
275                                        if let Some(port) = tcp_port {
276                                            tcp_plane.release_port(port);
277                                        }
278                                        let _ = tx
279                                            .send(ServerMessage::TunnelDenied {
280                                                reason: format!("DNS error: {}", e),
281                                            })
282                                            .await;
283                                    }
284                                }
285                            }
286                            ClientMessage::HttpResponse {
287                                stream_id,
288                                status,
289                                headers,
290                                body,
291                            } => {
292                                // Forward response to the waiting HTTP handler
293                                tracing::debug!(
294                                    "Received HTTP response for stream {}: status={}",
295                                    stream_id,
296                                    status
297                                );
298
299                                // Look up the pending response in the shared registry
300                                if let Some((_, sender)) = response_registry.remove(&stream_id) {
301                                    let response = HttpResponseData {
302                                        status,
303                                        headers,
304                                        body,
305                                    };
306                                    if sender.send(response).is_err() {
307                                        tracing::warn!(
308                                            "Failed to send response for stream {} (receiver dropped)",
309                                            stream_id
310                                        );
311                                    }
312                                } else {
313                                    tracing::warn!(
314                                        "No pending request for stream {} (may have timed out)",
315                                        stream_id
316                                    );
317                                }
318                            }
319                            ClientMessage::TcpData { stream_id, data } => {
320                                tracing::debug!(
321                                    "Received TCP data for stream {}: {} bytes",
322                                    stream_id,
323                                    data.len()
324                                );
325                                // Forward to TCP plane
326                                if let Some(writer) = tcp_plane.get_writer(stream_id) {
327                                    if let Err(e) = writer.send(data).await {
328                                        tracing::error!(
329                                            "Failed to forward TCP data to stream {}: {}",
330                                            stream_id,
331                                            e
332                                        );
333                                    }
334                                } else {
335                                    tracing::warn!(
336                                        "No TCP connection for stream {} (may have been closed)",
337                                        stream_id
338                                    );
339                                }
340                            }
341                            ClientMessage::TcpClose { stream_id } => {
342                                tracing::debug!("TCP connection {} closed by client", stream_id);
343                                // Close the TCP connection
344                                tcp_plane.close_connection(stream_id);
345                            }
346                            ClientMessage::Ping { timestamp } => {
347                                let _ = tx.send(ServerMessage::Pong { timestamp }).await;
348                            }
349                        }
350                    }
351                    Ok(None) => break, // Need more data
352                    Err(e) => {
353                        tracing::error!("Decode error: {}", e);
354                        break;
355                    }
356                }
357            }
358        }
359
360        // Cleanup
361        tracing::info!("Cleaning up connection for {}", client_id);
362
363        // Unregister tunnel
364        if let Some(subdomain) = &assigned_subdomain {
365            if let Some(handle) = router.unregister(subdomain) {
366                // Delete DNS record
367                if let Some(record_id) = handle.dns_record_id {
368                    if let Err(e) = dns_provider.delete_record(&record_id).await {
369                        tracing::error!("Failed to delete DNS record: {}", e);
370                    }
371                }
372            }
373        }
374
375        // Release TCP port if allocated
376        if let Some(port) = assigned_tcp_port {
377            tcp_plane.release_port(port);
378        }
379
380        write_handle.abort();
381        Ok(())
382    }
383}
384
385/// Extract client ID from TLS connection (certificate CN)
386fn extract_client_id<S>(tls_stream: &tokio_rustls::server::TlsStream<S>) -> String {
387    // In a full implementation, we would extract the CN from the client certificate
388    // For now, generate a unique ID
389    let (_, server_conn) = tls_stream.get_ref();
390
391    if let Some(certs) = server_conn.peer_certificates() {
392        if let Some(cert) = certs.first() {
393            // Hash the certificate for a stable ID
394            use std::collections::hash_map::DefaultHasher;
395            use std::hash::{Hash, Hasher};
396            let mut hasher = DefaultHasher::new();
397            cert.as_ref().hash(&mut hasher);
398            return format!("client-{:016x}", hasher.finish());
399        }
400    }
401
402    format!(
403        "unknown-{}",
404        CuidConstructor::new().with_length(8).create_id()
405    )
406}
407
/// Validate that `subdomain` is a single DNS label (RFC 1035 letters-digits-hyphens rule).
///
/// Accepts 1 to 63 ASCII letters, digits and hyphens that neither start nor end with a
/// hyphen. Non-ASCII characters are rejected: `char::is_alphanumeric` would accept Unicode
/// letters such as `é`, which are not valid in a raw DNS label.
fn is_valid_subdomain(subdomain: &str) -> bool {
    let bytes = subdomain.as_bytes();
    match (bytes.first(), bytes.last()) {
        (Some(first), Some(last)) => {
            bytes.len() <= 63
                && first.is_ascii_alphanumeric()
                && last.is_ascii_alphanumeric()
                && bytes.iter().all(|b| b.is_ascii_alphanumeric() || *b == b'-')
        }
        // Empty string
        _ => false,
    }
}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Labels made of alphanumerics, optionally with interior hyphens, are accepted.
    #[test]
    fn test_valid_subdomains() {
        for candidate in ["myapp", "my-app", "my-app-123", "a", "123"] {
            assert!(is_valid_subdomain(candidate));
        }
    }

    /// Empty labels, leading/trailing hyphens, other punctuation and labels longer than
    /// 63 characters are rejected.
    #[test]
    fn test_invalid_subdomains() {
        let too_long = "a".repeat(64);
        for candidate in ["", "-myapp", "myapp-", "my_app", "my.app", too_long.as_str()] {
            assert!(!is_valid_subdomain(candidate));
        }
    }
}
449}