1use socket2::SockAddr;
2use std::cmp::min;
3use std::io;
4use std::mem::{self, MaybeUninit};
5use std::net::{SocketAddr, UdpSocket};
6use std::ptr;
7use std::sync::Once;
8use std::time::Duration;
9
/// Local alias for the `int`-sized values passed to the WinSock calls in this
/// file (socket family, type, protocol, option level, flags).
// The C-style lowercase name is intentional, hence the lint exemption.
#[allow(non_camel_case_types)]
type c_int = i32;

/// Local alias used for the `cmd` argument of `ioctlsocket`.
// The C-style lowercase name is intentional, hence the lint exemption.
#[allow(non_camel_case_types)]
type c_long = i32;

/// Local alias for the 32-bit unsigned millisecond value produced by `into_ms`.
type DWORD = u32;
17use windows_sys::Win32::Networking::WinSock::SIO_RCVALL;
18use windows_sys::Win32::System::Threading::INFINITE;
19
/// Local alias for the payload word handed to `ioctlsocket`.
// The C-style lowercase name is intentional, hence the lint exemption.
#[allow(non_camel_case_types)]
type u_long = u32;
22
23use windows_sys::Win32::Networking::WinSock::{self as sock, SOCKET, WSA_FLAG_NO_HANDLE_INHERIT};
24use windows_sys::Win32::Networking::WinSock::{
25 AF_INET, AF_INET6, IPPROTO_ICMP, IPPROTO_ICMPV6, IPPROTO_IP, IPPROTO_IPV6, IPPROTO_TCP,
26 IPPROTO_UDP,
27};
28
29pub(crate) const NO_INHERIT: c_int = 1 << (c_int::BITS - 1);
30pub(crate) const MAX_BUF_LEN: usize = <c_int>::max_value() as usize;
31
32use super::{IpVersion, SocketOption, SocketType};
33use xenet_packet::ip::IpNextLevelProtocol;
34
35pub fn check_socket_option(socket_option: SocketOption) -> Result<(), String> {
36 match socket_option.ip_version {
37 IpVersion::V4 => {
38 match socket_option.socket_type {
39 SocketType::Raw => {
40 match socket_option.protocol {
41 Some(IpNextLevelProtocol::Icmp) => Ok(()),
42 Some(IpNextLevelProtocol::Tcp) => Err(String::from("TCP is not supported on IPv4 raw socket on Windows(Due to Winsock2 limitation))")),
43 Some(IpNextLevelProtocol::Udp) => Ok(()),
44 _ => Err(String::from("Invalid protocol")),
45 }
46 }
47 SocketType::Datagram => {
48 match socket_option.protocol {
49 Some(IpNextLevelProtocol::Icmp) => Ok(()),
50 Some(IpNextLevelProtocol::Udp) => Ok(()),
51 _ => Err(String::from("Invalid protocol")),
52 }
53 }
54 SocketType::Stream => {
55 match socket_option.protocol {
56 Some(IpNextLevelProtocol::Tcp) => Ok(()),
57 _ => Err(String::from("Invalid protocol")),
58 }
59 }
60 }
61 }
62 IpVersion::V6 => {
63 match socket_option.socket_type {
64 SocketType::Raw => {
65 match socket_option.protocol {
66 Some(IpNextLevelProtocol::Icmpv6) => Ok(()),
67 Some(IpNextLevelProtocol::Tcp) => Err(String::from("TCP is not supported on IPv6 raw socket on Windows(Due to Winsock2 limitation))")),
68 Some(IpNextLevelProtocol::Udp) => Ok(()),
69 _ => Err(String::from("Invalid protocol")),
70 }
71 }
72 SocketType::Datagram => {
73 match socket_option.protocol {
74 Some(IpNextLevelProtocol::Icmpv6) => Ok(()),
75 Some(IpNextLevelProtocol::Udp) => Ok(()),
76 _ => Err(String::from("Invalid protocol")),
77 }
78 }
79 SocketType::Stream => {
80 match socket_option.protocol {
81 Some(IpNextLevelProtocol::Tcp) => Ok(()),
82 _ => Err(String::from("Invalid protocol")),
83 }
84 }
85 }
86 }
87 }
88}
89
/// Calls a function from `windows_sys::Win32::Networking::WinSock` and turns its
/// return value into an `io::Result`.
///
/// `$err_test` is invoked as `$err_test(&ret, &$err_value)`; when it yields
/// `true` the macro evaluates to `Err(io::Error::last_os_error())`, otherwise to
/// `Ok(ret)`.
macro_rules! syscall {
    ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{
        // Call sites may already sit inside an `unsafe` block or `unsafe fn`.
        #[allow(unused_unsafe)]
        let ret = unsafe { windows_sys::Win32::Networking::WinSock::$fn($($arg, )*) };
        let failed: bool = $err_test(&ret, &$err_value);
        if failed {
            Err(io::Error::last_os_error())
        } else {
            Ok(ret)
        }
    }};
}
101
/// Binds a throwaway `UdpSocket` exactly once per process and discards the
/// outcome.
///
/// NOTE(review): presumably this relies on the standard library initialising
/// WinSock as a side effect of creating the socket — confirm.
pub(crate) fn init_socket() {
    static WINSOCK_INIT: Once = Once::new();
    // Success or failure of the bind itself is irrelevant; the result is dropped.
    WINSOCK_INIT.call_once(|| drop(UdpSocket::bind("127.0.0.1:34254")));
}
108
109pub(crate) fn ioctlsocket(socket: SOCKET, cmd: c_long, payload: &mut u_long) -> io::Result<()> {
110 syscall!(
111 ioctlsocket(socket, cmd, payload),
112 PartialEq::eq,
113 sock::SOCKET_ERROR
114 )
115 .map(|_| ())
116}
117
118pub(crate) fn create_socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Result<SOCKET> {
119 init_socket();
120 let flags = if ty & NO_INHERIT != 0 {
121 ty = ty & !NO_INHERIT;
122 WSA_FLAG_NO_HANDLE_INHERIT
123 } else {
124 0
125 };
126 syscall!(
127 WSASocketW(
128 family,
129 ty,
130 protocol,
131 ptr::null_mut(),
132 0,
133 sock::WSA_FLAG_OVERLAPPED | flags,
134 ),
135 PartialEq::eq,
136 sock::INVALID_SOCKET
137 )
138}
139
140pub(crate) fn bind(socket: SOCKET, addr: &SockAddr) -> io::Result<()> {
141 syscall!(bind(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
142}
143
144#[allow(dead_code)]
145pub(crate) fn set_nonblocking(socket: SOCKET, nonblocking: bool) -> io::Result<()> {
146 let mut nonblocking = nonblocking as u_long;
147 ioctlsocket(socket, sock::FIONBIO, &mut nonblocking)
148}
149
150pub(crate) fn set_promiscuous(socket: SOCKET, promiscuous: bool) -> io::Result<()> {
151 let mut promiscuous = promiscuous as u_long;
152 ioctlsocket(socket, SIO_RCVALL as i32, &mut promiscuous)
153}
154
155pub(crate) unsafe fn setsockopt<T>(
156 socket: SOCKET,
157 level: c_int,
158 optname: i32,
159 optval: T,
160) -> io::Result<()> {
161 syscall!(
162 setsockopt(
163 socket,
164 level as i32,
165 optname,
166 (&optval as *const T).cast(),
167 mem::size_of::<T>() as c_int,
168 ),
169 PartialEq::eq,
170 sock::SOCKET_ERROR
171 )
172 .map(|_| ())
173}
174
175pub(crate) fn into_ms(duration: Option<Duration>) -> DWORD {
176 duration
177 .map(|duration| min(duration.as_millis(), INFINITE as u128) as DWORD)
178 .unwrap_or(0)
179}
180
181pub(crate) fn set_timeout_opt(
182 fd: SOCKET,
183 level: c_int,
184 optname: c_int,
185 duration: Option<Duration>,
186) -> io::Result<()> {
187 let duration = into_ms(duration);
188 unsafe { setsockopt(fd, level, optname, duration) }
189}
190
191pub(crate) fn recv_from(
192 socket: SOCKET,
193 buf: &mut [MaybeUninit<u8>],
194 flags: c_int,
195) -> io::Result<(usize, SockAddr)> {
196 unsafe {
197 SockAddr::try_init(|storage, addrlen| {
198 let res = syscall!(
199 recvfrom(
200 socket,
201 buf.as_mut_ptr().cast(),
202 min(buf.len(), MAX_BUF_LEN) as c_int,
203 flags,
204 storage.cast(),
205 addrlen,
206 ),
207 PartialEq::eq,
208 sock::SOCKET_ERROR
209 );
210 match res {
211 Ok(n) => Ok(n as usize),
212 Err(ref err) if err.raw_os_error() == Some(sock::WSAESHUTDOWN as i32) => Ok(0),
213 Err(err) => Err(err),
214 }
215 })
216 }
217}
218
219pub struct ListenerSocket {
221 inner: SOCKET,
222}
223
224impl ListenerSocket {
225 pub fn new(
226 socket_addr: SocketAddr,
227 ip_version: IpVersion,
228 protocol: Option<IpNextLevelProtocol>,
229 timeout: Option<Duration>,
230 ) -> io::Result<ListenerSocket> {
231 let socket = match ip_version {
232 IpVersion::V4 => match protocol {
233 Some(IpNextLevelProtocol::Icmp) => {
234 create_socket(AF_INET as i32, sock::SOCK_RAW, IPPROTO_ICMP)?
235 }
236 Some(IpNextLevelProtocol::Tcp) => {
237 create_socket(AF_INET as i32, sock::SOCK_RAW, IPPROTO_TCP)?
238 }
239 Some(IpNextLevelProtocol::Udp) => {
240 create_socket(AF_INET as i32, sock::SOCK_RAW, IPPROTO_UDP)?
241 }
242 _ => create_socket(AF_INET as i32, sock::SOCK_RAW, IPPROTO_IP)?,
243 },
244 IpVersion::V6 => match protocol {
245 Some(IpNextLevelProtocol::Icmpv6) => {
246 create_socket(AF_INET6 as i32, sock::SOCK_RAW, IPPROTO_ICMPV6)?
247 }
248 Some(IpNextLevelProtocol::Tcp) => {
249 create_socket(AF_INET6 as i32, sock::SOCK_RAW, IPPROTO_TCP)?
250 }
251 Some(IpNextLevelProtocol::Udp) => {
252 create_socket(AF_INET6 as i32, sock::SOCK_RAW, IPPROTO_UDP)?
253 }
254 _ => create_socket(AF_INET6 as i32, sock::SOCK_RAW, IPPROTO_IPV6)?,
255 },
256 };
257 let sock_addr = SockAddr::from(socket_addr);
258 bind(socket, &sock_addr)?;
259 set_promiscuous(socket, true)?;
260 set_timeout_opt(socket, sock::SOL_SOCKET, sock::SO_RCVTIMEO, timeout)?;
261 Ok(ListenerSocket { inner: socket })
262 }
263 pub fn bind(&self, addr: &SockAddr) -> io::Result<()> {
264 bind(self.inner, addr)
265 }
266 pub fn receive_from(&self, buf: &mut Vec<u8>) -> io::Result<(usize, SocketAddr)> {
267 let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
268 match recv_from(self.inner, recv_buf, 0) {
269 Ok((n, addr)) => match addr.as_socket() {
270 Some(socket_addr) => {
271 return Ok((n, socket_addr));
272 }
273 None => Err(io::Error::new(
274 io::ErrorKind::Other,
275 "Invalid socket address",
276 )),
277 },
278 Err(e) => Err(e),
279 }
280 }
281}