Skip to main content

zlayer_tunnel/client/
agent.rs

1//! Tunnel client agent for connecting to tunnel servers
2//!
3//! The [`TunnelAgent`] manages the lifecycle of a tunnel client connection, including:
4//! - WebSocket connection to the tunnel server
5//! - Authentication and service registration
//! - Heartbeat acknowledgement and an authentication timeout
7//! - Automatic reconnection with exponential backoff
8//! - Incoming connection notification
9
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::Duration;
13
14use futures_util::{SinkExt, StreamExt};
15use parking_lot::RwLock;
16use tokio::sync::mpsc;
17use tokio::time::{interval, timeout};
18use tokio_tungstenite::tungstenite::Message as WsMessage;
19use uuid::Uuid;
20
21use crate::client::connector::OverlayAwareConnector;
22use crate::overlay::DynOverlayResolver;
23use crate::{Message, Result, ServiceConfig, ServiceProtocol, TunnelClientConfig, TunnelError};
24
25// =============================================================================
26// Agent State
27// =============================================================================
28
29/// Current state of the tunnel agent
30#[derive(Debug, Clone, Default, PartialEq, Eq)]
31pub enum AgentState {
32    /// Not connected to the server
33    #[default]
34    Disconnected,
35    /// Attempting to connect
36    Connecting,
37    /// Successfully connected and authenticated
38    Connected {
39        /// The assigned tunnel ID from the server
40        tunnel_id: Uuid,
41    },
42    /// Reconnecting after a disconnection
43    Reconnecting {
44        /// Current reconnection attempt number
45        attempt: u32,
46    },
47}
48
49// =============================================================================
50// Service Status
51// =============================================================================
52
/// Registration status of a service exposed through the tunnel.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub enum ServiceStatus {
    /// The registration has not been confirmed yet (the default).
    #[default]
    Pending,
    /// The server accepted the registration.
    Registered,
    /// The server rejected the registration; carries the reason it gave.
    Failed(String),
}
64
65// =============================================================================
66// Registered Service
67// =============================================================================
68
69/// A service being exposed through the tunnel
70#[derive(Debug, Clone)]
71pub struct RegisteredService {
72    /// The service configuration
73    pub config: ServiceConfig,
74    /// Server-assigned service ID (if registered)
75    pub service_id: Option<Uuid>,
76    /// Current status of the service
77    pub status: ServiceStatus,
78}
79
80impl RegisteredService {
81    /// Create a new registered service from a config
82    #[must_use]
83    pub fn new(config: ServiceConfig) -> Self {
84        Self {
85            config,
86            service_id: None,
87            status: ServiceStatus::Pending,
88        }
89    }
90
91    /// Check if the service is successfully registered
92    #[must_use]
93    pub fn is_registered(&self) -> bool {
94        matches!(self.status, ServiceStatus::Registered)
95    }
96}
97
98// =============================================================================
99// Control Events
100// =============================================================================
101
102/// Events received from the tunnel server
103#[derive(Debug, Clone)]
104pub enum ControlEvent {
105    /// Successfully authenticated with the server
106    Authenticated {
107        /// The assigned tunnel ID
108        tunnel_id: Uuid,
109    },
110    /// A service was successfully registered
111    ServiceRegistered {
112        /// Service name
113        name: String,
114        /// Server-assigned service ID
115        service_id: Uuid,
116    },
117    /// A service registration failed
118    ServiceFailed {
119        /// Service name
120        name: String,
121        /// Failure reason
122        reason: String,
123    },
124    /// An incoming connection is being established
125    IncomingConnection {
126        /// Service ID receiving the connection
127        service_id: Uuid,
128        /// Unique connection ID
129        connection_id: Uuid,
130        /// Remote client address
131        client_addr: String,
132    },
133    /// Heartbeat received from the server
134    Heartbeat {
135        /// Server timestamp
136        timestamp: u64,
137    },
138    /// Disconnected from the server
139    Disconnected {
140        /// Reason for disconnection
141        reason: String,
142    },
143    /// An error occurred
144    Error {
145        /// Error message
146        message: String,
147    },
148}
149
150// =============================================================================
151// Control Commands
152// =============================================================================
153
154/// Commands to send to the tunnel server
155#[derive(Debug, Clone)]
156pub enum ControlCommand {
157    /// Register a new service
158    Register {
159        /// Service name
160        name: String,
161        /// Protocol type
162        protocol: ServiceProtocol,
163        /// Local port
164        local_port: u16,
165        /// Requested remote port (0 = auto-assign)
166        remote_port: u16,
167    },
168    /// Unregister an existing service
169    Unregister {
170        /// Service ID to unregister
171        service_id: Uuid,
172    },
173    /// Acknowledge an incoming connection
174    ConnectAck {
175        /// Connection ID to acknowledge
176        connection_id: Uuid,
177    },
178    /// Reject an incoming connection
179    ConnectFail {
180        /// Connection ID to reject
181        connection_id: Uuid,
182        /// Failure reason
183        reason: String,
184    },
185    /// Gracefully disconnect from the server
186    Disconnect,
187}
188
189// =============================================================================
190// Connection Callback
191// =============================================================================
192
193/// Type alias for connection callback function
194///
195/// The callback receives the service ID, connection ID, and client address,
196/// and returns whether the connection was accepted.
197pub type ConnectionCallback = Arc<dyn Fn(Uuid, Uuid, String) -> bool + Send + Sync>;
198
199// =============================================================================
200// Tunnel Agent
201// =============================================================================
202
203/// Tunnel client agent that connects to a tunnel server
204///
205/// The `TunnelAgent` manages the complete lifecycle of a tunnel client connection,
206/// including authentication, service registration, heartbeat handling, and
207/// automatic reconnection on failure.
208///
209/// # Example
210///
211/// ```rust,no_run
212/// use zlayer_tunnel::{TunnelAgent, TunnelClientConfig, ServiceConfig};
213/// use std::sync::Arc;
214///
215/// #[tokio::main]
216/// async fn main() {
217///     let config = TunnelClientConfig::new(
218///         "wss://tunnel.example.com/tunnel/v1",
219///         "my-auth-token"
220///     )
221///     .with_service(ServiceConfig::tcp("ssh", 22).with_remote_port(2222));
222///
223///     let agent = TunnelAgent::new(config)
224///         .on_connection(Arc::new(|service_id, conn_id, client_addr| {
225///             println!("Incoming connection from {} to service {}", client_addr, service_id);
226///             true // Accept the connection
227///         }));
228///
229///     // Run the agent with auto-reconnect
230///     if let Err(e) = agent.run().await {
231///         eprintln!("Agent error: {}", e);
232///     }
233/// }
234/// ```
235pub struct TunnelAgent {
236    /// Client configuration
237    config: TunnelClientConfig,
238    /// Current agent state
239    state: Arc<RwLock<AgentState>>,
240    /// Registered services by name
241    services: Arc<RwLock<HashMap<String, RegisteredService>>>,
242    /// Callback for incoming connections
243    connection_callback: Option<ConnectionCallback>,
244    /// Channel to send commands to the agent loop
245    command_tx: Option<mpsc::Sender<ControlCommand>>,
246    /// Event channel for external listeners
247    event_tx: Option<mpsc::Sender<ControlEvent>>,
248    /// Optional overlay resolver for routing through overlay network
249    overlay_resolver: Option<DynOverlayResolver>,
250}
251
252impl TunnelAgent {
253    /// Create a new tunnel agent with the given configuration
254    ///
255    /// The agent will attempt to register all services specified in the config
256    /// when it connects to the server.
257    #[must_use]
258    pub fn new(config: TunnelClientConfig) -> Self {
259        // Initialize services from config
260        let services: HashMap<String, RegisteredService> = config
261            .services
262            .iter()
263            .map(|s| (s.name.clone(), RegisteredService::new(s.clone())))
264            .collect();
265
266        Self {
267            config,
268            state: Arc::new(RwLock::new(AgentState::Disconnected)),
269            services: Arc::new(RwLock::new(services)),
270            connection_callback: None,
271            command_tx: None,
272            event_tx: None,
273            overlay_resolver: None,
274        }
275    }
276
277    /// Set the connection callback (builder pattern)
278    ///
279    /// The callback is invoked when an incoming connection is received.
280    /// It should return `true` to accept the connection or `false` to reject it.
281    #[must_use]
282    pub fn on_connection(mut self, callback: ConnectionCallback) -> Self {
283        self.connection_callback = Some(callback);
284        self
285    }
286
287    /// Set an event channel for receiving control events
288    ///
289    /// Events will be sent to this channel for external processing.
290    #[must_use]
291    pub fn with_event_channel(mut self, tx: mpsc::Sender<ControlEvent>) -> Self {
292        self.event_tx = Some(tx);
293        self
294    }
295
296    /// Set the overlay resolver for routing connections through the overlay network
297    #[must_use]
298    pub fn with_overlay_resolver(mut self, resolver: DynOverlayResolver) -> Self {
299        self.overlay_resolver = Some(resolver);
300        self
301    }
302
303    /// Get the current agent state
304    #[must_use]
305    pub fn state(&self) -> AgentState {
306        self.state.read().clone()
307    }
308
309    /// Get a service by name
310    #[must_use]
311    pub fn get_service(&self, name: &str) -> Option<RegisteredService> {
312        self.services.read().get(name).cloned()
313    }
314
315    /// Get all registered services
316    #[must_use]
317    pub fn services(&self) -> Vec<RegisteredService> {
318        self.services.read().values().cloned().collect()
319    }
320
321    /// Check if the agent is connected
322    #[must_use]
323    pub fn is_connected(&self) -> bool {
324        matches!(*self.state.read(), AgentState::Connected { .. })
325    }
326
327    /// Get the tunnel ID if connected
328    #[must_use]
329    pub fn tunnel_id(&self) -> Option<Uuid> {
330        match *self.state.read() {
331            AgentState::Connected { tunnel_id } => Some(tunnel_id),
332            _ => None,
333        }
334    }
335
336    /// Send a command to the agent (if running)
337    ///
338    /// # Errors
339    ///
340    /// Returns an error if the agent is not running or the command channel is full.
341    pub async fn send_command(&self, command: ControlCommand) -> Result<()> {
342        let tx = self
343            .command_tx
344            .as_ref()
345            .ok_or_else(|| TunnelError::connection_msg("agent not running"))?;
346
347        tx.send(command)
348            .await
349            .map_err(|_| TunnelError::connection_msg("command channel closed"))
350    }
351
352    /// Run the agent with automatic reconnection
353    ///
354    /// This method will attempt to connect to the server and maintain the connection.
355    /// If the connection is lost, it will automatically reconnect with exponential backoff.
356    ///
357    /// # Errors
358    ///
359    /// Returns an error if:
360    /// - The configuration is invalid
361    /// - The connection fails and cannot be recovered
362    /// - A shutdown signal is received
363    pub async fn run(&self) -> Result<()> {
364        // Validate config
365        self.config.validate().map_err(TunnelError::config)?;
366
367        let mut current_interval = self.config.reconnect_interval;
368        let mut attempt = 0u32;
369
370        loop {
371            attempt += 1;
372            *self.state.write() = AgentState::Reconnecting { attempt };
373
374            tracing::info!(
375                attempt = attempt,
376                interval_ms = current_interval.as_millis(),
377                "attempting to connect"
378            );
379
380            match self.run_once().await {
381                Ok(()) => {
382                    // Clean shutdown requested
383                    tracing::info!("agent shutting down");
384                    return Ok(());
385                }
386                Err(TunnelError::Shutdown) => {
387                    // Clean shutdown
388                    tracing::info!("agent received shutdown signal");
389                    return Ok(());
390                }
391                Err(e) => {
392                    tracing::warn!(error = %e, "connection failed, will retry");
393
394                    // Notify of disconnection
395                    if let Some(ref tx) = self.event_tx {
396                        let _ = tx
397                            .send(ControlEvent::Disconnected {
398                                reason: e.to_string(),
399                            })
400                            .await;
401                    }
402                }
403            }
404
405            // Reset services to pending state
406            {
407                let mut services = self.services.write();
408                for service in services.values_mut() {
409                    service.service_id = None;
410                    service.status = ServiceStatus::Pending;
411                }
412            }
413
414            // Wait before reconnecting
415            tokio::time::sleep(current_interval).await;
416
417            // Exponential backoff
418            current_interval = std::cmp::min(
419                current_interval.saturating_mul(2),
420                self.config.max_reconnect_interval,
421            );
422        }
423    }
424
425    /// Run a single connection attempt
426    ///
427    /// This method connects to the server, authenticates, registers services,
428    /// and runs the message loop until disconnection.
429    ///
430    /// # Errors
431    ///
432    /// Returns an error if:
433    /// - WebSocket connection fails
434    /// - Authentication fails
435    /// - Connection is lost
436    pub async fn run_once(&self) -> Result<()> {
437        *self.state.write() = AgentState::Connecting;
438
439        // Connect to WebSocket server via overlay-aware connector
440        tracing::debug!(url = %self.config.server_url, "connecting to server");
441
442        let connector = OverlayAwareConnector::new(
443            &self.config.server_url,
444            self.config.overlay_server_url.as_deref(),
445            self.config.routing_mode,
446            self.overlay_resolver.clone(),
447        );
448        let (ws_stream, _response) = connector.connect().await?;
449
450        let (mut ws_sink, mut ws_stream) = ws_stream.split();
451
452        // Generate a client ID for this connection
453        let client_id = Uuid::new_v4();
454
455        // Send AUTH message
456        let auth_msg = Message::Auth {
457            token: self.config.token.clone(),
458            client_id,
459        };
460        ws_sink
461            .send(WsMessage::Binary(auth_msg.encode().into()))
462            .await
463            .map_err(TunnelError::connection)?;
464
465        // Wait for AUTH_OK with timeout (10 seconds)
466        let auth_timeout = Duration::from_secs(10);
467        let auth_response = timeout(auth_timeout, async {
468            while let Some(msg) = ws_stream.next().await {
469                match msg {
470                    Ok(WsMessage::Binary(data)) => {
471                        return Message::decode(&data).map(|(m, _)| m);
472                    }
473                    Ok(WsMessage::Close(frame)) => {
474                        let reason = frame.map_or_else(
475                            || "connection closed".to_string(),
476                            |f| f.reason.to_string(),
477                        );
478                        return Err(TunnelError::connection_msg(reason));
479                    }
480                    Ok(_) => {} // Ignore text, ping, pong
481                    Err(e) => return Err(TunnelError::connection(e)),
482                }
483            }
484            Err(TunnelError::connection_msg("connection closed before auth"))
485        })
486        .await
487        .map_err(|_| TunnelError::timeout())??;
488
489        // Handle auth response
490        let tunnel_id = match auth_response {
491            Message::AuthOk { tunnel_id } => tunnel_id,
492            Message::AuthFail { reason } => {
493                return Err(TunnelError::auth(reason));
494            }
495            other => {
496                return Err(TunnelError::protocol(format!(
497                    "expected AuthOk or AuthFail, got {:?}",
498                    other.message_type()
499                )));
500            }
501        };
502
503        *self.state.write() = AgentState::Connected { tunnel_id };
504
505        tracing::info!(
506            tunnel_id = %tunnel_id,
507            client_id = %client_id,
508            "authenticated with server"
509        );
510
511        // Notify of authentication
512        if let Some(ref tx) = self.event_tx {
513            let _ = tx.send(ControlEvent::Authenticated { tunnel_id }).await;
514        }
515
516        // Register all services
517        self.register_services(&mut ws_sink).await?;
518
519        // Run the main message loop
520        self.run_message_loop(tunnel_id, &mut ws_sink, &mut ws_stream)
521            .await
522    }
523
524    /// Register all services from the config
525    async fn register_services<S>(&self, ws_sink: &mut S) -> Result<()>
526    where
527        S: SinkExt<WsMessage> + Unpin,
528        S::Error: std::error::Error,
529    {
530        let services: Vec<ServiceConfig> = {
531            self.services
532                .read()
533                .values()
534                .map(|s| s.config.clone())
535                .collect()
536        };
537
538        for service in services {
539            let register_msg = Message::Register {
540                name: service.name.clone(),
541                protocol: service.protocol,
542                local_port: service.local_port,
543                remote_port: service.remote_port,
544            };
545
546            tracing::debug!(
547                service_name = %service.name,
548                local_port = service.local_port,
549                "registering service"
550            );
551
552            ws_sink
553                .send(WsMessage::Binary(register_msg.encode().into()))
554                .await
555                .map_err(|e| TunnelError::connection_msg(e.to_string()))?;
556        }
557
558        Ok(())
559    }
560
561    /// Run the main message loop
562    async fn run_message_loop<Sink, Stream>(
563        &self,
564        tunnel_id: Uuid,
565        ws_sink: &mut Sink,
566        ws_stream: &mut Stream,
567    ) -> Result<()>
568    where
569        Sink: SinkExt<WsMessage> + Unpin,
570        Sink::Error: std::error::Error,
571        Stream: StreamExt<Item = std::result::Result<WsMessage, tokio_tungstenite::tungstenite::Error>>
572            + Unpin,
573    {
574        // Create command channel
575        let (_command_tx, mut command_rx) = mpsc::channel::<ControlCommand>(256);
576
577        // Store the sender so external code can send commands
578        // Note: This is a bit awkward since we're in a method, but we need to
579        // allow external commands during the message loop
580        // For now, we'll handle this differently
581
582        // Track pending service registrations by name
583        let mut pending_services: Vec<String> = { self.services.read().keys().cloned().collect() };
584
585        // Heartbeat interval (we respond to server heartbeats, not send our own)
586        let mut check_interval = interval(Duration::from_secs(5));
587
588        loop {
589            tokio::select! {
590                // Check interval for health/status
591                _ = check_interval.tick() => {
592                    // Just a periodic check, we respond to server heartbeats
593                }
594
595                // Commands from external code
596                Some(command) = command_rx.recv() => {
597                    match command {
598                        ControlCommand::Register { name, protocol, local_port, remote_port } => {
599                            let msg = Message::Register {
600                                name: name.clone(),
601                                protocol,
602                                local_port,
603                                remote_port,
604                            };
605                            ws_sink
606                                .send(WsMessage::Binary(msg.encode().into()))
607                                .await
608                                .map_err(|e| TunnelError::connection_msg(e.to_string()))?;
609                            pending_services.push(name);
610                        }
611                        ControlCommand::Unregister { service_id } => {
612                            let msg = Message::Unregister { service_id };
613                            ws_sink
614                                .send(WsMessage::Binary(msg.encode().into()))
615                                .await
616                                .map_err(|e| TunnelError::connection_msg(e.to_string()))?;
617                        }
618                        ControlCommand::ConnectAck { connection_id } => {
619                            let msg = Message::ConnectAck { connection_id };
620                            ws_sink
621                                .send(WsMessage::Binary(msg.encode().into()))
622                                .await
623                                .map_err(|e| TunnelError::connection_msg(e.to_string()))?;
624                        }
625                        ControlCommand::ConnectFail { connection_id, reason } => {
626                            let msg = Message::ConnectFail { connection_id, reason };
627                            ws_sink
628                                .send(WsMessage::Binary(msg.encode().into()))
629                                .await
630                                .map_err(|e| TunnelError::connection_msg(e.to_string()))?;
631                        }
632                        ControlCommand::Disconnect => {
633                            tracing::info!("disconnect command received");
634                            return Ok(());
635                        }
636                    }
637                }
638
639                // Incoming WebSocket messages
640                Some(msg_result) = ws_stream.next() => {
641                    match msg_result {
642                        Ok(WsMessage::Binary(data)) => {
643                            let (msg, _) = Message::decode(&data)?;
644                            self.handle_server_message(
645                                tunnel_id,
646                                msg,
647                                ws_sink,
648                                &mut pending_services,
649                            ).await?;
650                        }
651                        Ok(WsMessage::Close(frame)) => {
652                            let reason = frame.map_or_else(
653                                || "server closed connection".to_string(),
654                                |f| f.reason.to_string(),
655                            );
656                            tracing::info!(reason = %reason, "server closed connection");
657                            return Err(TunnelError::connection_msg(reason));
658                        }
659                        Ok(WsMessage::Ping(data)) => {
660                            ws_sink
661                                .send(WsMessage::Pong(data))
662                                .await
663                                .map_err(|e| TunnelError::connection_msg(e.to_string()))?;
664                        }
665                        Ok(_) => {} // Ignore other message types
666                        Err(e) => {
667                            return Err(TunnelError::connection(e));
668                        }
669                    }
670                }
671
672                else => {
673                    // All channels closed
674                    break;
675                }
676            }
677        }
678
679        Ok(())
680    }
681
682    /// Handle a message from the server
683    async fn handle_server_message<S>(
684        &self,
685        tunnel_id: Uuid,
686        msg: Message,
687        ws_sink: &mut S,
688        pending_services: &mut Vec<String>,
689    ) -> Result<()>
690    where
691        S: SinkExt<WsMessage> + Unpin,
692        S::Error: std::error::Error,
693    {
694        match msg {
695            Message::RegisterOk { service_id } => {
696                self.handle_register_ok(service_id, pending_services).await;
697            }
698            Message::RegisterFail { reason } => {
699                self.handle_register_fail(reason, pending_services).await;
700            }
701            Message::Connect {
702                service_id,
703                connection_id,
704                client_addr,
705            } => {
706                self.handle_connect(service_id, connection_id, client_addr, ws_sink)
707                    .await?;
708            }
709            Message::Heartbeat { timestamp } => {
710                self.handle_heartbeat(timestamp, ws_sink).await?;
711            }
712            Message::Disconnect { reason } => {
713                return self.handle_disconnect(reason).await;
714            }
715            // Client shouldn't receive these messages
716            Message::Auth { .. }
717            | Message::AuthOk { .. }
718            | Message::AuthFail { .. }
719            | Message::Register { .. }
720            | Message::Unregister { .. }
721            | Message::ConnectAck { .. }
722            | Message::ConnectFail { .. }
723            | Message::HeartbeatAck { .. } => {
724                tracing::warn!(
725                    tunnel_id = %tunnel_id,
726                    msg_type = ?msg.message_type(),
727                    "unexpected message from server"
728                );
729            }
730        }
731        Ok(())
732    }
733
734    /// Handle `RegisterOk` message
735    async fn handle_register_ok(&self, service_id: Uuid, pending_services: &mut Vec<String>) {
736        let name = match pending_services.first().cloned() {
737            Some(n) => {
738                pending_services.remove(0);
739                n
740            }
741            None => return,
742        };
743
744        // Update service state without holding lock across await
745        {
746            let mut services = self.services.write();
747            if let Some(service) = services.get_mut(&name) {
748                service.service_id = Some(service_id);
749                service.status = ServiceStatus::Registered;
750            }
751        }
752
753        tracing::info!(
754            service_name = %name,
755            service_id = %service_id,
756            "service registered"
757        );
758
759        // Notify (after releasing lock)
760        if let Some(ref tx) = self.event_tx {
761            let _ = tx
762                .send(ControlEvent::ServiceRegistered { name, service_id })
763                .await;
764        }
765    }
766
767    /// Handle `RegisterFail` message
768    async fn handle_register_fail(&self, reason: String, pending_services: &mut Vec<String>) {
769        let name = match pending_services.first().cloned() {
770            Some(n) => {
771                pending_services.remove(0);
772                n
773            }
774            None => return,
775        };
776
777        // Update service state without holding lock across await
778        {
779            let mut services = self.services.write();
780            if let Some(service) = services.get_mut(&name) {
781                service.status = ServiceStatus::Failed(reason.clone());
782            }
783        }
784
785        tracing::warn!(
786            service_name = %name,
787            reason = %reason,
788            "service registration failed"
789        );
790
791        // Notify (after releasing lock)
792        if let Some(ref tx) = self.event_tx {
793            let _ = tx.send(ControlEvent::ServiceFailed { name, reason }).await;
794        }
795    }
796
797    /// Handle Connect message (incoming connection)
798    async fn handle_connect<S>(
799        &self,
800        service_id: Uuid,
801        connection_id: Uuid,
802        client_addr: String,
803        ws_sink: &mut S,
804    ) -> Result<()>
805    where
806        S: SinkExt<WsMessage> + Unpin,
807        S::Error: std::error::Error,
808    {
809        tracing::debug!(
810            service_id = %service_id,
811            connection_id = %connection_id,
812            client_addr = %client_addr,
813            "incoming connection"
814        );
815
816        // Notify via event channel
817        if let Some(ref tx) = self.event_tx {
818            let _ = tx
819                .send(ControlEvent::IncomingConnection {
820                    service_id,
821                    connection_id,
822                    client_addr: client_addr.clone(),
823                })
824                .await;
825        }
826
827        // Call the connection callback
828        let accepted = self
829            .connection_callback
830            .as_ref()
831            .is_none_or(|cb| cb(service_id, connection_id, client_addr.clone()));
832
833        // Send response
834        let response = if accepted {
835            Message::ConnectAck { connection_id }
836        } else {
837            Message::ConnectFail {
838                connection_id,
839                reason: "connection rejected by client".to_string(),
840            }
841        };
842
843        ws_sink
844            .send(WsMessage::Binary(response.encode().into()))
845            .await
846            .map_err(|e| TunnelError::connection_msg(e.to_string()))?;
847
848        Ok(())
849    }
850
851    /// Handle Heartbeat message
852    async fn handle_heartbeat<S>(&self, timestamp: u64, ws_sink: &mut S) -> Result<()>
853    where
854        S: SinkExt<WsMessage> + Unpin,
855        S::Error: std::error::Error,
856    {
857        tracing::trace!(timestamp = timestamp, "heartbeat received");
858
859        // Respond with heartbeat ack
860        let ack = Message::HeartbeatAck { timestamp };
861        ws_sink
862            .send(WsMessage::Binary(ack.encode().into()))
863            .await
864            .map_err(|e| TunnelError::connection_msg(e.to_string()))?;
865
866        // Notify
867        if let Some(ref tx) = self.event_tx {
868            let _ = tx.send(ControlEvent::Heartbeat { timestamp }).await;
869        }
870
871        Ok(())
872    }
873
874    /// Handle Disconnect message
875    async fn handle_disconnect(&self, reason: String) -> Result<()> {
876        tracing::info!(reason = %reason, "server requested disconnect");
877
878        // Notify
879        if let Some(ref tx) = self.event_tx {
880            let _ = tx
881                .send(ControlEvent::Disconnected {
882                    reason: reason.clone(),
883                })
884                .await;
885        }
886
887        Err(TunnelError::connection_msg(reason))
888    }
889
890    /// Gracefully disconnect from the server
891    ///
892    /// This signals the agent to stop and disconnect. If the agent is running
893    /// with auto-reconnect (`run()`), it will stop reconnecting.
894    pub fn disconnect(&self) {
895        *self.state.write() = AgentState::Disconnected;
896
897        // Send disconnect command if the agent is running
898        if let Some(ref tx) = self.command_tx {
899            let _ = tx.try_send(ControlCommand::Disconnect);
900        }
901    }
902}
903
904impl Clone for TunnelAgent {
905    fn clone(&self) -> Self {
906        Self {
907            config: self.config.clone(),
908            state: Arc::clone(&self.state),
909            services: Arc::clone(&self.services),
910            connection_callback: self.connection_callback.clone(),
911            command_tx: self.command_tx.clone(),
912            event_tx: self.event_tx.clone(),
913            overlay_resolver: self.overlay_resolver.clone(),
914        }
915    }
916}
917
918// =============================================================================
919// Tests
920// =============================================================================
921
#[cfg(test)]
mod tests {
    //! Unit tests for the tunnel agent's state enums, constructors, builder
    //! methods, clone semantics, and control command/event types.

    use super::*;

    /// Config with one TCP service (`ssh`, local port 22, remote port 2222)
    /// and one UDP service (`game`, local port 27015).
    fn create_test_config() -> TunnelClientConfig {
        TunnelClientConfig::new("ws://localhost:8080/tunnel/v1", "test-token")
            .with_service(ServiceConfig::tcp("ssh", 22).with_remote_port(2222))
            .with_service(ServiceConfig::udp("game", 27015))
    }

    #[test]
    fn test_agent_state_default() {
        let state = AgentState::default();
        assert_eq!(state, AgentState::Disconnected);
    }

    #[test]
    fn test_agent_state_variants() {
        let states = [
            AgentState::Disconnected,
            AgentState::Connecting,
            AgentState::Connected {
                tunnel_id: Uuid::new_v4(),
            },
            AgentState::Reconnecting { attempt: 3 },
        ];

        // Every variant must differ from every other variant, not only from
        // its neighbour in declaration order.
        for (i, a) in states.iter().enumerate() {
            for b in &states[i + 1..] {
                assert_ne!(a, b);
            }
        }

        // Variant payloads participate in equality as well.
        assert_ne!(
            AgentState::Reconnecting { attempt: 3 },
            AgentState::Reconnecting { attempt: 4 }
        );
        let tunnel_id = Uuid::new_v4();
        assert_eq!(
            AgentState::Connected { tunnel_id },
            AgentState::Connected { tunnel_id }
        );
    }

    #[test]
    fn test_service_status_default() {
        let status = ServiceStatus::default();
        assert_eq!(status, ServiceStatus::Pending);
    }

    #[test]
    fn test_service_status_variants() {
        assert_eq!(ServiceStatus::Pending, ServiceStatus::Pending);
        assert_eq!(ServiceStatus::Registered, ServiceStatus::Registered);
        assert_eq!(
            ServiceStatus::Failed("error".to_string()),
            ServiceStatus::Failed("error".to_string())
        );
        assert_ne!(
            ServiceStatus::Failed("error1".to_string()),
            ServiceStatus::Failed("error2".to_string())
        );
    }

    #[test]
    fn test_registered_service_new() {
        let config = ServiceConfig::tcp("ssh", 22);
        let service = RegisteredService::new(config.clone());

        assert_eq!(service.config.name, "ssh");
        assert!(service.service_id.is_none());
        assert_eq!(service.status, ServiceStatus::Pending);
        assert!(!service.is_registered());
    }

    #[test]
    fn test_registered_service_is_registered() {
        let config = ServiceConfig::tcp("ssh", 22);
        let mut service = RegisteredService::new(config);

        assert!(!service.is_registered());

        service.status = ServiceStatus::Registered;
        assert!(service.is_registered());

        service.status = ServiceStatus::Failed("error".to_string());
        assert!(!service.is_registered());
    }

    #[test]
    fn test_tunnel_agent_new() {
        let config = create_test_config();
        let agent = TunnelAgent::new(config);

        assert_eq!(agent.state(), AgentState::Disconnected);
        assert!(!agent.is_connected());
        assert!(agent.tunnel_id().is_none());

        let services = agent.services();
        assert_eq!(services.len(), 2);
    }

    #[test]
    fn test_tunnel_agent_get_service() {
        let config = create_test_config();
        let agent = TunnelAgent::new(config);

        let ssh = agent.get_service("ssh");
        assert!(ssh.is_some());
        assert_eq!(ssh.unwrap().config.local_port, 22);

        let game = agent.get_service("game");
        assert!(game.is_some());
        assert_eq!(game.unwrap().config.protocol, ServiceProtocol::Udp);

        let nonexistent = agent.get_service("nonexistent");
        assert!(nonexistent.is_none());
    }

    #[test]
    fn test_tunnel_agent_on_connection() {
        let config = create_test_config();
        let callback_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let callback_called_clone = Arc::clone(&callback_called);

        let callback: ConnectionCallback = Arc::new(move |_service_id, _conn_id, _addr| {
            callback_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
            true
        });

        let agent = TunnelAgent::new(config).on_connection(callback);

        // Callback is set but not called yet
        assert!(!callback_called.load(std::sync::atomic::Ordering::SeqCst));
        assert!(agent.connection_callback.is_some());
    }

    #[test]
    fn test_tunnel_agent_clone() {
        let config = create_test_config();
        let agent = TunnelAgent::new(config);

        let cloned = agent.clone();

        assert_eq!(agent.state(), cloned.state());
        assert_eq!(agent.services().len(), cloned.services().len());

        // `Clone` shares `state` through `Arc::clone`, so a change made via one
        // handle must be observable through the other.
        let tunnel_id = Uuid::new_v4();
        *agent.state.write() = AgentState::Connected { tunnel_id };
        assert_eq!(cloned.state(), AgentState::Connected { tunnel_id });
        assert!(cloned.is_connected());

        // ...and in the other direction.
        cloned.disconnect();
        assert_eq!(agent.state(), AgentState::Disconnected);
        assert!(!agent.is_connected());
    }

    #[test]
    fn test_tunnel_agent_disconnect() {
        let config = create_test_config();
        let agent = TunnelAgent::new(config);

        // Set state to connected
        *agent.state.write() = AgentState::Connected {
            tunnel_id: Uuid::new_v4(),
        };
        assert!(agent.is_connected());

        // Disconnect
        agent.disconnect();
        assert_eq!(agent.state(), AgentState::Disconnected);
        assert!(!agent.is_connected());
    }

    #[test]
    fn test_control_event_variants() {
        // Just verify the variants exist and can be created
        let _auth = ControlEvent::Authenticated {
            tunnel_id: Uuid::new_v4(),
        };
        let _registered = ControlEvent::ServiceRegistered {
            name: "ssh".to_string(),
            service_id: Uuid::new_v4(),
        };
        let _failed = ControlEvent::ServiceFailed {
            name: "ssh".to_string(),
            reason: "error".to_string(),
        };
        let _incoming = ControlEvent::IncomingConnection {
            service_id: Uuid::new_v4(),
            connection_id: Uuid::new_v4(),
            client_addr: "127.0.0.1:12345".to_string(),
        };
        let heartbeat = ControlEvent::Heartbeat { timestamp: 12345 };
        assert!(matches!(heartbeat, ControlEvent::Heartbeat { .. }));
        let _disconnected = ControlEvent::Disconnected {
            reason: "test".to_string(),
        };
        let _error = ControlEvent::Error {
            message: "test error".to_string(),
        };
    }

    #[test]
    fn test_control_command_variants() {
        // Just verify the variants exist and can be created
        let _register = ControlCommand::Register {
            name: "ssh".to_string(),
            protocol: ServiceProtocol::Tcp,
            local_port: 22,
            remote_port: 2222,
        };
        let _unregister = ControlCommand::Unregister {
            service_id: Uuid::new_v4(),
        };
        let _ack = ControlCommand::ConnectAck {
            connection_id: Uuid::new_v4(),
        };
        let _fail = ControlCommand::ConnectFail {
            connection_id: Uuid::new_v4(),
            reason: "error".to_string(),
        };
        let disconnect = ControlCommand::Disconnect;
        assert!(matches!(disconnect, ControlCommand::Disconnect));
    }

    #[test]
    fn test_tunnel_agent_with_event_channel() {
        let config = create_test_config();
        let (tx, _rx) = mpsc::channel(16);

        let agent = TunnelAgent::new(config).with_event_channel(tx);

        assert!(agent.event_tx.is_some());
    }

    #[tokio::test]
    async fn test_send_command_not_running() {
        let config = create_test_config();
        let agent = TunnelAgent::new(config);

        let result = agent.send_command(ControlCommand::Disconnect).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("agent not running"));
    }
}