Skip to main content

uds_windows/stdnet/
socket.rs

1#![allow(non_camel_case_types)]
2
3use std::io;
4use std::mem;
5use std::net::Shutdown;
6use std::os::raw::c_int;
7use std::os::windows::io::{
8    AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, IntoRawSocket, OwnedSocket, RawSocket,
9};
10use std::ptr;
11use std::sync::Once;
12use std::time::Duration;
13
14use super::{cvt, last_error};
15
16use windows_sys::Win32::Foundation::SetHandleInformation;
17use windows_sys::Win32::Foundation::HANDLE;
18use windows_sys::Win32::Networking::WinSock::{
19    accept, closesocket, getsockopt as c_getsockopt, ioctlsocket, recv, send,
20    setsockopt as c_setsockopt, shutdown, WSADuplicateSocketW, WSASocketW, WSAStartup, AF_UNIX,
21    FIONBIO, INVALID_SOCKET, SD_BOTH, SD_RECEIVE, SD_SEND, SOCKADDR, SOCKET, SOCK_STREAM,
22    SOL_SOCKET, SO_ERROR, WSADATA, WSAPROTOCOL_INFOW, WSA_FLAG_OVERLAPPED,
23};
24use windows_sys::Win32::System::Threading::GetCurrentProcessId;
25
/// Flag bit passed to `SetHandleInformation` to select the handle's
/// "inherit by child processes" property.
pub const HANDLE_FLAG_INHERIT: u32 = 0x0000_0001;
27
/// An owned Winsock socket handle.
///
/// The wrapped `SOCKET` is closed with `closesocket` when the value is
/// dropped; use `IntoRawSocket` or `From<Socket> for OwnedSocket` to give up
/// ownership without closing it.
#[derive(Debug)]
pub struct Socket(SOCKET);
30
31/// Checks whether the Windows socket interface has been started already, and
32/// if not, starts it.
33pub fn init() {
34    static START: Once = Once::new();
35
36    START.call_once(|| unsafe {
37        let mut data: WSADATA = mem::zeroed();
38        let ret = WSAStartup(
39            0x202, // version 2.2
40            &mut data,
41        );
42        assert_eq!(ret, 0);
43
44        // let _ = std::rt::at_exit(|| { WSACleanup(); });
45    });
46}
47
/// Integer types that can report whether they are zero.
///
/// Used by `cvt_z` to detect failure from calls that signal an error by
/// returning zero.
#[doc(hidden)]
pub trait IsZero {
    /// Returns `true` when `self` equals zero.
    fn is_zero(&self) -> bool;
}

/// Implements [`IsZero`] for every primitive integer type listed.
macro_rules! impl_is_zero {
    ($($ty:ident)*) => {
        $(
            impl IsZero for $ty {
                fn is_zero(&self) -> bool {
                    0 == *self
                }
            }
        )*
    };
}

impl_is_zero! { i8 i16 i32 i64 isize u8 u16 u32 u64 usize }

/// Turns a "zero means failure" return value into an `io::Result`.
///
/// A non-zero value is passed through unchanged; zero yields the calling
/// thread's last OS error.
fn cvt_z<I: IsZero>(i: I) -> io::Result<I> {
    if !i.is_zero() {
        return Ok(i);
    }
    Err(io::Error::last_os_error())
}
70
71impl Socket {
72    pub fn new() -> io::Result<Socket> {
73        let socket = unsafe {
74            match WSASocketW(
75                AF_UNIX as i32,
76                SOCK_STREAM,
77                0,
78                ptr::null_mut(),
79                0,
80                WSA_FLAG_OVERLAPPED,
81            ) {
82                INVALID_SOCKET => Err(last_error()),
83                n => Ok(Socket(n)),
84            }
85        }?;
86        socket.set_no_inherit()?;
87        Ok(socket)
88    }
89
90    // socketpair() not supported on Windows
91    // pub fn new_pair(fam: c_int, ty: c_int) -> io::Result<(Socket, Socket)> { ... }
92
93    pub fn accept(&self, storage: *mut SOCKADDR, len: *mut c_int) -> io::Result<Socket> {
94        let socket = unsafe {
95            match accept(self.0, storage, len) {
96                INVALID_SOCKET => Err(last_error()),
97                n => Ok(Socket(n)),
98            }
99        }?;
100        socket.set_no_inherit()?;
101        Ok(socket)
102    }
103
104    pub fn duplicate(&self) -> io::Result<Socket> {
105        let socket = unsafe {
106            let mut info: WSAPROTOCOL_INFOW = mem::zeroed();
107            cvt(WSADuplicateSocketW(
108                self.0,
109                GetCurrentProcessId(),
110                &mut info,
111            ))?;
112            match WSASocketW(
113                info.iAddressFamily,
114                info.iSocketType,
115                info.iProtocol,
116                &info,
117                0,
118                WSA_FLAG_OVERLAPPED,
119            ) {
120                INVALID_SOCKET => Err(last_error()),
121                n => Ok(Socket(n)),
122            }
123        }?;
124        socket.set_no_inherit()?;
125        Ok(socket)
126    }
127
128    fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
129        let ret = cvt(unsafe {
130            recv(
131                self.0,
132                buf.as_mut_ptr() as *mut _,
133                buf.len() as c_int,
134                flags,
135            )
136        })?;
137        Ok(ret as usize)
138    }
139
140    pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
141        self.recv_with_flags(buf, 0)
142    }
143
144    pub fn write(&self, buf: &[u8]) -> io::Result<usize> {
145        let ret = cvt(unsafe { send(self.0, buf as *const _ as *const _, buf.len() as c_int, 0) })?;
146        Ok(ret as usize)
147    }
148
149    fn set_no_inherit(&self) -> io::Result<()> {
150        cvt_z(unsafe { SetHandleInformation(self.0 as HANDLE, HANDLE_FLAG_INHERIT, 0) }).map(|_| ())
151    }
152
153    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
154        let mut nonblocking = nonblocking as u32;
155        let r = unsafe { ioctlsocket(self.0, FIONBIO as c_int, &mut nonblocking) };
156        if r == 0 {
157            Ok(())
158        } else {
159            Err(io::Error::last_os_error())
160        }
161    }
162
163    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
164        let how = match how {
165            Shutdown::Write => SD_SEND,
166            Shutdown::Read => SD_RECEIVE,
167            Shutdown::Both => SD_BOTH,
168        };
169        cvt(unsafe { shutdown(self.0, how) })?;
170        Ok(())
171    }
172
173    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
174        let raw: c_int = getsockopt(self, SOL_SOCKET, SO_ERROR)?;
175        if raw == 0 {
176            Ok(None)
177        } else {
178            Ok(Some(io::Error::from_raw_os_error(raw as i32)))
179        }
180    }
181
182    pub fn set_timeout(&self, dur: Option<Duration>, kind: c_int) -> io::Result<()> {
183        let timeout = match dur {
184            Some(dur) => {
185                let timeout = dur2timeout(dur);
186                if timeout == 0 {
187                    return Err(io::Error::new(
188                        io::ErrorKind::InvalidInput,
189                        "cannot set a 0 duration timeout",
190                    ));
191                }
192                timeout
193            }
194            None => 0,
195        };
196        setsockopt(self, SOL_SOCKET, kind, timeout)
197    }
198
199    pub fn timeout(&self, kind: c_int) -> io::Result<Option<Duration>> {
200        let raw: u32 = getsockopt(self, SOL_SOCKET, kind)?;
201        if raw == 0 {
202            Ok(None)
203        } else {
204            let secs = raw / 1000;
205            let nsec = (raw % 1000) * 1000000;
206            Ok(Some(Duration::new(secs as u64, nsec as u32)))
207        }
208    }
209}
210
211pub fn setsockopt<T>(sock: &Socket, opt: c_int, val: c_int, payload: T) -> io::Result<()> {
212    unsafe {
213        let payload = &payload as *const T as *const _;
214        cvt(c_setsockopt(
215            sock.as_raw_socket() as usize,
216            opt,
217            val,
218            payload,
219            mem::size_of::<T>() as i32,
220        ))?;
221        Ok(())
222    }
223}
224
225pub fn getsockopt<T: Copy>(sock: &Socket, opt: c_int, val: c_int) -> io::Result<T> {
226    unsafe {
227        let mut slot: T = mem::zeroed();
228        let mut len = mem::size_of::<T>() as i32;
229        cvt(c_getsockopt(
230            sock.as_raw_socket() as _,
231            opt,
232            val,
233            &mut slot as *mut _ as *mut _,
234            &mut len,
235        ))?;
236        assert_eq!(len as usize, mem::size_of::<T>());
237        Ok(slot)
238    }
239}
240
fn dur2timeout(dur: Duration) -> u32 {
    // Windows socket timeouts are `u32` milliseconds, so a `Duration` has to
    // be narrowed:
    //
    // * any sub-millisecond remainder rounds *up*, so a tiny non-zero
    //   duration never collapses to 0 (which callers treat as "no timeout");
    // * anything beyond `u32::MAX` ms (about 49.7 days) saturates to INFINITE.
    const INFINITE: u32 = u32::MAX;

    // `as_millis` returns a `u128`, which cannot overflow for any `Duration`.
    let whole_millis = dur.as_millis();
    let has_partial_milli = dur.subsec_nanos() % 1_000_000 != 0;
    let millis = whole_millis + u128::from(has_partial_milli);

    if millis > u128::from(u32::MAX) {
        INFINITE
    } else {
        millis as u32
    }
}
269
270impl Drop for Socket {
271    fn drop(&mut self) {
272        let _ = unsafe { closesocket(self.0) };
273    }
274}
275
276impl AsRawSocket for Socket {
277    fn as_raw_socket(&self) -> RawSocket {
278        self.0 as RawSocket
279    }
280}
281
282impl FromRawSocket for Socket {
283    unsafe fn from_raw_socket(sock: RawSocket) -> Self {
284        Socket(sock as SOCKET)
285    }
286}
287
288impl IntoRawSocket for Socket {
289    fn into_raw_socket(self) -> RawSocket {
290        let ret = self.0 as RawSocket;
291        mem::forget(self);
292        ret
293    }
294}
295
296impl AsSocket for Socket {
297    fn as_socket(&self) -> BorrowedSocket<'_> {
298        // SAFETY: Although the lifetime is elided, it is indeed a borrow from self, and the returned value can not outlive the lifetime of the Socket.
299        unsafe { BorrowedSocket::borrow_raw(self.as_raw_socket()) }
300    }
301}
302
303impl From<Socket> for OwnedSocket {
304    fn from(sock: Socket) -> OwnedSocket {
305        // SAFETY: This is safe because it consumes the socket using the `OwnedSocket::from(Socket)`, or by using `Socket.into::<OwnedSocket>()`
306        unsafe { OwnedSocket::from_raw_socket(sock.into_raw_socket()) }
307    }
308}
309
310impl From<OwnedSocket> for Socket {
311    fn from(owned: OwnedSocket) -> Socket {
312        // SAFETY: This is safe because it consumes the socket using the `Socket::from(OwnedSocket)`, or by using `OwnedSocket.into::<Socket>()`
313        unsafe { Socket::from_raw_socket(owned.into_raw_socket()) }
314    }
315}