1use std::cell::RefCell;
2use std::collections::{HashMap, HashSet};
3use std::convert::{TryFrom, TryInto};
4use std::path::PathBuf;
5use std::rc::Rc;
6use std::time::Duration;
7
8use futures::channel::mpsc;
9use futures::future::{Either, LocalBoxFuture};
10use futures::{Future, FutureExt, SinkExt, StreamExt, TryFutureExt};
11use smoltcp::iface::SocketHandle;
12use smoltcp::wire::IpEndpoint;
13use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
14use tokio::task::spawn_local;
15use tokio::time::MissedTickBehavior;
16
17use crate::connection::{Connection, ConnectionMeta};
18use crate::packet::{
19 ip_ntoh, ArpField, ArpPacket, EtherFrame, IpPacket, PeekPacket, TcpPacket, UdpPacket,
20};
21use crate::protocol::Protocol;
22use crate::socket::{SocketDesc, SocketEndpoint, SocketExt, SocketMemory, SocketState};
23use crate::stack::Stack;
24use crate::{ChannelMetrics, Error, Result};
25
26use ya_relay_util::Payload;
27
/// Environment variable read into `StackConfig::pcap_path` by `StackConfig::default()`.
pub const PCAP_FILE_ENV_VAR: &str = "YA_NET_PCAP_FILE";
/// Environment variable overriding the stack poll interval in milliseconds
/// (see `Network::spawn_local`; default 250 ms, `0` or unparsable values are ignored).
pub const STACK_POLL_MS_ENV_VAR: &str = "YA_NET_STACK_POLL_MS";
/// Environment variable overriding `StackConfig::max_send_batch` (bytes).
pub const STACK_POLL_SENT_ENV_VAR: &str = "YA_NET_STACK_POLL_SENT_BATCH";
/// Environment variable overriding `StackConfig::max_recv_batch` (bytes).
pub const STACK_POLL_RECV_ENV_VAR: &str = "YA_NET_STACK_POLL_RECV_BATCH";

/// Default egress byte budget per processing round (16 KiB).
const DEFAULT_POLL_SENT_BATCH: usize = 16384;
/// Default per-socket ingress byte budget per processing round (32 KiB).
const DEFAULT_POLL_RECV_BATCH: usize = 32768;
/// Lower bound applied to any configured send batch.
const MIN_STACK_POLL_SENT_BATCH: usize = 2048;
/// Lower bound applied to any configured receive batch.
const MIN_STACK_POLL_RECV_BATCH: usize = 4096;
37
/// Receiving half of the channel on which a `Network` publishes [`IngressEvent`]s.
pub type IngressReceiver = UnboundedReceiver<IngressEvent>;
/// Receiving half of the channel on which a `Network` publishes [`EgressEvent`]s.
pub type EgressReceiver = UnboundedReceiver<EgressEvent>;
40
/// Tunables for a [`Network`] and the stack it drives.
#[derive(Clone)]
pub struct StackConfig {
    /// Optional capture file path; `Default` reads it from `YA_NET_PCAP_FILE`.
    /// NOTE(review): presumably a pcap dump target consumed by the stack — confirm.
    pub pcap_path: Option<PathBuf>,
    /// MTU handed to the interface constructors (see `tap_iface`/`tun_iface` in the tests);
    /// defaults to 1400.
    pub max_transmission_unit: usize,
    /// Egress bytes `Network::process_egress` handles per round before reporting unfinished work.
    pub max_send_batch: usize,
    /// Bytes `Network::process_ingress` reads from a single socket per round before
    /// reporting unfinished work.
    pub max_recv_batch: usize,
    /// Socket memory settings for TCP sockets (`SocketMemory::default_tcp()` by default).
    pub tcp_mem: SocketMemory,
    /// Socket memory settings for UDP sockets (`SocketMemory::default_udp()` by default).
    pub udp_mem: SocketMemory,
    /// Socket memory settings for ICMP sockets (`SocketMemory::default_icmp()` by default).
    pub icmp_mem: SocketMemory,
    /// Socket memory settings for raw sockets (`SocketMemory::default_raw()` by default).
    pub raw_mem: SocketMemory,
}
52
53impl Default for StackConfig {
54 fn default() -> Self {
55 let max_send_batch = std::env::var(STACK_POLL_SENT_ENV_VAR)
56 .and_then(|s| {
57 s.parse::<usize>()
58 .map_err(|_| std::env::VarError::NotPresent)
59 })
60 .unwrap_or(DEFAULT_POLL_SENT_BATCH)
61 .max(MIN_STACK_POLL_SENT_BATCH);
62
63 let max_recv_batch = std::env::var(STACK_POLL_RECV_ENV_VAR)
64 .and_then(|s| {
65 s.parse::<usize>()
66 .map_err(|_| std::env::VarError::NotPresent)
67 })
68 .unwrap_or(DEFAULT_POLL_RECV_BATCH)
69 .max(MIN_STACK_POLL_RECV_BATCH);
70
71 Self {
72 pcap_path: std::env::var(PCAP_FILE_ENV_VAR).ok().map(PathBuf::from),
73 max_transmission_unit: 1400,
74 max_send_batch,
75 max_recv_batch,
76 tcp_mem: SocketMemory::default_tcp(),
77 udp_mem: SocketMemory::default_udp(),
78 icmp_mem: SocketMemory::default_icmp(),
79 raw_mem: SocketMemory::default_raw(),
80 }
81 }
82}
83
/// A virtual network: wraps a [`Stack`] and moves data between its sockets and the
/// ingress/egress channels. Cloning is cheap; clones share all state through `Rc`s.
#[derive(Clone)]
pub struct Network {
    /// Human-readable name, used as the prefix of log messages.
    pub name: Rc<String>,
    /// Shared configuration (batch sizes, MTU, socket memory).
    pub config: Rc<StackConfig>,
    /// The underlying stack (interface, sockets, metrics).
    pub stack: Stack<'static>,
    /// Cached result of `device().is_tun()`, captured in `Network::new`.
    is_tun: bool,
    /// Per-connection send queues (see `StackSender`).
    sender: StackSender,
    /// Periodic poll task driver (see `StackPoller`).
    poller: StackPoller,
    /// Handles of sockets created through `Network::bind`.
    pub bindings: Rc<RefCell<HashSet<SocketHandle>>>,
    /// Known connections keyed by their metadata.
    pub connections: Rc<RefCell<HashMap<ConnectionMeta, Connection>>>,
    /// Reverse index from socket handle to connection metadata.
    pub handles: Rc<RefCell<HashMap<SocketHandle, ConnectionMeta>>>,
    /// Channel on which received packets and connection events are published.
    ingress: Channel<IngressEvent>,
    /// Channel on which outgoing frames/packets are published.
    egress: Channel<EgressEvent>,
}
100
101impl Network {
102 pub fn new(name: impl ToString, config: Rc<StackConfig>, stack: Stack<'static>) -> Self {
104 let is_tun = {
105 let iface_rfc = stack.iface();
106 let iface = iface_rfc.borrow();
107 iface.device().is_tun()
108 };
109
110 let network = Self {
111 name: Rc::new(name.to_string()),
112 config,
113 stack,
114 is_tun,
115 sender: Default::default(),
116 poller: Default::default(),
117 bindings: Default::default(),
118 connections: Default::default(),
119 handles: Default::default(),
120 ingress: Default::default(),
121 egress: Default::default(),
122 };
123
124 network.sender.net.borrow_mut().replace(network.clone());
125 network.poller.net.borrow_mut().replace(network.clone());
126 network
127 }
128
129 pub fn get_bound(
132 &self,
133 protocol: Protocol,
134 local_endpoint: impl Into<SocketEndpoint>,
135 ) -> Option<SocketHandle> {
136 let endpoint = local_endpoint.into();
137 let iface_rfc = self.stack.iface();
138 let iface = iface_rfc.borrow();
139 let mut sockets = iface.sockets();
140 sockets
141 .find(|(handle, s)| {
142 s.protocol() == protocol
143 && s.local_endpoint() == endpoint
144 && self.bindings.borrow().contains(handle)
147 })
148 .map(|(h, _)| h)
149 }
150
151 pub fn bind(
153 &self,
154 protocol: Protocol,
155 endpoint: impl Into<SocketEndpoint>,
156 ) -> Result<SocketHandle> {
157 let endpoint = endpoint.into();
158 let handle = self.stack.bind(protocol, endpoint)?;
159 self.bindings.borrow_mut().insert(handle);
160 Ok(handle)
161 }
162
163 pub fn unbind(&self, protocol: Protocol, endpoint: impl Into<SocketEndpoint>) -> Result<()> {
165 let endpoint = endpoint.into();
166 let handle = self.stack.unbind(protocol, endpoint)?;
167 self.bindings.borrow_mut().remove(&handle);
168 Ok(())
169 }
170
171 pub fn connect(
173 &self,
174 remote: impl Into<IpEndpoint>,
175 timeout: impl Into<Duration>,
176 ) -> LocalBoxFuture<Result<Connection>> {
177 let remote = remote.into();
178 let timeout = timeout.into();
179
180 let connect = match self.stack.connect(remote) {
181 Ok(fut) => fut,
182 Err(err) => return futures::future::err(err).boxed_local(),
183 };
184 self.poll();
185
186 let connections = self.connections.clone();
187 let handles = self.handles.clone();
188
189 async move {
190 let connection = match tokio::time::timeout(timeout, connect).await {
191 Ok(Ok(conn)) => conn,
192 Ok(Err(error)) => return Err(error),
193 _ => return Err(Error::ConnectionTimeout),
194 };
195 Self::add_connection_to(connection, &connections, &handles);
196 Ok(connection)
197 }
198 .boxed_local()
199 }
200
201 pub fn disconnect_all(
203 &self,
204 remote_ip: Box<[u8]>,
205 timeout: impl Into<Duration>,
206 ) -> LocalBoxFuture<()> {
207 let (handles, futs): (Vec<_>, Vec<_>) = {
208 let connections = self.connections.borrow();
209 connections
210 .values()
211 .filter(|conn| {
212 conn.meta.remote.addr.as_bytes() == remote_ip.as_ref()
213 && conn.meta.protocol == Protocol::Tcp
214 })
215 .map(|conn| (conn.handle, self.stack.disconnect(conn.handle)))
216 .unzip()
217 };
218
219 if futs.is_empty() {
220 return futures::future::ready(()).boxed_local();
221 }
222
223 self.poll();
224
225 let timeout = timeout.into();
226 let net = self.clone();
227
228 async move {
229 let pending = futures::future::join_all(futs);
230 let timeout = tokio::time::sleep(timeout).boxed_local();
231
232 if let Either::Right((_, pending)) = futures::future::select(pending, timeout).await {
233 handles.into_iter().for_each(|h| net.stack.abort(h));
234 net.poll();
235
236 let timeout = tokio::time::sleep(Duration::from_millis(500));
237 let _ = futures::future::select(pending, timeout.boxed_local()).await;
238 }
239 }
240 .boxed_local()
241 }
242
243 pub fn bindings(&self) -> core::cell::Ref<'_, HashSet<SocketHandle>> {
244 self.bindings.borrow()
245 }
246
247 pub fn handles(&self) -> core::cell::Ref<'_, HashMap<SocketHandle, ConnectionMeta>> {
248 self.handles.borrow()
249 }
250
251 pub fn connections(&self) -> core::cell::Ref<'_, HashMap<ConnectionMeta, Connection>> {
252 self.connections.borrow()
253 }
254
255 pub fn sockets(&self) -> Vec<(SocketDesc, SocketState<ChannelMetrics>)> {
256 let iface_rfc = self.stack.iface();
257 let iface = iface_rfc.borrow();
258 let metrics_rfc = self.stack.metrics();
259 let metrics = metrics_rfc.borrow();
260
261 iface
262 .sockets()
263 .map(|(_, s)| {
264 let desc = s.desc();
265 let metrics = metrics.get(&desc).cloned().unwrap_or_default();
266 let mut state = s.state();
267 state.set_inner(metrics);
268 (desc, state)
269 })
270 .collect()
271 }
272
273 pub fn sockets_meta(&self) -> Vec<(SocketHandle, SocketDesc, SocketState<ChannelMetrics>)> {
274 let iface_rfc = self.stack.iface();
275 let iface = iface_rfc.borrow();
276 let connections = self.handles.borrow();
277
278 iface
279 .sockets()
280 .map(|(handle, s)| {
281 (
282 handle,
283 connections
284 .get(&handle)
285 .cloned()
286 .map(|meta| meta.into())
287 .unwrap_or(s.desc()),
288 s.state(),
289 )
290 })
291 .collect()
292 }
293
294 pub fn metrics(&self) -> ChannelMetrics {
295 let iface_rfc = self.stack.iface();
296 let iface = iface_rfc.borrow();
297 iface.device().metrics()
298 }
299
300 #[inline(always)]
301 fn is_connected(&self, meta: &ConnectionMeta) -> bool {
302 self.connections.borrow().contains_key(meta)
303 }
304
305 #[inline(always)]
306 fn add_connection(&self, connection: Connection) {
307 Self::add_connection_to(connection, &self.connections, &self.handles);
308 }
309
310 fn add_connection_to(
311 connection: Connection,
312 connections: &Rc<RefCell<HashMap<ConnectionMeta, Connection>>>,
313 handles: &Rc<RefCell<HashMap<SocketHandle, ConnectionMeta>>>,
314 ) {
315 let handle = connection.handle;
316 let meta = connection.into();
317 connections.borrow_mut().insert(meta, connection);
318 handles.borrow_mut().insert(handle, meta);
319 }
320
321 #[inline(always)]
322 fn remove_connection(&self, meta: &ConnectionMeta, handle: SocketHandle) {
323 self.stack.remove(meta, handle);
324 self.handles.borrow_mut().remove(&handle);
325 self.sender.remove(&handle);
326
327 let ip_endpoint = smoltcp::wire::IpListenEndpoint::from(meta.remote);
328 if !ip_endpoint.is_specified() {
329 return;
330 }
331 self.connections.borrow_mut().remove(meta);
332 }
333
334 #[inline(always)]
336 pub fn send<'a>(
337 &self,
338 data: impl Into<Payload>,
339 connection: Connection,
340 ) -> impl Future<Output = Result<()>> + 'a {
341 self.sender.send(data.into(), connection)
342 }
343
344 #[inline(always)]
346 pub fn receive(&self, data: impl Into<Payload>) {
347 self.stack.receive(data)
348 }
349
350 pub fn spawn_local(&self) {
351 let interval = std::env::var(STACK_POLL_MS_ENV_VAR)
352 .and_then(|s| s.parse::<u64>().map_err(|_| std::env::VarError::NotPresent))
353 .and_then(|v| match v {
354 0 => Err(std::env::VarError::NotPresent),
355 v => Ok(v),
356 })
357 .unwrap_or(250);
358 self.poller.clone().spawn(Duration::from_millis(interval));
359 }
360
361 pub fn poll(&self) {
363 loop {
364 let finished = match (self.stack.poll(), self.is_tun) {
365 (true, _) | (_, false) => self.process_ingress() && self.process_egress(),
366 (false, _) => true,
367 };
368 if finished {
369 break;
370 }
371 }
372 }
373
374 #[inline(always)]
376 pub fn ingress_receiver(&self) -> Option<IngressReceiver> {
377 self.ingress.receiver()
378 }
379
380 #[inline(always)]
382 pub fn egress_receiver(&self) -> Option<EgressReceiver> {
383 self.egress.receiver()
384 }
385
386 fn process_ingress(&self) -> bool {
387 let mut finished = true;
388
389 let iface_rfc = self.stack.iface();
390 let mut iface = iface_rfc.borrow_mut();
391 let mut bindings = self.bindings.borrow_mut();
392 let mut events = Vec::new();
393 let mut remove = Vec::new();
394 let mut rebind = None;
395
396 for (handle, socket) in iface.sockets_mut() {
397 let mut desc = socket.desc();
398
399 if socket.is_closed() {
403 let meta = self
404 .handles
405 .borrow()
406 .get(&handle)
407 .copied()
408 .or(desc.try_into().ok());
409
410 if let Some(meta) = meta {
411 log::debug!("{}: closing socket [{handle}]: {desc} / {meta}", self.name);
413 events.push(IngressEvent::Disconnected { desc: meta.into() });
414 } else {
415 log::debug!("Removing socket {handle} with reset metadata");
419 }
420
421 remove.push((
422 meta.unwrap_or(ConnectionMeta::unspecified(desc.protocol)),
423 handle,
424 ));
425 }
426
427 let mut received = 0;
428
429 while socket.can_recv() {
430 let (remote, payload) = match socket.recv() {
431 Ok(Some(tuple)) => tuple,
432 Ok(None) => break,
433 Err(err) => {
434 log::debug!("{}: ingress packet error: {err}", self.name);
435 continue;
436 }
437 };
438
439 let len = payload.len();
440
441 received += len;
442 desc.remote = remote.into();
443
444 if let Ok(meta) = desc.try_into() {
445 if !self.is_connected(&meta) {
446 self.add_connection(Connection { handle, meta });
447 events.push(IngressEvent::InboundConnection { desc: meta.into() });
448 }
449 }
450
451 log::trace!("{}: ingress {len} B packet", self.name);
452
453 self.stack.on_received(&desc, len);
454 events.push(IngressEvent::Packet { desc, payload });
455
456 if received >= self.config.max_recv_batch {
457 finished = false;
458 break;
459 }
460 }
461
462 if bindings.contains(&handle) && socket.remote_endpoint().is_specified() {
463 bindings.remove(&handle);
464 rebind = Some((socket.protocol(), socket.local_endpoint()));
465
466 finished = false;
467 break;
468 }
469 }
470
471 drop(bindings);
472 drop(iface);
473
474 remove.into_iter().for_each(|(meta, handle)| {
475 self.remove_connection(&meta, handle);
476 });
477
478 if !events.is_empty() {
479 let ingress_tx = self.ingress.tx.clone();
480 for event in events {
481 if ingress_tx.send(event).is_err() {
482 log::debug!(
483 "{}: ingress channel closed, unable to receive packets",
484 self.name
485 );
486 break;
487 }
488 }
489 }
490
491 if let Some((p, ep)) = rebind {
492 if let Err(e) = self.bind(p, ep) {
493 log::warn!("{}: cannot bind socket {p} {ep:?}: {e}", self.name);
494 }
495 let _ = self.stack.poll();
496 return self.process_ingress();
497 }
498
499 finished
500 }
501
502 fn process_egress(&self) -> bool {
503 let mut sent = 0;
504 let mut finished = true;
505
506 let iface_rfc = self.stack.iface();
507 let mut iface = iface_rfc.borrow_mut();
508 let device = iface.device_mut();
509 let is_tun = device.is_tun();
510
511 while let Some(data) = device.next_phy_tx() {
512 match {
513 if is_tun {
514 EgressEvent::from_ip_packet(data)
515 } else {
516 EgressEvent::from_eth_frame(data)
517 }
518 } {
519 Ok(event) => {
520 sent += event.payload.len();
521
522 if let Some((desc, size)) = event.desc.as_ref() {
523 self.stack.on_sent(desc, *size);
524 }
525
526 if self.egress.tx.send(event).is_err() {
527 log::trace!(
528 "{}: egress channel closed, unable to send packets",
529 *self.name
530 );
531 break;
532 }
533 }
534 Err(err) => log::trace!("{}: egress packet error: {}", *self.name, err),
535 }
536
537 if sent >= self.config.max_send_batch {
538 finished = false;
539 continue;
540 }
541 }
542
543 finished
544 }
545}
546
/// Events published by `Network::process_ingress` on the ingress channel.
#[derive(Clone, Debug)]
pub enum IngressEvent {
    /// A packet arrived for connection metadata that was not yet registered.
    InboundConnection { desc: SocketDesc },
    /// A socket was found closed and is being removed.
    Disconnected { desc: SocketDesc },
    /// Data received on a socket; `desc.remote` is the sender of this payload.
    Packet { desc: SocketDesc, payload: Vec<u8> },
}
556
/// An outgoing frame (TAP) or IP packet (TUN) produced by the stack.
#[derive(Clone, Debug)]
pub struct EgressEvent {
    /// Destination IP address bytes, or the ARP target protocol address for ARP frames.
    pub remote: Box<[u8]>,
    /// The complete frame / packet as emitted by the device.
    pub payload: Box<[u8]>,
    /// Socket descriptor and transport payload size, for TCP/UDP traffic only.
    pub desc: Option<(SocketDesc, usize)>,
}
563
564impl EgressEvent {
565 pub fn from_eth_frame(data: Vec<u8>) -> Result<Self> {
566 let frame = EtherFrame::try_from(data)?;
567 let (desc, remote) = match &frame {
568 EtherFrame::Ip(_) => {
569 let data = frame.payload();
570 IpPacket::peek(data)?;
571
572 let packet = IpPacket::packet(data);
573 let remote = packet.dst_address().into();
574 let desc = Self::payload_desc(&packet);
575 (desc, remote)
576 }
577 EtherFrame::Arp(_) => {
578 let packet = ArpPacket::packet(frame.payload());
579 let remote = packet.get_field(ArpField::TPA).into();
580 (None, remote)
581 }
582 };
583
584 Ok(EgressEvent {
585 remote,
586 payload: frame.into(),
587 desc,
588 })
589 }
590
591 pub fn from_ip_packet(data: Vec<u8>) -> Result<Self> {
592 let (desc, remote) = {
593 IpPacket::peek(&data)?;
594 let packet = IpPacket::packet(&data);
595 let remote = packet.dst_address().into();
596 let desc = Self::payload_desc(&packet);
597
598 (desc, remote)
599 };
600
601 Ok(EgressEvent {
602 remote,
603 payload: data.into_boxed_slice(),
604 desc,
605 })
606 }
607
608 fn payload_desc(packet: &IpPacket) -> Option<(SocketDesc, usize)> {
609 let protocol = Protocol::try_from(packet.protocol()).ok()?;
610
611 let (local_port, remote_port, size) = match protocol {
612 Protocol::Tcp => {
613 TcpPacket::peek(packet.payload()).ok()?;
614 let tcp = TcpPacket::packet(packet.payload());
615 (tcp.src_port(), tcp.dst_port(), tcp.payload_size)
616 }
617 Protocol::Udp => {
618 UdpPacket::peek(packet.payload()).ok()?;
619 let udp = UdpPacket::packet(packet.payload());
620 (udp.src_port(), udp.dst_port(), udp.payload_size)
621 }
622 _ => return None,
623 };
624
625 let local_ip = ip_ntoh(packet.src_address())?;
626 let remote_ip = ip_ntoh(packet.dst_address())?;
627
628 let desc = SocketDesc {
629 protocol,
630 local: (local_ip, local_port).into(),
631 remote: (remote_ip, remote_port).into(),
632 };
633
634 Some((desc, size))
635 }
636}
637
/// Per-socket outgoing queues. Each socket handle gets a bounded channel drained by its
/// own local task (see `StackSender::spawn`).
#[derive(Clone, Default)]
struct StackSender {
    /// Handle → queue sender map.
    inner: Rc<RefCell<StackSenderInner>>,
    /// Back-reference to the owning network, set in `Network::new`.
    net: Rc<RefCell<Option<Network>>>,
}
643
644impl StackSender {
645 #[inline]
646 pub fn send<'a>(
647 &self,
648 data: Payload,
649 conn: Connection,
650 ) -> impl Future<Output = Result<()>> + 'a {
651 let mut sender = {
652 match {
653 let inner = self.inner.borrow();
654 inner.map.get(&conn.handle).cloned()
655 } {
656 Some(sender) => sender,
657 None => self.spawn(conn.handle),
658 }
659 };
660 async move { sender.send((data, conn)).map_err(Error::from).await }
661 }
662
663 fn spawn(&self, handle: SocketHandle) -> mpsc::Sender<(Payload, Connection)> {
664 let net = self.net.borrow().clone().expect("Network not initialized");
665 let (tx, rx) = mpsc::channel(1);
666
667 spawn_local(async move {
668 rx.for_each(|(vec, conn)| {
669 let net = net.clone();
670 let stack = net.stack.clone();
671 async move {
672 let _ = stack.send(vec, conn, move || net.poll()).await;
673 }
674 })
675 .await;
676 });
677
678 let mut inner = self.inner.borrow_mut();
679 inner.map.insert(handle, tx.clone());
680
681 tx
682 }
683
684 pub fn remove(&self, handle: &SocketHandle) {
685 let mut inner = self.inner.borrow_mut();
686 if let Some(mut tx) = inner.map.remove(handle) {
687 spawn_local(async move {
688 let _ = tx.close().await;
689 });
690 }
691 }
692}
693
/// Mutable state of `StackSender`.
#[derive(Default)]
struct StackSenderInner {
    /// Queue sender for each socket handle that has sent data.
    map: HashMap<SocketHandle, mpsc::Sender<(Payload, Connection)>>,
}
698
/// Drives `Network::poll` on a fixed interval (see `StackPoller::spawn`).
#[derive(Clone, Default)]
struct StackPoller {
    /// Back-reference to the owning network, set in `Network::new`.
    net: Rc<RefCell<Option<Network>>>,
}
703
704impl StackPoller {
705 pub fn spawn(&self, interval: Duration) {
706 let poller = self.clone();
707 spawn_local(async move {
708 let mut interval = tokio::time::interval(interval);
709 interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
710 loop {
711 interval.tick().await;
712 poller.net.borrow().as_ref().unwrap().poll();
713 }
714 });
715 }
716}
717
/// An unbounded channel whose receiver can be taken exactly once, while clones keep
/// sharing the sender.
#[derive(Clone)]
pub struct Channel<T> {
    /// Sending half.
    pub tx: UnboundedSender<T>,
    /// Receiving half; `None` once taken via `receiver()`.
    rx: Rc<RefCell<Option<UnboundedReceiver<T>>>>,
}
723
724impl<T> Channel<T> {
725 pub fn receiver(&self) -> Option<UnboundedReceiver<T>> {
726 self.rx.borrow_mut().take()
727 }
728}
729
730impl<T> Default for Channel<T> {
731 fn default() -> Self {
732 let (tx, rx) = unbounded_channel();
733 Self {
734 tx,
735 rx: Rc::new(RefCell::new(Some(rx))),
736 }
737 }
738}
739
#[cfg(test)]
mod tests {
    use std::fmt::Debug;
    use std::rc::Rc;
    use std::time::Duration;

    use futures::channel::{mpsc, oneshot};
    use futures::{Sink, SinkExt, Stream, StreamExt};
    use sha3::Digest;
    use smoltcp::iface::Route;
    use smoltcp::phy::Medium;
    use smoltcp::wire::{IpAddress, IpCidr, Ipv4Address};
    use tokio::task::spawn_local;
    use tokio_stream::wrappers::UnboundedReceiverStream;

    use crate::interface::{add_iface_address, add_iface_route, ip_to_mac, tap_iface, tun_iface};
    use crate::{Connection, EgressEvent, IngressEvent, Network, Protocol, Stack, StackConfig};

    /// Upper bound on a single exchange / re-bind scenario.
    const EXCHANGE_TIMEOUT: Duration = Duration::from_secs(30);

    /// Builds a `Network` on a TAP (`Medium::Ethernet`) or TUN (`Medium::Ip`) interface.
    /// The interface gets address `ip/16` and a route with `ip` as gateway.
    ///
    /// # Panics
    /// Panics for any other medium.
    fn new_network(medium: Medium, ip: IpAddress, config: StackConfig) -> Network {
        let config = Rc::new(config);
        let cidr = IpCidr::new(ip, 16);
        let route = match ip {
            IpAddress::Ipv4(ipv4) => Route::new_ipv4_gateway(ipv4),
            IpAddress::Ipv6(ipv6) => Route::new_ipv6_gateway(ipv6),
        };

        let mut iface = match medium {
            Medium::Ethernet => tap_iface(ip_to_mac(ip), config.max_transmission_unit),
            Medium::Ip => tun_iface(config.max_transmission_unit),
            _ => panic!("unsupported medium: {:?}", medium),
        };

        add_iface_address(&mut iface, cidr);
        add_iface_route(&mut iface, cidr, route);
        Network::new(
            format!("[{:?}] {}", medium, ip),
            config.clone(),
            Stack::new(iface, config),
        )
    }

    /// Spawns a task that sends `total` random bytes into `tx`, in chunks of at most
    /// `chunk_size`, hashing everything sent with SHA3-224.
    ///
    /// Resolves to the sink and the digest, or to the first send error.
    /// NOTE: `chunk_size` must be non-zero when `total > 0`; otherwise every chunk is
    /// empty and the loop never advances.
    fn produce_data<S, E>(
        mut tx: S,
        total: usize,
        chunk_size: usize,
    ) -> oneshot::Receiver<anyhow::Result<(S, Vec<u8>)>>
    where
        S: Sink<Vec<u8>, Error = E> + Unpin + 'static,
        E: Into<anyhow::Error>,
    {
        let (dtx, drx) = oneshot::channel();

        spawn_local(async move {
            let mut digest = sha3::Sha3_224::new();
            let mut sent = 0;
            let mut err = None;

            while sent < total {
                let vec: Vec<u8> = (0..chunk_size.min(total - sent))
                    .map(|_| rand::random())
                    .collect();

                digest.input(&vec);
                sent += vec.len();

                if let Err(e) = tx.send(vec).await {
                    err = Some(e);
                    break;
                }
            }

            println!("Produced {} B", sent);
            match err {
                Some(e) => dtx.send(Err(e.into())),
                None => dtx.send(Ok((tx, digest.result().to_vec()))),
            }
        });

        drx
    }

    /// Spawns a task that reads chunks from `rx` until at least `total` bytes arrived
    /// or the stream ends. Resolves to the SHA3-224 digest of everything read.
    fn consume_data(
        mut rx: mpsc::Receiver<Vec<u8>>,
        total: usize,
    ) -> oneshot::Receiver<anyhow::Result<Vec<u8>>> {
        let (dtx, drx) = oneshot::channel();

        spawn_local(async move {
            let mut digest = sha3::Sha3_224::new();
            let mut read: usize = 0;

            while let Some(vec) = rx.next().await {
                let len = vec.len();

                read += len;
                digest.input(&vec);

                if read >= total {
                    break;
                }
            }

            println!("Consumed {} B", read);
            let _ = dtx.send(Ok(digest.result().to_vec()));
        });

        drx
    }

    /// Feeds every raw frame/packet from `rx` into `net`, polling after each one.
    fn net_inject<S>(rx: S, net: Network)
    where
        S: Stream<Item = Vec<u8>> + 'static,
    {
        spawn_local(async move {
            rx.for_each(|vec| {
                let net = net.clone();
                async move {
                    net.receive(vec);
                    net.poll();
                }
            })
            .await;
        });
    }

    /// Routes egress events from `rx` to `net1` when their destination equals `net1`'s
    /// address, and to `net2` otherwise. Polls the receiving network after each event.
    fn net_inject2<S>(rx: S, net1: Network, net2: Network)
    where
        S: Stream<Item = EgressEvent> + 'static,
    {
        let ip1 = net1
            .stack
            .address()
            .unwrap()
            .address()
            .as_bytes()
            .to_vec()
            .into_boxed_slice();

        spawn_local(async move {
            rx.for_each(|event| {
                let net = if event.remote == ip1 {
                    net1.clone()
                } else {
                    net2.clone()
                };
                async move {
                    net.receive(event.payload);
                    net.poll();
                }
            })
            .await;
        });
    }

    /// Sends every chunk from `rx` over `conn` on `net`, printing send failures.
    fn net_send<S>(rx: S, net: Network, conn: Connection)
    where
        S: Stream<Item = Vec<u8>> + 'static,
    {
        spawn_local(async move {
            let net = net.clone();
            rx.for_each(|vec| async {
                let _ = net
                    .send(vec, conn)
                    .await
                    .map_err(|e| eprintln!("failed to send packet: {}", e));
            })
            .await;
        });
    }

    /// Forwards the payload of each `IngressEvent::Packet` from `rx` into `tx`.
    /// Connection and disconnection events are only printed.
    fn net_receive<Si, St, E>(tx: Si, rx: St)
    where
        Si: Sink<Vec<u8>, Error = E> + Clone + Unpin + 'static,
        St: Stream<Item = IngressEvent> + 'static,
        E: Into<anyhow::Error> + Debug,
    {
        spawn_local(async move {
            rx.for_each(move |event| {
                let mut tx = tx.clone();
                async move {
                    match event {
                        IngressEvent::Packet { payload, .. } => {
                            if let Err(e) = tx.send(payload).await {
                                eprintln!("net send error: {:?}", e);
                            }
                        }
                        IngressEvent::Disconnected { desc } => {
                            println!("disconnected: {:?}", desc);
                        }
                        IngressEvent::InboundConnection { desc } => {
                            println!("inbound connection: {:?}", desc);
                        }
                    }
                }
            })
            .await;
        });
    }

    /// Wires two networks back to back, with each one's egress injected into the other.
    /// Then sends `total` bytes in both directions in `chunk_size` chunks and compares
    /// the produced and consumed digests.
    async fn net_exchange(medium: Medium, total: usize, chunk_size: usize) -> anyhow::Result<()> {
        const MTU: usize = 65535;

        println!(">> exchanging {} B in {} B chunks", total, chunk_size);

        let ip1 = Ipv4Address::new(10, 0, 0, 1);
        let ip2 = Ipv4Address::new(10, 0, 0, 2);

        let config = StackConfig {
            max_transmission_unit: MTU,
            ..Default::default()
        };

        let net1 = new_network(medium, ip1.into(), config.clone());
        let net2 = new_network(medium, ip2.into(), config.clone());

        net1.spawn_local();
        net2.spawn_local();

        net1.bind(Protocol::Tcp, (ip1, 1))?;
        net2.bind(Protocol::Tcp, (ip2, 1))?;

        // net2 -> net1 traffic, and net1's received payloads -> consumer 1
        net_inject(
            UnboundedReceiverStream::new(net2.egress_receiver().unwrap())
                .map(|e| e.payload.into_vec()),
            net1.clone(),
        );
        let (tx, rx) = mpsc::channel(1);
        net_receive(
            tx,
            UnboundedReceiverStream::new(net1.ingress_receiver().unwrap()),
        );
        let consume1 = consume_data(rx, total);

        // net1 -> net2 traffic, and net2's received payloads -> consumer 2
        net_inject(
            UnboundedReceiverStream::new(net1.egress_receiver().unwrap())
                .map(|e| e.payload.into_vec()),
            net2.clone(),
        );
        let (tx, rx) = mpsc::channel(1);
        net_receive(
            tx,
            UnboundedReceiverStream::new(net2.ingress_receiver().unwrap()),
        );
        let consume2 = consume_data(rx, total);

        let conn1 = net1.connect((ip2, 1), Duration::from_secs(3)).await?;
        let conn2 = net2.connect((ip1, 1), Duration::from_secs(3)).await?;

        net1.poll();
        net2.poll();

        let (tx, rx) = mpsc::channel(1);
        let produce1 = produce_data(tx, total, chunk_size);
        net_send(rx, net1.clone(), conn1);

        let (tx, rx) = mpsc::channel(1);
        let produce2 = produce_data(tx, total, chunk_size);
        net_send(rx, net2.clone(), conn2);

        let (f1, f2, f3, f4) = futures::future::join4(produce1, produce2, consume1, consume2).await;

        let (mut ptx1, produced1) = f1??;
        let (mut ptx2, produced2) = f2??;
        let consumed1 = f3??;
        let consumed2 = f4??;

        let _ = ptx1.close().await;
        let _ = ptx2.close().await;

        // What net1 produced must be exactly what net2 consumed, and vice versa.
        assert_eq!(hex::encode(produced1), hex::encode(consumed2));
        assert_eq!(hex::encode(produced2), hex::encode(consumed1));

        Ok(())
    }

    /// Connects net2 and net3 concurrently to the same bound TCP endpoint on net1, and
    /// expects both connections to succeed. This exercises the re-bind path in
    /// `Network::process_ingress`. `total` and `chunk_size` are only printed.
    async fn re_bind(medium: Medium, total: usize, chunk_size: usize) -> anyhow::Result<()> {
        const MTU: usize = 65535;

        println!(">> exchanging {} B in {} B chunks", total, chunk_size);

        let ip1 = Ipv4Address::new(10, 0, 0, 1);
        let ip2 = Ipv4Address::new(10, 0, 0, 2);
        let ip3 = Ipv4Address::new(10, 0, 0, 3);

        let config = StackConfig {
            max_transmission_unit: MTU,
            ..Default::default()
        };

        let net1 = new_network(medium, ip1.into(), config.clone());
        let net2 = new_network(medium, ip2.into(), config.clone());
        let net3 = new_network(medium, ip3.into(), config.clone());

        net1.spawn_local();
        net2.spawn_local();
        net3.spawn_local();

        net1.bind(Protocol::Tcp, (ip1, 1))?;
        net2.bind(Protocol::Tcp, (ip2, 1))?;
        net3.bind(Protocol::Tcp, (ip3, 1))?;

        // net2 and net3 both send to net1
        net_inject(
            UnboundedReceiverStream::new(net2.egress_receiver().unwrap())
                .map(|e| e.payload.into_vec()),
            net1.clone(),
        );
        net_inject(
            UnboundedReceiverStream::new(net3.egress_receiver().unwrap())
                .map(|e| e.payload.into_vec()),
            net1.clone(),
        );
        let (tx, rx) = mpsc::channel(1);
        net_receive(
            tx,
            UnboundedReceiverStream::new(net1.ingress_receiver().unwrap()),
        );

        let _consume1 = spawn_local(rx.for_each(|e| async move { println!("consumer 1: {e:?}") }));

        // net1 replies are routed to net2 or net3 by destination address
        net_inject2(
            UnboundedReceiverStream::new(net1.egress_receiver().unwrap()),
            net2.clone(),
            net3.clone(),
        );

        let (tx, rx) = mpsc::channel(1);
        net_receive(
            tx,
            UnboundedReceiverStream::new(net2.ingress_receiver().unwrap()),
        );

        let _consume2 = spawn_local(rx.for_each(|e| async move { println!("consumer 2: {e:?}") }));

        let (tx, rx) = mpsc::channel(1);
        net_receive(
            tx,
            UnboundedReceiverStream::new(net3.ingress_receiver().unwrap()),
        );

        let _consume3 = spawn_local(rx.for_each(|e| async move { println!("consumer 3: {e:?}") }));

        let conn1 = net2.connect((ip1, 1), Duration::from_secs(3));
        let conn2 = net3.connect((ip1, 1), Duration::from_secs(3));

        let (f1, f2) = futures::future::join(conn1, conn2).await;

        f1.expect("Connection failed!");
        f2.expect("Connection failed!");

        Ok(())
    }

    /// Opens `conn_num` sequential connections from net2 to net1. Every connection must
    /// succeed, except the one made when the counter reaches `u16::MAX`, which must fail
    /// with "no ports available". `total` and `chunk_size` are only printed.
    #[cfg(feature = "test-suite")]
    async fn establish_multiple_conn(
        medium: Medium,
        total: usize,
        chunk_size: usize,
        conn_num: u16,
    ) -> anyhow::Result<()> {
        use crate::error;

        const MTU: usize = 65535;

        println!(">> exchanging {} B in {} B chunks", total, chunk_size);

        let ip1 = Ipv4Address::new(10, 0, 0, 1);
        let ip2 = Ipv4Address::new(10, 0, 0, 2);

        let config = StackConfig {
            max_transmission_unit: MTU,
            ..Default::default()
        };

        let net1 = new_network(medium, ip1.into(), config.clone());
        let net2 = new_network(medium, ip2.into(), config.clone());

        net1.spawn_local();
        net2.spawn_local();

        net1.bind(Protocol::Tcp, (ip1, 1))?;
        net2.bind(Protocol::Tcp, (ip2, 1))?;

        // net2 -> net1
        net_inject(
            UnboundedReceiverStream::new(net2.egress_receiver().unwrap())
                .map(|e| e.payload.into_vec()),
            net1.clone(),
        );

        let (tx, rx) = mpsc::channel(1);
        net_receive(
            tx,
            UnboundedReceiverStream::new(net1.ingress_receiver().unwrap()),
        );

        let _consume1 = spawn_local(rx.for_each(|e| async move { println!("consumer 1: {e:?}") }));

        // net1 -> net2
        net_inject(
            UnboundedReceiverStream::new(net1.egress_receiver().unwrap())
                .map(|e| e.payload.into_vec()),
            net2.clone(),
        );

        let (tx, rx) = mpsc::channel(1);
        net_receive(
            tx,
            UnboundedReceiverStream::new(net2.ingress_receiver().unwrap()),
        );

        let _consume2 = spawn_local(rx.for_each(|e| async move { println!("consumer 2: {e:?}") }));

        for i in 1..=conn_num {
            let conn = net2.connect((ip1, 1), Duration::from_secs(3)).await;
            match conn {
                Ok(_) => println!("Connection({i}) successful"),
                Err(e) => {
                    if i != u16::MAX {
                        panic!("Connection failed! Error: {}", e);
                    };

                    let expected = error::Error::Other("no ports available".into());
                    assert_eq!(expected, e)
                }
            }
        }

        Ok(())
    }

    /// Runs `net_exchange` inside a fresh `LocalSet`, bounded by `EXCHANGE_TIMEOUT`.
    async fn spawn_exchange(medium: Medium, total: usize, chunk_size: usize) -> anyhow::Result<()> {
        tokio::task::LocalSet::new()
            .run_until(tokio::time::timeout(
                EXCHANGE_TIMEOUT,
                net_exchange(medium, total, chunk_size),
            ))
            .await?
    }

    /// Runs exchanges over a range of total sizes and chunk sizes. The largest ones run
    /// only in builds without debug assertions.
    async fn spawn_exchange_scenarios(medium: Medium) -> anyhow::Result<()> {
        spawn_exchange(medium, 1024, 1).await?;
        spawn_exchange(medium, 1024, 4).await?;
        spawn_exchange(medium, 1024, 7).await?;
        spawn_exchange(medium, 10240, 16).await?;
        spawn_exchange(medium, 1024000, 383).await?;
        spawn_exchange(medium, 1024000, 384).await?;
        spawn_exchange(medium, 1024000, 4096).await?;
        spawn_exchange(medium, 1024000, 40960).await?;

        #[cfg(not(debug_assertions))]
        {
            spawn_exchange(medium, 10240000, 40960).await?;
            spawn_exchange(medium, 10240000, 131070).await?;
            spawn_exchange(medium, 10240000, 1024000).await?;
        }

        Ok(())
    }

    #[tokio::test]
    async fn tap_exchange() -> anyhow::Result<()> {
        spawn_exchange_scenarios(Medium::Ethernet).await
    }

    #[tokio::test]
    async fn tun_exchange() -> anyhow::Result<()> {
        spawn_exchange_scenarios(Medium::Ip).await
    }

    #[tokio::test]
    async fn socket_re_binding() -> anyhow::Result<()> {
        tokio::task::LocalSet::new()
            .run_until(tokio::time::timeout(
                EXCHANGE_TIMEOUT,
                re_bind(Medium::Ip, 0, 0),
            ))
            .await?
    }

    /// Opens `u16::MAX - 1` connections; all must succeed.
    #[cfg(feature = "test-suite")]
    #[tokio::test]
    async fn multiple_conn() -> anyhow::Result<()> {
        tokio::task::LocalSet::new()
            .run_until(establish_multiple_conn(Medium::Ip, 0, 0, u16::MAX - 1))
            .await
    }

    /// Opens `u16::MAX` connections; the last one must fail with "no ports available".
    #[cfg(feature = "test-suite")]
    #[tokio::test]
    async fn overload_conn() -> anyhow::Result<()> {
        tokio::task::LocalSet::new()
            .run_until(establish_multiple_conn(Medium::Ip, 0, 0, u16::MAX))
            .await
    }
}
1260}