uni_socket/
windows.rs

//! A unified stream type for both TCP and Unix domain sockets.
//!
//! This module exists to keep the API consistent across different platforms.
//! On Windows, only TCP sockets are supported.
5
6use std::future::poll_fn;
7use std::mem::MaybeUninit;
8use std::net::{Shutdown, SocketAddr};
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::{fmt, io};
12
13use socket2::SockRef;
14use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15use tokio::net::TcpSocket;
16use uni_addr::{UniAddr, UniAddrInner};
17
18/// A simple wrapper of [`tokio::net::TcpSocket`].
19pub struct UniSocket {
20    inner: tokio::net::TcpSocket,
21}
22
23impl fmt::Debug for UniSocket {
24    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25        f.debug_tuple("UniSocket").field(&self.inner).finish()
26    }
27}
28
29impl UniSocket {
30    #[inline]
31    const fn from_inner(inner: TcpSocket) -> Self {
32        Self { inner }
33    }
34
35    /// Creates a new [`UniSocket`], and applies the given initialization
36    /// function to the underlying socket.
37    ///
38    /// The given address determines the socket type, and the caller should bind
39    /// to / connect to the address later.
40    pub fn new(addr: &UniAddr) -> io::Result<Self> {
41        match addr.as_inner() {
42            UniAddrInner::Inet(SocketAddr::V4(_)) => TcpSocket::new_v4().map(Self::from_inner),
43            UniAddrInner::Inet(SocketAddr::V6(_)) => TcpSocket::new_v6().map(Self::from_inner),
44            _ => Err(io::Error::new(
45                io::ErrorKind::Other,
46                "unsupported address type",
47            )),
48        }
49    }
50
51    /// Binds the socket to the specified address.
52    ///
53    /// Notes that the address must be the one used to create the socket.
54    pub fn bind(self, addr: &UniAddr) -> io::Result<Self> {
55        match addr.as_inner() {
56            UniAddrInner::Inet(addr) => self.inner.bind(*addr)?,
57            UniAddrInner::Host(_) => {
58                return Err(io::Error::new(
59                    io::ErrorKind::Other,
60                    "The Host address type must be resolved before creating a socket",
61                ))
62            }
63            _ => {
64                return Err(io::Error::new(
65                    io::ErrorKind::Other,
66                    "unsupported address type",
67                ))
68            }
69        }
70
71        Ok(self)
72    }
73
74    /// Mark a socket as ready to accept incoming connection requests using
75    /// [`UniListener::accept`].
76    ///
77    /// This function directly corresponds to the `listen(2)` function on
78    /// Windows.
79    pub fn listen(self, backlog: u32) -> io::Result<UniListener> {
80        self.inner.listen(backlog).map(UniListener::from_inner)
81    }
82
83    /// Initiates and completes a connection on this socket to the specified
84    /// address.
85    ///
86    /// This function directly corresponds to the `connect(2)` function on
87    /// Windows.
88    pub async fn connect(self, addr: &UniAddr) -> io::Result<UniStream> {
89        match addr.as_inner() {
90            UniAddrInner::Inet(addr) => self.inner.connect(*addr).await.map(UniStream::from_inner),
91            _ => Err(io::Error::new(
92                io::ErrorKind::Other,
93                "unsupported address type",
94            )),
95        }
96    }
97
98    /// Returns the socket address of the local half of this socket.
99    ///
100    /// This function directly corresponds to the `getsockname(2)` function on
101    /// Windows.
102    ///
103    /// # Notes
104    ///
105    /// Depending on the OS this may return an error if the socket is not
106    /// [bound](Self::bind).
107    pub fn local_addr(&self) -> io::Result<UniAddr> {
108        self.inner.local_addr().map(UniAddr::from)
109    }
110
111    /// Returns a [`SockRef`] to the underlying socket for configuration.
112    pub fn as_socket_ref(&self) -> SockRef<'_> {
113        SockRef::from(&self.inner)
114    }
115}
116
wrapper_lite::wrapper!(
    /// A simple wrapper of [`tokio::net::TcpListener`].
    ///
    /// Values are obtained from [`UniSocket::listen`] or by converting an
    /// existing standard library or Tokio TCP listener with `TryFrom`.
    pub struct UniListener(tokio::net::TcpListener);
);
121
122impl fmt::Debug for UniListener {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        f.debug_struct("UniListener")
125            .field("local_addr", &self.local_addr().ok())
126            .finish()
127    }
128}
129
130impl TryFrom<std::net::TcpListener> for UniListener {
131    type Error = io::Error;
132
133    /// Converts a standard library TCP listener into a unified [`UniListener`].
134    fn try_from(listener: std::net::TcpListener) -> Result<Self, Self::Error> {
135        listener.set_nonblocking(true)?;
136
137        Ok(Self::from_inner(listener.try_into()?))
138    }
139}
140
141impl TryFrom<tokio::net::TcpListener> for UniListener {
142    type Error = io::Error;
143
144    /// Converts a Tokio library TCP listener into a unified [`UniListener`].
145    ///
146    /// # Errors
147    ///
148    /// Actually, this is infallible and always returns `Ok`, for APIs
149    /// consistency.
150    fn try_from(listener: tokio::net::TcpListener) -> Result<Self, Self::Error> {
151        Ok(Self::from_inner(listener))
152    }
153}
154
155impl UniListener {
156    /// Accepts an incoming connection to this listener, and returns the
157    /// accepted stream and the peer address.
158    ///
159    /// This method will retry on non-deadly errors, including:
160    ///
161    /// - `ECONNREFUSED`.
162    /// - `ECONNABORTED`.
163    /// - `ECONNRESET`.
164    pub async fn accept(&self) -> io::Result<(UniStream, UniAddr)> {
165        loop {
166            match self.inner.accept().await {
167                Ok((stream, addr)) => {
168                    return Ok((UniStream::from_inner(stream), UniAddr::from(addr)))
169                }
170                Err(e)
171                    if matches!(
172                        e.kind(),
173                        io::ErrorKind::ConnectionRefused
174                            | io::ErrorKind::ConnectionAborted
175                            | io::ErrorKind::ConnectionReset
176                    ) => {}
177                Err(e) => return Err(e),
178            }
179        }
180    }
181
182    /// Accepts an incoming connection to this listener, and returns the
183    /// accepted stream and the peer address.
184    ///
185    /// Notes that on multiple calls to [`poll_accept`](Self::poll_accept), only
186    /// the waker from the [`Context`] passed to the most recent call is
187    /// scheduled to receive a wakeup. Unless you are implementing your own
188    /// future accepting connections, you probably want to use the asynchronous
189    /// [`accept`](Self::accept) method instead.
190    pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(UniStream, UniAddr)>> {
191        loop {
192            match self.inner.poll_accept(cx) {
193                Poll::Ready(Ok((stream, addr))) => {
194                    return Poll::Ready(Ok((UniStream::from_inner(stream), UniAddr::from(addr))))
195                }
196                Poll::Ready(Err(e))
197                    if matches!(
198                        e.kind(),
199                        io::ErrorKind::ConnectionRefused
200                            | io::ErrorKind::ConnectionAborted
201                            | io::ErrorKind::ConnectionReset
202                    ) => {}
203                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
204                Poll::Pending => return Poll::Pending,
205            }
206        }
207    }
208
209    /// Returns the socket address of the local half of this socket.
210    ///
211    /// This function directly corresponds to the `getsockname(2)` function on
212    /// Windows.
213    ///
214    /// # Notes
215    ///
216    /// Depending on the OS this may return an error if the socket is not
217    /// [bound](UniSocket::bind).
218    pub fn local_addr(&self) -> io::Result<UniAddr> {
219        self.inner.local_addr().map(UniAddr::from)
220    }
221
222    /// Returns a [`SockRef`] to the underlying socket for configuration.
223    pub fn as_socket_ref(&self) -> SockRef<'_> {
224        SockRef::from(&self.inner)
225    }
226}
227
wrapper_lite::wrapper!(
    /// A simple wrapper of [`tokio::net::TcpStream`].
    ///
    /// Values are obtained from [`UniSocket::connect`],
    /// [`UniListener::accept`], or by converting an existing standard library
    /// or Tokio TCP stream with `TryFrom`.
    pub struct UniStream(tokio::net::TcpStream);
);
232
233impl fmt::Debug for UniStream {
234    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
235        f.debug_struct("UniStream")
236            .field("local_addr", &self.local_addr().ok())
237            .field("peer_addr", &self.peer_addr().ok())
238            .finish()
239    }
240}
241
242impl TryFrom<tokio::net::TcpStream> for UniStream {
243    type Error = io::Error;
244
245    /// Converts a Tokio TCP stream into a [`UniStream`].
246    ///
247    /// # Errors
248    ///
249    /// This is infallible and always returns `Ok`, for APIs consistency.
250    fn try_from(value: tokio::net::TcpStream) -> Result<Self, Self::Error> {
251        Ok(Self::from_inner(value))
252    }
253}
254
255impl TryFrom<std::net::TcpStream> for UniStream {
256    type Error = std::io::Error;
257
258    /// Converts a standard library TCP stream into a [`UniStream`].
259    ///
260    /// # Panics
261    ///
262    /// This function panics if it is not called from within a runtime with
263    /// IO enabled.
264    ///
265    /// The runtime is usually set implicitly when this function is called
266    /// from a future driven by a tokio runtime, otherwise runtime can be
267    /// set explicitly with `Runtime::enter` function.
268    fn try_from(stream: std::net::TcpStream) -> Result<Self, Self::Error> {
269        stream.set_nonblocking(true)?;
270
271        Ok(Self::from_inner(stream.try_into()?))
272    }
273}
274
275impl UniStream {
276    /// Returns the socket address of the local half of this socket.
277    ///
278    /// This function directly corresponds to the `getsockname(2)` function on
279    /// Windows.
280    ///
281    /// # Notes
282    ///
283    /// Depending on the OS this may return an error if the socket is not
284    /// [bound](UniSocket::bind).
285    pub fn local_addr(&self) -> io::Result<UniAddr> {
286        self.inner.local_addr().map(UniAddr::from)
287    }
288
289    /// Returns the socket address of the remote peer of this socket.
290    ///
291    /// This function directly corresponds to the `getpeername(2)` function on
292    /// Windows and Unix.
293    ///
294    /// # Notes
295    ///
296    /// This returns an error if the socket is not
297    /// [`connect`ed](UniSocket::connect).
298    pub fn peer_addr(&self) -> io::Result<UniAddr> {
299        self.inner.peer_addr().map(UniAddr::from)
300    }
301
302    /// Receives data on the socket from the remote adress to which it is
303    /// connected, without removing that data from the queue. On success,
304    /// returns the number of bytes peeked.
305    ///
306    /// Successive calls return the same data. This is accomplished by passing
307    /// `MSG_PEEK` as a flag to the underlying `recv` system call.
308    pub async fn peek(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
309        buf.fill(MaybeUninit::new(0));
310
311        #[allow(unsafe_code)]
312        let buf = unsafe { std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len()) };
313
314        self.inner.peek(buf).await
315    }
316
317    /// Receives data on the socket from the remote adress to which it is
318    /// connected, without removing that data from the queue. On success,
319    /// returns the number of bytes peeked.
320    ///
321    /// Successive calls return the same data. This is accomplished by passing
322    /// `MSG_PEEK` as a flag to the underlying `recv` system call.
323    ///
324    /// Notes that on multiple calls to [`poll_peek`](Self::poll_peek), only
325    /// the waker from the [`Context`] passed to the most recent call is
326    /// scheduled to receive a wakeup. Unless you are implementing your own
327    /// future accepting connections, you probably want to use the asynchronous
328    /// [`accept`](UniListener::accept) method instead.
329    pub fn poll_peek(
330        self: Pin<&mut Self>,
331        cx: &mut Context<'_>,
332        buf: &mut ReadBuf<'_>,
333    ) -> Poll<io::Result<usize>> {
334        self.get_mut().inner.poll_peek(cx, buf)
335    }
336
337    #[inline]
338    /// Receives data on the socket from the remote address to which it is
339    /// connected.
340    pub async fn read(&mut self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
341        let mut this = Pin::new(&mut self.inner);
342
343        let buf = &mut ReadBuf::uninit(buf);
344
345        poll_fn(|cx| this.as_mut().poll_read(cx, buf)).await?;
346
347        Ok(buf.filled().len())
348    }
349
350    #[inline]
351    /// Sends data on the socket to a connected peer.
352    pub async fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
353        let mut this = Pin::new(&mut self.inner);
354
355        poll_fn(|cx| this.as_mut().poll_write(cx, buf)).await
356    }
357
358    /// Shuts down the read, write, or both halves of this connection.
359    ///
360    /// This function will cause all pending and future I/O on the specified
361    /// portions to return immediately with an appropriate value.
362    pub fn shutdown(&self, shutdown: Shutdown) -> io::Result<()> {
363        match self.as_socket_ref().shutdown(shutdown) {
364            Ok(()) => Ok(()),
365            Err(e) if e.kind() == io::ErrorKind::NotConnected => Ok(()),
366            Err(e) => Err(e),
367        }
368    }
369
370    /// See [`tokio::net::TcpStream::into_split`].
371    pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
372        self.inner.into_split()
373    }
374
375    /// Returns a [`SockRef`] to the underlying socket for configuration.
376    pub fn as_socket_ref(&self) -> SockRef<'_> {
377        SockRef::from(&self.inner)
378    }
379}
380
381impl AsyncRead for UniStream {
382    #[inline]
383    fn poll_read(
384        mut self: Pin<&mut Self>,
385        cx: &mut Context<'_>,
386        buf: &mut ReadBuf<'_>,
387    ) -> Poll<io::Result<()>> {
388        Pin::new(&mut self.inner).poll_read(cx, buf)
389    }
390}
391
392impl AsyncWrite for UniStream {
393    #[inline]
394    fn poll_write(
395        mut self: Pin<&mut Self>,
396        cx: &mut Context<'_>,
397        buf: &[u8],
398    ) -> Poll<io::Result<usize>> {
399        Pin::new(&mut self.inner).poll_write(cx, buf)
400    }
401
402    #[inline]
403    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
404        Pin::new(&mut self.inner).poll_flush(cx)
405    }
406
407    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
408        Pin::new(&mut self.inner).poll_shutdown(cx)
409    }
410}
411
412// Re-export split halves
413pub use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
414
#[cfg(windows)]
mod sys {
    //! Windows socket handle access for the unified socket types.

    use std::os::windows::io::{AsRawSocket, AsSocket, BorrowedSocket, RawSocket};

    use super::{UniListener, UniSocket, UniStream};

    /// Implements [`AsSocket`] and [`AsRawSocket`] for each listed wrapper by
    /// forwarding to its `inner` field.
    macro_rules! forward_socket_handle {
        ($($wrapper:ty),+ $(,)?) => {
            $(
                impl AsSocket for $wrapper {
                    fn as_socket(&self) -> BorrowedSocket<'_> {
                        self.inner.as_socket()
                    }
                }

                impl AsRawSocket for $wrapper {
                    fn as_raw_socket(&self) -> RawSocket {
                        self.inner.as_raw_socket()
                    }
                }
            )+
        };
    }

    forward_socket_handle!(UniSocket, UniListener, UniStream);
}