use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::Duration;

use futures_util::{SinkExt, StreamExt};
use parking_lot::RwLock;
use tokio::sync::mpsc;
use tokio::time::{interval, timeout};
use tokio_tungstenite::tungstenite::Message as WsMessage;
use uuid::Uuid;

use crate::client::connector::OverlayAwareConnector;
use crate::overlay::DynOverlayResolver;
use crate::{Message, Result, ServiceConfig, ServiceProtocol, TunnelClientConfig, TunnelError};
24
25#[derive(Debug, Clone, Default, PartialEq, Eq)]
31pub enum AgentState {
32 #[default]
34 Disconnected,
35 Connecting,
37 Connected {
39 tunnel_id: Uuid,
41 },
42 Reconnecting {
44 attempt: u32,
46 },
47}
48
/// Registration status of a configured service on the current connection.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum ServiceStatus {
    /// The server has not confirmed the registration yet.
    ///
    /// This is the initial status. Every service is put back here before a reconnect.
    #[default]
    Pending,
    /// The server accepted the registration.
    Registered,
    /// The server rejected the registration; carries the reason it reported.
    Failed(String),
}
64
65#[derive(Debug, Clone)]
71pub struct RegisteredService {
72 pub config: ServiceConfig,
74 pub service_id: Option<Uuid>,
76 pub status: ServiceStatus,
78}
79
80impl RegisteredService {
81 #[must_use]
83 pub fn new(config: ServiceConfig) -> Self {
84 Self {
85 config,
86 service_id: None,
87 status: ServiceStatus::Pending,
88 }
89 }
90
91 #[must_use]
93 pub fn is_registered(&self) -> bool {
94 matches!(self.status, ServiceStatus::Registered)
95 }
96}
97
98#[derive(Debug, Clone)]
104pub enum ControlEvent {
105 Authenticated {
107 tunnel_id: Uuid,
109 },
110 ServiceRegistered {
112 name: String,
114 service_id: Uuid,
116 },
117 ServiceFailed {
119 name: String,
121 reason: String,
123 },
124 IncomingConnection {
126 service_id: Uuid,
128 connection_id: Uuid,
130 client_addr: String,
132 },
133 Heartbeat {
135 timestamp: u64,
137 },
138 Disconnected {
140 reason: String,
142 },
143 Error {
145 message: String,
147 },
148}
149
150#[derive(Debug, Clone)]
156pub enum ControlCommand {
157 Register {
159 name: String,
161 protocol: ServiceProtocol,
163 local_port: u16,
165 remote_port: u16,
167 },
168 Unregister {
170 service_id: Uuid,
172 },
173 ConnectAck {
175 connection_id: Uuid,
177 },
178 ConnectFail {
180 connection_id: Uuid,
182 reason: String,
184 },
185 Disconnect,
187}
188
189pub type ConnectionCallback = Arc<dyn Fn(Uuid, Uuid, String) -> bool + Send + Sync>;
198
199pub struct TunnelAgent {
236 config: TunnelClientConfig,
238 state: Arc<RwLock<AgentState>>,
240 services: Arc<RwLock<HashMap<String, RegisteredService>>>,
242 connection_callback: Option<ConnectionCallback>,
244 command_tx: Option<mpsc::Sender<ControlCommand>>,
246 event_tx: Option<mpsc::Sender<ControlEvent>>,
248 overlay_resolver: Option<DynOverlayResolver>,
250}
251
252impl TunnelAgent {
253 #[must_use]
258 pub fn new(config: TunnelClientConfig) -> Self {
259 let services: HashMap<String, RegisteredService> = config
261 .services
262 .iter()
263 .map(|s| (s.name.clone(), RegisteredService::new(s.clone())))
264 .collect();
265
266 Self {
267 config,
268 state: Arc::new(RwLock::new(AgentState::Disconnected)),
269 services: Arc::new(RwLock::new(services)),
270 connection_callback: None,
271 command_tx: None,
272 event_tx: None,
273 overlay_resolver: None,
274 }
275 }
276
277 #[must_use]
282 pub fn on_connection(mut self, callback: ConnectionCallback) -> Self {
283 self.connection_callback = Some(callback);
284 self
285 }
286
287 #[must_use]
291 pub fn with_event_channel(mut self, tx: mpsc::Sender<ControlEvent>) -> Self {
292 self.event_tx = Some(tx);
293 self
294 }
295
296 #[must_use]
298 pub fn with_overlay_resolver(mut self, resolver: DynOverlayResolver) -> Self {
299 self.overlay_resolver = Some(resolver);
300 self
301 }
302
303 #[must_use]
305 pub fn state(&self) -> AgentState {
306 self.state.read().clone()
307 }
308
309 #[must_use]
311 pub fn get_service(&self, name: &str) -> Option<RegisteredService> {
312 self.services.read().get(name).cloned()
313 }
314
315 #[must_use]
317 pub fn services(&self) -> Vec<RegisteredService> {
318 self.services.read().values().cloned().collect()
319 }
320
321 #[must_use]
323 pub fn is_connected(&self) -> bool {
324 matches!(*self.state.read(), AgentState::Connected { .. })
325 }
326
327 #[must_use]
329 pub fn tunnel_id(&self) -> Option<Uuid> {
330 match *self.state.read() {
331 AgentState::Connected { tunnel_id } => Some(tunnel_id),
332 _ => None,
333 }
334 }
335
336 pub async fn send_command(&self, command: ControlCommand) -> Result<()> {
342 let tx = self
343 .command_tx
344 .as_ref()
345 .ok_or_else(|| TunnelError::connection_msg("agent not running"))?;
346
347 tx.send(command)
348 .await
349 .map_err(|_| TunnelError::connection_msg("command channel closed"))
350 }
351
352 pub async fn run(&self) -> Result<()> {
364 self.config.validate().map_err(TunnelError::config)?;
366
367 let mut current_interval = self.config.reconnect_interval;
368 let mut attempt = 0u32;
369
370 loop {
371 attempt += 1;
372 *self.state.write() = AgentState::Reconnecting { attempt };
373
374 tracing::info!(
375 attempt = attempt,
376 interval_ms = current_interval.as_millis(),
377 "attempting to connect"
378 );
379
380 match self.run_once().await {
381 Ok(()) => {
382 tracing::info!("agent shutting down");
384 return Ok(());
385 }
386 Err(TunnelError::Shutdown) => {
387 tracing::info!("agent received shutdown signal");
389 return Ok(());
390 }
391 Err(e) => {
392 tracing::warn!(error = %e, "connection failed, will retry");
393
394 if let Some(ref tx) = self.event_tx {
396 let _ = tx
397 .send(ControlEvent::Disconnected {
398 reason: e.to_string(),
399 })
400 .await;
401 }
402 }
403 }
404
405 {
407 let mut services = self.services.write();
408 for service in services.values_mut() {
409 service.service_id = None;
410 service.status = ServiceStatus::Pending;
411 }
412 }
413
414 tokio::time::sleep(current_interval).await;
416
417 current_interval = std::cmp::min(
419 current_interval.saturating_mul(2),
420 self.config.max_reconnect_interval,
421 );
422 }
423 }
424
425 pub async fn run_once(&self) -> Result<()> {
437 *self.state.write() = AgentState::Connecting;
438
439 tracing::debug!(url = %self.config.server_url, "connecting to server");
441
442 let connector = OverlayAwareConnector::new(
443 &self.config.server_url,
444 self.config.overlay_server_url.as_deref(),
445 self.config.routing_mode,
446 self.overlay_resolver.clone(),
447 );
448 let (ws_stream, _response) = connector.connect().await?;
449
450 let (mut ws_sink, mut ws_stream) = ws_stream.split();
451
452 let client_id = Uuid::new_v4();
454
455 let auth_msg = Message::Auth {
457 token: self.config.token.clone(),
458 client_id,
459 };
460 ws_sink
461 .send(WsMessage::Binary(auth_msg.encode().into()))
462 .await
463 .map_err(TunnelError::connection)?;
464
465 let auth_timeout = Duration::from_secs(10);
467 let auth_response = timeout(auth_timeout, async {
468 while let Some(msg) = ws_stream.next().await {
469 match msg {
470 Ok(WsMessage::Binary(data)) => {
471 return Message::decode(&data).map(|(m, _)| m);
472 }
473 Ok(WsMessage::Close(frame)) => {
474 let reason = frame.map_or_else(
475 || "connection closed".to_string(),
476 |f| f.reason.to_string(),
477 );
478 return Err(TunnelError::connection_msg(reason));
479 }
480 Ok(_) => {} Err(e) => return Err(TunnelError::connection(e)),
482 }
483 }
484 Err(TunnelError::connection_msg("connection closed before auth"))
485 })
486 .await
487 .map_err(|_| TunnelError::timeout())??;
488
489 let tunnel_id = match auth_response {
491 Message::AuthOk { tunnel_id } => tunnel_id,
492 Message::AuthFail { reason } => {
493 return Err(TunnelError::auth(reason));
494 }
495 other => {
496 return Err(TunnelError::protocol(format!(
497 "expected AuthOk or AuthFail, got {:?}",
498 other.message_type()
499 )));
500 }
501 };
502
503 *self.state.write() = AgentState::Connected { tunnel_id };
504
505 tracing::info!(
506 tunnel_id = %tunnel_id,
507 client_id = %client_id,
508 "authenticated with server"
509 );
510
511 if let Some(ref tx) = self.event_tx {
513 let _ = tx.send(ControlEvent::Authenticated { tunnel_id }).await;
514 }
515
516 self.register_services(&mut ws_sink).await?;
518
519 self.run_message_loop(tunnel_id, &mut ws_sink, &mut ws_stream)
521 .await
522 }
523
524 async fn register_services<S>(&self, ws_sink: &mut S) -> Result<()>
526 where
527 S: SinkExt<WsMessage> + Unpin,
528 S::Error: std::error::Error,
529 {
530 let services: Vec<ServiceConfig> = {
531 self.services
532 .read()
533 .values()
534 .map(|s| s.config.clone())
535 .collect()
536 };
537
538 for service in services {
539 let register_msg = Message::Register {
540 name: service.name.clone(),
541 protocol: service.protocol,
542 local_port: service.local_port,
543 remote_port: service.remote_port,
544 };
545
546 tracing::debug!(
547 service_name = %service.name,
548 local_port = service.local_port,
549 "registering service"
550 );
551
552 ws_sink
553 .send(WsMessage::Binary(register_msg.encode().into()))
554 .await
555 .map_err(|e| TunnelError::connection_msg(e.to_string()))?;
556 }
557
558 Ok(())
559 }
560
561 async fn run_message_loop<Sink, Stream>(
563 &self,
564 tunnel_id: Uuid,
565 ws_sink: &mut Sink,
566 ws_stream: &mut Stream,
567 ) -> Result<()>
568 where
569 Sink: SinkExt<WsMessage> + Unpin,
570 Sink::Error: std::error::Error,
571 Stream: StreamExt<Item = std::result::Result<WsMessage, tokio_tungstenite::tungstenite::Error>>
572 + Unpin,
573 {
574 let (_command_tx, mut command_rx) = mpsc::channel::<ControlCommand>(256);
576
577 let mut pending_services: Vec<String> = { self.services.read().keys().cloned().collect() };
584
585 let mut check_interval = interval(Duration::from_secs(5));
587
588 loop {
589 tokio::select! {
590 _ = check_interval.tick() => {
592 }
594
595 Some(command) = command_rx.recv() => {
597 match command {
598 ControlCommand::Register { name, protocol, local_port, remote_port } => {
599 let msg = Message::Register {
600 name: name.clone(),
601 protocol,
602 local_port,
603 remote_port,
604 };
605 ws_sink
606 .send(WsMessage::Binary(msg.encode().into()))
607 .await
608 .map_err(|e| TunnelError::connection_msg(e.to_string()))?;
609 pending_services.push(name);
610 }
611 ControlCommand::Unregister { service_id } => {
612 let msg = Message::Unregister { service_id };
613 ws_sink
614 .send(WsMessage::Binary(msg.encode().into()))
615 .await
616 .map_err(|e| TunnelError::connection_msg(e.to_string()))?;
617 }
618 ControlCommand::ConnectAck { connection_id } => {
619 let msg = Message::ConnectAck { connection_id };
620 ws_sink
621 .send(WsMessage::Binary(msg.encode().into()))
622 .await
623 .map_err(|e| TunnelError::connection_msg(e.to_string()))?;
624 }
625 ControlCommand::ConnectFail { connection_id, reason } => {
626 let msg = Message::ConnectFail { connection_id, reason };
627 ws_sink
628 .send(WsMessage::Binary(msg.encode().into()))
629 .await
630 .map_err(|e| TunnelError::connection_msg(e.to_string()))?;
631 }
632 ControlCommand::Disconnect => {
633 tracing::info!("disconnect command received");
634 return Ok(());
635 }
636 }
637 }
638
639 Some(msg_result) = ws_stream.next() => {
641 match msg_result {
642 Ok(WsMessage::Binary(data)) => {
643 let (msg, _) = Message::decode(&data)?;
644 self.handle_server_message(
645 tunnel_id,
646 msg,
647 ws_sink,
648 &mut pending_services,
649 ).await?;
650 }
651 Ok(WsMessage::Close(frame)) => {
652 let reason = frame.map_or_else(
653 || "server closed connection".to_string(),
654 |f| f.reason.to_string(),
655 );
656 tracing::info!(reason = %reason, "server closed connection");
657 return Err(TunnelError::connection_msg(reason));
658 }
659 Ok(WsMessage::Ping(data)) => {
660 ws_sink
661 .send(WsMessage::Pong(data))
662 .await
663 .map_err(|e| TunnelError::connection_msg(e.to_string()))?;
664 }
665 Ok(_) => {} Err(e) => {
667 return Err(TunnelError::connection(e));
668 }
669 }
670 }
671
672 else => {
673 break;
675 }
676 }
677 }
678
679 Ok(())
680 }
681
682 async fn handle_server_message<S>(
684 &self,
685 tunnel_id: Uuid,
686 msg: Message,
687 ws_sink: &mut S,
688 pending_services: &mut Vec<String>,
689 ) -> Result<()>
690 where
691 S: SinkExt<WsMessage> + Unpin,
692 S::Error: std::error::Error,
693 {
694 match msg {
695 Message::RegisterOk { service_id } => {
696 self.handle_register_ok(service_id, pending_services).await;
697 }
698 Message::RegisterFail { reason } => {
699 self.handle_register_fail(reason, pending_services).await;
700 }
701 Message::Connect {
702 service_id,
703 connection_id,
704 client_addr,
705 } => {
706 self.handle_connect(service_id, connection_id, client_addr, ws_sink)
707 .await?;
708 }
709 Message::Heartbeat { timestamp } => {
710 self.handle_heartbeat(timestamp, ws_sink).await?;
711 }
712 Message::Disconnect { reason } => {
713 return self.handle_disconnect(reason).await;
714 }
715 Message::Auth { .. }
717 | Message::AuthOk { .. }
718 | Message::AuthFail { .. }
719 | Message::Register { .. }
720 | Message::Unregister { .. }
721 | Message::ConnectAck { .. }
722 | Message::ConnectFail { .. }
723 | Message::HeartbeatAck { .. } => {
724 tracing::warn!(
725 tunnel_id = %tunnel_id,
726 msg_type = ?msg.message_type(),
727 "unexpected message from server"
728 );
729 }
730 }
731 Ok(())
732 }
733
734 async fn handle_register_ok(&self, service_id: Uuid, pending_services: &mut Vec<String>) {
736 let name = match pending_services.first().cloned() {
737 Some(n) => {
738 pending_services.remove(0);
739 n
740 }
741 None => return,
742 };
743
744 {
746 let mut services = self.services.write();
747 if let Some(service) = services.get_mut(&name) {
748 service.service_id = Some(service_id);
749 service.status = ServiceStatus::Registered;
750 }
751 }
752
753 tracing::info!(
754 service_name = %name,
755 service_id = %service_id,
756 "service registered"
757 );
758
759 if let Some(ref tx) = self.event_tx {
761 let _ = tx
762 .send(ControlEvent::ServiceRegistered { name, service_id })
763 .await;
764 }
765 }
766
767 async fn handle_register_fail(&self, reason: String, pending_services: &mut Vec<String>) {
769 let name = match pending_services.first().cloned() {
770 Some(n) => {
771 pending_services.remove(0);
772 n
773 }
774 None => return,
775 };
776
777 {
779 let mut services = self.services.write();
780 if let Some(service) = services.get_mut(&name) {
781 service.status = ServiceStatus::Failed(reason.clone());
782 }
783 }
784
785 tracing::warn!(
786 service_name = %name,
787 reason = %reason,
788 "service registration failed"
789 );
790
791 if let Some(ref tx) = self.event_tx {
793 let _ = tx.send(ControlEvent::ServiceFailed { name, reason }).await;
794 }
795 }
796
797 async fn handle_connect<S>(
799 &self,
800 service_id: Uuid,
801 connection_id: Uuid,
802 client_addr: String,
803 ws_sink: &mut S,
804 ) -> Result<()>
805 where
806 S: SinkExt<WsMessage> + Unpin,
807 S::Error: std::error::Error,
808 {
809 tracing::debug!(
810 service_id = %service_id,
811 connection_id = %connection_id,
812 client_addr = %client_addr,
813 "incoming connection"
814 );
815
816 if let Some(ref tx) = self.event_tx {
818 let _ = tx
819 .send(ControlEvent::IncomingConnection {
820 service_id,
821 connection_id,
822 client_addr: client_addr.clone(),
823 })
824 .await;
825 }
826
827 let accepted = self
829 .connection_callback
830 .as_ref()
831 .is_none_or(|cb| cb(service_id, connection_id, client_addr.clone()));
832
833 let response = if accepted {
835 Message::ConnectAck { connection_id }
836 } else {
837 Message::ConnectFail {
838 connection_id,
839 reason: "connection rejected by client".to_string(),
840 }
841 };
842
843 ws_sink
844 .send(WsMessage::Binary(response.encode().into()))
845 .await
846 .map_err(|e| TunnelError::connection_msg(e.to_string()))?;
847
848 Ok(())
849 }
850
851 async fn handle_heartbeat<S>(&self, timestamp: u64, ws_sink: &mut S) -> Result<()>
853 where
854 S: SinkExt<WsMessage> + Unpin,
855 S::Error: std::error::Error,
856 {
857 tracing::trace!(timestamp = timestamp, "heartbeat received");
858
859 let ack = Message::HeartbeatAck { timestamp };
861 ws_sink
862 .send(WsMessage::Binary(ack.encode().into()))
863 .await
864 .map_err(|e| TunnelError::connection_msg(e.to_string()))?;
865
866 if let Some(ref tx) = self.event_tx {
868 let _ = tx.send(ControlEvent::Heartbeat { timestamp }).await;
869 }
870
871 Ok(())
872 }
873
874 async fn handle_disconnect(&self, reason: String) -> Result<()> {
876 tracing::info!(reason = %reason, "server requested disconnect");
877
878 if let Some(ref tx) = self.event_tx {
880 let _ = tx
881 .send(ControlEvent::Disconnected {
882 reason: reason.clone(),
883 })
884 .await;
885 }
886
887 Err(TunnelError::connection_msg(reason))
888 }
889
890 pub fn disconnect(&self) {
895 *self.state.write() = AgentState::Disconnected;
896
897 if let Some(ref tx) = self.command_tx {
899 let _ = tx.try_send(ControlCommand::Disconnect);
900 }
901 }
902}
903
904impl Clone for TunnelAgent {
905 fn clone(&self) -> Self {
906 Self {
907 config: self.config.clone(),
908 state: Arc::clone(&self.state),
909 services: Arc::clone(&self.services),
910 connection_callback: self.connection_callback.clone(),
911 command_tx: self.command_tx.clone(),
912 event_tx: self.event_tx.clone(),
913 overlay_resolver: self.overlay_resolver.clone(),
914 }
915 }
916}
917
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;

    /// Config with a TCP "ssh" service (22, remote 2222) and a UDP "game" service (27015).
    fn create_test_config() -> TunnelClientConfig {
        TunnelClientConfig::new("ws://localhost:8080/tunnel/v1", "test-token")
            .with_service(ServiceConfig::tcp("ssh", 22).with_remote_port(2222))
            .with_service(ServiceConfig::udp("game", 27015))
    }

    /// Fresh agent built from `create_test_config`.
    fn test_agent() -> TunnelAgent {
        TunnelAgent::new(create_test_config())
    }

    #[test]
    fn test_agent_state_default() {
        assert_eq!(AgentState::default(), AgentState::Disconnected);
    }

    #[test]
    fn test_agent_state_variants() {
        let connected = AgentState::Connected {
            tunnel_id: Uuid::new_v4(),
        };
        let reconnecting = AgentState::Reconnecting { attempt: 3 };

        assert_ne!(AgentState::Disconnected, AgentState::Connecting);
        assert_ne!(AgentState::Connecting, connected);
        assert_ne!(connected, reconnecting);
    }

    #[test]
    fn test_service_status_default() {
        assert_eq!(ServiceStatus::default(), ServiceStatus::Pending);
    }

    #[test]
    fn test_service_status_variants() {
        let failed = |reason: &str| ServiceStatus::Failed(reason.to_owned());

        assert_eq!(ServiceStatus::Pending, ServiceStatus::Pending);
        assert_eq!(ServiceStatus::Registered, ServiceStatus::Registered);
        assert_eq!(failed("error"), failed("error"));
        assert_ne!(failed("error1"), failed("error2"));
    }

    #[test]
    fn test_registered_service_new() {
        let service = RegisteredService::new(ServiceConfig::tcp("ssh", 22));

        assert_eq!(service.config.name, "ssh");
        assert!(service.service_id.is_none());
        assert_eq!(service.status, ServiceStatus::Pending);
        assert!(!service.is_registered());
    }

    #[test]
    fn test_registered_service_is_registered() {
        let mut service = RegisteredService::new(ServiceConfig::tcp("ssh", 22));
        assert!(!service.is_registered());

        let cases = [
            (ServiceStatus::Registered, true),
            (ServiceStatus::Failed("error".to_string()), false),
        ];
        for (status, expected) in cases {
            service.status = status;
            assert_eq!(service.is_registered(), expected);
        }
    }

    #[test]
    fn test_tunnel_agent_new() {
        let agent = test_agent();

        assert_eq!(agent.state(), AgentState::Disconnected);
        assert!(!agent.is_connected());
        assert!(agent.tunnel_id().is_none());
        assert_eq!(agent.services().len(), 2);
    }

    #[test]
    fn test_tunnel_agent_get_service() {
        let agent = test_agent();

        let ssh = agent.get_service("ssh").expect("ssh is configured");
        assert_eq!(ssh.config.local_port, 22);

        let game = agent.get_service("game").expect("game is configured");
        assert_eq!(game.config.protocol, ServiceProtocol::Udp);

        assert!(agent.get_service("nonexistent").is_none());
    }

    #[test]
    fn test_tunnel_agent_on_connection() {
        let called = Arc::new(AtomicBool::new(false));

        let callback: ConnectionCallback = {
            let called = Arc::clone(&called);
            Arc::new(
                move |_service_id: Uuid, _connection_id: Uuid, _client_addr: String| {
                    called.store(true, Ordering::SeqCst);
                    true
                },
            )
        };

        let agent = TunnelAgent::new(create_test_config()).on_connection(callback);

        // Installing the hook must not invoke it.
        assert!(!called.load(Ordering::SeqCst));
        assert!(agent.connection_callback.is_some());
    }

    #[test]
    fn test_tunnel_agent_clone() {
        let original = test_agent();
        let copy = original.clone();

        assert_eq!(original.state(), copy.state());
        assert_eq!(original.services().len(), copy.services().len());
    }

    #[test]
    fn test_tunnel_agent_disconnect() {
        let agent = test_agent();

        *agent.state.write() = AgentState::Connected {
            tunnel_id: Uuid::new_v4(),
        };
        assert!(agent.is_connected());

        agent.disconnect();

        assert_eq!(agent.state(), AgentState::Disconnected);
        assert!(!agent.is_connected());
    }

    #[test]
    fn test_control_event_variants() {
        let events = [
            ControlEvent::Authenticated {
                tunnel_id: Uuid::new_v4(),
            },
            ControlEvent::ServiceRegistered {
                name: "ssh".to_string(),
                service_id: Uuid::new_v4(),
            },
            ControlEvent::ServiceFailed {
                name: "ssh".to_string(),
                reason: "error".to_string(),
            },
            ControlEvent::IncomingConnection {
                service_id: Uuid::new_v4(),
                connection_id: Uuid::new_v4(),
                client_addr: "127.0.0.1:12345".to_string(),
            },
            ControlEvent::Heartbeat { timestamp: 12345 },
            ControlEvent::Disconnected {
                reason: "test".to_string(),
            },
            ControlEvent::Error {
                message: "test error".to_string(),
            },
        ];

        assert!(matches!(events[4], ControlEvent::Heartbeat { .. }));
    }

    #[test]
    fn test_control_command_variants() {
        let commands = [
            ControlCommand::Register {
                name: "ssh".to_string(),
                protocol: ServiceProtocol::Tcp,
                local_port: 22,
                remote_port: 2222,
            },
            ControlCommand::Unregister {
                service_id: Uuid::new_v4(),
            },
            ControlCommand::ConnectAck {
                connection_id: Uuid::new_v4(),
            },
            ControlCommand::ConnectFail {
                connection_id: Uuid::new_v4(),
                reason: "error".to_string(),
            },
            ControlCommand::Disconnect,
        ];

        assert!(matches!(commands[4], ControlCommand::Disconnect));
    }

    #[test]
    fn test_tunnel_agent_with_event_channel() {
        let (tx, _rx) = mpsc::channel(16);

        let agent = TunnelAgent::new(create_test_config()).with_event_channel(tx);

        assert!(agent.event_tx.is_some());
    }

    #[tokio::test]
    async fn test_send_command_not_running() {
        let agent = test_agent();

        let err = agent
            .send_command(ControlCommand::Disconnect)
            .await
            .expect_err("an idle agent must reject commands");

        assert!(err.to_string().contains("agent not running"));
    }
}