wireguard_netstack/
wireguard.rs1use bytes::BytesMut;
7use gotatun::noise::{Tunn, TunnResult};
8use gotatun::packet::Packet;
9use gotatun::x25519::{PublicKey, StaticSecret};
10use parking_lot::Mutex;
11use zerocopy::IntoBytes;
12use std::net::{Ipv4Addr, SocketAddr};
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::net::UdpSocket;
16use tokio::sync::mpsc;
17
18use crate::error::{Error, Result};
19
/// Static configuration for a single-peer WireGuard tunnel, consumed by
/// `WireGuardTunnel::new`.
///
/// `Debug` is implemented by hand so the private key and preshared key are never
/// written to logs; a derived `Debug` would print them verbatim.
#[derive(Clone)]
pub struct WireGuardConfig {
    /// Local static private key as raw 32 bytes (converted with `StaticSecret::from`).
    pub private_key: [u8; 32],
    /// The peer's static public key as raw 32 bytes (converted with `PublicKey::from`).
    pub peer_public_key: [u8; 32],
    /// UDP address of the peer. It is the only address datagrams are sent to, and
    /// datagrams from any other address are dropped.
    pub peer_endpoint: SocketAddr,
    /// IPv4 address of this end inside the tunnel; stored and returned by
    /// `WireGuardTunnel::tunnel_ip`.
    pub tunnel_ip: Ipv4Addr,
    /// Optional preshared key, passed through to the noise `Tunn`.
    pub preshared_key: Option<[u8; 32]>,
    /// Optional keepalive setting passed through to the noise `Tunn`
    /// (presumably the persistent-keepalive interval in seconds — confirm with gotatun).
    pub keepalive_seconds: Option<u16>,
    /// Tunnel MTU; when `None`, the tunnel falls back to its built-in default.
    pub mtu: Option<u16>,
}

impl std::fmt::Debug for WireGuardConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WireGuardConfig")
            .field("private_key", &"<redacted>")
            .field("peer_public_key", &self.peer_public_key)
            .field("peer_endpoint", &self.peer_endpoint)
            .field("tunnel_ip", &self.tunnel_ip)
            // Only reveal whether a PSK is configured, never its bytes.
            .field("preshared_key", &self.preshared_key.map(|_| "<redacted>"))
            .field("keepalive_seconds", &self.keepalive_seconds)
            .field("mtu", &self.mtu)
            .finish()
    }
}
38
/// A userspace WireGuard tunnel to a single peer over one UDP socket.
///
/// Plaintext IP packets enter through the channel returned by `outgoing_sender` and
/// are encrypted by `run_send_loop`. Datagrams received by `run_receive_loop` are
/// decrypted and delivered on the channel returned by `take_incoming_receiver`.
/// `run_timer_loop` drives the protocol timers.
pub struct WireGuardTunnel {
    /// Noise protocol state. Each method locks it only briefly and never holds the
    /// guard across an `.await`.
    tunn: Mutex<Tunn>,
    /// Socket used to exchange encrypted datagrams with `peer_endpoint`.
    udp_socket: Arc<UdpSocket>,
    /// The only address packets are sent to; datagrams from other addresses are dropped.
    peer_endpoint: SocketAddr,
    /// IPv4 address of this end inside the tunnel (copied from the config).
    tunnel_ip: Ipv4Addr,
    /// MTU reported by `mtu()`.
    mtu: u16,
    /// Sender side for decrypted IP packets produced by the receive loop.
    incoming_tx: mpsc::Sender<BytesMut>,
    /// Plaintext packets waiting to be encrypted. This is a tokio mutex because
    /// `run_send_loop` holds the guard across `.await`.
    outgoing_rx: tokio::sync::Mutex<mpsc::Receiver<BytesMut>>,
    /// Receiver for decrypted packets, handed out once by `take_incoming_receiver`.
    incoming_rx: Mutex<Option<mpsc::Receiver<BytesMut>>>,
    /// Sender cloned out by `outgoing_sender` so callers can submit plaintext packets.
    outgoing_tx: mpsc::Sender<BytesMut>,
}
60
61impl WireGuardTunnel {
62 pub async fn new(config: WireGuardConfig) -> Result<Arc<Self>> {
64 let private_key = StaticSecret::from(config.private_key);
66 let peer_public_key = PublicKey::from(config.peer_public_key);
67
68 let tunn = Tunn::new(
70 private_key,
71 peer_public_key,
72 config.preshared_key,
73 config.keepalive_seconds,
74 rand::random::<u32>() >> 8, None, );
77
78 let udp_socket = UdpSocket::bind("0.0.0.0:0").await?;
80
81 let sock_ref = socket2::SockRef::from(&udp_socket);
83 if let Err(e) = sock_ref.set_recv_buffer_size(1024 * 1024) {
84 log::warn!("Failed to set UDP recv buffer size: {}", e);
86 }
87 if let Err(e) = sock_ref.set_send_buffer_size(1024 * 1024) {
88 log::warn!("Failed to set UDP send buffer size: {}", e);
90 }
91 log::info!("UDP recv buffer size: {:?}", sock_ref.recv_buffer_size());
92 log::info!("UDP send buffer size: {:?}", sock_ref.send_buffer_size());
93
94 log::info!(
95 "WireGuard UDP socket bound to {}",
96 udp_socket.local_addr()?
97 );
98
99 let (incoming_tx, incoming_rx) = mpsc::channel(256);
103 let (outgoing_tx, outgoing_rx) = mpsc::channel(256);
104
105 let tunnel = Arc::new(Self {
106 tunn: Mutex::new(tunn),
107 udp_socket: Arc::new(udp_socket),
108 peer_endpoint: config.peer_endpoint,
109 tunnel_ip: config.tunnel_ip,
110 mtu: config.mtu.unwrap_or(460), incoming_tx,
112 incoming_rx: Mutex::new(Some(incoming_rx)),
113 outgoing_tx,
114 outgoing_rx: tokio::sync::Mutex::new(outgoing_rx),
115 });
116
117 Ok(tunnel)
118 }
119
120 pub fn tunnel_ip(&self) -> Ipv4Addr {
122 self.tunnel_ip
123 }
124
125 pub fn mtu(&self) -> u16 {
127 self.mtu
128 }
129
130 pub fn outgoing_sender(&self) -> mpsc::Sender<BytesMut> {
132 self.outgoing_tx.clone()
133 }
134
135 pub fn take_incoming_receiver(&self) -> Option<mpsc::Receiver<BytesMut>> {
137 self.incoming_rx.lock().take()
138 }
139
140 pub async fn initiate_handshake(&self) -> Result<()> {
142 log::info!("Initiating WireGuard handshake...");
143
144 let handshake_init = {
145 let mut tunn = self.tunn.lock();
146 tunn.format_handshake_initiation(false)
147 };
148
149 if let Some(packet) = handshake_init {
150 let data = packet.as_bytes();
152 self.udp_socket.send_to(data, self.peer_endpoint).await?;
153 log::debug!("Sent handshake initiation ({} bytes)", data.len());
154 }
155
156 Ok(())
157 }
158
159 pub async fn send_ip_packet(&self, packet: BytesMut) -> Result<()> {
161 let encrypted = {
162 let mut tunn = self.tunn.lock();
163 let pkt = Packet::from_bytes(packet);
164 tunn.handle_outgoing_packet(pkt)
165 };
166
167 if let Some(wg_packet) = encrypted {
168 let pkt: Packet = wg_packet.into();
170 let data = pkt.as_bytes();
171 self.udp_socket.send_to(data, self.peer_endpoint).await?;
172 log::trace!("Sent encrypted packet ({} bytes)", data.len());
173 }
174
175 Ok(())
176 }
177
178 fn process_incoming_udp(&self, data: &[u8]) -> Option<BytesMut> {
180 let packet = Packet::from_bytes(BytesMut::from(data));
181 let wg_packet = match packet.try_into_wg() {
182 Ok(wg) => wg,
183 Err(_) => {
184 log::warn!("Received non-WireGuard packet");
185 return None;
186 }
187 };
188
189 let mut tunn = self.tunn.lock();
190 match tunn.handle_incoming_packet(wg_packet) {
191 TunnResult::Done => {
192 log::trace!("WG: Packet processed (no output)");
193 None
194 }
195 TunnResult::Err(e) => {
196 log::warn!("WG error: {:?}", e);
197 None
198 }
199 TunnResult::WriteToNetwork(response) => {
200 log::trace!("WG: Sending response packet");
201 let pkt: Packet = response.into();
203 let data = BytesMut::from(pkt.as_bytes());
204 let socket = self.udp_socket.clone();
205 let endpoint = self.peer_endpoint;
206 tokio::spawn(async move {
207 if let Err(e) = socket.send_to(&data, endpoint).await {
208 log::error!("Failed to send response: {}", e);
209 }
210 });
211
212 while let Some(queued) = tunn.next_queued_packet() {
214 let pkt: Packet = queued.into();
215 let data = BytesMut::from(pkt.as_bytes());
216 let socket = self.udp_socket.clone();
217 let endpoint = self.peer_endpoint;
218 tokio::spawn(async move {
219 if let Err(e) = socket.send_to(&data, endpoint).await {
220 log::error!("Failed to send queued packet: {}", e);
221 }
222 });
223 }
224
225 None
226 }
227 TunnResult::WriteToTunnel(decrypted) => {
228 if decrypted.is_empty() {
229 log::trace!("WG: Received keepalive");
230 return None;
231 }
232 let bytes = BytesMut::from(decrypted.as_bytes());
233 log::trace!("WG: Decrypted {} bytes", bytes.len());
234 Some(bytes)
235 }
236 }
237 }
238
239 pub async fn run_receive_loop(self: &Arc<Self>) -> Result<()> {
241 let mut buf = vec![0u8; 65535];
242
243 loop {
244 match self.udp_socket.recv_from(&mut buf).await {
245 Ok((len, from)) => {
246 if from != self.peer_endpoint {
247 log::warn!("Received packet from unknown peer: {}", from);
248 continue;
249 }
250
251 log::trace!("Received UDP packet ({} bytes) from {}", len, from);
252
253 if let Some(ip_packet) = self.process_incoming_udp(&buf[..len]) {
254 if self.incoming_tx.send(ip_packet).await.is_err() {
255 log::error!("Incoming channel closed");
256 break;
257 }
258 }
259 }
260 Err(e) => {
261 log::error!("UDP receive error: {}", e);
262 break;
263 }
264 }
265 }
266
267 Ok(())
268 }
269
270 pub async fn run_send_loop(self: &Arc<Self>) -> Result<()> {
272 let mut outgoing_rx = self.outgoing_rx.lock().await;
273
274 while let Some(packet) = outgoing_rx.recv().await {
275 if let Err(e) = self.send_ip_packet(packet).await {
276 log::error!("Failed to send packet: {}", e);
277 }
278 }
279
280 Ok(())
281 }
282
283 pub async fn run_timer_loop(self: &Arc<Self>) -> Result<()> {
285 let mut interval = tokio::time::interval(Duration::from_millis(250));
286
287 loop {
288 interval.tick().await;
289
290 let packets_to_send: Vec<Vec<u8>> = {
291 let mut tunn = self.tunn.lock();
292 match tunn.update_timers() {
293 Ok(Some(packet)) => {
294 let pkt: Packet = packet.into();
295 vec![pkt.as_bytes().to_vec()]
296 }
297 Ok(None) => vec![],
298 Err(e) => {
299 log::trace!("Timer error (may be normal): {:?}", e);
300 vec![]
301 }
302 }
303 };
304
305 for packet in packets_to_send {
306 if let Err(e) = self.udp_socket.send_to(&packet, self.peer_endpoint).await {
307 log::error!("Failed to send timer packet: {}", e);
308 }
309 }
310 }
311 }
312
313 pub async fn wait_for_handshake(&self, timeout_duration: Duration) -> Result<()> {
315 let start = std::time::Instant::now();
316
317 loop {
318 {
319 let tunn = self.tunn.lock();
320 let (time_since_handshake, _tx_bytes, _rx_bytes, _, _) = tunn.stats();
322 if time_since_handshake.is_some() {
323 log::info!("WireGuard handshake completed!");
324 return Ok(());
325 }
326 }
327
328 if start.elapsed() > timeout_duration {
329 return Err(Error::HandshakeTimeout(timeout_duration));
330 }
331
332 tokio::time::sleep(Duration::from_millis(50)).await;
333 }
334 }
335}