Skip to main content

zlayer_tunnel/server/
control.rs

//! WebSocket control channel handler for the tunnel server
//!
//! This module implements the server side of the control channel used by
//! tunnel clients over WebSocket. It handles authentication, service
//! registration, heartbeat monitoring, and message routing.
6
7use std::net::{Ipv4Addr, SocketAddr};
8use std::sync::Arc;
9use std::time::Duration;
10
11use futures_util::{SinkExt, StreamExt};
12use sha2::{Digest, Sha256};
13use tokio::net::TcpStream;
14use tokio::sync::mpsc;
15use tokio::time::{interval, timeout};
16use tokio_tungstenite::{accept_async, tungstenite::Message as WsMessage, WebSocketStream};
17use uuid::Uuid;
18
19use crate::overlay::DynTunnelDnsRegistrar;
20use crate::{
21    ControlMessage, Message, Result, ServiceProtocol, TunnelError, TunnelRegistry,
22    TunnelServerConfig,
23};
24
25/// Type alias for token validator function
26pub type TokenValidator = Arc<dyn Fn(&str) -> Result<()> + Send + Sync>;
27
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch. The millisecond
/// count is converted with `u64::try_from` and saturates at `u64::MAX` rather than
/// truncating, although a `u64` of milliseconds only overflows after roughly 584
/// million years. This keeps the narrowing explicit without an `as` cast or a lint
/// suppression.
fn current_timestamp_ms() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| {
            u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
        })
}
42
43/// Control channel handler for a single tunnel connection
44///
45/// The `ControlHandler` manages the lifecycle of a tunnel connection, including:
46/// - WebSocket upgrade and authentication
47/// - Service registration and unregistration
48/// - Heartbeat monitoring
49/// - Message routing between server and client
50///
51/// # Example
52///
53/// ```rust,no_run
54/// use std::sync::Arc;
55/// use zlayer_tunnel::{TunnelRegistry, TunnelServerConfig, ControlHandler, accept_all_tokens};
56///
57/// async fn handle_client(stream: tokio::net::TcpStream, addr: std::net::SocketAddr) {
58///     let registry = Arc::new(TunnelRegistry::default());
59///     let config = TunnelServerConfig::default();
60///     let validator = Arc::new(accept_all_tokens);
61///
62///     let handler = ControlHandler::new(registry, config, validator);
63///     if let Err(e) = handler.handle_connection(stream, addr).await {
64///         tracing::error!("Connection error: {}", e);
65///     }
66/// }
67/// ```
68pub struct ControlHandler {
69    registry: Arc<TunnelRegistry>,
70    config: TunnelServerConfig,
71
72    /// Token validator function (returns `Ok(())` if token is valid, `Err` with reason if not)
73    token_validator: TokenValidator,
74
75    /// Optional DNS registrar for registering tunnel services in overlay DNS
76    dns_registrar: Option<DynTunnelDnsRegistrar>,
77
78    /// Local overlay IP for DNS registration
79    local_overlay_ip: Option<Ipv4Addr>,
80}
81
82impl ControlHandler {
83    /// Create a new control handler
84    ///
85    /// # Arguments
86    ///
87    /// * `registry` - The tunnel registry for managing tunnel state
88    /// * `config` - Server configuration
89    /// * `token_validator` - Function to validate authentication tokens
90    #[must_use]
91    pub fn new(
92        registry: Arc<TunnelRegistry>,
93        config: TunnelServerConfig,
94        token_validator: TokenValidator,
95    ) -> Self {
96        Self {
97            registry,
98            config,
99            token_validator,
100            dns_registrar: None,
101            local_overlay_ip: None,
102        }
103    }
104
105    /// Set the DNS registrar for registering tunnel services in overlay DNS
106    #[must_use]
107    pub fn with_dns_registrar(mut self, registrar: DynTunnelDnsRegistrar) -> Self {
108        self.dns_registrar = Some(registrar);
109        self
110    }
111
112    /// Set the local overlay IP for DNS registration
113    #[must_use]
114    pub fn with_local_overlay_ip(mut self, ip: Ipv4Addr) -> Self {
115        self.local_overlay_ip = Some(ip);
116        self
117    }
118
119    /// Handle a new WebSocket connection
120    ///
121    /// This is the main entry point for each tunnel client connection.
122    /// It performs the following steps:
123    /// 1. Upgrade TCP connection to WebSocket
124    /// 2. Wait for and validate AUTH message
125    /// 3. Register the tunnel in the registry
126    /// 4. Run the main message loop until disconnection
127    /// 5. Clean up the tunnel on disconnect
128    ///
129    /// # Arguments
130    ///
131    /// * `stream` - The TCP stream to upgrade
132    /// * `client_addr` - The client's socket address
133    ///
134    /// # Errors
135    ///
136    /// Returns an error if:
137    /// - WebSocket upgrade fails
138    /// - Authentication times out or fails
139    /// - Token is invalid or already in use
140    /// - Connection is closed unexpectedly
141    pub async fn handle_connection(
142        &self,
143        stream: TcpStream,
144        client_addr: SocketAddr,
145    ) -> Result<()> {
146        // Upgrade to WebSocket
147        let ws_stream = accept_async(stream)
148            .await
149            .map_err(TunnelError::connection)?;
150
151        let (mut ws_sink, mut ws_stream) = ws_stream.split();
152
153        // Wait for AUTH message with timeout (10 seconds)
154        let auth_timeout = Duration::from_secs(10);
155        let auth_msg = timeout(auth_timeout, async {
156            while let Some(msg) = ws_stream.next().await {
157                match msg {
158                    Ok(WsMessage::Binary(data)) => {
159                        return Message::decode(&data).map(|(m, _)| m);
160                    }
161                    Ok(WsMessage::Close(_)) => {
162                        return Err(TunnelError::connection_msg("Client closed connection"));
163                    }
164                    Ok(_) => {} // Ignore text, ping, pong
165                    Err(e) => return Err(TunnelError::connection(e)),
166                }
167            }
168            Err(TunnelError::connection_msg("Connection closed before auth"))
169        })
170        .await
171        .map_err(|_| TunnelError::timeout())??;
172
173        // Validate AUTH message
174        let Message::Auth {
175            token,
176            client_id: _,
177        } = auth_msg
178        else {
179            let fail = Message::AuthFail {
180                reason: "Expected AUTH message".to_string(),
181            };
182            let _ = ws_sink.send(WsMessage::Binary(fail.encode().into())).await;
183            return Err(TunnelError::auth("Expected AUTH message"));
184        };
185
186        // Validate token
187        if let Err(e) = (self.token_validator)(&token) {
188            let fail = Message::AuthFail {
189                reason: e.to_string(),
190            };
191            let _ = ws_sink.send(WsMessage::Binary(fail.encode().into())).await;
192            return Err(e);
193        }
194
195        // Hash token for storage (don't store raw token)
196        let token_hash = hash_token(&token);
197
198        // Check if token already connected
199        if self.registry.token_exists(&token_hash) {
200            let fail = Message::AuthFail {
201                reason: "Token already in use".to_string(),
202            };
203            let _ = ws_sink.send(WsMessage::Binary(fail.encode().into())).await;
204            return Err(TunnelError::auth("Token already in use"));
205        }
206
207        // Create control message channel
208        let (control_tx, mut control_rx) = mpsc::channel::<ControlMessage>(256);
209
210        // Register tunnel
211        let tunnel = self.registry.register_tunnel(
212            token_hash.clone(),
213            None, // Name can be set later via REGISTER
214            control_tx,
215            Some(client_addr),
216        )?;
217
218        let tunnel_id = tunnel.id;
219
220        // Send AUTH_OK
221        let auth_ok = Message::AuthOk { tunnel_id };
222        ws_sink
223            .send(WsMessage::Binary(auth_ok.encode().into()))
224            .await
225            .map_err(TunnelError::connection)?;
226
227        tracing::info!(
228            tunnel_id = %tunnel_id,
229            client_addr = %client_addr,
230            "Tunnel authenticated"
231        );
232
233        // Main message loop
234        let result = self
235            .run_message_loop(tunnel_id, &mut ws_sink, &mut ws_stream, &mut control_rx)
236            .await;
237
238        // Cleanup on disconnect
239        self.registry.unregister_tunnel(tunnel_id);
240
241        tracing::info!(tunnel_id = %tunnel_id, "Tunnel disconnected");
242
243        result
244    }
245
246    /// Run the main message loop for a connected tunnel
247    ///
248    /// This handles:
249    /// - Heartbeat sending and timeout detection
250    /// - Control messages from the registry
251    /// - Incoming WebSocket messages from the client
252    async fn run_message_loop(
253        &self,
254        tunnel_id: Uuid,
255        ws_sink: &mut futures_util::stream::SplitSink<WebSocketStream<TcpStream>, WsMessage>,
256        ws_stream: &mut futures_util::stream::SplitStream<WebSocketStream<TcpStream>>,
257        control_rx: &mut mpsc::Receiver<ControlMessage>,
258    ) -> Result<()> {
259        let mut heartbeat_interval = interval(self.config.heartbeat_interval);
260        let heartbeat_timeout = self.config.heartbeat_timeout;
261        let mut last_heartbeat_ack = std::time::Instant::now();
262
263        loop {
264            tokio::select! {
265                // Heartbeat timer
266                _ = heartbeat_interval.tick() => {
267                    // Check if we've received a heartbeat ack recently
268                    if last_heartbeat_ack.elapsed() > heartbeat_timeout {
269                        tracing::warn!(tunnel_id = %tunnel_id, "Heartbeat timeout");
270                        return Err(TunnelError::timeout());
271                    }
272
273                    // Send heartbeat
274                    let timestamp = current_timestamp_ms();
275                    let hb = Message::Heartbeat { timestamp };
276                    ws_sink
277                        .send(WsMessage::Binary(hb.encode().into()))
278                        .await
279                        .map_err(TunnelError::connection)?;
280                }
281
282                // Control messages from registry
283                Some(ctrl_msg) = control_rx.recv() => {
284                    let msg = match ctrl_msg {
285                        ControlMessage::Connect {
286                            service_id,
287                            connection_id,
288                            client_addr,
289                        } => Message::Connect {
290                            service_id,
291                            connection_id,
292                            client_addr: client_addr.to_string(),
293                        },
294                        ControlMessage::Heartbeat { timestamp } => {
295                            Message::Heartbeat { timestamp }
296                        }
297                        ControlMessage::Disconnect { reason } => {
298                            let _ = ws_sink
299                                .send(WsMessage::Binary(
300                                    Message::Disconnect { reason }.encode().into(),
301                                ))
302                                .await;
303                            return Ok(());
304                        }
305                    };
306                    ws_sink
307                        .send(WsMessage::Binary(msg.encode().into()))
308                        .await
309                        .map_err(TunnelError::connection)?;
310                }
311
312                // Incoming WebSocket messages
313                Some(msg_result) = ws_stream.next() => {
314                    match msg_result {
315                        Ok(WsMessage::Binary(data)) => {
316                            let (msg, _) = Message::decode(&data)?;
317
318                            // Check for heartbeat ack before handling (so we update even if handling fails)
319                            if matches!(msg, Message::HeartbeatAck { .. }) {
320                                last_heartbeat_ack = std::time::Instant::now();
321                            }
322
323                            self.handle_client_message(tunnel_id, msg, ws_sink).await?;
324
325                            // Update activity
326                            self.registry.touch_tunnel(tunnel_id);
327                        }
328                        Ok(WsMessage::Close(_)) => {
329                            return Ok(());
330                        }
331                        Ok(WsMessage::Ping(data)) => {
332                            ws_sink
333                                .send(WsMessage::Pong(data))
334                                .await
335                                .map_err(TunnelError::connection)?;
336                        }
337                        Ok(_) => {} // Ignore other message types
338                        Err(e) => {
339                            return Err(TunnelError::connection(e));
340                        }
341                    }
342                }
343
344                else => break,
345            }
346        }
347
348        Ok(())
349    }
350
351    /// Handle a message from the tunnel client
352    #[allow(clippy::too_many_lines)]
353    async fn handle_client_message(
354        &self,
355        tunnel_id: Uuid,
356        msg: Message,
357        ws_sink: &mut futures_util::stream::SplitSink<WebSocketStream<TcpStream>, WsMessage>,
358    ) -> Result<()> {
359        match msg {
360            Message::Register {
361                name,
362                protocol,
363                local_port,
364                remote_port,
365            } => {
366                self.handle_register(tunnel_id, &name, protocol, local_port, remote_port, ws_sink)
367                    .await?;
368            }
369
370            Message::Unregister { service_id } => {
371                if let Err(e) = self.registry.remove_service(tunnel_id, service_id) {
372                    tracing::warn!(
373                        tunnel_id = %tunnel_id,
374                        service_id = %service_id,
375                        error = %e,
376                        "Service unregistration failed"
377                    );
378                } else {
379                    tracing::info!(
380                        tunnel_id = %tunnel_id,
381                        service_id = %service_id,
382                        "Service unregistered"
383                    );
384                }
385            }
386
387            Message::ConnectAck { connection_id } => {
388                tracing::debug!(
389                    tunnel_id = %tunnel_id,
390                    connection_id = %connection_id,
391                    "Connection acknowledged"
392                );
393                // The listener handles this via a callback
394            }
395
396            Message::ConnectFail {
397                connection_id,
398                reason,
399            } => {
400                tracing::warn!(
401                    tunnel_id = %tunnel_id,
402                    connection_id = %connection_id,
403                    reason = %reason,
404                    "Connection failed"
405                );
406            }
407
408            Message::HeartbeatAck { timestamp } => {
409                let now = current_timestamp_ms();
410                let latency_ms = now.saturating_sub(timestamp);
411                tracing::trace!(
412                    tunnel_id = %tunnel_id,
413                    latency_ms = latency_ms,
414                    "Heartbeat ack received"
415                );
416            }
417
418            // Client shouldn't send these - they're server-to-client messages
419            Message::Auth { .. }
420            | Message::AuthOk { .. }
421            | Message::AuthFail { .. }
422            | Message::RegisterOk { .. }
423            | Message::RegisterFail { .. }
424            | Message::Connect { .. }
425            | Message::Heartbeat { .. }
426            | Message::Disconnect { .. } => {
427                tracing::warn!(
428                    tunnel_id = %tunnel_id,
429                    msg_type = ?msg.message_type(),
430                    "Unexpected message from client"
431                );
432            }
433        }
434
435        Ok(())
436    }
437
438    /// Handle a service registration request
439    async fn handle_register(
440        &self,
441        tunnel_id: Uuid,
442        name: &str,
443        protocol: ServiceProtocol,
444        local_port: u16,
445        remote_port: u16,
446        ws_sink: &mut futures_util::stream::SplitSink<WebSocketStream<TcpStream>, WsMessage>,
447    ) -> Result<()> {
448        let result = self
449            .registry
450            .add_service(tunnel_id, name, protocol, local_port, remote_port);
451
452        let response = match result {
453            Ok(service) => {
454                let assigned_port = service.assigned_port.unwrap_or(remote_port);
455                tracing::info!(
456                    tunnel_id = %tunnel_id,
457                    service_name = %name,
458                    local_port = local_port,
459                    remote_port = assigned_port,
460                    "Service registered"
461                );
462
463                // Optionally register in overlay DNS
464                if let (Some(ref registrar), Some(overlay_ip)) =
465                    (&self.dns_registrar, self.local_overlay_ip)
466                {
467                    let dns_name = format!("tun-{name}");
468                    if let Err(e) = registrar
469                        .register_service(&dns_name, overlay_ip, assigned_port)
470                        .await
471                    {
472                        tracing::warn!(
473                            service_name = %name,
474                            dns_name = %dns_name,
475                            error = %e,
476                            "Failed to register service in overlay DNS"
477                        );
478                    } else {
479                        tracing::debug!(
480                            dns_name = %dns_name,
481                            overlay_ip = %overlay_ip,
482                            port = assigned_port,
483                            "Registered service in overlay DNS"
484                        );
485                    }
486                }
487
488                Message::RegisterOk {
489                    service_id: service.id,
490                }
491            }
492            Err(e) => {
493                tracing::warn!(
494                    tunnel_id = %tunnel_id,
495                    service_name = %name,
496                    error = %e,
497                    "Service registration failed"
498                );
499                Message::RegisterFail {
500                    reason: e.to_string(),
501                }
502            }
503        };
504
505        ws_sink
506            .send(WsMessage::Binary(response.encode().into()))
507            .await
508            .map_err(TunnelError::connection)?;
509
510        Ok(())
511    }
512}
513
514/// Hash a token for storage (SHA256, hex encoded)
515///
516/// This function takes a raw authentication token and produces a secure hash
517/// suitable for storage and comparison. The raw token should never be stored.
518///
519/// # Example
520///
521/// ```rust
522/// use zlayer_tunnel::hash_token;
523///
524/// let hash = hash_token("my-secret-token");
525/// assert_eq!(hash.len(), 64); // SHA256 produces 32 bytes = 64 hex chars
526/// ```
527#[must_use]
528pub fn hash_token(token: &str) -> String {
529    let mut hasher = Sha256::new();
530    hasher.update(token.as_bytes());
531    hex::encode(hasher.finalize())
532}
533
534/// Simple token validator that accepts any non-empty token
535///
536/// This is a basic validator suitable for development or testing.
537/// In production, you should implement proper token validation against
538/// your authentication system.
539///
540/// # Errors
541///
542/// Returns an error if the token is empty.
543///
544/// # Example
545///
546/// ```rust
547/// use zlayer_tunnel::accept_all_tokens;
548///
549/// assert!(accept_all_tokens("valid-token").is_ok());
550/// assert!(accept_all_tokens("").is_err());
551/// ```
552pub fn accept_all_tokens(token: &str) -> Result<()> {
553    if token.is_empty() {
554        return Err(TunnelError::auth("Token cannot be empty"));
555    }
556    Ok(())
557}
558
#[cfg(test)]
mod tests {
    use super::*;

    /// SHA-256 of the empty string (standard test vector).
    const SHA256_EMPTY: &str = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";

    /// True if every character is a lowercase hexadecimal digit.
    fn is_lower_hex(s: &str) -> bool {
        s.bytes().all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'))
    }

    #[test]
    fn test_hash_token_consistent() {
        let token = "my-secret-token";
        let hash1 = hash_token(token);
        let hash2 = hash_token(token);

        assert_eq!(hash1, hash2);
        assert_eq!(hash1.len(), 64); // SHA256 produces 32 bytes = 64 hex chars
        assert!(is_lower_hex(&hash1));
    }

    #[test]
    fn test_hash_token_different_tokens() {
        let hash1 = hash_token("token1");
        let hash2 = hash_token("token2");

        assert_ne!(hash1, hash2);
    }

    #[test]
    fn test_hash_token_empty() {
        // Even the empty string hashes to a valid digest: pin the known vector.
        assert_eq!(hash_token(""), SHA256_EMPTY);
    }

    #[test]
    fn test_hash_token_known_value() {
        // Verify against a known SHA256 hash
        let hash = hash_token("test");
        // SHA256("test") = 9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08
        assert_eq!(
            hash,
            "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08"
        );
    }

    #[test]
    fn test_accept_all_tokens_valid() {
        assert!(accept_all_tokens("valid-token").is_ok());
        assert!(accept_all_tokens("a").is_ok());
        assert!(accept_all_tokens("very-long-token-with-many-characters").is_ok());
    }

    #[test]
    fn test_accept_all_tokens_empty() {
        let result = accept_all_tokens("");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot be empty"));
    }

    #[test]
    fn test_control_handler_creation() {
        let registry = Arc::new(TunnelRegistry::default());
        let config = TunnelServerConfig::default();
        let validator = Arc::new(accept_all_tokens);

        let handler = ControlHandler::new(registry.clone(), config, validator);

        // The handler must keep exactly the registry it was given (one extra strong ref),
        // and start without any overlay DNS configuration.
        assert!(Arc::ptr_eq(&handler.registry, &registry));
        assert_eq!(Arc::strong_count(&registry), 2);
        assert!(handler.dns_registrar.is_none());
        assert!(handler.local_overlay_ip.is_none());
    }

    #[test]
    fn test_control_handler_with_local_overlay_ip() {
        let registry = Arc::new(TunnelRegistry::default());
        let config = TunnelServerConfig::default();
        let validator = Arc::new(accept_all_tokens);
        let ip = Ipv4Addr::new(10, 200, 0, 1);

        let handler = ControlHandler::new(registry, config, validator).with_local_overlay_ip(ip);

        assert_eq!(handler.local_overlay_ip, Some(ip));
    }

    #[test]
    fn test_current_timestamp_ms_is_recent() {
        // 1_600_000_000_000 ms is 2020-09-13T12:26:40Z; any correctly set clock is past it.
        assert!(current_timestamp_ms() > 1_600_000_000_000);
    }

    #[test]
    fn test_hash_token_unicode() {
        // Ensure unicode tokens work correctly
        let hash = hash_token("token-with-unicode-\u{1F600}");
        assert_eq!(hash.len(), 64);
    }

    #[test]
    fn test_hash_token_special_chars() {
        let hash = hash_token("token!@#$%^&*()");
        assert_eq!(hash.len(), 64);
    }
}