ya_relay_stack/
network.rs

1use std::cell::RefCell;
2use std::collections::{HashMap, HashSet};
3use std::convert::{TryFrom, TryInto};
4use std::path::PathBuf;
5use std::rc::Rc;
6use std::time::Duration;
7
8use futures::channel::mpsc;
9use futures::future::{Either, LocalBoxFuture};
10use futures::{Future, FutureExt, SinkExt, StreamExt, TryFutureExt};
11use smoltcp::iface::SocketHandle;
12use smoltcp::wire::IpEndpoint;
13use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
14use tokio::task::spawn_local;
15use tokio::time::MissedTickBehavior;
16
17use crate::connection::{Connection, ConnectionMeta};
18use crate::packet::{
19    ip_ntoh, ArpField, ArpPacket, EtherFrame, IpPacket, PeekPacket, TcpPacket, UdpPacket,
20};
21use crate::protocol::Protocol;
22use crate::socket::{SocketDesc, SocketEndpoint, SocketExt, SocketMemory, SocketState};
23use crate::stack::Stack;
24use crate::{ChannelMetrics, Error, Result};
25
26use ya_relay_util::Payload;
27
/// Environment variable whose value (a file path) is stored in `StackConfig::pcap_path`.
pub const PCAP_FILE_ENV_VAR: &str = "YA_NET_PCAP_FILE";
/// Environment variable overriding the interval of the periodic stack poller,
/// in milliseconds. Zero or unparsable values fall back to the default.
pub const STACK_POLL_MS_ENV_VAR: &str = "YA_NET_STACK_POLL_MS";
/// Environment variable overriding `StackConfig::max_send_batch`
/// (egress bytes handled per processing round).
pub const STACK_POLL_SENT_ENV_VAR: &str = "YA_NET_STACK_POLL_SENT_BATCH";
/// Environment variable overriding `StackConfig::max_recv_batch`
/// (ingress bytes handled per socket per processing round).
pub const STACK_POLL_RECV_ENV_VAR: &str = "YA_NET_STACK_POLL_RECV_BATCH";

// Default byte budgets; both are powers of two (16 KiB egress, 32 KiB ingress).
// The egress default used to read 16348, a transposition of 16384.
const DEFAULT_POLL_SENT_BATCH: usize = 16384;
const DEFAULT_POLL_RECV_BATCH: usize = 32768;
// Lower bounds applied to the batch sizes, including values from the environment.
const MIN_STACK_POLL_SENT_BATCH: usize = 2048;
const MIN_STACK_POLL_RECV_BATCH: usize = 4096;
37
38pub type IngressReceiver = UnboundedReceiver<IngressEvent>;
39pub type EgressReceiver = UnboundedReceiver<EgressEvent>;
40
41#[derive(Clone)]
42pub struct StackConfig {
43    pub pcap_path: Option<PathBuf>,
44    pub max_transmission_unit: usize,
45    pub max_send_batch: usize,
46    pub max_recv_batch: usize,
47    pub tcp_mem: SocketMemory,
48    pub udp_mem: SocketMemory,
49    pub icmp_mem: SocketMemory,
50    pub raw_mem: SocketMemory,
51}
52
53impl Default for StackConfig {
54    fn default() -> Self {
55        let max_send_batch = std::env::var(STACK_POLL_SENT_ENV_VAR)
56            .and_then(|s| {
57                s.parse::<usize>()
58                    .map_err(|_| std::env::VarError::NotPresent)
59            })
60            .unwrap_or(DEFAULT_POLL_SENT_BATCH)
61            .max(MIN_STACK_POLL_SENT_BATCH);
62
63        let max_recv_batch = std::env::var(STACK_POLL_RECV_ENV_VAR)
64            .and_then(|s| {
65                s.parse::<usize>()
66                    .map_err(|_| std::env::VarError::NotPresent)
67            })
68            .unwrap_or(DEFAULT_POLL_RECV_BATCH)
69            .max(MIN_STACK_POLL_RECV_BATCH);
70
71        Self {
72            pcap_path: std::env::var(PCAP_FILE_ENV_VAR).ok().map(PathBuf::from),
73            max_transmission_unit: 1400,
74            max_send_batch,
75            max_recv_batch,
76            tcp_mem: SocketMemory::default_tcp(),
77            udp_mem: SocketMemory::default_udp(),
78            icmp_mem: SocketMemory::default_icmp(),
79            raw_mem: SocketMemory::default_raw(),
80        }
81    }
82}
83
84#[derive(Clone)]
85pub struct Network {
86    pub name: Rc<String>,
87    pub config: Rc<StackConfig>,
88    pub stack: Stack<'static>,
89    is_tun: bool,
90    sender: StackSender,
91    poller: StackPoller,
92    /// Set of listening sockets. Socket is removed from this set, when connection is created.
93    /// Network stack will create new binding using new handle in place of previous.
94    pub bindings: Rc<RefCell<HashSet<SocketHandle>>>,
95    pub connections: Rc<RefCell<HashMap<ConnectionMeta, Connection>>>,
96    pub handles: Rc<RefCell<HashMap<SocketHandle, ConnectionMeta>>>,
97    ingress: Channel<IngressEvent>,
98    egress: Channel<EgressEvent>,
99}
100
101impl Network {
102    /// Creates a new Network instance
103    pub fn new(name: impl ToString, config: Rc<StackConfig>, stack: Stack<'static>) -> Self {
104        let is_tun = {
105            let iface_rfc = stack.iface();
106            let iface = iface_rfc.borrow();
107            iface.device().is_tun()
108        };
109
110        let network = Self {
111            name: Rc::new(name.to_string()),
112            config,
113            stack,
114            is_tun,
115            sender: Default::default(),
116            poller: Default::default(),
117            bindings: Default::default(),
118            connections: Default::default(),
119            handles: Default::default(),
120            ingress: Default::default(),
121            egress: Default::default(),
122        };
123
124        network.sender.net.borrow_mut().replace(network.clone());
125        network.poller.net.borrow_mut().replace(network.clone());
126        network
127    }
128
129    /// Returns a socket listening on an endpoint and ready for incoming
130    /// connections. Sockets already connected won't be returned.
131    pub fn get_bound(
132        &self,
133        protocol: Protocol,
134        local_endpoint: impl Into<SocketEndpoint>,
135    ) -> Option<SocketHandle> {
136        let endpoint = local_endpoint.into();
137        let iface_rfc = self.stack.iface();
138        let iface = iface_rfc.borrow();
139        let mut sockets = iface.sockets();
140        sockets
141            .find(|(handle, s)| {
142                s.protocol() == protocol
143                    && s.local_endpoint() == endpoint
144                    // This condition prevents from returning socket connection instead
145                    // of listening socket, unattached to connection.
146                    && self.bindings.borrow().contains(handle)
147            })
148            .map(|(h, _)| h)
149    }
150
151    /// Listen on a local endpoint
152    pub fn bind(
153        &self,
154        protocol: Protocol,
155        endpoint: impl Into<SocketEndpoint>,
156    ) -> Result<SocketHandle> {
157        let endpoint = endpoint.into();
158        let handle = self.stack.bind(protocol, endpoint)?;
159        self.bindings.borrow_mut().insert(handle);
160        Ok(handle)
161    }
162
163    /// Stop listening on a local endpoint
164    pub fn unbind(&self, protocol: Protocol, endpoint: impl Into<SocketEndpoint>) -> Result<()> {
165        let endpoint = endpoint.into();
166        let handle = self.stack.unbind(protocol, endpoint)?;
167        self.bindings.borrow_mut().remove(&handle);
168        Ok(())
169    }
170
171    /// Initiate a TCP connection
172    pub fn connect(
173        &self,
174        remote: impl Into<IpEndpoint>,
175        timeout: impl Into<Duration>,
176    ) -> LocalBoxFuture<Result<Connection>> {
177        let remote = remote.into();
178        let timeout = timeout.into();
179
180        let connect = match self.stack.connect(remote) {
181            Ok(fut) => fut,
182            Err(err) => return futures::future::err(err).boxed_local(),
183        };
184        self.poll();
185
186        let connections = self.connections.clone();
187        let handles = self.handles.clone();
188
189        async move {
190            let connection = match tokio::time::timeout(timeout, connect).await {
191                Ok(Ok(conn)) => conn,
192                Ok(Err(error)) => return Err(error),
193                _ => return Err(Error::ConnectionTimeout),
194            };
195            Self::add_connection_to(connection, &connections, &handles);
196            Ok(connection)
197        }
198        .boxed_local()
199    }
200
201    /// Close all TCP connections with a remote IP address
202    pub fn disconnect_all(
203        &self,
204        remote_ip: Box<[u8]>,
205        timeout: impl Into<Duration>,
206    ) -> LocalBoxFuture<()> {
207        let (handles, futs): (Vec<_>, Vec<_>) = {
208            let connections = self.connections.borrow();
209            connections
210                .values()
211                .filter(|conn| {
212                    conn.meta.remote.addr.as_bytes() == remote_ip.as_ref()
213                        && conn.meta.protocol == Protocol::Tcp
214                })
215                .map(|conn| (conn.handle, self.stack.disconnect(conn.handle)))
216                .unzip()
217        };
218
219        if futs.is_empty() {
220            return futures::future::ready(()).boxed_local();
221        }
222
223        self.poll();
224
225        let timeout = timeout.into();
226        let net = self.clone();
227
228        async move {
229            let pending = futures::future::join_all(futs);
230            let timeout = tokio::time::sleep(timeout).boxed_local();
231
232            if let Either::Right((_, pending)) = futures::future::select(pending, timeout).await {
233                handles.into_iter().for_each(|h| net.stack.abort(h));
234                net.poll();
235
236                let timeout = tokio::time::sleep(Duration::from_millis(500));
237                let _ = futures::future::select(pending, timeout.boxed_local()).await;
238            }
239        }
240        .boxed_local()
241    }
242
243    pub fn bindings(&self) -> core::cell::Ref<'_, HashSet<SocketHandle>> {
244        self.bindings.borrow()
245    }
246
247    pub fn handles(&self) -> core::cell::Ref<'_, HashMap<SocketHandle, ConnectionMeta>> {
248        self.handles.borrow()
249    }
250
251    pub fn connections(&self) -> core::cell::Ref<'_, HashMap<ConnectionMeta, Connection>> {
252        self.connections.borrow()
253    }
254
255    pub fn sockets(&self) -> Vec<(SocketDesc, SocketState<ChannelMetrics>)> {
256        let iface_rfc = self.stack.iface();
257        let iface = iface_rfc.borrow();
258        let metrics_rfc = self.stack.metrics();
259        let metrics = metrics_rfc.borrow();
260
261        iface
262            .sockets()
263            .map(|(_, s)| {
264                let desc = s.desc();
265                let metrics = metrics.get(&desc).cloned().unwrap_or_default();
266                let mut state = s.state();
267                state.set_inner(metrics);
268                (desc, state)
269            })
270            .collect()
271    }
272
273    pub fn sockets_meta(&self) -> Vec<(SocketHandle, SocketDesc, SocketState<ChannelMetrics>)> {
274        let iface_rfc = self.stack.iface();
275        let iface = iface_rfc.borrow();
276        let connections = self.handles.borrow();
277
278        iface
279            .sockets()
280            .map(|(handle, s)| {
281                (
282                    handle,
283                    connections
284                        .get(&handle)
285                        .cloned()
286                        .map(|meta| meta.into())
287                        .unwrap_or(s.desc()),
288                    s.state(),
289                )
290            })
291            .collect()
292    }
293
294    pub fn metrics(&self) -> ChannelMetrics {
295        let iface_rfc = self.stack.iface();
296        let iface = iface_rfc.borrow();
297        iface.device().metrics()
298    }
299
300    #[inline(always)]
301    fn is_connected(&self, meta: &ConnectionMeta) -> bool {
302        self.connections.borrow().contains_key(meta)
303    }
304
305    #[inline(always)]
306    fn add_connection(&self, connection: Connection) {
307        Self::add_connection_to(connection, &self.connections, &self.handles);
308    }
309
310    fn add_connection_to(
311        connection: Connection,
312        connections: &Rc<RefCell<HashMap<ConnectionMeta, Connection>>>,
313        handles: &Rc<RefCell<HashMap<SocketHandle, ConnectionMeta>>>,
314    ) {
315        let handle = connection.handle;
316        let meta = connection.into();
317        connections.borrow_mut().insert(meta, connection);
318        handles.borrow_mut().insert(handle, meta);
319    }
320
321    #[inline(always)]
322    fn remove_connection(&self, meta: &ConnectionMeta, handle: SocketHandle) {
323        self.stack.remove(meta, handle);
324        self.handles.borrow_mut().remove(&handle);
325        self.sender.remove(&handle);
326
327        let ip_endpoint = smoltcp::wire::IpListenEndpoint::from(meta.remote);
328        if !ip_endpoint.is_specified() {
329            return;
330        }
331        self.connections.borrow_mut().remove(meta);
332    }
333
334    /// Inject send data into the stack
335    #[inline(always)]
336    pub fn send<'a>(
337        &self,
338        data: impl Into<Payload>,
339        connection: Connection,
340    ) -> impl Future<Output = Result<()>> + 'a {
341        self.sender.send(data.into(), connection)
342    }
343
344    /// Inject received data into the stack
345    #[inline(always)]
346    pub fn receive(&self, data: impl Into<Payload>) {
347        self.stack.receive(data)
348    }
349
350    pub fn spawn_local(&self) {
351        let interval = std::env::var(STACK_POLL_MS_ENV_VAR)
352            .and_then(|s| s.parse::<u64>().map_err(|_| std::env::VarError::NotPresent))
353            .and_then(|v| match v {
354                0 => Err(std::env::VarError::NotPresent),
355                v => Ok(v),
356            })
357            .unwrap_or(250);
358        self.poller.clone().spawn(Duration::from_millis(interval));
359    }
360
361    /// Polls the inner network stack
362    pub fn poll(&self) {
363        loop {
364            let finished = match (self.stack.poll(), self.is_tun) {
365                (true, _) | (_, false) => self.process_ingress() && self.process_egress(),
366                (false, _) => true,
367            };
368            if finished {
369                break;
370            }
371        }
372    }
373
374    /// Take the ingress traffic receive channel
375    #[inline(always)]
376    pub fn ingress_receiver(&self) -> Option<IngressReceiver> {
377        self.ingress.receiver()
378    }
379
380    /// Take the egress traffic receive channel
381    #[inline(always)]
382    pub fn egress_receiver(&self) -> Option<EgressReceiver> {
383        self.egress.receiver()
384    }
385
386    fn process_ingress(&self) -> bool {
387        let mut finished = true;
388
389        let iface_rfc = self.stack.iface();
390        let mut iface = iface_rfc.borrow_mut();
391        let mut bindings = self.bindings.borrow_mut();
392        let mut events = Vec::new();
393        let mut remove = Vec::new();
394        let mut rebind = None;
395
396        for (handle, socket) in iface.sockets_mut() {
397            let mut desc = socket.desc();
398
399            // When socket is closing, smoltcp clears remote endpoint at some point
400            // to `Unspecified`. This is why SocketDesc and ConnectionMeta can differ here.
401            // We will try to use ConnectionMeta, because it conveys more information.
402            if socket.is_closed() {
403                let meta = self
404                    .handles
405                    .borrow()
406                    .get(&handle)
407                    .copied()
408                    .or(desc.try_into().ok());
409
410                if let Some(meta) = meta {
411                    // We had established connection with someone and it was closed.
412                    log::debug!("{}: closing socket [{handle}]: {desc} / {meta}", self.name);
413                    events.push(IngressEvent::Disconnected { desc: meta.into() });
414                } else {
415                    // Connection metadata was cleared.
416                    // Socket probably got RST packet and it's state was cleared,
417                    // or it was only listening socket that wasn't connected at any moment.
418                    log::debug!("Removing socket {handle} with reset metadata");
419                }
420
421                remove.push((
422                    meta.unwrap_or(ConnectionMeta::unspecified(desc.protocol)),
423                    handle,
424                ));
425            }
426
427            let mut received = 0;
428
429            while socket.can_recv() {
430                let (remote, payload) = match socket.recv() {
431                    Ok(Some(tuple)) => tuple,
432                    Ok(None) => break,
433                    Err(err) => {
434                        log::debug!("{}: ingress packet error: {err}", self.name);
435                        continue;
436                    }
437                };
438
439                let len = payload.len();
440
441                received += len;
442                desc.remote = remote.into();
443
444                if let Ok(meta) = desc.try_into() {
445                    if !self.is_connected(&meta) {
446                        self.add_connection(Connection { handle, meta });
447                        events.push(IngressEvent::InboundConnection { desc: meta.into() });
448                    }
449                }
450
451                log::trace!("{}: ingress {len} B packet", self.name);
452
453                self.stack.on_received(&desc, len);
454                events.push(IngressEvent::Packet { desc, payload });
455
456                if received >= self.config.max_recv_batch {
457                    finished = false;
458                    break;
459                }
460            }
461
462            if bindings.contains(&handle) && socket.remote_endpoint().is_specified() {
463                bindings.remove(&handle);
464                rebind = Some((socket.protocol(), socket.local_endpoint()));
465
466                finished = false;
467                break;
468            }
469        }
470
471        drop(bindings);
472        drop(iface);
473
474        remove.into_iter().for_each(|(meta, handle)| {
475            self.remove_connection(&meta, handle);
476        });
477
478        if !events.is_empty() {
479            let ingress_tx = self.ingress.tx.clone();
480            for event in events {
481                if ingress_tx.send(event).is_err() {
482                    log::debug!(
483                        "{}: ingress channel closed, unable to receive packets",
484                        self.name
485                    );
486                    break;
487                }
488            }
489        }
490
491        if let Some((p, ep)) = rebind {
492            if let Err(e) = self.bind(p, ep) {
493                log::warn!("{}: cannot bind socket {p} {ep:?}: {e}", self.name);
494            }
495            let _ = self.stack.poll();
496            return self.process_ingress();
497        }
498
499        finished
500    }
501
502    fn process_egress(&self) -> bool {
503        let mut sent = 0;
504        let mut finished = true;
505
506        let iface_rfc = self.stack.iface();
507        let mut iface = iface_rfc.borrow_mut();
508        let device = iface.device_mut();
509        let is_tun = device.is_tun();
510
511        while let Some(data) = device.next_phy_tx() {
512            match {
513                if is_tun {
514                    EgressEvent::from_ip_packet(data)
515                } else {
516                    EgressEvent::from_eth_frame(data)
517                }
518            } {
519                Ok(event) => {
520                    sent += event.payload.len();
521
522                    if let Some((desc, size)) = event.desc.as_ref() {
523                        self.stack.on_sent(desc, *size);
524                    }
525
526                    if self.egress.tx.send(event).is_err() {
527                        log::trace!(
528                            "{}: egress channel closed, unable to send packets",
529                            *self.name
530                        );
531                        break;
532                    }
533                }
534                Err(err) => log::trace!("{}: egress packet error: {}", *self.name, err),
535            }
536
537            if sent >= self.config.max_send_batch {
538                finished = false;
539                continue;
540            }
541        }
542
543        finished
544    }
545}
546
547#[derive(Clone, Debug)]
548pub enum IngressEvent {
549    /// New connection to a bound endpoint
550    InboundConnection { desc: SocketDesc },
551    /// Disconnection from a bound endpoint
552    Disconnected { desc: SocketDesc },
553    /// Bound endpoint packet
554    Packet { desc: SocketDesc, payload: Vec<u8> },
555}
556
557#[derive(Clone, Debug)]
558pub struct EgressEvent {
559    pub remote: Box<[u8]>,
560    pub payload: Box<[u8]>,
561    pub desc: Option<(SocketDesc, usize)>,
562}
563
564impl EgressEvent {
565    pub fn from_eth_frame(data: Vec<u8>) -> Result<Self> {
566        let frame = EtherFrame::try_from(data)?;
567        let (desc, remote) = match &frame {
568            EtherFrame::Ip(_) => {
569                let data = frame.payload();
570                IpPacket::peek(data)?;
571
572                let packet = IpPacket::packet(data);
573                let remote = packet.dst_address().into();
574                let desc = Self::payload_desc(&packet);
575                (desc, remote)
576            }
577            EtherFrame::Arp(_) => {
578                let packet = ArpPacket::packet(frame.payload());
579                let remote = packet.get_field(ArpField::TPA).into();
580                (None, remote)
581            }
582        };
583
584        Ok(EgressEvent {
585            remote,
586            payload: frame.into(),
587            desc,
588        })
589    }
590
591    pub fn from_ip_packet(data: Vec<u8>) -> Result<Self> {
592        let (desc, remote) = {
593            IpPacket::peek(&data)?;
594            let packet = IpPacket::packet(&data);
595            let remote = packet.dst_address().into();
596            let desc = Self::payload_desc(&packet);
597
598            (desc, remote)
599        };
600
601        Ok(EgressEvent {
602            remote,
603            payload: data.into_boxed_slice(),
604            desc,
605        })
606    }
607
608    fn payload_desc(packet: &IpPacket) -> Option<(SocketDesc, usize)> {
609        let protocol = Protocol::try_from(packet.protocol()).ok()?;
610
611        let (local_port, remote_port, size) = match protocol {
612            Protocol::Tcp => {
613                TcpPacket::peek(packet.payload()).ok()?;
614                let tcp = TcpPacket::packet(packet.payload());
615                (tcp.src_port(), tcp.dst_port(), tcp.payload_size)
616            }
617            Protocol::Udp => {
618                UdpPacket::peek(packet.payload()).ok()?;
619                let udp = UdpPacket::packet(packet.payload());
620                (udp.src_port(), udp.dst_port(), udp.payload_size)
621            }
622            _ => return None,
623        };
624
625        let local_ip = ip_ntoh(packet.src_address())?;
626        let remote_ip = ip_ntoh(packet.dst_address())?;
627
628        let desc = SocketDesc {
629            protocol,
630            local: (local_ip, local_port).into(),
631            remote: (remote_ip, remote_port).into(),
632        };
633
634        Some((desc, size))
635    }
636}
637
638#[derive(Clone, Default)]
639struct StackSender {
640    inner: Rc<RefCell<StackSenderInner>>,
641    net: Rc<RefCell<Option<Network>>>,
642}
643
644impl StackSender {
645    #[inline]
646    pub fn send<'a>(
647        &self,
648        data: Payload,
649        conn: Connection,
650    ) -> impl Future<Output = Result<()>> + 'a {
651        let mut sender = {
652            match {
653                let inner = self.inner.borrow();
654                inner.map.get(&conn.handle).cloned()
655            } {
656                Some(sender) => sender,
657                None => self.spawn(conn.handle),
658            }
659        };
660        async move { sender.send((data, conn)).map_err(Error::from).await }
661    }
662
663    fn spawn(&self, handle: SocketHandle) -> mpsc::Sender<(Payload, Connection)> {
664        let net = self.net.borrow().clone().expect("Network not initialized");
665        let (tx, rx) = mpsc::channel(1);
666
667        spawn_local(async move {
668            rx.for_each(|(vec, conn)| {
669                let net = net.clone();
670                let stack = net.stack.clone();
671                async move {
672                    let _ = stack.send(vec, conn, move || net.poll()).await;
673                }
674            })
675            .await;
676        });
677
678        let mut inner = self.inner.borrow_mut();
679        inner.map.insert(handle, tx.clone());
680
681        tx
682    }
683
684    pub fn remove(&self, handle: &SocketHandle) {
685        let mut inner = self.inner.borrow_mut();
686        if let Some(mut tx) = inner.map.remove(handle) {
687            spawn_local(async move {
688                let _ = tx.close().await;
689            });
690        }
691    }
692}
693
694#[derive(Default)]
695struct StackSenderInner {
696    map: HashMap<SocketHandle, mpsc::Sender<(Payload, Connection)>>,
697}
698
699#[derive(Clone, Default)]
700struct StackPoller {
701    net: Rc<RefCell<Option<Network>>>,
702}
703
704impl StackPoller {
705    pub fn spawn(&self, interval: Duration) {
706        let poller = self.clone();
707        spawn_local(async move {
708            let mut interval = tokio::time::interval(interval);
709            interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
710            loop {
711                interval.tick().await;
712                poller.net.borrow().as_ref().unwrap().poll();
713            }
714        });
715    }
716}
717
718#[derive(Clone)]
719pub struct Channel<T> {
720    pub tx: UnboundedSender<T>,
721    rx: Rc<RefCell<Option<UnboundedReceiver<T>>>>,
722}
723
724impl<T> Channel<T> {
725    pub fn receiver(&self) -> Option<UnboundedReceiver<T>> {
726        self.rx.borrow_mut().take()
727    }
728}
729
730impl<T> Default for Channel<T> {
731    fn default() -> Self {
732        let (tx, rx) = unbounded_channel();
733        Self {
734            tx,
735            rx: Rc::new(RefCell::new(Some(rx))),
736        }
737    }
738}
739
740#[cfg(test)]
741mod tests {
742    use std::fmt::Debug;
743    use std::rc::Rc;
744    use std::time::Duration;
745
746    use futures::channel::{mpsc, oneshot};
747    use futures::{Sink, SinkExt, Stream, StreamExt};
748    use sha3::Digest;
749    use smoltcp::iface::Route;
750    use smoltcp::phy::Medium;
751    use smoltcp::wire::{IpAddress, IpCidr, Ipv4Address};
752    use tokio::task::spawn_local;
753    use tokio_stream::wrappers::UnboundedReceiverStream;
754
755    use crate::interface::{add_iface_address, add_iface_route, ip_to_mac, tap_iface, tun_iface};
756    use crate::{Connection, EgressEvent, IngressEvent, Network, Protocol, Stack, StackConfig};
757
758    const EXCHANGE_TIMEOUT: Duration = Duration::from_secs(30);
759
760    fn new_network(medium: Medium, ip: IpAddress, config: StackConfig) -> Network {
761        let config = Rc::new(config);
762        let cidr = IpCidr::new(ip, 16);
763        let route = match ip {
764            IpAddress::Ipv4(ipv4) => Route::new_ipv4_gateway(ipv4),
765            IpAddress::Ipv6(ipv6) => Route::new_ipv6_gateway(ipv6),
766        };
767
768        let mut iface = match medium {
769            Medium::Ethernet => tap_iface(ip_to_mac(ip), config.max_transmission_unit),
770            Medium::Ip => tun_iface(config.max_transmission_unit),
771            _ => panic!("unsupported medium: {:?}", medium),
772        };
773
774        add_iface_address(&mut iface, cidr);
775        add_iface_route(&mut iface, cidr, route);
776        Network::new(
777            format!("[{:?}] {}", medium, ip),
778            config.clone(),
779            Stack::new(iface, config),
780        )
781    }
782
783    fn produce_data<S, E>(
784        mut tx: S,
785        total: usize,
786        chunk_size: usize,
787    ) -> oneshot::Receiver<anyhow::Result<(S, Vec<u8>)>>
788    where
789        S: Sink<Vec<u8>, Error = E> + Unpin + 'static,
790        E: Into<anyhow::Error>,
791    {
792        let (dtx, drx) = oneshot::channel();
793
794        spawn_local(async move {
795            let mut digest = sha3::Sha3_224::new();
796            let mut sent = 0;
797            let mut err = None;
798
799            while sent < total {
800                let vec: Vec<u8> = (0..chunk_size.min(total - sent))
801                    .map(|_| rand::random())
802                    .collect();
803
804                digest.input(&vec);
805                sent += vec.len();
806
807                if let Err(e) = tx.send(vec).await {
808                    err = Some(e);
809                    break;
810                }
811            }
812
813            println!("Produced {} B", sent);
814            match err {
815                Some(e) => dtx.send(Err(e.into())),
816                None => dtx.send(Ok((tx, digest.result().to_vec()))),
817            }
818        });
819
820        drx
821    }
822
823    fn consume_data(
824        mut rx: mpsc::Receiver<Vec<u8>>,
825        total: usize,
826    ) -> oneshot::Receiver<anyhow::Result<Vec<u8>>> {
827        let (dtx, drx) = oneshot::channel();
828
829        spawn_local(async move {
830            let mut digest = sha3::Sha3_224::new();
831            let mut read: usize = 0;
832
833            while let Some(vec) = rx.next().await {
834                let len = vec.len();
835
836                read += len;
837                digest.input(&vec);
838
839                if read >= total {
840                    break;
841                }
842            }
843
844            println!("Consumed {} B", read);
845            let _ = dtx.send(Ok(digest.result().to_vec()));
846        });
847
848        drx
849    }
850
851    fn net_inject<S>(rx: S, net: Network)
852    where
853        S: Stream<Item = Vec<u8>> + 'static,
854    {
855        spawn_local(async move {
856            rx.for_each(|vec| {
857                let net = net.clone();
858                async move {
859                    net.receive(vec);
860                    net.poll();
861                }
862            })
863            .await;
864        });
865    }
866
867    fn net_inject2<S>(rx: S, net1: Network, net2: Network)
868    where
869        S: Stream<Item = EgressEvent> + 'static,
870    {
871        let ip1 = net1
872            .stack
873            .address()
874            .unwrap()
875            .address()
876            .as_bytes()
877            .to_vec()
878            .into_boxed_slice();
879
880        spawn_local(async move {
881            rx.for_each(|event| {
882                let net = if event.remote == ip1 {
883                    net1.clone()
884                } else {
885                    net2.clone()
886                };
887                async move {
888                    net.receive(event.payload);
889                    net.poll();
890                }
891            })
892            .await;
893        });
894    }
895
896    fn net_send<S>(rx: S, net: Network, conn: Connection)
897    where
898        S: Stream<Item = Vec<u8>> + 'static,
899    {
900        spawn_local(async move {
901            let net = net.clone();
902            rx.for_each(|vec| async {
903                let _ = net
904                    .send(vec, conn)
905                    .await
906                    .map_err(|e| eprintln!("failed to send packet: {}", e));
907            })
908            .await;
909        });
910    }
911
912    fn net_receive<Si, St, E>(tx: Si, rx: St)
913    where
914        Si: Sink<Vec<u8>, Error = E> + Clone + Unpin + 'static,
915        St: Stream<Item = IngressEvent> + 'static,
916        E: Into<anyhow::Error> + Debug,
917    {
918        spawn_local(async move {
919            rx.for_each(move |event| {
920                let mut tx = tx.clone();
921                async move {
922                    match event {
923                        IngressEvent::Packet { payload, .. } => {
924                            if let Err(e) = tx.send(payload).await {
925                                eprintln!("net send error: {:?}", e);
926                            }
927                        }
928                        IngressEvent::Disconnected { desc } => {
929                            println!("disconnected: {:?}", desc);
930                        }
931                        IngressEvent::InboundConnection { desc } => {
932                            println!("inbound connection: {:?}", desc);
933                        }
934                    }
935                }
936            })
937            .await;
938        });
939    }
940
941    /// Generate, send and receive data across 2 network instances
942    async fn net_exchange(medium: Medium, total: usize, chunk_size: usize) -> anyhow::Result<()> {
943        const MTU: usize = 65535;
944
945        println!(">> exchanging {} B in {} B chunks", total, chunk_size);
946
947        let ip1 = Ipv4Address::new(10, 0, 0, 1);
948        let ip2 = Ipv4Address::new(10, 0, 0, 2);
949
950        let config = StackConfig {
951            max_transmission_unit: MTU,
952            ..Default::default()
953        };
954
955        let net1 = new_network(medium, ip1.into(), config.clone());
956        let net2 = new_network(medium, ip2.into(), config.clone());
957
958        net1.spawn_local();
959        net2.spawn_local();
960
961        net1.bind(Protocol::Tcp, (ip1, 1))?;
962        net2.bind(Protocol::Tcp, (ip2, 1))?;
963
964        // net 1
965        // inject egress packets from net 2 into net 1 rx buffer
966        net_inject(
967            UnboundedReceiverStream::new(net2.egress_receiver().unwrap())
968                .map(|e| e.payload.into_vec()),
969            net1.clone(),
970        );
971        // process net 1 events
972        let (tx, rx) = mpsc::channel(1);
973        net_receive(
974            tx,
975            UnboundedReceiverStream::new(net1.ingress_receiver().unwrap()),
976        );
977        // future for calculating checksum from data received by net 1
978        let consume1 = consume_data(rx, total);
979
980        // net 2
981        // inject egress packets from net 1 into net 2 rx buffer
982        net_inject(
983            UnboundedReceiverStream::new(net1.egress_receiver().unwrap())
984                .map(|e| e.payload.into_vec()),
985            net2.clone(),
986        );
987        // process net 2 events
988        let (tx, rx) = mpsc::channel(1);
989        net_receive(
990            tx,
991            UnboundedReceiverStream::new(net2.ingress_receiver().unwrap()),
992        );
993        // future for calculating checksum from data received by net 2
994        let consume2 = consume_data(rx, total);
995
996        let conn1 = net1.connect((ip2, 1), Duration::from_secs(3)).await?;
997        let conn2 = net2.connect((ip1, 1), Duration::from_secs(3)).await?;
998
999        net1.poll();
1000        net2.poll();
1001
1002        let (tx, rx) = mpsc::channel(1);
1003        let produce1 = produce_data(tx, total, chunk_size);
1004        net_send(rx, net1.clone(), conn1);
1005
1006        let (tx, rx) = mpsc::channel(1);
1007        let produce2 = produce_data(tx, total, chunk_size);
1008        net_send(rx, net2.clone(), conn2);
1009
1010        let (f1, f2, f3, f4) = futures::future::join4(produce1, produce2, consume1, consume2).await;
1011
1012        let (mut ptx1, produced1) = f1??;
1013        let (mut ptx2, produced2) = f2??;
1014        let consumed1 = f3??;
1015        let consumed2 = f4??;
1016
1017        let _ = ptx1.close().await;
1018        let _ = ptx2.close().await;
1019
1020        assert_eq!(hex::encode(produced1), hex::encode(consumed2));
1021        assert_eq!(hex::encode(produced2), hex::encode(consumed1));
1022
1023        Ok(())
1024    }
1025
1026    async fn re_bind(medium: Medium, total: usize, chunk_size: usize) -> anyhow::Result<()> {
1027        const MTU: usize = 65535;
1028
1029        println!(">> exchanging {} B in {} B chunks", total, chunk_size);
1030
1031        let ip1 = Ipv4Address::new(10, 0, 0, 1);
1032        let ip2 = Ipv4Address::new(10, 0, 0, 2);
1033        let ip3 = Ipv4Address::new(10, 0, 0, 3);
1034
1035        let config = StackConfig {
1036            max_transmission_unit: MTU,
1037            ..Default::default()
1038        };
1039
1040        let net1 = new_network(medium, ip1.into(), config.clone());
1041        let net2 = new_network(medium, ip2.into(), config.clone());
1042        let net3 = new_network(medium, ip3.into(), config.clone());
1043
1044        net1.spawn_local();
1045        net2.spawn_local();
1046        net3.spawn_local();
1047
1048        net1.bind(Protocol::Tcp, (ip1, 1))?;
1049        net2.bind(Protocol::Tcp, (ip2, 1))?;
1050        net3.bind(Protocol::Tcp, (ip3, 1))?;
1051
1052        // net 1
1053        // inject egress packets from net 2 into net 1 rx buffer
1054        net_inject(
1055            UnboundedReceiverStream::new(net2.egress_receiver().unwrap())
1056                .map(|e| e.payload.into_vec()),
1057            net1.clone(),
1058        );
1059        // inject egress packets from net 3 into net 1 rx buffer
1060        net_inject(
1061            UnboundedReceiverStream::new(net3.egress_receiver().unwrap())
1062                .map(|e| e.payload.into_vec()),
1063            net1.clone(),
1064        );
1065        // process net 1 events
1066        let (tx, rx) = mpsc::channel(1);
1067        net_receive(
1068            tx,
1069            UnboundedReceiverStream::new(net1.ingress_receiver().unwrap()),
1070        );
1071
1072        let _consume1 = spawn_local(rx.for_each(|e| async move { println!("consumer 1: {e:?}") }));
1073
1074        // net 2
1075        // inject egress packets from net 1 into net 2 or net 3 rx buffer
1076        net_inject2(
1077            UnboundedReceiverStream::new(net1.egress_receiver().unwrap()),
1078            net2.clone(),
1079            net3.clone(),
1080        );
1081
1082        // process net 2 events
1083        let (tx, rx) = mpsc::channel(1);
1084        net_receive(
1085            tx,
1086            UnboundedReceiverStream::new(net2.ingress_receiver().unwrap()),
1087        );
1088
1089        let _consume2 = spawn_local(rx.for_each(|e| async move { println!("consumer 2: {e:?}") }));
1090
1091        // process net 3 events
1092        let (tx, rx) = mpsc::channel(1);
1093        net_receive(
1094            tx,
1095            UnboundedReceiverStream::new(net3.ingress_receiver().unwrap()),
1096        );
1097
1098        let _consume3 = spawn_local(rx.for_each(|e| async move { println!("consumer 3: {e:?}") }));
1099
1100        let conn1 = net2.connect((ip1, 1), Duration::from_secs(3));
1101        let conn2 = net3.connect((ip1, 1), Duration::from_secs(3));
1102
1103        let (f1, f2) = futures::future::join(conn1, conn2).await;
1104
1105        f1.expect("Connection failed!");
1106        f2.expect("Connection failed!");
1107
1108        Ok(())
1109    }
1110
1111    /// Establish given number of connections between single client and server
1112    #[cfg(feature = "test-suite")]
1113    async fn establish_multiple_conn(
1114        medium: Medium,
1115        total: usize,
1116        chunk_size: usize,
1117        conn_num: u16,
1118    ) -> anyhow::Result<()> {
1119        use crate::error;
1120
1121        const MTU: usize = 65535;
1122
1123        println!(">> exchanging {} B in {} B chunks", total, chunk_size);
1124
1125        let ip1 = Ipv4Address::new(10, 0, 0, 1);
1126        let ip2 = Ipv4Address::new(10, 0, 0, 2);
1127
1128        let config = StackConfig {
1129            max_transmission_unit: MTU,
1130            ..Default::default()
1131        };
1132
1133        let net1 = new_network(medium, ip1.into(), config.clone());
1134        let net2 = new_network(medium, ip2.into(), config.clone());
1135
1136        net1.spawn_local();
1137        net2.spawn_local();
1138
1139        net1.bind(Protocol::Tcp, (ip1, 1))?;
1140        net2.bind(Protocol::Tcp, (ip2, 1))?;
1141
1142        // net 1
1143        // inject egress packets from net 2 into net 1 rx buffer
1144        net_inject(
1145            UnboundedReceiverStream::new(net2.egress_receiver().unwrap())
1146                .map(|e| e.payload.into_vec()),
1147            net1.clone(),
1148        );
1149
1150        // process net 1 events
1151        let (tx, rx) = mpsc::channel(1);
1152        net_receive(
1153            tx,
1154            UnboundedReceiverStream::new(net1.ingress_receiver().unwrap()),
1155        );
1156
1157        let _consume1 = spawn_local(rx.for_each(|e| async move { println!("consumer 1: {e:?}") }));
1158
1159        // net 2
1160        // inject egress packets from net 1 into net 2 rx buffer
1161        net_inject(
1162            UnboundedReceiverStream::new(net1.egress_receiver().unwrap())
1163                .map(|e| e.payload.into_vec()),
1164            net2.clone(),
1165        );
1166
1167        // process net 2 events
1168        let (tx, rx) = mpsc::channel(1);
1169        net_receive(
1170            tx,
1171            UnboundedReceiverStream::new(net2.ingress_receiver().unwrap()),
1172        );
1173
1174        let _consume2 = spawn_local(rx.for_each(|e| async move { println!("consumer 2: {e:?}") }));
1175
1176        for i in 1..=conn_num {
1177            let conn = net2.connect((ip1, 1), Duration::from_secs(3)).await;
1178            match conn {
1179                Ok(_) => println!("Connection({i}) successful"),
1180                Err(e) => {
1181                    if i != u16::MAX {
1182                        panic!("Connection failed! Error: {}", e);
1183                    };
1184
1185                    let expected = error::Error::Other("no ports available".into());
1186                    assert_eq!(expected, e)
1187                }
1188            }
1189        }
1190
1191        Ok(())
1192    }
1193
1194    async fn spawn_exchange(medium: Medium, total: usize, chunk_size: usize) -> anyhow::Result<()> {
1195        tokio::task::LocalSet::new()
1196            .run_until(tokio::time::timeout(
1197                EXCHANGE_TIMEOUT,
1198                net_exchange(medium, total, chunk_size),
1199            ))
1200            .await?
1201    }
1202
1203    async fn spawn_exchange_scenarios(medium: Medium) -> anyhow::Result<()> {
1204        spawn_exchange(medium, 1024, 1).await?;
1205        spawn_exchange(medium, 1024, 4).await?;
1206        spawn_exchange(medium, 1024, 7).await?;
1207        spawn_exchange(medium, 10240, 16).await?;
1208        spawn_exchange(medium, 1024000, 383).await?;
1209        spawn_exchange(medium, 1024000, 384).await?;
1210        spawn_exchange(medium, 1024000, 4096).await?;
1211        spawn_exchange(medium, 1024000, 40960).await?;
1212
1213        #[cfg(not(debug_assertions))]
1214        {
1215            spawn_exchange(medium, 10240000, 40960).await?;
1216            spawn_exchange(medium, 10240000, 131070).await?;
1217            spawn_exchange(medium, 10240000, 1024000).await?;
1218        }
1219
1220        Ok(())
1221    }
1222
1223    #[tokio::test]
1224    async fn tap_exchange() -> anyhow::Result<()> {
1225        spawn_exchange_scenarios(Medium::Ethernet).await
1226    }
1227
1228    #[tokio::test]
1229    async fn tun_exchange() -> anyhow::Result<()> {
1230        spawn_exchange_scenarios(Medium::Ip).await
1231    }
1232
1233    #[tokio::test]
1234    async fn socket_re_binding() -> anyhow::Result<()> {
1235        tokio::task::LocalSet::new()
1236            .run_until(tokio::time::timeout(
1237                EXCHANGE_TIMEOUT,
1238                re_bind(Medium::Ip, 0, 0),
1239            ))
1240            .await?
1241    }
1242
1243    // Test case where establishing a maximum number of connections (equal to 65 534 connections) does not fail.
1244    #[cfg(feature = "test-suite")]
1245    #[tokio::test]
1246    async fn multiple_conn() -> anyhow::Result<()> {
1247        tokio::task::LocalSet::new()
1248            .run_until(establish_multiple_conn(Medium::Ip, 0, 0, u16::MAX - 1))
1249            .await
1250    }
1251
1252    // Test case where establishing a new connection (above the number of 65 534 connections) results in the "no ports available" error.
1253    #[cfg(feature = "test-suite")]
1254    #[tokio::test]
1255    async fn overload_conn() -> anyhow::Result<()> {
1256        tokio::task::LocalSet::new()
1257            .run_until(establish_multiple_conn(Medium::Ip, 0, 0, u16::MAX))
1258            .await
1259    }
1260}