1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use anyhow::Result;
5use bytes::BytesMut;
6use cuid2::CuidConstructor;
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8use tokio::net::{TcpListener, TcpStream};
9use tokio::sync::mpsc;
10use tokio_rustls::TlsAcceptor;
11use tokio_util::codec::{Decoder, Encoder};
12
13use siphon_protocol::{ClientMessage, ServerMessage, TunnelCodec, TunnelType};
14
15use crate::dns_provider::DnsProvider;
16use crate::router::{Router, TunnelHandle};
17use crate::state::{HttpResponseData, ResponseRegistry, TcpConnectionRegistry};
18use crate::tcp_plane::TcpPlane;
19
20pub struct ControlPlane {
22 router: Arc<Router>,
23 tls_acceptor: TlsAcceptor,
24 dns_provider: Arc<dyn DnsProvider>,
25 base_domain: String,
26 response_registry: ResponseRegistry,
27 tcp_plane: Arc<TcpPlane>,
28 tcp_registry: TcpConnectionRegistry,
29}
30
31impl ControlPlane {
32 pub fn new(
33 router: Arc<Router>,
34 tls_acceptor: TlsAcceptor,
35 dns_provider: Arc<dyn DnsProvider>,
36 base_domain: String,
37 response_registry: ResponseRegistry,
38 tcp_plane: Arc<TcpPlane>,
39 tcp_registry: TcpConnectionRegistry,
40 ) -> Arc<Self> {
41 Arc::new(Self {
42 router,
43 tls_acceptor,
44 dns_provider,
45 base_domain,
46 response_registry,
47 tcp_plane,
48 tcp_registry,
49 })
50 }
51
52 pub async fn run(self: Arc<Self>, addr: SocketAddr) -> Result<()> {
54 let listener = TcpListener::bind(addr).await?;
55 tracing::info!("Control plane listening on {}", addr);
56 self.run_with_listener(listener).await
57 }
58
59 pub async fn run_with_listener(self: Arc<Self>, listener: TcpListener) -> Result<()> {
64 loop {
65 let (stream, peer_addr) = listener.accept().await?;
66 let this = self.clone();
67
68 tokio::spawn(async move {
69 if let Err(e) = this.handle_connection(stream, peer_addr).await {
70 tracing::error!("Connection error from {}: {}", peer_addr, e);
71 }
72 });
73 }
74 }
75
76 async fn handle_connection(
77 self: Arc<Self>,
78 stream: TcpStream,
79 peer_addr: SocketAddr,
80 ) -> Result<()> {
81 tracing::info!("New connection from {}", peer_addr);
82
83 let tls_stream = self.tls_acceptor.accept(stream).await?;
85 tracing::info!("TLS handshake complete with {}", peer_addr);
86
87 let client_id = extract_client_id(&tls_stream);
89 tracing::info!("Client identified as: {}", client_id);
90
91 let (read_half, write_half) = tokio::io::split(tls_stream);
93
94 let (tx, mut rx) = mpsc::channel::<ServerMessage>(32);
96
97 let router = self.router.clone();
99 let dns_provider = self.dns_provider.clone();
100 let base_domain = self.base_domain.clone();
101 let client_id_clone = client_id.clone();
102 let response_registry = self.response_registry.clone();
103 let tcp_plane = self.tcp_plane.clone();
104 let _tcp_registry = self.tcp_registry.clone();
105
106 let mut codec = TunnelCodec::<ClientMessage>::new();
107 let mut read_buf = BytesMut::with_capacity(8192);
108
109 let mut assigned_subdomain: Option<String> = None;
111 let mut assigned_tcp_port: Option<u16> = None;
112
113 let write_handle = tokio::spawn(async move {
115 let mut write_half = write_half;
116 let mut codec = TunnelCodec::<ServerMessage>::new();
117 let mut write_buf = BytesMut::with_capacity(8192);
118
119 while let Some(msg) = rx.recv().await {
120 write_buf.clear();
121 if let Err(e) = codec.encode(msg, &mut write_buf) {
122 tracing::error!("Failed to encode message: {}", e);
123 break;
124 }
125 if let Err(e) = write_half.write_all(&write_buf).await {
126 tracing::error!("Failed to write message: {}", e);
127 break;
128 }
129 }
130 });
131
132 let mut read_half = read_half;
134 loop {
135 match read_half.read_buf(&mut read_buf).await {
137 Ok(0) => {
138 tracing::info!("Client {} disconnected", peer_addr);
139 break;
140 }
141 Ok(_) => {}
142 Err(e) => {
143 tracing::error!("Read error: {}", e);
144 break;
145 }
146 };
147
148 loop {
150 match codec.decode(&mut read_buf) {
151 Ok(Some(msg)) => {
152 match msg {
153 ClientMessage::RequestTunnel {
154 subdomain,
155 tunnel_type,
156 local_port,
157 } => {
158 tracing::info!(
159 "Tunnel request from {}: subdomain={:?}, type={:?}, local_port={}",
160 client_id_clone,
161 subdomain,
162 tunnel_type,
163 local_port
164 );
165
166 let subdomain = subdomain.unwrap_or_else(|| {
168 CuidConstructor::new().with_length(8).create_id()
170 });
171
172 if !is_valid_subdomain(&subdomain) {
174 let _ = tx
175 .send(ServerMessage::TunnelDenied {
176 reason: "Invalid subdomain format".to_string(),
177 })
178 .await;
179 continue;
180 }
181
182 if !router.is_available(&subdomain) {
184 let _ = tx
185 .send(ServerMessage::TunnelDenied {
186 reason: "Subdomain already in use".to_string(),
187 })
188 .await;
189 continue;
190 }
191
192 let tcp_port = if tunnel_type == TunnelType::Tcp {
194 match tcp_plane
195 .clone()
196 .allocate_and_listen(subdomain.clone())
197 .await
198 {
199 Ok(port) => Some(port),
200 Err(e) => {
201 tracing::error!("Failed to allocate TCP port: {}", e);
202 let _ = tx
203 .send(ServerMessage::TunnelDenied {
204 reason: format!(
205 "TCP port allocation failed: {}",
206 e
207 ),
208 })
209 .await;
210 continue;
211 }
212 }
213 } else {
214 None
215 };
216
217 let proxied = tunnel_type == TunnelType::Http;
219 match dns_provider.create_record(&subdomain, proxied).await {
220 Ok(record_id) => {
221 let handle = TunnelHandle {
223 sender: tx.clone(),
224 client_id: client_id_clone.clone(),
225 tunnel_type: tunnel_type.clone(),
226 dns_record_id: Some(record_id),
227 };
228
229 if let Err(e) =
231 router.register(subdomain.clone(), handle, tcp_port)
232 {
233 tracing::error!("Failed to register tunnel: {}", e);
234 if let Some(port) = tcp_port {
236 tcp_plane.release_port(port);
237 }
238 let _ = tx
239 .send(ServerMessage::TunnelDenied {
240 reason: format!("Registration failed: {}", e),
241 })
242 .await;
243 continue;
244 }
245
246 assigned_subdomain = Some(subdomain.clone());
247 assigned_tcp_port = tcp_port;
248
249 let (full_url, response_port) = if tunnel_type
250 == TunnelType::Http
251 {
252 (format!("https://{}.{}", subdomain, base_domain), None)
253 } else {
254 (format!("{}.{}", subdomain, base_domain), tcp_port)
255 };
256
257 tracing::info!(
258 "Tunnel established: {} -> {} (port: {:?})",
259 full_url,
260 local_port,
261 response_port
262 );
263
264 let _ = tx
265 .send(ServerMessage::TunnelEstablished {
266 subdomain: subdomain.clone(),
267 url: full_url,
268 port: response_port,
269 })
270 .await;
271 }
272 Err(e) => {
273 tracing::error!("Failed to create DNS record: {}", e);
274 if let Some(port) = tcp_port {
276 tcp_plane.release_port(port);
277 }
278 let _ = tx
279 .send(ServerMessage::TunnelDenied {
280 reason: format!("DNS error: {}", e),
281 })
282 .await;
283 }
284 }
285 }
286 ClientMessage::HttpResponse {
287 stream_id,
288 status,
289 headers,
290 body,
291 } => {
292 tracing::debug!(
294 "Received HTTP response for stream {}: status={}",
295 stream_id,
296 status
297 );
298
299 if let Some((_, sender)) = response_registry.remove(&stream_id) {
301 let response = HttpResponseData {
302 status,
303 headers,
304 body,
305 };
306 if sender.send(response).is_err() {
307 tracing::warn!(
308 "Failed to send response for stream {} (receiver dropped)",
309 stream_id
310 );
311 }
312 } else {
313 tracing::warn!(
314 "No pending request for stream {} (may have timed out)",
315 stream_id
316 );
317 }
318 }
319 ClientMessage::TcpData { stream_id, data } => {
320 tracing::debug!(
321 "Received TCP data for stream {}: {} bytes",
322 stream_id,
323 data.len()
324 );
325 if let Some(writer) = tcp_plane.get_writer(stream_id) {
327 if let Err(e) = writer.send(data).await {
328 tracing::error!(
329 "Failed to forward TCP data to stream {}: {}",
330 stream_id,
331 e
332 );
333 }
334 } else {
335 tracing::warn!(
336 "No TCP connection for stream {} (may have been closed)",
337 stream_id
338 );
339 }
340 }
341 ClientMessage::TcpClose { stream_id } => {
342 tracing::debug!("TCP connection {} closed by client", stream_id);
343 tcp_plane.close_connection(stream_id);
345 }
346 ClientMessage::Ping { timestamp } => {
347 let _ = tx.send(ServerMessage::Pong { timestamp }).await;
348 }
349 }
350 }
351 Ok(None) => break, Err(e) => {
353 tracing::error!("Decode error: {}", e);
354 break;
355 }
356 }
357 }
358 }
359
360 tracing::info!("Cleaning up connection for {}", client_id);
362
363 if let Some(subdomain) = &assigned_subdomain {
365 if let Some(handle) = router.unregister(subdomain) {
366 if let Some(record_id) = handle.dns_record_id {
368 if let Err(e) = dns_provider.delete_record(&record_id).await {
369 tracing::error!("Failed to delete DNS record: {}", e);
370 }
371 }
372 }
373 }
374
375 if let Some(port) = assigned_tcp_port {
377 tcp_plane.release_port(port);
378 }
379
380 write_handle.abort();
381 Ok(())
382 }
383}
384
385fn extract_client_id<S>(tls_stream: &tokio_rustls::server::TlsStream<S>) -> String {
387 let (_, server_conn) = tls_stream.get_ref();
390
391 if let Some(certs) = server_conn.peer_certificates() {
392 if let Some(cert) = certs.first() {
393 use std::collections::hash_map::DefaultHasher;
395 use std::hash::{Hash, Hasher};
396 let mut hasher = DefaultHasher::new();
397 cert.as_ref().hash(&mut hasher);
398 return format!("client-{:016x}", hasher.finish());
399 }
400 }
401
402 format!(
403 "unknown-{}",
404 CuidConstructor::new().with_length(8).create_id()
405 )
406}
407
/// Returns `true` if `subdomain` is usable as a single DNS label.
///
/// A valid label:
/// - is 1–63 bytes long;
/// - contains only ASCII letters, ASCII digits and `-`;
/// - starts and ends with an ASCII letter or digit.
///
/// This is the letters-digits-hyphen rule, with RFC 1123's allowance for a leading digit.
/// Non-ASCII characters are rejected because DNS labels must be ASCII; internationalized
/// names have to be Punycode-encoded first. (The Unicode-aware `char::is_alphanumeric`
/// would wrongly accept labels like `café`.)
fn is_valid_subdomain(subdomain: &str) -> bool {
    let bytes = subdomain.as_bytes();

    match (bytes.first(), bytes.last()) {
        (Some(first), Some(last)) if bytes.len() <= 63 => {
            first.is_ascii_alphanumeric()
                && last.is_ascii_alphanumeric()
                && bytes.iter().all(|b| b.is_ascii_alphanumeric() || *b == b'-')
        }
        // Empty input, or longer than the 63-byte DNS label limit.
        _ => false,
    }
}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed labels, including a single character and an all-digit label.
    #[test]
    fn test_valid_subdomains() {
        let accepted = ["myapp", "my-app", "my-app-123", "a", "123"];
        for label in accepted {
            assert!(is_valid_subdomain(label), "expected {:?} to be accepted", label);
        }
    }

    /// Empty input, leading/trailing hyphens, disallowed punctuation, and labels over the
    /// 63-character limit.
    #[test]
    fn test_invalid_subdomains() {
        let too_long = "a".repeat(64);
        let rejected = ["", "-myapp", "myapp-", "my_app", "my.app", too_long.as_str()];
        for label in rejected {
            assert!(!is_valid_subdomain(label), "expected {:?} to be rejected", label);
        }
    }
}