xenet_socket/
windows.rs

1use socket2::SockAddr;
2use std::cmp::min;
3use std::io;
4use std::mem::{self, MaybeUninit};
5use std::net::{SocketAddr, UdpSocket};
6use std::ptr;
7use std::sync::Once;
8use std::time::Duration;
9
10#[allow(non_camel_case_types)]
11type c_int = i32;
12
13#[allow(non_camel_case_types)]
14type c_long = i32;
15
16type DWORD = u32;
17use windows_sys::Win32::Networking::WinSock::SIO_RCVALL;
18use windows_sys::Win32::System::Threading::INFINITE;
19
20#[allow(non_camel_case_types)]
21type u_long = u32;
22
23use windows_sys::Win32::Networking::WinSock::{self as sock, SOCKET, WSA_FLAG_NO_HANDLE_INHERIT};
24use windows_sys::Win32::Networking::WinSock::{
25    AF_INET, AF_INET6, IPPROTO_ICMP, IPPROTO_ICMPV6, IPPROTO_IP, IPPROTO_IPV6, IPPROTO_TCP,
26    IPPROTO_UDP,
27};
28
29pub(crate) const NO_INHERIT: c_int = 1 << (c_int::BITS - 1);
30pub(crate) const MAX_BUF_LEN: usize = <c_int>::max_value() as usize;
31
32use super::{IpVersion, SocketOption, SocketType};
33use xenet_packet::ip::IpNextLevelProtocol;
34
35pub fn check_socket_option(socket_option: SocketOption) -> Result<(), String> {
36    match socket_option.ip_version {
37        IpVersion::V4 => {
38            match socket_option.socket_type {
39                SocketType::Raw => {
40                    match socket_option.protocol {
41                        Some(IpNextLevelProtocol::Icmp) => Ok(()),
42                        Some(IpNextLevelProtocol::Tcp) => Err(String::from("TCP is not supported on IPv4 raw socket on Windows(Due to Winsock2 limitation))")),
43                        Some(IpNextLevelProtocol::Udp) => Ok(()),
44                        _ => Err(String::from("Invalid protocol")),
45                    }
46                }
47                SocketType::Datagram => {
48                    match socket_option.protocol {
49                        Some(IpNextLevelProtocol::Icmp) => Ok(()),
50                        Some(IpNextLevelProtocol::Udp) => Ok(()),
51                        _ => Err(String::from("Invalid protocol")),
52                    }
53                }
54                SocketType::Stream => {
55                    match socket_option.protocol {
56                        Some(IpNextLevelProtocol::Tcp) => Ok(()),
57                        _ => Err(String::from("Invalid protocol")),
58                    }
59                }
60            }
61        }
62        IpVersion::V6 => {
63            match socket_option.socket_type {
64                SocketType::Raw => {
65                    match socket_option.protocol {
66                        Some(IpNextLevelProtocol::Icmpv6) => Ok(()),
67                        Some(IpNextLevelProtocol::Tcp) => Err(String::from("TCP is not supported on IPv6 raw socket on Windows(Due to Winsock2 limitation))")),
68                        Some(IpNextLevelProtocol::Udp) => Ok(()),
69                        _ => Err(String::from("Invalid protocol")),
70                    }
71                }
72                SocketType::Datagram => {
73                    match socket_option.protocol {
74                        Some(IpNextLevelProtocol::Icmpv6) => Ok(()),
75                        Some(IpNextLevelProtocol::Udp) => Ok(()),
76                        _ => Err(String::from("Invalid protocol")),
77                    }
78                }
79                SocketType::Stream => {
80                    match socket_option.protocol {
81                        Some(IpNextLevelProtocol::Tcp) => Ok(()),
82                        _ => Err(String::from("Invalid protocol")),
83                    }
84                }
85            }
86        }
87    }
88}
89
/// Calls the Winsock function `$fn` with the given arguments and turns its
/// return value into an `io::Result`.
///
/// `$err_test(&ret, &$err_value)` decides whether the call failed. On failure
/// the error is taken from `io::Error::last_os_error()`; otherwise the raw
/// return value is passed through as `Ok`.
macro_rules! syscall {
    ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{
        #[allow(unused_unsafe)]
        let ret = unsafe { windows_sys::Win32::Networking::WinSock::$fn($($arg, )*) };
        let failed = $err_test(&ret, &$err_value);
        if failed {
            Err(io::Error::last_os_error())
        } else {
            Ok(ret)
        }
    }};
}
101
/// Makes sure Winsock is initialised for this process before any raw
/// `WSASocketW` call.
///
/// This relies on the standard library initialising Winsock (`WSAStartup`)
/// when it creates its first socket. A throw-away UDP socket is therefore
/// bound once and dropped immediately. The bind result is deliberately
/// ignored.
pub(crate) fn init_socket() {
    static INIT: Once = Once::new();
    INIT.call_once(|| {
        // Port 0 lets the OS pick a free ephemeral port, so this never collides
        // with (or briefly occupies) a fixed port another program may need.
        let _ = UdpSocket::bind("127.0.0.1:0");
    });
}
108
109pub(crate) fn ioctlsocket(socket: SOCKET, cmd: c_long, payload: &mut u_long) -> io::Result<()> {
110    syscall!(
111        ioctlsocket(socket, cmd, payload),
112        PartialEq::eq,
113        sock::SOCKET_ERROR
114    )
115    .map(|_| ())
116}
117
118pub(crate) fn create_socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Result<SOCKET> {
119    init_socket();
120    let flags = if ty & NO_INHERIT != 0 {
121        ty = ty & !NO_INHERIT;
122        WSA_FLAG_NO_HANDLE_INHERIT
123    } else {
124        0
125    };
126    syscall!(
127        WSASocketW(
128            family,
129            ty,
130            protocol,
131            ptr::null_mut(),
132            0,
133            sock::WSA_FLAG_OVERLAPPED | flags,
134        ),
135        PartialEq::eq,
136        sock::INVALID_SOCKET
137    )
138}
139
140pub(crate) fn bind(socket: SOCKET, addr: &SockAddr) -> io::Result<()> {
141    syscall!(bind(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
142}
143
144#[allow(dead_code)]
145pub(crate) fn set_nonblocking(socket: SOCKET, nonblocking: bool) -> io::Result<()> {
146    let mut nonblocking = nonblocking as u_long;
147    ioctlsocket(socket, sock::FIONBIO, &mut nonblocking)
148}
149
150pub(crate) fn set_promiscuous(socket: SOCKET, promiscuous: bool) -> io::Result<()> {
151    let mut promiscuous = promiscuous as u_long;
152    ioctlsocket(socket, SIO_RCVALL as i32, &mut promiscuous)
153}
154
155pub(crate) unsafe fn setsockopt<T>(
156    socket: SOCKET,
157    level: c_int,
158    optname: i32,
159    optval: T,
160) -> io::Result<()> {
161    syscall!(
162        setsockopt(
163            socket,
164            level as i32,
165            optname,
166            (&optval as *const T).cast(),
167            mem::size_of::<T>() as c_int,
168        ),
169        PartialEq::eq,
170        sock::SOCKET_ERROR
171    )
172    .map(|_| ())
173}
174
175pub(crate) fn into_ms(duration: Option<Duration>) -> DWORD {
176    duration
177        .map(|duration| min(duration.as_millis(), INFINITE as u128) as DWORD)
178        .unwrap_or(0)
179}
180
181pub(crate) fn set_timeout_opt(
182    fd: SOCKET,
183    level: c_int,
184    optname: c_int,
185    duration: Option<Duration>,
186) -> io::Result<()> {
187    let duration = into_ms(duration);
188    unsafe { setsockopt(fd, level, optname, duration) }
189}
190
191pub(crate) fn recv_from(
192    socket: SOCKET,
193    buf: &mut [MaybeUninit<u8>],
194    flags: c_int,
195) -> io::Result<(usize, SockAddr)> {
196    unsafe {
197        SockAddr::try_init(|storage, addrlen| {
198            let res = syscall!(
199                recvfrom(
200                    socket,
201                    buf.as_mut_ptr().cast(),
202                    min(buf.len(), MAX_BUF_LEN) as c_int,
203                    flags,
204                    storage.cast(),
205                    addrlen,
206                ),
207                PartialEq::eq,
208                sock::SOCKET_ERROR
209            );
210            match res {
211                Ok(n) => Ok(n as usize),
212                Err(ref err) if err.raw_os_error() == Some(sock::WSAESHUTDOWN as i32) => Ok(0),
213                Err(err) => Err(err),
214            }
215        })
216    }
217}
218
219/// Receive all IPv4 or IPv6 packets passing through a network interface.
220pub struct ListenerSocket {
221    inner: SOCKET,
222}
223
224impl ListenerSocket {
225    pub fn new(
226        socket_addr: SocketAddr,
227        ip_version: IpVersion,
228        protocol: Option<IpNextLevelProtocol>,
229        timeout: Option<Duration>,
230    ) -> io::Result<ListenerSocket> {
231        let socket = match ip_version {
232            IpVersion::V4 => match protocol {
233                Some(IpNextLevelProtocol::Icmp) => {
234                    create_socket(AF_INET as i32, sock::SOCK_RAW, IPPROTO_ICMP)?
235                }
236                Some(IpNextLevelProtocol::Tcp) => {
237                    create_socket(AF_INET as i32, sock::SOCK_RAW, IPPROTO_TCP)?
238                }
239                Some(IpNextLevelProtocol::Udp) => {
240                    create_socket(AF_INET as i32, sock::SOCK_RAW, IPPROTO_UDP)?
241                }
242                _ => create_socket(AF_INET as i32, sock::SOCK_RAW, IPPROTO_IP)?,
243            },
244            IpVersion::V6 => match protocol {
245                Some(IpNextLevelProtocol::Icmpv6) => {
246                    create_socket(AF_INET6 as i32, sock::SOCK_RAW, IPPROTO_ICMPV6)?
247                }
248                Some(IpNextLevelProtocol::Tcp) => {
249                    create_socket(AF_INET6 as i32, sock::SOCK_RAW, IPPROTO_TCP)?
250                }
251                Some(IpNextLevelProtocol::Udp) => {
252                    create_socket(AF_INET6 as i32, sock::SOCK_RAW, IPPROTO_UDP)?
253                }
254                _ => create_socket(AF_INET6 as i32, sock::SOCK_RAW, IPPROTO_IPV6)?,
255            },
256        };
257        let sock_addr = SockAddr::from(socket_addr);
258        bind(socket, &sock_addr)?;
259        set_promiscuous(socket, true)?;
260        set_timeout_opt(socket, sock::SOL_SOCKET, sock::SO_RCVTIMEO, timeout)?;
261        Ok(ListenerSocket { inner: socket })
262    }
263    pub fn bind(&self, addr: &SockAddr) -> io::Result<()> {
264        bind(self.inner, addr)
265    }
266    pub fn receive_from(&self, buf: &mut Vec<u8>) -> io::Result<(usize, SocketAddr)> {
267        let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
268        match recv_from(self.inner, recv_buf, 0) {
269            Ok((n, addr)) => match addr.as_socket() {
270                Some(socket_addr) => {
271                    return Ok((n, socket_addr));
272                }
273                None => Err(io::Error::new(
274                    io::ErrorKind::Other,
275                    "Invalid socket address",
276                )),
277            },
278            Err(e) => Err(e),
279        }
280    }
281}