1use std::{io, mem::MaybeUninit, net::SocketAddr, time::Duration};
2
3use socket2::{Domain, Protocol, Socket as SystemSocket, Type};
4use xenet_packet::ip::IpNextLevelProtocol;
5
6use super::{IpVersion, SocketOption, SocketType};
7
8pub(crate) fn check_socket_option(socket_option: SocketOption) -> Result<(), String> {
9 match socket_option.ip_version {
10 IpVersion::V4 => match socket_option.socket_type {
11 SocketType::Raw => match socket_option.protocol {
12 Some(IpNextLevelProtocol::Icmp) => Ok(()),
13 Some(IpNextLevelProtocol::Tcp) => Ok(()),
14 Some(IpNextLevelProtocol::Udp) => Ok(()),
15 _ => Err(String::from("Invalid protocol")),
16 },
17 SocketType::Datagram => match socket_option.protocol {
18 Some(IpNextLevelProtocol::Icmp) => Ok(()),
19 Some(IpNextLevelProtocol::Udp) => Ok(()),
20 _ => Err(String::from("Invalid protocol")),
21 },
22 SocketType::Stream => match socket_option.protocol {
23 Some(IpNextLevelProtocol::Tcp) => Ok(()),
24 _ => Err(String::from("Invalid protocol")),
25 },
26 },
27 IpVersion::V6 => match socket_option.socket_type {
28 SocketType::Raw => match socket_option.protocol {
29 Some(IpNextLevelProtocol::Icmpv6) => Ok(()),
30 Some(IpNextLevelProtocol::Tcp) => Ok(()),
31 Some(IpNextLevelProtocol::Udp) => Ok(()),
32 _ => Err(String::from("Invalid protocol")),
33 },
34 SocketType::Datagram => match socket_option.protocol {
35 Some(IpNextLevelProtocol::Icmpv6) => Ok(()),
36 Some(IpNextLevelProtocol::Udp) => Ok(()),
37 _ => Err(String::from("Invalid protocol")),
38 },
39 SocketType::Stream => match socket_option.protocol {
40 Some(IpNextLevelProtocol::Tcp) => Ok(()),
41 _ => Err(String::from("Invalid protocol")),
42 },
43 },
44 }
45}
46
47pub struct ListenerSocket {
49 inner: SystemSocket,
50}
51
52impl ListenerSocket {
53 pub fn new(
55 _socket_addr: SocketAddr,
56 ip_version: IpVersion,
57 protocol: Option<IpNextLevelProtocol>,
58 timeout: Option<Duration>,
59 ) -> io::Result<ListenerSocket> {
60 let socket = match ip_version {
61 IpVersion::V4 => match protocol {
62 Some(IpNextLevelProtocol::Icmp) => {
63 SystemSocket::new(Domain::IPV4, Type::RAW, Some(Protocol::ICMPV4))?
64 }
65 Some(IpNextLevelProtocol::Tcp) => {
66 SystemSocket::new(Domain::IPV4, Type::RAW, Some(Protocol::TCP))?
67 }
68 Some(IpNextLevelProtocol::Udp) => {
69 SystemSocket::new(Domain::IPV4, Type::RAW, Some(Protocol::UDP))?
70 }
71 _ => SystemSocket::new(Domain::IPV4, Type::RAW, None)?,
72 },
73 IpVersion::V6 => match protocol {
74 Some(IpNextLevelProtocol::Icmpv6) => {
75 SystemSocket::new(Domain::IPV6, Type::RAW, Some(Protocol::ICMPV6))?
76 }
77 Some(IpNextLevelProtocol::Tcp) => {
78 SystemSocket::new(Domain::IPV6, Type::RAW, Some(Protocol::TCP))?
79 }
80 Some(IpNextLevelProtocol::Udp) => {
81 SystemSocket::new(Domain::IPV6, Type::RAW, Some(Protocol::UDP))?
82 }
83 _ => SystemSocket::new(Domain::IPV6, Type::RAW, None)?,
84 },
85 };
86 if let Some(timeout) = timeout {
87 socket.set_read_timeout(Some(timeout))?;
88 }
89 Ok(ListenerSocket { inner: socket })
91 }
92 pub fn receive_from(&self, buf: &mut Vec<u8>) -> io::Result<(usize, SocketAddr)> {
94 let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
95 match self.inner.recv_from(recv_buf) {
96 Ok((packet_len, addr)) => match addr.as_socket() {
97 Some(socket_addr) => {
98 return Ok((packet_len, socket_addr));
99 }
100 None => Err(io::Error::new(
101 io::ErrorKind::Other,
102 "Invalid socket address",
103 )),
104 },
105 Err(e) => Err(e),
106 }
107 }
108}