// wireguard_netstack/wireguard.rs

use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;

use bytes::BytesMut;
use gotatun::noise::{Tunn, TunnResult};
use gotatun::packet::Packet;
use gotatun::x25519::{PublicKey, StaticSecret};
use parking_lot::Mutex;
use tokio::net::UdpSocket;
use tokio::sync::mpsc;
use zerocopy::IntoBytes;

use crate::error::{Error, Result};

/// Settings needed to bring up a single-peer WireGuard tunnel.
///
/// The value is consumed by `WireGuardTunnel::new`.
#[derive(Clone)]
pub struct WireGuardConfig {
    /// Raw 32-byte X25519 private key of the local side.
    pub private_key: [u8; 32],
    /// Raw 32-byte X25519 public key of the remote peer.
    pub peer_public_key: [u8; 32],
    /// UDP address of the peer; encrypted datagrams are sent to it and only
    /// datagrams arriving from it are accepted.
    pub peer_endpoint: SocketAddr,
    /// IPv4 address of this side of the tunnel. It is stored and exposed
    /// through `WireGuardTunnel::tunnel_ip`, not otherwise interpreted here.
    pub tunnel_ip: Ipv4Addr,
    /// Optional 32-byte pre-shared key, handed to the Noise state unchanged.
    pub preshared_key: Option<[u8; 32]>,
    /// Optional keepalive setting handed to the Noise state unchanged
    /// (presumably the persistent-keepalive interval in seconds — confirm
    /// against the gotatun documentation).
    pub keepalive_seconds: Option<u16>,
}

37pub struct WireGuardTunnel {
39 tunn: Mutex<Tunn>,
41 udp_socket: Arc<UdpSocket>,
43 peer_endpoint: SocketAddr,
45 tunnel_ip: Ipv4Addr,
47 incoming_tx: mpsc::Sender<BytesMut>,
49 outgoing_rx: tokio::sync::Mutex<mpsc::Receiver<BytesMut>>,
51 incoming_rx: Mutex<Option<mpsc::Receiver<BytesMut>>>,
53 outgoing_tx: mpsc::Sender<BytesMut>,
55}
56
57impl WireGuardTunnel {
58 pub async fn new(config: WireGuardConfig) -> Result<Arc<Self>> {
60 let private_key = StaticSecret::from(config.private_key);
62 let peer_public_key = PublicKey::from(config.peer_public_key);
63
64 let tunn = Tunn::new(
66 private_key,
67 peer_public_key,
68 config.preshared_key,
69 config.keepalive_seconds,
70 rand::random::<u32>() >> 8, None, );
73
74 let udp_socket = UdpSocket::bind("0.0.0.0:0").await?;
76
77 let sock_ref = socket2::SockRef::from(&udp_socket);
79 if let Err(e) = sock_ref.set_recv_buffer_size(1024 * 1024) {
80 log::warn!("Failed to set UDP recv buffer size: {}", e);
82 }
83 if let Err(e) = sock_ref.set_send_buffer_size(1024 * 1024) {
84 log::warn!("Failed to set UDP send buffer size: {}", e);
86 }
87 log::info!("UDP recv buffer size: {:?}", sock_ref.recv_buffer_size());
88 log::info!("UDP send buffer size: {:?}", sock_ref.send_buffer_size());
89
90 log::info!(
91 "WireGuard UDP socket bound to {}",
92 udp_socket.local_addr()?
93 );
94
95 let (incoming_tx, incoming_rx) = mpsc::channel(256);
99 let (outgoing_tx, outgoing_rx) = mpsc::channel(256);
100
101 let tunnel = Arc::new(Self {
102 tunn: Mutex::new(tunn),
103 udp_socket: Arc::new(udp_socket),
104 peer_endpoint: config.peer_endpoint,
105 tunnel_ip: config.tunnel_ip,
106 incoming_tx,
107 incoming_rx: Mutex::new(Some(incoming_rx)),
108 outgoing_tx,
109 outgoing_rx: tokio::sync::Mutex::new(outgoing_rx),
110 });
111
112 Ok(tunnel)
113 }
114
115 pub fn tunnel_ip(&self) -> Ipv4Addr {
117 self.tunnel_ip
118 }
119
120 pub fn outgoing_sender(&self) -> mpsc::Sender<BytesMut> {
122 self.outgoing_tx.clone()
123 }
124
125 pub fn take_incoming_receiver(&self) -> Option<mpsc::Receiver<BytesMut>> {
127 self.incoming_rx.lock().take()
128 }
129
130 pub async fn initiate_handshake(&self) -> Result<()> {
132 log::info!("Initiating WireGuard handshake...");
133
134 let handshake_init = {
135 let mut tunn = self.tunn.lock();
136 tunn.format_handshake_initiation(false)
137 };
138
139 if let Some(packet) = handshake_init {
140 let data = packet.as_bytes();
142 self.udp_socket.send_to(data, self.peer_endpoint).await?;
143 log::debug!("Sent handshake initiation ({} bytes)", data.len());
144 }
145
146 Ok(())
147 }
148
149 pub async fn send_ip_packet(&self, packet: BytesMut) -> Result<()> {
151 let encrypted = {
152 let mut tunn = self.tunn.lock();
153 let pkt = Packet::from_bytes(packet);
154 tunn.handle_outgoing_packet(pkt)
155 };
156
157 if let Some(wg_packet) = encrypted {
158 let pkt: Packet = wg_packet.into();
160 let data = pkt.as_bytes();
161 self.udp_socket.send_to(data, self.peer_endpoint).await?;
162 log::trace!("Sent encrypted packet ({} bytes)", data.len());
163 }
164
165 Ok(())
166 }
167
168 fn process_incoming_udp(&self, data: &[u8]) -> Option<BytesMut> {
170 let packet = Packet::from_bytes(BytesMut::from(data));
171 let wg_packet = match packet.try_into_wg() {
172 Ok(wg) => wg,
173 Err(_) => {
174 log::warn!("Received non-WireGuard packet");
175 return None;
176 }
177 };
178
179 let mut tunn = self.tunn.lock();
180 match tunn.handle_incoming_packet(wg_packet) {
181 TunnResult::Done => {
182 log::trace!("WG: Packet processed (no output)");
183 None
184 }
185 TunnResult::Err(e) => {
186 log::warn!("WG error: {:?}", e);
187 None
188 }
189 TunnResult::WriteToNetwork(response) => {
190 log::trace!("WG: Sending response packet");
191 let pkt: Packet = response.into();
193 let data = BytesMut::from(pkt.as_bytes());
194 let socket = self.udp_socket.clone();
195 let endpoint = self.peer_endpoint;
196 tokio::spawn(async move {
197 if let Err(e) = socket.send_to(&data, endpoint).await {
198 log::error!("Failed to send response: {}", e);
199 }
200 });
201
202 while let Some(queued) = tunn.next_queued_packet() {
204 let pkt: Packet = queued.into();
205 let data = BytesMut::from(pkt.as_bytes());
206 let socket = self.udp_socket.clone();
207 let endpoint = self.peer_endpoint;
208 tokio::spawn(async move {
209 if let Err(e) = socket.send_to(&data, endpoint).await {
210 log::error!("Failed to send queued packet: {}", e);
211 }
212 });
213 }
214
215 None
216 }
217 TunnResult::WriteToTunnel(decrypted) => {
218 if decrypted.is_empty() {
219 log::trace!("WG: Received keepalive");
220 return None;
221 }
222 let bytes = BytesMut::from(decrypted.as_bytes());
223 log::trace!("WG: Decrypted {} bytes", bytes.len());
224 Some(bytes)
225 }
226 }
227 }
228
229 pub async fn run_receive_loop(self: &Arc<Self>) -> Result<()> {
231 let mut buf = vec![0u8; 65535];
232
233 loop {
234 match self.udp_socket.recv_from(&mut buf).await {
235 Ok((len, from)) => {
236 if from != self.peer_endpoint {
237 log::warn!("Received packet from unknown peer: {}", from);
238 continue;
239 }
240
241 log::trace!("Received UDP packet ({} bytes) from {}", len, from);
242
243 if let Some(ip_packet) = self.process_incoming_udp(&buf[..len]) {
244 if self.incoming_tx.send(ip_packet).await.is_err() {
245 log::error!("Incoming channel closed");
246 break;
247 }
248 }
249 }
250 Err(e) => {
251 log::error!("UDP receive error: {}", e);
252 break;
253 }
254 }
255 }
256
257 Ok(())
258 }
259
260 pub async fn run_send_loop(self: &Arc<Self>) -> Result<()> {
262 let mut outgoing_rx = self.outgoing_rx.lock().await;
263
264 while let Some(packet) = outgoing_rx.recv().await {
265 if let Err(e) = self.send_ip_packet(packet).await {
266 log::error!("Failed to send packet: {}", e);
267 }
268 }
269
270 Ok(())
271 }
272
273 pub async fn run_timer_loop(self: &Arc<Self>) -> Result<()> {
275 let mut interval = tokio::time::interval(Duration::from_millis(250));
276
277 loop {
278 interval.tick().await;
279
280 let packets_to_send: Vec<Vec<u8>> = {
281 let mut tunn = self.tunn.lock();
282 match tunn.update_timers() {
283 Ok(Some(packet)) => {
284 let pkt: Packet = packet.into();
285 vec![pkt.as_bytes().to_vec()]
286 }
287 Ok(None) => vec![],
288 Err(e) => {
289 log::trace!("Timer error (may be normal): {:?}", e);
290 vec![]
291 }
292 }
293 };
294
295 for packet in packets_to_send {
296 if let Err(e) = self.udp_socket.send_to(&packet, self.peer_endpoint).await {
297 log::error!("Failed to send timer packet: {}", e);
298 }
299 }
300 }
301 }
302
303 pub async fn wait_for_handshake(&self, timeout_duration: Duration) -> Result<()> {
305 let start = std::time::Instant::now();
306
307 loop {
308 {
309 let tunn = self.tunn.lock();
310 let (time_since_handshake, _tx_bytes, _rx_bytes, _, _) = tunn.stats();
312 if time_since_handshake.is_some() {
313 log::info!("WireGuard handshake completed!");
314 return Ok(());
315 }
316 }
317
318 if start.elapsed() > timeout_duration {
319 return Err(Error::HandshakeTimeout(timeout_duration));
320 }
321
322 tokio::time::sleep(Duration::from_millis(50)).await;
323 }
324 }
325}