// wireguard_netstack/tunnel.rs

//! High-level managed WireGuard tunnel.
//!
//! This module provides [`ManagedTunnel`], a convenient abstraction that handles
//! all the background tasks required to run a WireGuard tunnel.

use std::fmt::Display;
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;

use tokio::task::JoinSet;

use crate::error::{Error, Result};
use crate::netstack::NetStack;
use crate::wireguard::{WireGuardConfig, WireGuardTunnel};

13/// A managed WireGuard tunnel that handles all background tasks automatically.
14///
15/// This is the main entry point for library users. It:
16/// - Creates and configures the WireGuard tunnel
17/// - Creates the userspace network stack
18/// - Spawns all required background tasks
19/// - Performs the WireGuard handshake
20/// - Provides access to the `NetStack` for making TCP connections
21///
22/// # Example
23///
24/// ```no_run
25/// use wireguard_netstack::{ManagedTunnel, WgConfigFile};
26///
27/// #[tokio::main]
28/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
29///     // Load config and connect
30///     let config = WgConfigFile::from_file("wg.conf")?
31///         .into_wireguard_config()
32///         .await?;
33///     
34///     let tunnel = ManagedTunnel::connect(config).await?;
35///     
36///     // Use tunnel.netstack() to create TCP connections
37///     // ...
38///     
39///     // Graceful shutdown
40///     tunnel.shutdown().await;
41///     Ok(())
42/// }
43/// ```
44pub struct ManagedTunnel {
45    /// The underlying WireGuard tunnel.
46    wg_tunnel: Arc<WireGuardTunnel>,
47    /// The userspace network stack.
48    netstack: Arc<NetStack>,
49    /// Background task handles.
50    tasks: JoinSet<()>,
51}
52
53impl ManagedTunnel {
54    /// Connect to a WireGuard peer using the provided configuration.
55    ///
56    /// This will:
57    /// 1. Create the WireGuard tunnel
58    /// 2. Create the userspace network stack
59    /// 3. Spawn all background tasks
60    /// 4. Initiate and wait for the WireGuard handshake
61    ///
62    /// # Arguments
63    ///
64    /// * `config` - WireGuard configuration
65    ///
66    /// # Returns
67    ///
68    /// A `ManagedTunnel` ready to use for making TCP connections.
69    pub async fn connect(config: WireGuardConfig) -> Result<Self> {
70        Self::connect_with_timeout(config, Duration::from_secs(10)).await
71    }
72
73    /// Connect with a custom handshake timeout.
74    pub async fn connect_with_timeout(
75        config: WireGuardConfig,
76        handshake_timeout: Duration,
77    ) -> Result<Self> {
78        log::info!("Creating WireGuard tunnel...");
79        let wg_tunnel = WireGuardTunnel::new(config)
80            .await
81            .map_err(|e| Error::TunnelCreation(e.to_string()))?;
82
83        // Take the incoming receiver before starting tasks
84        let incoming_rx = wg_tunnel
85            .take_incoming_receiver()
86            .ok_or_else(|| Error::TunnelCreation("Failed to get incoming receiver".into()))?;
87
88        // Create the network stack
89        log::info!("Creating userspace network stack...");
90        let netstack = NetStack::new(wg_tunnel.clone());
91
92        // Spawn background tasks
93        log::info!("Starting background tasks...");
94        let mut tasks = JoinSet::new();
95
96        // WireGuard receive loop
97        let wg = wg_tunnel.clone();
98        tasks.spawn(async move {
99            if let Err(e) = wg.run_receive_loop().await {
100                log::error!("WireGuard receive loop error: {}", e);
101            }
102        });
103
104        // WireGuard send loop
105        let wg = wg_tunnel.clone();
106        tasks.spawn(async move {
107            if let Err(e) = wg.run_send_loop().await {
108                log::error!("WireGuard send loop error: {}", e);
109            }
110        });
111
112        // WireGuard timer loop
113        let wg = wg_tunnel.clone();
114        tasks.spawn(async move {
115            if let Err(e) = wg.run_timer_loop().await {
116                log::error!("WireGuard timer loop error: {}", e);
117            }
118        });
119
120        // Network stack poll loop
121        let ns = netstack.clone();
122        tasks.spawn(async move {
123            if let Err(e) = ns.run_poll_loop().await {
124                log::error!("Network stack poll loop error: {}", e);
125            }
126        });
127
128        // Network stack RX loop
129        let ns = netstack.clone();
130        tasks.spawn(async move {
131            if let Err(e) = ns.run_rx_loop(incoming_rx).await {
132                log::error!("Network stack RX loop error: {}", e);
133            }
134        });
135
136        // Give tasks time to start
137        tokio::time::sleep(Duration::from_millis(100)).await;
138
139        // Initiate handshake
140        log::info!("Initiating WireGuard handshake...");
141        wg_tunnel
142            .initiate_handshake()
143            .await
144            .map_err(|e| Error::TunnelCreation(e.to_string()))?;
145
146        // Wait for handshake
147        log::info!("Waiting for WireGuard handshake to complete...");
148        wg_tunnel.wait_for_handshake(handshake_timeout).await?;
149
150        log::info!("WireGuard tunnel established!");
151
152        Ok(Self {
153            wg_tunnel,
154            netstack,
155            tasks,
156        })
157    }
158
159    /// Get the network stack for creating TCP connections.
160    pub fn netstack(&self) -> Arc<NetStack> {
161        self.netstack.clone()
162    }
163
164    /// Get the underlying WireGuard tunnel.
165    pub fn wg_tunnel(&self) -> Arc<WireGuardTunnel> {
166        self.wg_tunnel.clone()
167    }
168
169    /// Returns the time elapsed since the last successful WireGuard handshake.
170    ///
171    /// Returns `Some(duration)` if a handshake has completed, or `None` if no
172    /// handshake has occurred yet. This is useful for health-checking the tunnel:
173    /// WireGuard re-handshakes every ~120s on an active session, so a value
174    /// exceeding ~180s typically indicates the tunnel is stale.
175    ///
176    /// # Example
177    ///
178    /// ```no_run
179    /// use std::time::Duration;
180    /// use wireguard_netstack::ManagedTunnel;
181    ///
182    /// fn check_health(tunnel: &ManagedTunnel) -> bool {
183    ///     match tunnel.time_since_last_handshake() {
184    ///         Some(elapsed) => elapsed < Duration::from_secs(180),
185    ///         None => false,
186    ///     }
187    /// }
188    /// ```
189    pub fn time_since_last_handshake(&self) -> Option<Duration> {
190        self.wg_tunnel.time_since_last_handshake()
191    }
192
193    /// Gracefully shutdown the tunnel.
194    ///
195    /// This aborts all background tasks and waits for them to complete.
196    pub async fn shutdown(mut self) {
197        log::info!("Shutting down WireGuard tunnel...");
198        self.tasks.abort_all();
199        while self.tasks.join_next().await.is_some() {}
200        log::info!("WireGuard tunnel shutdown complete.");
201    }
202}
203
204impl Drop for ManagedTunnel {
205    fn drop(&mut self) {
206        // Abort all tasks on drop
207        self.tasks.abort_all();
208    }
209}