socket2_plus/sys/
windows.rs

1// Copyright 2015 The Rust Project Developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use std::cmp::min;
10use std::io::{self, IoSlice};
11use std::marker::PhantomData;
12use std::mem::{self, size_of, MaybeUninit};
13use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown};
14use std::os::windows::io::{
15    AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, IntoRawSocket, OwnedSocket, RawSocket,
16};
17use std::path::Path;
18use std::sync::Once;
19use std::time::{Duration, Instant};
20use std::{process, ptr, slice};
21
22use windows_sys::Win32::Foundation::{SetHandleInformation, HANDLE, HANDLE_FLAG_INHERIT};
23#[cfg(feature = "all")]
24use windows_sys::Win32::Networking::WinSock::SO_PROTOCOL_INFOW;
25use windows_sys::Win32::Networking::WinSock::{
26    self, tcp_keepalive, FIONBIO, IN6_ADDR, IN6_ADDR_0, INVALID_SOCKET, IN_ADDR, IN_ADDR_0,
27    LPFN_WSARECVMSG, LPWSAOVERLAPPED_COMPLETION_ROUTINE, POLLERR, POLLHUP, POLLRDNORM, POLLWRNORM,
28    SD_BOTH, SD_RECEIVE, SD_SEND, SIO_GET_EXTENSION_FUNCTION_POINTER, SIO_KEEPALIVE_VALS,
29    SOCKET_ERROR, WSAEMSGSIZE, WSAESHUTDOWN, WSAID_WSARECVMSG, WSAPOLLFD, WSAPROTOCOL_INFOW,
30    WSA_FLAG_NO_HANDLE_INHERIT, WSA_FLAG_OVERLAPPED,
31};
32use windows_sys::Win32::System::Threading::INFINITE;
33use windows_sys::Win32::System::IO::OVERLAPPED;
34
35use crate::{MsgHdr, RecvFlags, SockAddr, TcpKeepalive, Type};
36
/// Alias for the platform C `int` (`std::os::raw::c_int`), used for flags,
/// lengths and constants throughout this module.
#[allow(non_camel_case_types)]
pub(crate) type c_int = std::os::raw::c_int;

/// Fake MSG_TRUNC flag for the [`RecvFlags`] struct.
///
/// The flag is enabled when a `WSARecv[From]` call returns `WSAEMSGSIZE`. The
/// value of the flag is defined by us.
pub(crate) const MSG_TRUNC: c_int = 0x01;

// Used in `Domain`.
// The WinSock constants are cast to `c_int` so they share the crate's integer type.
pub(crate) const AF_INET: c_int = windows_sys::Win32::Networking::WinSock::AF_INET as c_int;
pub(crate) const AF_INET6: c_int = windows_sys::Win32::Networking::WinSock::AF_INET6 as c_int;
pub(crate) const AF_UNIX: c_int = windows_sys::Win32::Networking::WinSock::AF_UNIX as c_int;
pub(crate) const AF_UNSPEC: c_int = windows_sys::Win32::Networking::WinSock::AF_UNSPEC as c_int;
// Used in `Type`.
pub(crate) const SOCK_STREAM: c_int = windows_sys::Win32::Networking::WinSock::SOCK_STREAM as c_int;
pub(crate) const SOCK_DGRAM: c_int = windows_sys::Win32::Networking::WinSock::SOCK_DGRAM as c_int;
pub(crate) const SOCK_RAW: c_int = windows_sys::Win32::Networking::WinSock::SOCK_RAW as c_int;
// Module-private: in this module it appears only in the `Type` debug listing.
const SOCK_RDM: c_int = windows_sys::Win32::Networking::WinSock::SOCK_RDM as c_int;
pub(crate) const SOCK_SEQPACKET: c_int =
    windows_sys::Win32::Networking::WinSock::SOCK_SEQPACKET as c_int;
58// Used in `Protocol`.
59pub(crate) use windows_sys::Win32::Networking::WinSock::{
60    CMSGHDR as cmsghdr, IN6_PKTINFO as In6PktInfo, IN_PKTINFO as InPktInfo, IPPROTO_ICMP,
61    IPPROTO_ICMPV6, IPPROTO_TCP, IPPROTO_UDP, IPV6_PKTINFO, IP_PKTINFO, WSABUF,
62};
63// Used in `SockAddr`.
64pub(crate) use windows_sys::Win32::Networking::WinSock::{
65    SOCKADDR as sockaddr, SOCKADDR_IN as sockaddr_in, SOCKADDR_IN6 as sockaddr_in6,
66    SOCKADDR_STORAGE as sockaddr_storage,
67};
/// Address family type of a socket address (WinSock `ADDRESS_FAMILY`), under
/// its POSIX name — presumably so platform-independent code can use one name.
#[allow(non_camel_case_types)]
pub(crate) type sa_family_t = windows_sys::Win32::Networking::WinSock::ADDRESS_FAMILY;
/// Length type for socket addresses (WinSock `socklen_t`).
#[allow(non_camel_case_types)]
pub(crate) type socklen_t = windows_sys::Win32::Networking::WinSock::socklen_t;
72// Used in `Socket`.
73#[cfg(feature = "all")]
74pub(crate) use windows_sys::Win32::Networking::WinSock::IP_HDRINCL;
75pub(crate) use windows_sys::Win32::Networking::WinSock::{
76    IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP, IPV6_DROP_MEMBERSHIP, IPV6_MREQ as Ipv6Mreq,
77    IPV6_MULTICAST_HOPS, IPV6_MULTICAST_IF, IPV6_MULTICAST_LOOP, IPV6_RECVTCLASS,
78    IPV6_UNICAST_HOPS, IPV6_V6ONLY, IP_ADD_MEMBERSHIP, IP_ADD_SOURCE_MEMBERSHIP,
79    IP_DROP_MEMBERSHIP, IP_DROP_SOURCE_MEMBERSHIP, IP_MREQ as IpMreq,
80    IP_MREQ_SOURCE as IpMreqSource, IP_MULTICAST_IF, IP_MULTICAST_LOOP, IP_MULTICAST_TTL,
81    IP_RECVTOS, IP_TOS, IP_TTL, LINGER as linger, MSG_OOB, MSG_PEEK, SO_BROADCAST, SO_ERROR,
82    SO_KEEPALIVE, SO_LINGER, SO_OOBINLINE, SO_RCVBUF, SO_RCVTIMEO, SO_REUSEADDR, SO_SNDBUF,
83    SO_SNDTIMEO, SO_TYPE, TCP_NODELAY,
84};
// Option levels, cast to `c_int` to match the `level` parameter of the
// `getsockopt`/`setsockopt` wrappers in this module.
pub(crate) const IPPROTO_IP: c_int = windows_sys::Win32::Networking::WinSock::IPPROTO_IP as c_int;
pub(crate) const SOL_SOCKET: c_int = windows_sys::Win32::Networking::WinSock::SOL_SOCKET as c_int;
87
/// Type used in set/getsockopt to retrieve the `TCP_NODELAY` option.
///
/// NOTE: <https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-getsockopt>
/// documents that options such as `TCP_NODELAY` and `SO_KEEPALIVE` expect a
/// `BOOL` (alias for `c_int`, 4 bytes), however in practice this turns out to
/// be false (or misleading) as a `BOOLEAN` (`c_uchar`, 1 byte) is returned by
/// `getsockopt`.
pub(crate) type Bool = windows_sys::Win32::Foundation::BOOLEAN;

/// Maximum size of a buffer passed to a system call like `recv` and `send`.
///
/// Those calls take the length as a `c_int`, so longer buffers are clamped to
/// this many bytes (`min(buf.len(), MAX_BUF_LEN)`) and the caller sees a short
/// read or write.
const MAX_BUF_LEN: usize = c_int::MAX as usize;
99
/// Helper macro to execute a system call that returns an `io::Result`.
///
/// `$fn` is called from `windows_sys::Win32::Networking::WinSock` with the given
/// arguments. `$err_test(&res, &$err_value)` decides whether the return value
/// signals failure (e.g. `PartialEq::eq` with `SOCKET_ERROR`, or `PartialEq::ne`
/// with `0`). On failure the error is taken from `io::Error::last_os_error()`;
/// otherwise the raw return value is yielded as `Ok(res)`.
macro_rules! syscall {
    ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{
        // Callers may already be inside an `unsafe` block or `unsafe fn`.
        #[allow(unused_unsafe)]
        let res = unsafe { windows_sys::Win32::Networking::WinSock::$fn($($arg, )*) };
        if $err_test(&res, &$err_value) {
            Err(io::Error::last_os_error())
        } else {
            Ok(res)
        }
    }};
}
112
// Generate `Debug` for `Domain` over the address families defined in this
// module (`impl_debug!` itself is defined elsewhere in the crate).
impl_debug!(
    crate::Domain,
    self::AF_INET,
    self::AF_INET6,
    self::AF_UNIX,
    self::AF_UNSPEC,
);
120
121/// Windows only API.
122impl Type {
123    /// Our custom flag to set `WSA_FLAG_NO_HANDLE_INHERIT` on socket creation.
124    /// Trying to mimic `Type::cloexec` on windows.
125    const NO_INHERIT: c_int = 1 << ((size_of::<c_int>() * 8) - 1); // Last bit.
126
127    /// Set `WSA_FLAG_NO_HANDLE_INHERIT` on the socket.
128    #[cfg(feature = "all")]
129    #[cfg_attr(docsrs, doc(cfg(all(windows, feature = "all"))))]
130    pub const fn no_inherit(self) -> Type {
131        self._no_inherit()
132    }
133
134    pub(crate) const fn _no_inherit(self) -> Type {
135        Type(self.0 | Type::NO_INHERIT)
136    }
137}
138
// Generate `Debug` for `Type` over the socket types defined in this module.
impl_debug!(
    crate::Type,
    self::SOCK_STREAM,
    self::SOCK_DGRAM,
    self::SOCK_RAW,
    self::SOCK_RDM,
    self::SOCK_SEQPACKET,
);

// Generate `Debug` for `Protocol` over the WinSock protocol constants.
impl_debug!(
    crate::Protocol,
    WinSock::IPPROTO_ICMP,
    WinSock::IPPROTO_ICMPV6,
    WinSock::IPPROTO_TCP,
    WinSock::IPPROTO_UDP,
);
155
156impl std::fmt::Debug for RecvFlags {
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158        f.debug_struct("RecvFlags")
159            .field("is_truncated", &self.is_truncated())
160            .finish()
161    }
162}
163
164#[repr(transparent)]
165pub struct MaybeUninitSlice<'a> {
166    vec: WSABUF,
167    _lifetime: PhantomData<&'a mut [MaybeUninit<u8>]>,
168}
169
170unsafe impl<'a> Send for MaybeUninitSlice<'a> {}
171
172unsafe impl<'a> Sync for MaybeUninitSlice<'a> {}
173
174impl<'a> MaybeUninitSlice<'a> {
175    pub fn new(buf: &'a mut [MaybeUninit<u8>]) -> MaybeUninitSlice<'a> {
176        assert!(buf.len() <= u32::MAX as usize);
177        MaybeUninitSlice {
178            vec: WSABUF {
179                len: buf.len() as u32,
180                buf: buf.as_mut_ptr().cast(),
181            },
182            _lifetime: PhantomData,
183        }
184    }
185
186    pub fn as_slice(&self) -> &[MaybeUninit<u8>] {
187        unsafe { slice::from_raw_parts(self.vec.buf.cast(), self.vec.len as usize) }
188    }
189
190    pub fn as_mut_slice(&mut self) -> &mut [MaybeUninit<u8>] {
191        unsafe { slice::from_raw_parts_mut(self.vec.buf.cast(), self.vec.len as usize) }
192    }
193}
194
195// Used in `MsgHdr`.
196pub(crate) use windows_sys::Win32::Networking::WinSock::WSAMSG as msghdr;
197
198use crate::{CMsgHdrOps, MsgHdrInit};
199
200impl CMsgHdrOps for cmsghdr {
201    fn cmsg_data(&self) -> *mut u8 {
202        (self as *const _ as usize + cmsgdata_align(mem::size_of::<Self>())) as *mut u8
203    }
204}
205
206pub(crate) const fn _cmsg_space(length: usize) -> usize {
207    cmsgdata_align(mem::size_of::<cmsghdr>() + cmsghdr_align(length))
208}
209
210// Helpers functions for `WinSock::WSAMSG` and `WinSock::CMSGHDR` are based on C macros from
211// https://github.com/microsoft/win32metadata/blob/main/generation/WinSDK/RecompiledIdlHeaders/shared/ws2def.h#L741
212const fn cmsghdr_align(length: usize) -> usize {
213    (length + mem::align_of::<cmsghdr>() - 1) & !(mem::align_of::<cmsghdr>() - 1)
214}
215
216const fn cmsgdata_align(length: usize) -> usize {
217    (length + mem::align_of::<usize>() - 1) & !(mem::align_of::<usize>() - 1)
218}
219
220use crate::MsgHdrOps;
221
222impl MsgHdrOps for msghdr {
223    fn cmsg_first_hdr(&self) -> *mut cmsghdr {
224        if self.Control.len as usize >= mem::size_of::<cmsghdr>() {
225            self.Control.buf as *mut cmsghdr
226        } else {
227            ptr::null_mut::<cmsghdr>()
228        }
229    }
230
231    fn cmsg_next_hdr(&self, cmsg: &cmsghdr) -> *mut cmsghdr {
232        let next = (cmsg as *const _ as usize + cmsghdr_align(cmsg.cmsg_len)) as *mut cmsghdr;
233
234        // check if the end of the next cmsg overshoots the buf.
235        let max = self.Control.buf as usize + self.Control.len as usize;
236        if unsafe { next.offset(1) } as usize > max {
237            ptr::null_mut()
238        } else {
239            next
240        }
241    }
242}
243
244pub(crate) fn set_msghdr_name(msg: &mut msghdr, name: &SockAddr) {
245    msg.name = name.as_ptr() as *mut _;
246    msg.namelen = name.len();
247}
248
249pub(crate) fn set_msghdr_iov(msg: &mut msghdr, ptr: *mut WSABUF, len: usize) {
250    msg.lpBuffers = ptr;
251    msg.dwBufferCount = min(len, u32::MAX as usize) as u32;
252}
253
254pub(crate) fn set_msghdr_control(msg: &mut msghdr, ptr: *mut u8, len: usize) {
255    msg.Control.buf = ptr;
256    msg.Control.len = len as u32;
257}
258
259pub(crate) fn set_msghdr_flags(msg: &mut msghdr, flags: c_int) {
260    msg.dwFlags = flags as u32;
261}
262
263pub(crate) fn msghdr_flags(msg: &msghdr) -> RecvFlags {
264    RecvFlags(msg.dwFlags as c_int)
265}
266
267pub(crate) fn msghdr_control_len(msg: &msghdr) -> usize {
268    msg.Control.len as _
269}
270
/// Makes sure WinSock is initialised before the first raw WinSock call.
///
/// Runs at most once per process (guarded by a `Once`).
fn init() {
    static INIT: Once = Once::new();

    INIT.call_once(|| {
        // Initialize winsock through the standard library by just creating a
        // dummy socket. Whether this is successful or not we drop the result as
        // libstd will be sure to have initialized winsock.
        //
        // Port 0 lets the OS pick a free ephemeral port, so this never
        // (transiently) occupies a fixed port that another program may need.
        let _ = net::UdpSocket::bind("127.0.0.1:0");
    });
}
281
282pub(crate) type Socket = windows_sys::Win32::Networking::WinSock::SOCKET;
283
284pub(crate) unsafe fn socket_from_raw(socket: Socket) -> crate::socket::Inner {
285    crate::socket::Inner::from_raw_socket(socket as RawSocket)
286}
287
288pub(crate) fn socket_as_raw(socket: &crate::socket::Inner) -> Socket {
289    socket.as_raw_socket() as Socket
290}
291
292pub(crate) fn socket_into_raw(socket: crate::socket::Inner) -> Socket {
293    socket.into_raw_socket() as Socket
294}
295
296pub(crate) fn socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Result<Socket> {
297    init();
298
299    // Check if we set our custom flag.
300    let flags = if ty & Type::NO_INHERIT != 0 {
301        ty = ty & !Type::NO_INHERIT;
302        WSA_FLAG_NO_HANDLE_INHERIT
303    } else {
304        0
305    };
306
307    syscall!(
308        WSASocketW(
309            family,
310            ty,
311            protocol,
312            ptr::null_mut(),
313            0,
314            WSA_FLAG_OVERLAPPED | flags,
315        ),
316        PartialEq::eq,
317        INVALID_SOCKET
318    )
319}
320
321pub(crate) fn bind(socket: Socket, addr: &SockAddr) -> io::Result<()> {
322    syscall!(bind(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
323}
324
325pub(crate) fn connect(socket: Socket, addr: &SockAddr) -> io::Result<()> {
326    syscall!(connect(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
327}
328
329pub(crate) fn poll_connect(socket: &crate::Socket, timeout: Duration) -> io::Result<()> {
330    let start = Instant::now();
331
332    let mut fd_array = WSAPOLLFD {
333        fd: socket.as_raw(),
334        events: (POLLRDNORM | POLLWRNORM) as i16,
335        revents: 0,
336    };
337
338    loop {
339        let elapsed = start.elapsed();
340        if elapsed >= timeout {
341            return Err(io::ErrorKind::TimedOut.into());
342        }
343
344        let timeout = (timeout - elapsed).as_millis();
345        let timeout = clamp(timeout, 1, c_int::MAX as u128) as c_int;
346
347        match syscall!(
348            WSAPoll(&mut fd_array, 1, timeout),
349            PartialEq::eq,
350            SOCKET_ERROR
351        ) {
352            Ok(0) => return Err(io::ErrorKind::TimedOut.into()),
353            Ok(_) => {
354                // Error or hang up indicates an error (or failure to connect).
355                if (fd_array.revents & POLLERR as i16) != 0
356                    || (fd_array.revents & POLLHUP as i16) != 0
357                {
358                    match socket.take_error() {
359                        Ok(Some(err)) => return Err(err),
360                        Ok(None) => {
361                            return Err(io::Error::new(
362                                io::ErrorKind::Other,
363                                "no error set after POLLHUP",
364                            ))
365                        }
366                        Err(err) => return Err(err),
367                    }
368                }
369                return Ok(());
370            }
371            // Got interrupted, try again.
372            Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
373            Err(err) => return Err(err),
374        }
375    }
376}
377
// TODO: use clamp from std lib, stable since 1.50.
/// Restricts `value` to `min..=max`.
///
/// Values at or below `min` yield `min`; otherwise values at or above `max`
/// yield `max`; anything else is returned unchanged. Unlike `Ord::clamp`, this
/// never panics, even when `min > max`.
fn clamp<T>(value: T, min: T, max: T) -> T
where
    T: Ord,
{
    if value > min {
        if value < max {
            value
        } else {
            max
        }
    } else {
        min
    }
}
391
392pub(crate) fn listen(socket: Socket, backlog: c_int) -> io::Result<()> {
393    syscall!(listen(socket, backlog), PartialEq::ne, 0).map(|_| ())
394}
395
396pub(crate) fn accept(socket: Socket) -> io::Result<(Socket, SockAddr)> {
397    // Safety: `accept` initialises the `SockAddr` for us.
398    unsafe {
399        SockAddr::try_init(|storage, len| {
400            syscall!(
401                accept(socket, storage.cast(), len),
402                PartialEq::eq,
403                INVALID_SOCKET
404            )
405        })
406    }
407}
408
409pub(crate) fn getsockname(socket: Socket) -> io::Result<SockAddr> {
410    // Safety: `getsockname` initialises the `SockAddr` for us.
411    unsafe {
412        SockAddr::try_init(|storage, len| {
413            syscall!(
414                getsockname(socket, storage.cast(), len),
415                PartialEq::eq,
416                SOCKET_ERROR
417            )
418        })
419    }
420    .map(|(_, addr)| addr)
421}
422
423pub(crate) fn getpeername(socket: Socket) -> io::Result<SockAddr> {
424    // Safety: `getpeername` initialises the `SockAddr` for us.
425    unsafe {
426        SockAddr::try_init(|storage, len| {
427            syscall!(
428                getpeername(socket, storage.cast(), len),
429                PartialEq::eq,
430                SOCKET_ERROR
431            )
432        })
433    }
434    .map(|(_, addr)| addr)
435}
436
437pub(crate) fn try_clone(socket: Socket) -> io::Result<Socket> {
438    let mut info: MaybeUninit<WSAPROTOCOL_INFOW> = MaybeUninit::uninit();
439    syscall!(
440        // NOTE: `process.id` is the same as `GetCurrentProcessId`.
441        WSADuplicateSocketW(socket, process::id(), info.as_mut_ptr()),
442        PartialEq::eq,
443        SOCKET_ERROR
444    )?;
445    // Safety: `WSADuplicateSocketW` intialised `info` for us.
446    let mut info = unsafe { info.assume_init() };
447
448    syscall!(
449        WSASocketW(
450            info.iAddressFamily,
451            info.iSocketType,
452            info.iProtocol,
453            &mut info,
454            0,
455            WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT,
456        ),
457        PartialEq::eq,
458        INVALID_SOCKET
459    )
460}
461
462pub(crate) fn set_nonblocking(socket: Socket, nonblocking: bool) -> io::Result<()> {
463    let mut nonblocking = if nonblocking { 1 } else { 0 };
464    ioctlsocket(socket, FIONBIO, &mut nonblocking)
465}
466
467pub(crate) fn shutdown(socket: Socket, how: Shutdown) -> io::Result<()> {
468    let how = match how {
469        Shutdown::Write => SD_SEND,
470        Shutdown::Read => SD_RECEIVE,
471        Shutdown::Both => SD_BOTH,
472    } as i32;
473    syscall!(shutdown(socket, how), PartialEq::eq, SOCKET_ERROR).map(|_| ())
474}
475
476pub(crate) fn recv(socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io::Result<usize> {
477    let res = syscall!(
478        recv(
479            socket,
480            buf.as_mut_ptr().cast(),
481            min(buf.len(), MAX_BUF_LEN) as c_int,
482            flags,
483        ),
484        PartialEq::eq,
485        SOCKET_ERROR
486    );
487    match res {
488        Ok(n) => Ok(n as usize),
489        Err(ref err) if err.raw_os_error() == Some(WSAESHUTDOWN as i32) => Ok(0),
490        Err(err) => Err(err),
491    }
492}
493
494pub(crate) fn recv_vectored(
495    socket: Socket,
496    bufs: &mut [crate::MaybeUninitSlice<'_>],
497    flags: c_int,
498) -> io::Result<(usize, RecvFlags)> {
499    let mut nread = 0;
500    let mut flags = flags as u32;
501    let res = syscall!(
502        WSARecv(
503            socket,
504            bufs.as_mut_ptr().cast(),
505            min(bufs.len(), u32::MAX as usize) as u32,
506            &mut nread,
507            &mut flags,
508            ptr::null_mut(),
509            None,
510        ),
511        PartialEq::eq,
512        SOCKET_ERROR
513    );
514    match res {
515        Ok(_) => Ok((nread as usize, RecvFlags(0))),
516        Err(ref err) if err.raw_os_error() == Some(WSAESHUTDOWN as i32) => Ok((0, RecvFlags(0))),
517        Err(ref err) if err.raw_os_error() == Some(WSAEMSGSIZE as i32) => {
518            Ok((nread as usize, RecvFlags(MSG_TRUNC)))
519        }
520        Err(err) => Err(err),
521    }
522}
523
524pub(crate) fn recv_from(
525    socket: Socket,
526    buf: &mut [MaybeUninit<u8>],
527    flags: c_int,
528) -> io::Result<(usize, SockAddr)> {
529    // Safety: `recvfrom` initialises the `SockAddr` for us.
530    unsafe {
531        SockAddr::try_init(|storage, addrlen| {
532            let res = syscall!(
533                recvfrom(
534                    socket,
535                    buf.as_mut_ptr().cast(),
536                    min(buf.len(), MAX_BUF_LEN) as c_int,
537                    flags,
538                    storage.cast(),
539                    addrlen,
540                ),
541                PartialEq::eq,
542                SOCKET_ERROR
543            );
544            match res {
545                Ok(n) => Ok(n as usize),
546                Err(ref err) if err.raw_os_error() == Some(WSAESHUTDOWN as i32) => Ok(0),
547                Err(err) => Err(err),
548            }
549        })
550    }
551}
552
553pub(crate) fn peek_sender(socket: Socket) -> io::Result<SockAddr> {
554    // Safety: `recvfrom` initialises the `SockAddr` for us.
555    let ((), sender) = unsafe {
556        SockAddr::try_init(|storage, addrlen| {
557            let res = syscall!(
558                recvfrom(
559                    socket,
560                    // Windows *appears* not to care if you pass a null pointer.
561                    ptr::null_mut(),
562                    0,
563                    MSG_PEEK,
564                    storage.cast(),
565                    addrlen,
566                ),
567                PartialEq::eq,
568                SOCKET_ERROR
569            );
570            match res {
571                Ok(_n) => Ok(()),
572                Err(e) => match e.raw_os_error() {
573                    Some(code) if code == (WSAESHUTDOWN as i32) || code == (WSAEMSGSIZE as i32) => {
574                        Ok(())
575                    }
576                    _ => Err(e),
577                },
578            }
579        })
580    }?;
581
582    Ok(sender)
583}
584
585pub(crate) fn recv_from_vectored(
586    socket: Socket,
587    bufs: &mut [crate::MaybeUninitSlice<'_>],
588    flags: c_int,
589) -> io::Result<(usize, RecvFlags, SockAddr)> {
590    // Safety: `recvfrom` initialises the `SockAddr` for us.
591    unsafe {
592        SockAddr::try_init(|storage, addrlen| {
593            let mut nread = 0;
594            let mut flags = flags as u32;
595            let res = syscall!(
596                WSARecvFrom(
597                    socket,
598                    bufs.as_mut_ptr().cast(),
599                    min(bufs.len(), u32::MAX as usize) as u32,
600                    &mut nread,
601                    &mut flags,
602                    storage.cast(),
603                    addrlen,
604                    ptr::null_mut(),
605                    None,
606                ),
607                PartialEq::eq,
608                SOCKET_ERROR
609            );
610            match res {
611                Ok(_) => Ok((nread as usize, RecvFlags(0))),
612                Err(ref err) if err.raw_os_error() == Some(WSAESHUTDOWN as i32) => {
613                    Ok((nread as usize, RecvFlags(0)))
614                }
615                Err(ref err) if err.raw_os_error() == Some(WSAEMSGSIZE as i32) => {
616                    Ok((nread as usize, RecvFlags(MSG_TRUNC)))
617                }
618                Err(err) => Err(err),
619            }
620        })
621    }
622    .map(|((n, recv_flags), addr)| (n, recv_flags, addr))
623}
624
625pub(crate) fn send(socket: Socket, buf: &[u8], flags: c_int) -> io::Result<usize> {
626    syscall!(
627        send(
628            socket,
629            buf.as_ptr().cast(),
630            min(buf.len(), MAX_BUF_LEN) as c_int,
631            flags,
632        ),
633        PartialEq::eq,
634        SOCKET_ERROR
635    )
636    .map(|n| n as usize)
637}
638
639pub(crate) fn send_vectored(
640    socket: Socket,
641    bufs: &[IoSlice<'_>],
642    flags: c_int,
643) -> io::Result<usize> {
644    let mut nsent = 0;
645    syscall!(
646        WSASend(
647            socket,
648            // FIXME: From the `WSASend` docs [1]:
649            // > For a Winsock application, once the WSASend function is called,
650            // > the system owns these buffers and the application may not
651            // > access them.
652            //
653            // So what we're doing is actually UB as `bufs` needs to be `&mut
654            // [IoSlice<'_>]`.
655            //
656            // Tracking issue: https://github.com/rust-lang/socket2-rs/issues/129.
657            //
658            // NOTE: `send_to_vectored` has the same problem.
659            //
660            // [1] https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasend
661            bufs.as_ptr() as *mut _,
662            min(bufs.len(), u32::MAX as usize) as u32,
663            &mut nsent,
664            flags as u32,
665            std::ptr::null_mut(),
666            None,
667        ),
668        PartialEq::eq,
669        SOCKET_ERROR
670    )
671    .map(|_| nsent as usize)
672}
673
674pub(crate) fn send_to(
675    socket: Socket,
676    buf: &[u8],
677    addr: &SockAddr,
678    flags: c_int,
679) -> io::Result<usize> {
680    syscall!(
681        sendto(
682            socket,
683            buf.as_ptr().cast(),
684            min(buf.len(), MAX_BUF_LEN) as c_int,
685            flags,
686            addr.as_ptr(),
687            addr.len(),
688        ),
689        PartialEq::eq,
690        SOCKET_ERROR
691    )
692    .map(|n| n as usize)
693}
694
695pub(crate) fn send_to_vectored(
696    socket: Socket,
697    bufs: &[IoSlice<'_>],
698    addr: &SockAddr,
699    flags: c_int,
700) -> io::Result<usize> {
701    let mut nsent = 0;
702    syscall!(
703        WSASendTo(
704            socket,
705            // FIXME: Same problem as in `send_vectored`.
706            bufs.as_ptr() as *mut _,
707            bufs.len().min(u32::MAX as usize) as u32,
708            &mut nsent,
709            flags as u32,
710            addr.as_ptr(),
711            addr.len(),
712            ptr::null_mut(),
713            None,
714        ),
715        PartialEq::eq,
716        SOCKET_ERROR
717    )
718    .map(|_| nsent as usize)
719}
720
721pub(crate) fn sendmsg(socket: Socket, msg: &MsgHdr<'_, '_, '_>, flags: c_int) -> io::Result<usize> {
722    let mut nsent = 0;
723    syscall!(
724        WSASendMsg(
725            socket,
726            &msg.inner,
727            flags as u32,
728            &mut nsent,
729            ptr::null_mut(),
730            None,
731        ),
732        PartialEq::eq,
733        SOCKET_ERROR
734    )
735    .map(|_| nsent as usize)
736}
737
/// Signature of the `WSARecvMsg` extension function.
///
/// WinSock does not export this function directly; its pointer is obtained at
/// runtime by `locate_wsarecvmsg` and then called by `recvmsg_init`.
pub(crate) type WSARecvMsgExtension = unsafe extern "system" fn(
    s: Socket,
    lpMsg: *mut msghdr,
    lpdwNumberOfBytesRecvd: *mut u32,
    lpOverlapped: *mut OVERLAPPED,
    lpCompletionRoutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE,
) -> i32;
745
746/// Find the WSARECVMSG function pointer
747//
748// This implementation is copied from:
749// https://github.com/pixsper/socket-pktinfo/blob/3845f44eef707eaa3d34f9d4bc4ebcb6dc9c5959/src/win.rs#L44
750pub(crate) fn locate_wsarecvmsg(socket: Socket) -> io::Result<WSARecvMsgExtension> {
751    let mut fn_pointer: usize = 0;
752    let mut byte_len: u32 = 0;
753
754    let r = unsafe {
755        WinSock::WSAIoctl(
756            socket as _,
757            SIO_GET_EXTENSION_FUNCTION_POINTER,
758            &WSAID_WSARECVMSG as *const _ as *mut _,
759            mem::size_of_val(&WSAID_WSARECVMSG) as u32,
760            &mut fn_pointer as *const _ as *mut _,
761            mem::size_of_val(&fn_pointer) as u32,
762            &mut byte_len,
763            ptr::null_mut(),
764            None,
765        )
766    };
767
768    if r != 0 {
769        return Err(io::Error::last_os_error());
770    }
771
772    if mem::size_of::<LPFN_WSARECVMSG>() != byte_len as _ {
773        return Err(io::Error::new(
774            io::ErrorKind::Other,
775            "Locating fn pointer to WSARecvMsg returned different expected bytes",
776        ));
777    }
778    let cast_to_fn: LPFN_WSARECVMSG = unsafe { mem::transmute(fn_pointer) };
779
780    match cast_to_fn {
781        None => Err(io::Error::new(
782            io::ErrorKind::Other,
783            "WSARecvMsg extension not found",
784        )),
785        Some(extension) => Ok(extension),
786    }
787}
788
789/// `recvmsg` with fully initialized buffers.
790pub(crate) fn recvmsg_init(
791    wsarecvmsg: WSARecvMsgExtension,
792    fd: Socket,
793    msg: &mut MsgHdrInit<'_, '_, '_>,
794    _flags: c_int,
795) -> io::Result<usize> {
796    let mut read_bytes = 0;
797    let error_code = unsafe {
798        (wsarecvmsg)(
799            fd as _,
800            &mut msg.inner,
801            &mut read_bytes,
802            std::ptr::null_mut(),
803            None,
804        )
805    };
806
807    if error_code != 0 {
808        return Err(io::Error::last_os_error());
809    }
810
811    if let Some(src) = msg.src.as_mut() {
812        // SAFETY: `msg.inner.namelen` has been update properly in the success case.
813        unsafe {
814            src.set_length(msg.inner.namelen as socklen_t);
815        }
816    }
817
818    Ok(read_bytes as usize)
819}
820
821/// Wrapper around `getsockopt` to deal with platform specific timeouts.
822pub(crate) fn timeout_opt(fd: Socket, lvl: c_int, name: i32) -> io::Result<Option<Duration>> {
823    unsafe { getsockopt(fd, lvl, name).map(from_ms) }
824}
825
/// Converts a WinSock millisecond timeout into a `Duration`.
///
/// `0` means "no timeout" and maps to `None`.
fn from_ms(duration: u32) -> Option<Duration> {
    match duration {
        0 => None,
        millis => Some(Duration::from_millis(u64::from(millis))),
    }
}
835
836/// Wrapper around `setsockopt` to deal with platform specific timeouts.
837pub(crate) fn set_timeout_opt(
838    socket: Socket,
839    level: c_int,
840    optname: i32,
841    duration: Option<Duration>,
842) -> io::Result<()> {
843    let duration = into_ms(duration);
844    unsafe { setsockopt(socket, level, optname, duration) }
845}
846
847fn into_ms(duration: Option<Duration>) -> u32 {
848    // Note that a duration is a (u64, u32) (seconds, nanoseconds) pair, and the
849    // timeouts in windows APIs are typically u32 milliseconds. To translate, we
850    // have two pieces to take care of:
851    //
852    // * Nanosecond precision is rounded up
853    // * Greater than u32::MAX milliseconds (50 days) is rounded up to
854    //   INFINITE (never time out).
855    duration.map_or(0, |duration| {
856        min(duration.as_millis(), INFINITE as u128) as u32
857    })
858}
859
860pub(crate) fn set_tcp_keepalive(socket: Socket, keepalive: &TcpKeepalive) -> io::Result<()> {
861    let mut keepalive = tcp_keepalive {
862        onoff: 1,
863        keepalivetime: into_ms(keepalive.time),
864        keepaliveinterval: into_ms(keepalive.interval),
865    };
866    let mut out = 0;
867    syscall!(
868        WSAIoctl(
869            socket,
870            SIO_KEEPALIVE_VALS,
871            &mut keepalive as *mut _ as *mut _,
872            size_of::<tcp_keepalive>() as _,
873            ptr::null_mut(),
874            0,
875            &mut out,
876            ptr::null_mut(),
877            None,
878        ),
879        PartialEq::eq,
880        SOCKET_ERROR
881    )
882    .map(|_| ())
883}
884
885/// Caller must ensure `T` is the correct type for `level` and `optname`.
886// NOTE: `optname` is actually `i32`, but all constants are `u32`.
887pub(crate) unsafe fn getsockopt<T>(socket: Socket, level: c_int, optname: i32) -> io::Result<T> {
888    let mut optval: MaybeUninit<T> = MaybeUninit::uninit();
889    let mut optlen = mem::size_of::<T>() as c_int;
890    syscall!(
891        getsockopt(
892            socket,
893            level as i32,
894            optname,
895            optval.as_mut_ptr().cast(),
896            &mut optlen,
897        ),
898        PartialEq::eq,
899        SOCKET_ERROR
900    )
901    .map(|_| {
902        debug_assert_eq!(optlen as usize, mem::size_of::<T>());
903        // Safety: `getsockopt` initialised `optval` for us.
904        optval.assume_init()
905    })
906}
907
908/// Caller must ensure `T` is the correct type for `level` and `optname`.
909// NOTE: `optname` is actually `i32`, but all constants are `u32`.
910pub(crate) unsafe fn setsockopt<T>(
911    socket: Socket,
912    level: c_int,
913    optname: i32,
914    optval: T,
915) -> io::Result<()> {
916    syscall!(
917        setsockopt(
918            socket,
919            level as i32,
920            optname,
921            (&optval as *const T).cast(),
922            mem::size_of::<T>() as c_int,
923        ),
924        PartialEq::eq,
925        SOCKET_ERROR
926    )
927    .map(|_| ())
928}
929
930fn ioctlsocket(socket: Socket, cmd: i32, payload: &mut u32) -> io::Result<()> {
931    syscall!(
932        ioctlsocket(socket, cmd, payload),
933        PartialEq::eq,
934        SOCKET_ERROR
935    )
936    .map(|_| ())
937}
938
939pub(crate) fn to_in_addr(addr: &Ipv4Addr) -> IN_ADDR {
940    IN_ADDR {
941        S_un: IN_ADDR_0 {
942            // `S_un` is stored as BE on all machines, and the array is in BE
943            // order. So the native endian conversion method is used so that
944            // it's never swapped.
945            S_addr: u32::from_ne_bytes(addr.octets()),
946        },
947    }
948}
949
950pub(crate) fn from_in_addr(in_addr: IN_ADDR) -> Ipv4Addr {
951    Ipv4Addr::from(unsafe { in_addr.S_un.S_addr }.to_ne_bytes())
952}
953
954pub(crate) fn to_in6_addr(addr: &Ipv6Addr) -> IN6_ADDR {
955    IN6_ADDR {
956        u: IN6_ADDR_0 {
957            Byte: addr.octets(),
958        },
959    }
960}
961
962pub(crate) fn from_in6_addr(addr: IN6_ADDR) -> Ipv6Addr {
963    Ipv6Addr::from(unsafe { addr.u.Byte })
964}
965
966pub(crate) fn to_mreqn(
967    multiaddr: &Ipv4Addr,
968    interface: &crate::socket::InterfaceIndexOrAddress,
969) -> IpMreq {
970    IpMreq {
971        imr_multiaddr: to_in_addr(multiaddr),
972        // Per https://docs.microsoft.com/en-us/windows/win32/api/ws2ipdef/ns-ws2ipdef-ip_mreq#members:
973        //
974        // imr_interface
975        //
976        // The local IPv4 address of the interface or the interface index on
977        // which the multicast group should be joined or dropped. This value is
978        // in network byte order. If this member specifies an IPv4 address of
979        // 0.0.0.0, the default IPv4 multicast interface is used.
980        //
981        // To use an interface index of 1 would be the same as an IP address of
982        // 0.0.0.1.
983        imr_interface: match interface {
984            crate::socket::InterfaceIndexOrAddress::Index(interface) => {
985                to_in_addr(&(*interface).into())
986            }
987            crate::socket::InterfaceIndexOrAddress::Address(interface) => to_in_addr(interface),
988        },
989    }
990}
991
992#[allow(unsafe_op_in_unsafe_fn)]
993pub(crate) fn unix_sockaddr(path: &Path) -> io::Result<SockAddr> {
994    // SAFETY: a `sockaddr_storage` of all zeros is valid.
995    let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
996    let len = {
997        let storage: &mut windows_sys::Win32::Networking::WinSock::SOCKADDR_UN =
998            unsafe { &mut *(&mut storage as *mut sockaddr_storage).cast() };
999
1000        // Windows expects a UTF-8 path here even though Windows paths are
1001        // usually UCS-2 encoded. If Rust exposed OsStr's Wtf8 encoded
1002        // buffer, this could be used directly, relying on Windows to
1003        // validate the path, but Rust hides this implementation detail.
1004        //
1005        // See <https://github.com/rust-lang/rust/pull/95290>.
1006        let bytes = path
1007            .to_str()
1008            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "path must be valid UTF-8"))?
1009            .as_bytes();
1010
1011        // Windows appears to allow non-null-terminated paths, but this is
1012        // not documented, so do not rely on it yet.
1013        //
1014        // See <https://github.com/rust-lang/socket2/issues/331>.
1015        if bytes.len() >= storage.sun_path.len() {
1016            return Err(io::Error::new(
1017                io::ErrorKind::InvalidInput,
1018                "path must be shorter than SUN_LEN",
1019            ));
1020        }
1021
1022        storage.sun_family = crate::sys::AF_UNIX as sa_family_t;
1023        // `storage` was initialized to zero above, so the path is
1024        // already null terminated.
1025        storage.sun_path[..bytes.len()].copy_from_slice(bytes);
1026
1027        let base = storage as *const _ as usize;
1028        let path = &storage.sun_path as *const _ as usize;
1029        let sun_path_offset = path - base;
1030        sun_path_offset + bytes.len() + 1
1031    };
1032    Ok(unsafe { SockAddr::new(storage, len as socklen_t) })
1033}
1034
1035/// Windows only API.
1036impl crate::Socket {
1037    /// Sets `HANDLE_FLAG_INHERIT` using `SetHandleInformation`.
1038    #[cfg(feature = "all")]
1039    #[cfg_attr(docsrs, doc(cfg(all(windows, feature = "all"))))]
1040    pub fn set_no_inherit(&self, no_inherit: bool) -> io::Result<()> {
1041        self._set_no_inherit(no_inherit)
1042    }
1043
1044    pub(crate) fn _set_no_inherit(&self, no_inherit: bool) -> io::Result<()> {
1045        // NOTE: can't use `syscall!` because it expects the function in the
1046        // `windows_sys::Win32::Networking::WinSock::` path.
1047        let res = unsafe {
1048            SetHandleInformation(
1049                self.as_raw() as HANDLE,
1050                HANDLE_FLAG_INHERIT,
1051                !no_inherit as _,
1052            )
1053        };
1054        if res == 0 {
1055            // Zero means error.
1056            Err(io::Error::last_os_error())
1057        } else {
1058            Ok(())
1059        }
1060    }
1061
1062    /// Returns the [`Protocol`] of this socket by checking the `SO_PROTOCOL_INFOW`
1063    /// option on this socket.
1064    ///
1065    /// [`Protocol`]: crate::Protocol
1066    #[cfg(feature = "all")]
1067    pub fn protocol(&self) -> io::Result<Option<crate::Protocol>> {
1068        let info = unsafe {
1069            getsockopt::<WSAPROTOCOL_INFOW>(self.as_raw(), SOL_SOCKET, SO_PROTOCOL_INFOW)?
1070        };
1071        match info.iProtocol {
1072            0 => Ok(None),
1073            p => Ok(Some(crate::Protocol::from(p))),
1074        }
1075    }
1076}
1077
1078#[cfg_attr(docsrs, doc(cfg(windows)))]
1079impl AsSocket for crate::Socket {
1080    fn as_socket(&self) -> BorrowedSocket<'_> {
1081        // SAFETY: lifetime is bound by self.
1082        unsafe { BorrowedSocket::borrow_raw(self.as_raw() as RawSocket) }
1083    }
1084}
1085
1086#[cfg_attr(docsrs, doc(cfg(windows)))]
1087impl AsRawSocket for crate::Socket {
1088    fn as_raw_socket(&self) -> RawSocket {
1089        self.as_raw() as RawSocket
1090    }
1091}
1092
1093#[cfg_attr(docsrs, doc(cfg(windows)))]
1094impl From<crate::Socket> for OwnedSocket {
1095    fn from(sock: crate::Socket) -> OwnedSocket {
1096        // SAFETY: sock.into_raw() always returns a valid fd.
1097        unsafe { OwnedSocket::from_raw_socket(sock.into_raw() as RawSocket) }
1098    }
1099}
1100
1101#[cfg_attr(docsrs, doc(cfg(windows)))]
1102impl IntoRawSocket for crate::Socket {
1103    fn into_raw_socket(self) -> RawSocket {
1104        self.into_raw() as RawSocket
1105    }
1106}
1107
1108#[cfg_attr(docsrs, doc(cfg(windows)))]
1109impl From<OwnedSocket> for crate::Socket {
1110    fn from(fd: OwnedSocket) -> crate::Socket {
1111        // SAFETY: `OwnedFd` ensures the fd is valid.
1112        unsafe { crate::Socket::from_raw_socket(fd.into_raw_socket()) }
1113    }
1114}
1115
1116#[cfg_attr(docsrs, doc(cfg(windows)))]
1117impl FromRawSocket for crate::Socket {
1118    unsafe fn from_raw_socket(socket: RawSocket) -> crate::Socket {
1119        crate::Socket::from_raw(socket as Socket)
1120    }
1121}
1122
#[test]
fn in_addr_convertion() {
    // `S_addr` holds the address in network order. On the little-endian
    // targets Windows runs on, that equals the octets read as little-endian.
    let cases = [Ipv4Addr::new(127, 0, 0, 1), Ipv4Addr::new(127, 34, 4, 12)];
    for ip in cases {
        let raw = to_in_addr(&ip);
        assert_eq!(unsafe { raw.S_un.S_addr }, u32::from_le_bytes(ip.octets()));
        assert_eq!(from_in_addr(raw), ip);
    }
}
1138
#[test]
fn in6_addr_convertion() {
    let ip = Ipv6Addr::new(0x2000, 1, 2, 3, 4, 5, 6, 7);
    let raw = to_in6_addr(&ip);
    // Viewed as 16-bit words, each segment must be stored big-endian.
    let mut want = ip.segments();
    for word in want.iter_mut() {
        *word = word.to_be();
    }
    assert_eq!(unsafe { raw.u.Word }, want);
    assert_eq!(from_in6_addr(raw), ip);
}