xenet_socket/
unix.rs

1use std::{io, mem::MaybeUninit, net::SocketAddr, time::Duration};
2
3use socket2::{Domain, Protocol, Socket as SystemSocket, Type};
4use xenet_packet::ip::IpNextLevelProtocol;
5
6use super::{IpVersion, SocketOption, SocketType};
7
8pub(crate) fn check_socket_option(socket_option: SocketOption) -> Result<(), String> {
9    match socket_option.ip_version {
10        IpVersion::V4 => match socket_option.socket_type {
11            SocketType::Raw => match socket_option.protocol {
12                Some(IpNextLevelProtocol::Icmp) => Ok(()),
13                Some(IpNextLevelProtocol::Tcp) => Ok(()),
14                Some(IpNextLevelProtocol::Udp) => Ok(()),
15                _ => Err(String::from("Invalid protocol")),
16            },
17            SocketType::Datagram => match socket_option.protocol {
18                Some(IpNextLevelProtocol::Icmp) => Ok(()),
19                Some(IpNextLevelProtocol::Udp) => Ok(()),
20                _ => Err(String::from("Invalid protocol")),
21            },
22            SocketType::Stream => match socket_option.protocol {
23                Some(IpNextLevelProtocol::Tcp) => Ok(()),
24                _ => Err(String::from("Invalid protocol")),
25            },
26        },
27        IpVersion::V6 => match socket_option.socket_type {
28            SocketType::Raw => match socket_option.protocol {
29                Some(IpNextLevelProtocol::Icmpv6) => Ok(()),
30                Some(IpNextLevelProtocol::Tcp) => Ok(()),
31                Some(IpNextLevelProtocol::Udp) => Ok(()),
32                _ => Err(String::from("Invalid protocol")),
33            },
34            SocketType::Datagram => match socket_option.protocol {
35                Some(IpNextLevelProtocol::Icmpv6) => Ok(()),
36                Some(IpNextLevelProtocol::Udp) => Ok(()),
37                _ => Err(String::from("Invalid protocol")),
38            },
39            SocketType::Stream => match socket_option.protocol {
40                Some(IpNextLevelProtocol::Tcp) => Ok(()),
41                _ => Err(String::from("Invalid protocol")),
42            },
43        },
44    }
45}
46
/// Receives IPv4 or IPv6 packets through a raw system socket.
///
/// The value is a thin wrapper around a `socket2` socket; it is created with
/// `ListenerSocket::new` and read from with `ListenerSocket::receive_from`.
pub struct ListenerSocket {
    /// Underlying `socket2` socket that performs the actual I/O.
    inner: SystemSocket,
}
51
52impl ListenerSocket {
53    /// Constructs a new ListenerSocket.
54    pub fn new(
55        _socket_addr: SocketAddr,
56        ip_version: IpVersion,
57        protocol: Option<IpNextLevelProtocol>,
58        timeout: Option<Duration>,
59    ) -> io::Result<ListenerSocket> {
60        let socket = match ip_version {
61            IpVersion::V4 => match protocol {
62                Some(IpNextLevelProtocol::Icmp) => {
63                    SystemSocket::new(Domain::IPV4, Type::RAW, Some(Protocol::ICMPV4))?
64                }
65                Some(IpNextLevelProtocol::Tcp) => {
66                    SystemSocket::new(Domain::IPV4, Type::RAW, Some(Protocol::TCP))?
67                }
68                Some(IpNextLevelProtocol::Udp) => {
69                    SystemSocket::new(Domain::IPV4, Type::RAW, Some(Protocol::UDP))?
70                }
71                _ => SystemSocket::new(Domain::IPV4, Type::RAW, None)?,
72            },
73            IpVersion::V6 => match protocol {
74                Some(IpNextLevelProtocol::Icmpv6) => {
75                    SystemSocket::new(Domain::IPV6, Type::RAW, Some(Protocol::ICMPV6))?
76                }
77                Some(IpNextLevelProtocol::Tcp) => {
78                    SystemSocket::new(Domain::IPV6, Type::RAW, Some(Protocol::TCP))?
79                }
80                Some(IpNextLevelProtocol::Udp) => {
81                    SystemSocket::new(Domain::IPV6, Type::RAW, Some(Protocol::UDP))?
82                }
83                _ => SystemSocket::new(Domain::IPV6, Type::RAW, None)?,
84            },
85        };
86        if let Some(timeout) = timeout {
87            socket.set_read_timeout(Some(timeout))?;
88        }
89        //socket.bind(&socket_addr.into())?;
90        Ok(ListenerSocket { inner: socket })
91    }
92    /// Receive packet without source address.
93    pub fn receive_from(&self, buf: &mut Vec<u8>) -> io::Result<(usize, SocketAddr)> {
94        let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
95        match self.inner.recv_from(recv_buf) {
96            Ok((packet_len, addr)) => match addr.as_socket() {
97                Some(socket_addr) => {
98                    return Ok((packet_len, socket_addr));
99                }
100                None => Err(io::Error::new(
101                    io::ErrorKind::Other,
102                    "Invalid socket address",
103                )),
104            },
105            Err(e) => Err(e),
106        }
107    }
108}