1use crate::error::{Error, Result};
7use crate::wireguard::WireGuardTunnel;
8use bytes::BytesMut;
9use parking_lot::Mutex;
10use smoltcp::iface::{Config, Interface, PollResult, SocketHandle, SocketSet};
11use smoltcp::phy::{Device, DeviceCapabilities, Medium, RxToken, TxToken};
12use smoltcp::socket::tcp::{Socket as TcpSocket, SocketBuffer, State as TcpState};
13use smoltcp::time::Instant;
14use smoltcp::wire::{HardwareAddress, IpAddress, IpCidr, Ipv4Address, Ipv4Packet, TcpPacket};
15use std::collections::VecDeque;
16use std::net::{SocketAddr, SocketAddrV4};
17use std::sync::Arc;
18use std::time::Duration;
19use tokio::sync::mpsc;
20
/// MTU of the virtual IP device. Must leave room for WireGuard encapsulation
/// overhead so each inner packet fits in one outer datagram.
/// NOTE(review): 460 is unusually small — confirm it matches the tunnel's
/// outer-path MTU budget.
const MTU: usize = 460;

/// Capacity of each TCP socket's rx and tx ring buffer (64 KiB - 1, the
/// largest unscaled TCP window).
const TCP_BUFFER_SIZE: usize = 65535;
/// In-memory packet device bridging smoltcp to the WireGuard tunnel.
struct VirtualDevice {
    // Inbound IP packets (already decrypted) awaiting an interface poll.
    rx_queue: VecDeque<BytesMut>,
    // Outbound IP packets produced by smoltcp, awaiting hand-off to WireGuard.
    tx_queue: VecDeque<BytesMut>,
}
38
39impl VirtualDevice {
40 fn new() -> Self {
41 Self {
42 rx_queue: VecDeque::new(),
43 tx_queue: VecDeque::new(),
44 }
45 }
46
47 fn push_rx(&mut self, packet: BytesMut) {
49 self.rx_queue.push_back(packet);
50 }
51
52 fn drain_tx(&mut self) -> Vec<BytesMut> {
54 self.tx_queue.drain(..).collect()
55 }
56}
57
/// Receive token yielding one queued inbound packet to smoltcp.
struct VirtualRxToken {
    // The raw IP packet handed to the interface on consume().
    buffer: BytesMut,
}
62
impl RxToken for VirtualRxToken {
    /// Lends the packet bytes to smoltcp's packet processor exactly once;
    /// the token (and buffer) is consumed afterwards.
    fn consume<R, F>(self, f: F) -> R
    where
        F: FnOnce(&[u8]) -> R,
    {
        f(&self.buffer)
    }
}
71
/// Transmit token that appends one outbound packet to the device's egress queue.
struct VirtualTxToken<'a> {
    // Borrow of the owning device's tx queue; the written packet is pushed here.
    tx_queue: &'a mut VecDeque<BytesMut>,
}
76
impl<'a> TxToken for VirtualTxToken<'a> {
    /// Gives smoltcp a zeroed buffer of exactly `len` bytes to serialize a
    /// packet into, then queues the filled buffer for transmission.
    fn consume<R, F>(self, len: usize, f: F) -> R
    where
        F: FnOnce(&mut [u8]) -> R,
    {
        let mut buffer = BytesMut::zeroed(len);
        let result = f(&mut buffer);
        self.tx_queue.push_back(buffer);
        result
    }

    /// Packet metadata is not used by this virtual device; intentionally a no-op.
    fn set_meta(&mut self, _meta: smoltcp::phy::PacketMeta) {
    }
}
92
93impl Device for VirtualDevice {
94 type RxToken<'a> = VirtualRxToken;
95 type TxToken<'a> = VirtualTxToken<'a>;
96
97 fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
98 if let Some(buffer) = self.rx_queue.pop_front() {
99 Some((
100 VirtualRxToken { buffer },
101 VirtualTxToken {
102 tx_queue: &mut self.tx_queue,
103 },
104 ))
105 } else {
106 None
107 }
108 }
109
110 fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {
111 Some(VirtualTxToken {
112 tx_queue: &mut self.tx_queue,
113 })
114 }
115
116 fn capabilities(&self) -> DeviceCapabilities {
117 let mut caps = DeviceCapabilities::default();
118 caps.medium = Medium::Ip;
119 caps.max_transmission_unit = MTU;
120 caps
121 }
122}
123
/// State guarded by the NetStack mutex: smoltcp requires exclusive access
/// to the interface, device, and socket set while polling.
struct NetStackInner {
    interface: Interface,
    device: VirtualDevice,
    sockets: SocketSet<'static>,
}
130
/// Userspace TCP/IP stack exchanging raw IP packets with a WireGuard tunnel.
pub struct NetStack {
    inner: Mutex<NetStackInner>,
    // Used to derive the local (tunnel) address for outgoing connections.
    wg_tunnel: Arc<WireGuardTunnel>,
    // Outbound IP packets pushed here are consumed by the WireGuard sender.
    wg_tx: mpsc::Sender<BytesMut>,
}
138
139impl NetStack {
140 pub fn new(wg_tunnel: Arc<WireGuardTunnel>) -> Arc<Self> {
142 let tunnel_ip = wg_tunnel.tunnel_ip();
143 let wg_tx = wg_tunnel.outgoing_sender();
144
145 let mut device = VirtualDevice::new();
147
148 let config = Config::new(HardwareAddress::Ip);
150
151 let mut interface = Interface::new(config, &mut device, Instant::now());
153
154 interface.update_ip_addrs(|addrs| {
156 addrs
157 .push(IpCidr::new(
158 IpAddress::v4(
159 tunnel_ip.octets()[0],
160 tunnel_ip.octets()[1],
161 tunnel_ip.octets()[2],
162 tunnel_ip.octets()[3],
163 ),
164 32,
165 ))
166 .unwrap();
167 });
168
169 interface
171 .routes_mut()
172 .add_default_ipv4_route(Ipv4Address::new(0, 0, 0, 0))
173 .unwrap();
174
175 let sockets = SocketSet::new(vec![]);
177
178 let inner = NetStackInner {
179 interface,
180 device,
181 sockets,
182 };
183
184 Arc::new(Self {
185 inner: Mutex::new(inner),
186 wg_tunnel,
187 wg_tx,
188 })
189 }
190
191 pub fn create_tcp_socket(&self) -> SocketHandle {
193 let mut inner = self.inner.lock();
194
195 let rx_buffer = SocketBuffer::new(vec![0u8; TCP_BUFFER_SIZE]);
196 let tx_buffer = SocketBuffer::new(vec![0u8; TCP_BUFFER_SIZE]);
197 let socket = TcpSocket::new(rx_buffer, tx_buffer);
198
199 inner.sockets.add(socket)
200 }
201
202 pub fn connect(&self, handle: SocketHandle, addr: SocketAddr) -> Result<()> {
204 let mut inner = self.inner.lock();
205
206 let local_port = 49152 + (rand::random::<u16>() % 16384);
207 let local_addr = SocketAddrV4::new(self.wg_tunnel.tunnel_ip(), local_port);
208
209 let remote = match addr {
210 SocketAddr::V4(v4) => smoltcp::wire::IpEndpoint::new(
211 IpAddress::v4(
212 v4.ip().octets()[0],
213 v4.ip().octets()[1],
214 v4.ip().octets()[2],
215 v4.ip().octets()[3],
216 ),
217 v4.port(),
218 ),
219 SocketAddr::V6(_) => return Err(Error::Ipv6NotSupported),
220 };
221
222 let local = smoltcp::wire::IpEndpoint::new(
223 IpAddress::v4(
224 local_addr.ip().octets()[0],
225 local_addr.ip().octets()[1],
226 local_addr.ip().octets()[2],
227 local_addr.ip().octets()[3],
228 ),
229 local_addr.port(),
230 );
231
232 let NetStackInner {
234 ref mut interface,
235 ref mut sockets,
236 ..
237 } = *inner;
238 let cx = interface.context();
239 let socket = sockets.get_mut::<TcpSocket>(handle);
240 socket
241 .connect(cx, remote, local)
242 .map_err(|e| Error::TcpConnectGeneric(format!("TCP connect failed: {}", e)))?;
243
244 log::debug!("TCP socket connecting to {} from {}", addr, local_addr);
245
246 Ok(())
247 }
248
249 pub fn is_connected(&self, handle: SocketHandle) -> bool {
251 let inner = self.inner.lock();
252 let socket = inner.sockets.get::<TcpSocket>(handle);
253 socket.state() == TcpState::Established
254 }
255
256 pub fn can_send(&self, handle: SocketHandle) -> bool {
258 let inner = self.inner.lock();
259 let socket = inner.sockets.get::<TcpSocket>(handle);
260 socket.can_send()
261 }
262
263 pub fn can_recv(&self, handle: SocketHandle) -> bool {
265 let inner = self.inner.lock();
266 let socket = inner.sockets.get::<TcpSocket>(handle);
267 let can = socket.can_recv();
268 let recv_queue = socket.recv_queue();
269 if recv_queue > 0 {
270 log::debug!(
271 "Socket can_recv={}, recv_queue={}, state={:?}",
272 can,
273 recv_queue,
274 socket.state()
275 );
276 }
277 can
278 }
279
280 pub fn may_send(&self, handle: SocketHandle) -> bool {
282 let inner = self.inner.lock();
283 let socket = inner.sockets.get::<TcpSocket>(handle);
284 socket.may_send()
285 }
286
287 pub fn may_recv(&self, handle: SocketHandle) -> bool {
289 let inner = self.inner.lock();
290 let socket = inner.sockets.get::<TcpSocket>(handle);
291 socket.may_recv()
292 }
293
294 pub fn socket_state(&self, handle: SocketHandle) -> TcpState {
296 let inner = self.inner.lock();
297 let socket = inner.sockets.get::<TcpSocket>(handle);
298 socket.state()
299 }
300
301 pub fn send(&self, handle: SocketHandle, data: &[u8]) -> Result<usize> {
303 let mut inner = self.inner.lock();
304 let socket = inner.sockets.get_mut::<TcpSocket>(handle);
305
306 socket
307 .send_slice(data)
308 .map_err(|e| Error::TcpSend(e.to_string()))
309 }
310
311 pub fn recv(&self, handle: SocketHandle, buffer: &mut [u8]) -> Result<usize> {
313 let mut inner = self.inner.lock();
314 let socket = inner.sockets.get_mut::<TcpSocket>(handle);
315
316 socket
317 .recv_slice(buffer)
318 .map_err(|e| Error::TcpRecv(e.to_string()))
319 }
320
321 pub fn close(&self, handle: SocketHandle) {
323 let mut inner = self.inner.lock();
324 let socket = inner.sockets.get_mut::<TcpSocket>(handle);
325 socket.close();
326 }
327
328 pub fn remove_socket(&self, handle: SocketHandle) {
330 let mut inner = self.inner.lock();
331 inner.sockets.remove(handle);
332 }
333
334 pub fn poll(&self) -> bool {
337 let mut inner = self.inner.lock();
338
339 let timestamp = Instant::now();
340
341 let NetStackInner {
343 ref mut interface,
344 ref mut device,
345 ref mut sockets,
346 } = *inner;
347
348 let rx_queue_len = device.rx_queue.len();
350 if rx_queue_len > 0 {
351 log::trace!("NetStack poll: {} packets in rx_queue", rx_queue_len);
352 }
353
354 let poll_result = interface.poll(timestamp, device, sockets);
356 let processed = poll_result != PollResult::None;
357
358 if processed {
359 log::trace!("NetStack poll processed packets");
360 }
361
362 let tx_packets = device.drain_tx();
364 let tx_count = tx_packets.len();
365 drop(inner); if tx_count > 0 {
368 log::trace!("NetStack poll sending {} packets", tx_count);
369 }
370
371 for packet in tx_packets {
372 if log::log_enabled!(log::Level::Debug) {
374 if let Ok(ip_packet) = Ipv4Packet::new_checked(&packet) {
375 let protocol = ip_packet.next_header();
376 if protocol == smoltcp::wire::IpProtocol::Tcp {
377 if let Ok(tcp_packet) = TcpPacket::new_checked(ip_packet.payload()) {
378 let dst_port = tcp_packet.dst_port();
379 let payload_len = tcp_packet.payload().len();
380
381 let mut flags = String::new();
382 if tcp_packet.syn() {
383 flags.push_str("SYN ");
384 }
385 if tcp_packet.ack() {
386 flags.push_str("ACK ");
387 }
388 if tcp_packet.fin() {
389 flags.push_str("FIN ");
390 }
391 if tcp_packet.rst() {
392 flags.push_str("RST ");
393 }
394 if tcp_packet.psh() {
395 flags.push_str("PSH ");
396 }
397
398 log::debug!(
399 "TX: {}:{} [{}] {} bytes",
400 ip_packet.dst_addr(),
401 dst_port,
402 flags.trim(),
403 payload_len
404 );
405 }
406 }
407 }
408 }
409
410 let tx = self.wg_tx.clone();
411 tokio::spawn(async move {
412 if let Err(e) = tx.send(packet).await {
413 log::error!("Failed to queue packet for WireGuard: {}", e);
414 }
415 });
416 }
417
418 processed
419 }
420
421 pub fn push_rx_packet(&self, packet: BytesMut) {
423 if log::log_enabled!(log::Level::Debug) {
425 if let Ok(ip_packet) = Ipv4Packet::new_checked(&packet) {
426 let protocol = ip_packet.next_header();
427 if protocol == smoltcp::wire::IpProtocol::Tcp {
428 if let Ok(tcp_packet) = TcpPacket::new_checked(ip_packet.payload()) {
429 let src_port = tcp_packet.src_port();
430 let payload_len = tcp_packet.payload().len();
431
432 let mut flags = String::new();
433 if tcp_packet.syn() {
434 flags.push_str("SYN ");
435 }
436 if tcp_packet.ack() {
437 flags.push_str("ACK ");
438 }
439 if tcp_packet.fin() {
440 flags.push_str("FIN ");
441 }
442 if tcp_packet.rst() {
443 flags.push_str("RST ");
444 }
445 if tcp_packet.psh() {
446 flags.push_str("PSH ");
447 }
448
449 log::debug!(
450 "RX: {}:{} [{}] {} bytes",
451 ip_packet.src_addr(),
452 src_port,
453 flags.trim(),
454 payload_len
455 );
456 }
457 }
458 }
459 }
460
461 let mut inner = self.inner.lock();
462 inner.device.push_rx(packet);
463 }
464
465 pub async fn run_poll_loop(self: &Arc<Self>) -> Result<()> {
467 let mut interval = tokio::time::interval(Duration::from_millis(1));
468
469 loop {
470 interval.tick().await;
471 self.poll();
472 }
473 }
474
475 pub async fn run_rx_loop(self: &Arc<Self>, mut rx: mpsc::Receiver<BytesMut>) -> Result<()> {
477 while let Some(packet) = rx.recv().await {
478 log::debug!("NetStack received packet ({} bytes)", packet.len());
479 self.push_rx_packet(packet);
480 self.poll();
481 }
482
483 Ok(())
484 }
485}
486
/// One TCP connection multiplexed over the shared userspace [`NetStack`].
pub struct TcpConnection {
    // Owning stack; also kept alive by this handle's Drop impl.
    pub netstack: Arc<NetStack>,
    // smoltcp handle identifying this connection's socket in the stack.
    pub handle: SocketHandle,
}
494
495impl TcpConnection {
496 pub async fn connect(netstack: Arc<NetStack>, addr: SocketAddr) -> Result<Self> {
498 let handle = netstack.create_tcp_socket();
499 netstack.connect(handle, addr)?;
500
501 let start = std::time::Instant::now();
503 let timeout = Duration::from_secs(30);
504
505 loop {
506 netstack.poll();
507
508 let state = netstack.socket_state(handle);
509 log::trace!("TCP state: {:?}", state);
510
511 if state == TcpState::Established {
512 log::info!("TCP connection established to {}", addr);
513 return Ok(Self { netstack, handle });
514 }
515
516 if state == TcpState::Closed || state == TcpState::TimeWait {
517 netstack.remove_socket(handle);
518 return Err(Error::TcpConnect {
519 addr,
520 message: format!("Connection failed (state: {:?})", state),
521 });
522 }
523
524 if start.elapsed() > timeout {
525 netstack.remove_socket(handle);
526 return Err(Error::TcpTimeout);
527 }
528
529 tokio::time::sleep(Duration::from_millis(1)).await;
530 }
531 }
532
533 pub async fn read(&self, buf: &mut [u8]) -> Result<usize> {
535 let timeout = Duration::from_secs(30);
536 let start = std::time::Instant::now();
537
538 loop {
539 self.netstack.poll();
540
541 if self.netstack.can_recv(self.handle) {
542 match self.netstack.recv(self.handle, buf) {
543 Ok(n) if n > 0 => return Ok(n),
544 Ok(_) => {}
545 Err(e) => return Err(e),
546 }
547 }
548
549 if !self.netstack.may_recv(self.handle) {
550 return Ok(0);
552 }
553
554 if start.elapsed() > timeout {
555 return Err(Error::ReadTimeout);
556 }
557
558 tokio::time::sleep(Duration::from_millis(1)).await;
559 }
560 }
561
562 pub async fn write(&self, data: &[u8]) -> Result<usize> {
564 let timeout = Duration::from_secs(30);
565 let start = std::time::Instant::now();
566
567 let mut written = 0;
568
569 while written < data.len() {
570 self.netstack.poll();
571
572 if self.netstack.can_send(self.handle) {
573 match self.netstack.send(self.handle, &data[written..]) {
574 Ok(n) => {
575 written += n;
576 log::trace!("Wrote {} bytes (total: {})", n, written);
577 }
578 Err(e) => return Err(e),
579 }
580 }
581
582 if !self.netstack.may_send(self.handle) {
583 return Err(Error::ConnectionClosed);
585 }
586
587 if start.elapsed() > timeout {
588 return Err(Error::WriteTimeout);
589 }
590
591 if written < data.len() {
592 tokio::time::sleep(Duration::from_millis(1)).await;
593 }
594 }
595
596 self.netstack.poll();
597 Ok(written)
598 }
599
600 pub async fn write_all(&self, data: &[u8]) -> Result<()> {
602 let n = self.write(data).await?;
603 if n != data.len() {
604 return Err(Error::ShortWrite {
605 written: n,
606 expected: data.len(),
607 });
608 }
609 Ok(())
610 }
611
612 pub fn shutdown(&self) {
614 self.netstack.close(self.handle);
615 }
616
617 pub fn handle(&self) -> SocketHandle {
619 self.handle
620 }
621}
622
impl Drop for TcpConnection {
    fn drop(&mut self) {
        // Initiate an orderly close (FIN) and poll once so the segment is
        // handed to the device immediately rather than on the next tick.
        // NOTE(review): the handle is never removed from the SocketSet here,
        // so dropped connections leak a socket slot unless remove_socket is
        // called elsewhere — confirm the intended lifecycle.
        self.netstack.close(self.handle);
        self.netstack.poll();
    }
}