uni_socket/
unix.rs

1//! A unified stream type for both TCP and Unix domain sockets.
2
3use std::marker::PhantomData;
4use std::mem::MaybeUninit;
5use std::net::{Shutdown, SocketAddr};
6use std::os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{ready, Context, Poll};
10use std::time::Duration;
11use std::{fmt, io};
12
13use socket2::{SockRef, Socket};
14use tokio::io::unix::AsyncFd;
15use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16use tokio::time::sleep;
17use uni_addr::{UniAddr, UniAddrInner};
18
/// An async [`Socket`].
///
/// Wraps a [`socket2::Socket`](Socket) in a Tokio [`AsyncFd`]. The `Ty` type
/// parameter is a compile-time marker only (see [`ListenerTy`] and
/// [`StreamTy`]); it stores no data.
pub struct UniSocket<Ty = ()> {
    // The underlying socket, wrapped in Tokio's `AsyncFd`.
    inner: AsyncFd<Socket>,
    // Zero-sized marker carrying the `Ty` role parameter.
    ty: PhantomData<Ty>,
}
24
25impl fmt::Debug for UniSocket {
26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27        f.debug_tuple("UniSocket").field(&self.inner).finish()
28    }
29}
30
31impl<Ty> AsFd for UniSocket<Ty> {
32    #[inline]
33    fn as_fd(&self) -> BorrowedFd<'_> {
34        self.inner.as_fd()
35    }
36}
37
38impl<Ty> AsRawFd for UniSocket<Ty> {
39    #[inline]
40    fn as_raw_fd(&self) -> RawFd {
41        self.inner.as_raw_fd()
42    }
43}
44
impl UniSocket {
    /// Creates a new [`UniSocket`] whose socket type is determined by the
    /// given address.
    ///
    /// This only creates the socket: the caller should bind to / connect to
    /// the address later.
    ///
    /// # Errors
    ///
    /// Returns any error reported by the private constructor this delegates
    /// to.
    pub fn new(addr: &UniAddr) -> io::Result<Self> {
        Self::new_priv(addr)
    }
}
55
56impl<Ty> UniSocket<Ty> {
57    #[inline]
58    const fn from_inner(inner: AsyncFd<Socket>) -> Self {
59        Self {
60            inner,
61            ty: PhantomData,
62        }
63    }
64
65    fn new_priv(addr: &UniAddr) -> io::Result<Self> {
66        let ty = socket2::Type::STREAM;
67
68        #[cfg(any(
69            target_os = "android",
70            target_os = "dragonfly",
71            target_os = "freebsd",
72            target_os = "fuchsia",
73            target_os = "illumos",
74            target_os = "linux",
75            target_os = "netbsd",
76            target_os = "openbsd"
77        ))]
78        let ty = ty.nonblocking();
79
80        let inner = match addr.as_inner() {
81            UniAddrInner::Inet(SocketAddr::V4(_)) => {
82                Socket::new(socket2::Domain::IPV4, ty, Some(socket2::Protocol::TCP))
83            }
84            UniAddrInner::Inet(SocketAddr::V6(_)) => {
85                Socket::new(socket2::Domain::IPV6, ty, Some(socket2::Protocol::TCP))
86            }
87            UniAddrInner::Unix(_) => Socket::new(socket2::Domain::UNIX, ty, None),
88            UniAddrInner::Host(_) => Err(io::Error::new(
89                io::ErrorKind::Other,
90                "The Host address type must be resolved before creating a socket",
91            )),
92            _ => Err(io::Error::new(
93                io::ErrorKind::Other,
94                "Unsupported address type",
95            )),
96        }?;
97
98        #[cfg(not(any(
99            target_os = "android",
100            target_os = "dragonfly",
101            target_os = "freebsd",
102            target_os = "fuchsia",
103            target_os = "illumos",
104            target_os = "linux",
105            target_os = "netbsd",
106            target_os = "openbsd"
107        )))]
108        inner.set_nonblocking(true)?;
109
110        // On platforms with Berkeley-derived sockets, this allows to quickly
111        // rebind a socket, without needing to wait for the OS to clean up the
112        // previous one.
113        //
114        // On Windows, this allows rebinding sockets which are actively in use,
115        // which allows “socket hijacking”, so we explicitly don't set it here.
116        // https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
117        #[cfg(not(windows))]
118        inner.set_reuse_address(true)?;
119
120        AsyncFd::new(inner).map(Self::from_inner)
121    }
122
123    /// Binds the socket to the specified address.
124    ///
125    /// Notes that the address must be the one used to create the socket.
126    pub fn bind(self, addr: &UniAddr) -> io::Result<Self> {
127        self.inner.get_ref().bind(&addr.try_into()?)?;
128
129        Ok(Self::from_inner(self.inner))
130    }
131
132    #[cfg(any(
133        target_os = "ios",
134        target_os = "visionos",
135        target_os = "macos",
136        target_os = "tvos",
137        target_os = "watchos",
138        target_os = "illumos",
139        target_os = "solaris",
140        target_os = "linux",
141        target_os = "android",
142        target_os = "fuchsia",
143    ))]
144    /// Sets the value for the `SO_BINDTODEVICE` option on this socket, then
145    /// [`bind`s](Self::bind) the socket to the specified address.
146    ///
147    /// If a socket is bound to an interface, only packets received from that
148    /// particular interface are processed by the socket. Note that this only
149    /// works for some socket types, particularly `AF_INET` sockets.
150    ///
151    /// For those platforms, like macOS, that do not support `SO_BINDTODEVICE`,
152    /// this function will fallback to `bind_device_by_index_v(4|6)`, while the
153    /// `if_index` obtained from the interface name with `if_nametoindex(3)`.
154    ///
155    /// When `device` is `None`, this is equivalent to calling
156    /// [`bind`](Self::bind), instead of clearing the option like
157    /// [`Socket::bind_device`].
158    pub fn bind_device(self, addr: &UniAddr, device: Option<&str>) -> io::Result<Self> {
159        if let Some(device) = device {
160            #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
161            {
162                self.inner.get_ref().bind_device(Some(device.as_bytes()))?;
163            }
164
165            #[cfg(all(
166                not(any(target_os = "android", target_os = "fuchsia", target_os = "linux")),
167                any(
168                    target_os = "ios",
169                    target_os = "visionos",
170                    target_os = "macos",
171                    target_os = "tvos",
172                    target_os = "watchos",
173                    target_os = "illumos",
174                    target_os = "solaris",
175                )
176            ))]
177            {
178                use std::num::NonZeroU32;
179
180                #[allow(unsafe_code)]
181                let if_index = unsafe { libc::if_nametoindex(device.as_ptr().cast()) };
182
183                let Some(if_index) = NonZeroU32::new(if_index) else {
184                    return Err(io::Error::last_os_error());
185                };
186
187                match addr.as_inner() {
188                    UniAddrInner::Inet(SocketAddr::V4(_)) => {
189                        self.inner
190                            .get_ref()
191                            .bind_device_by_index_v4(Some(if_index))?;
192                    }
193                    UniAddrInner::Inet(SocketAddr::V6(_)) => {
194                        self.inner
195                            .get_ref()
196                            .bind_device_by_index_v6(Some(if_index))?;
197                    }
198                    _ => {
199                        return Err(io::Error::new(
200                            io::ErrorKind::Other,
201                            "`bind_device_by_index` only works for IPv4 and IPv6 addresses",
202                        ))
203                    }
204                }
205            }
206        }
207
208        self.bind(addr)
209    }
210
211    /// Mark a socket as ready to accept incoming connection requests using
212    /// [`UniListener::accept`].
213    ///
214    /// This function directly corresponds to the `listen(2)` function on Unix.
215    ///
216    /// An error will be returned if `listen` or `connect` has already been
217    /// called on this builder.
218    pub fn listen(self, backlog: u32) -> io::Result<UniListener> {
219        #[allow(clippy::cast_possible_wrap)]
220        self.inner.get_ref().listen(backlog as i32)?;
221
222        Ok(UniListener::from_inner(self.inner))
223    }
224
225    /// Initiates and completes a connection on this socket to the specified
226    /// address.
227    ///
228    /// This function directly corresponds to the `connect(2)` function on Unix.
229    ///
230    /// An error will be returned if `connect` has already been called.
231    pub async fn connect(self, addr: &UniAddr) -> io::Result<UniStream> {
232        if let Err(e) = self.inner.get_ref().connect(&addr.try_into()?) {
233            if e.raw_os_error() != Some(libc::EINPROGRESS) {
234                return Err(e);
235            }
236        }
237
238        let this = UniStream::from_inner(self.inner);
239
240        // Poll connection completion
241        loop {
242            let mut guard = this.inner.writable().await?;
243
244            match guard.try_io(|inner| inner.get_ref().take_error()) {
245                Ok(Ok(None)) => break,
246                Ok(Ok(Some(e)) | Err(e)) => return Err(e),
247                Err(_would_block) => {}
248            }
249        }
250
251        Ok(this)
252    }
253
254    /// Returns the socket address of the local half of this socket.
255    ///
256    /// This function directly corresponds to the `getsockname(2)` function on
257    /// Windows and Unix.
258    ///
259    /// # Notes
260    ///
261    /// Depending on the OS this may return an error if the socket is not
262    /// [bound](Self::bind).
263    pub fn local_addr(&self) -> io::Result<UniAddr> {
264        self.inner
265            .get_ref()
266            .local_addr()
267            .and_then(TryFrom::try_from)
268    }
269
270    /// Returns a [`SockRef`] to the underlying socket for configuration.
271    pub fn as_socket_ref(&self) -> SockRef<'_> {
272        SockRef::from(&self.inner)
273    }
274}
275
/// Marker type: this socket is a listener.
#[derive(Debug)]
pub struct ListenerTy;
279
280/// A [`UniSocket`] used as a listener.
281pub type UniListener = UniSocket<ListenerTy>;
282
283impl fmt::Debug for UniListener {
284    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285        f.debug_struct("UniListener")
286            .field("local_addr", &self.local_addr().ok())
287            .finish()
288    }
289}
290
291impl TryFrom<std::net::TcpListener> for UniListener {
292    type Error = io::Error;
293
294    /// Converts a standard library TCP listener into a unified [`UniListener`].
295    ///
296    /// # Panics
297    ///
298    /// This function panics if there is no current Tokio reactor set, or if
299    /// the `rt` feature flag is not enabled.
300    fn try_from(listener: std::net::TcpListener) -> Result<Self, Self::Error> {
301        listener.set_nonblocking(true)?;
302
303        AsyncFd::new(listener.into()).map(Self::from_inner)
304    }
305}
306
307impl TryFrom<tokio::net::TcpListener> for UniListener {
308    type Error = io::Error;
309
310    /// Converts a Tokio library TCP listener into a unified [`UniListener`].
311    ///
312    /// # Panics
313    ///
314    /// This function panics if there is no current Tokio reactor set, or if
315    /// the `rt` feature flag is not enabled.
316    fn try_from(listener: tokio::net::TcpListener) -> Result<Self, Self::Error> {
317        listener
318            .into_std()
319            .map(Into::into)
320            .and_then(AsyncFd::new)
321            .map(Self::from_inner)
322    }
323}
324
325impl UniListener {
326    /// Accepts an incoming connection to this listener, and returns the
327    /// accepted stream and the peer address.
328    ///
329    /// This method will retry on non-deadly errors, including:
330    ///
331    /// - `ECONNREFUSED`.
332    /// - `ECONNABORTED`.
333    /// - `ECONNRESET`.
334    /// - `EMFILE`.
335    pub async fn accept(&self) -> io::Result<(UniStream, UniAddr)> {
336        fn accept(socket: &Socket) -> io::Result<(UniStream, UniAddr)> {
337            // On platforms that support it we can use `accept4(2)` to set `NONBLOCK`
338            // and `CLOEXEC` in the call to accept the connection.
339            // Android x86's seccomp profile forbids calls to `accept4(2)`
340            // See https://github.com/tokio-rs/mio/issues/1445 for details
341            #[cfg(any(
342                all(not(target_arch = "x86"), target_os = "android"),
343                target_os = "dragonfly",
344                target_os = "freebsd",
345                target_os = "fuchsia",
346                target_os = "hurd",
347                target_os = "illumos",
348                target_os = "linux",
349                target_os = "netbsd",
350                target_os = "openbsd",
351                target_os = "solaris",
352                target_os = "cygwin",
353            ))]
354            let (accepted, peer_addr) = socket.accept4(libc::SOCK_CLOEXEC | libc::SOCK_NONBLOCK)?;
355
356            // But not all platforms have the `accept4(2)` call. Luckily BSD (derived)
357            // OSs inherit the non-blocking flag from the listener, so we just have to
358            // set `CLOEXEC`.
359            #[cfg(any(
360                target_os = "aix",
361                target_os = "haiku",
362                target_os = "ios",
363                target_os = "macos",
364                target_os = "redox",
365                target_os = "tvos",
366                target_os = "visionos",
367                target_os = "watchos",
368                target_os = "espidf",
369                target_os = "vita",
370                target_os = "hermit",
371                target_os = "nto",
372                all(target_arch = "x86", target_os = "android"),
373            ))]
374            let (accepted, peer_addr) = socket.accept_raw().and_then(|(accepted, peer_addr)| {
375                #[cfg(not(any(target_os = "espidf", target_os = "vita")))]
376                accepted.set_cloexec(true)?;
377
378                #[cfg(any(
379                    all(target_arch = "x86", target_os = "android"),
380                    target_os = "aix",
381                    target_os = "espidf",
382                    target_os = "vita",
383                    target_os = "hermit",
384                    target_os = "nto",
385                ))]
386                accepted.set_nonblocking(true)?;
387
388                // On Apple platforms set `NOSIGPIPE`.
389                #[cfg(any(
390                    target_os = "ios",
391                    target_os = "visionos",
392                    target_os = "macos",
393                    target_os = "tvos",
394                    target_os = "watchos",
395                ))]
396                socket.set_nosigpipe(true)?;
397
398                Ok((accepted, peer_addr))
399            })?;
400
401            // The non-blocking state of `listener` is inherited. See
402            // https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#remarks.
403            #[cfg(windows)]
404            let (accepted, peer_addr) = socket.accept_raw()?;
405
406            Ok((
407                UniStream::from_inner(AsyncFd::new(accepted)?),
408                peer_addr.try_into()?,
409            ))
410        }
411
412        loop {
413            let accepted = self
414                .inner
415                .readable()
416                .await?
417                .try_io(|socket| accept(socket.get_ref()));
418
419            match accepted {
420                Ok(ret @ Ok(_)) => {
421                    return ret;
422                }
423                Ok(Err(e))
424                    if matches!(
425                        e.kind(),
426                        io::ErrorKind::ConnectionRefused
427                            | io::ErrorKind::ConnectionAborted
428                            | io::ErrorKind::ConnectionReset
429                    ) =>
430                {
431                    // This is not a deadly error, too, just continue
432                }
433                Ok(Err(e)) if matches!(e.raw_os_error(), Some(libc::EMFILE)) => {
434                    // This is not a deadly error, but we may wait for a while.
435                    sleep(Duration::from_secs(1)).await;
436                }
437                Ok(Err(e)) => {
438                    return Err(e);
439                }
440                Err(_would_block) => {}
441            }
442        }
443    }
444
445    /// Accepts an incoming connection to this listener, and returns the
446    /// accepted stream and the peer address.
447    ///
448    /// Notes that on multiple calls to [`poll_accept`](Self::poll_accept), only
449    /// the waker from the [`Context`] passed to the most recent call is
450    /// scheduled to receive a wakeup. Unless you are implementing your own
451    /// future accepting connections, you probably want to use the asynchronous
452    /// [`accept`](Self::accept) method instead.
453    ///
454    /// Unlike [`accept`](Self::accept), this method does not handle `EMFILE`
455    /// (i.e., too many open files) errors and the caller may need to handle
456    /// it by itself.
457    pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(UniStream, UniAddr)>> {
458        loop {
459            let mut guard = ready!(self.inner.poll_read_ready(cx))?;
460
461            match guard.try_io(|socket| socket.get_ref().accept()) {
462                Ok(Ok((socket, addr))) => {
463                    let addr = if let Some(addr) = addr.as_socket() {
464                        UniAddr::from(addr)
465                    } else if let Some(addr) = addr.as_unix() {
466                        UniAddr::from(addr)
467                    } else {
468                        return Poll::Ready(Err(io::Error::new(
469                            io::ErrorKind::Other,
470                            "unsupported address type",
471                        )));
472                    };
473
474                    return Poll::Ready(Ok((UniStream::from_inner(AsyncFd::new(socket)?), addr)));
475                }
476                Ok(Err(e))
477                    if matches!(
478                        e.kind(),
479                        io::ErrorKind::ConnectionRefused
480                            | io::ErrorKind::ConnectionAborted
481                            | io::ErrorKind::ConnectionReset
482                    ) =>
483                {
484                    // This is not a deadly error, too, just continue
485                }
486                Ok(Err(e)) => {
487                    return Poll::Ready(Err(e));
488                }
489                Err(_would_block) => {}
490            }
491        }
492    }
493}
494
/// Marker type: this socket is a (TCP) stream.
#[derive(Debug)]
pub struct StreamTy;
498
499/// A [`UniSocket`] used as a stream.
500pub type UniStream = UniSocket<StreamTy>;
501
502impl fmt::Debug for UniStream {
503    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
504        f.debug_struct("UniStream")
505            .field("local_addr", &self.local_addr().ok())
506            .field("peer_addr", &self.peer_addr().ok())
507            .finish()
508    }
509}
510
511impl TryFrom<tokio::net::TcpStream> for UniStream {
512    type Error = io::Error;
513
514    /// Converts a Tokio TCP stream into a unified [`UniStream`].
515    ///
516    /// # Panics
517    ///
518    /// This function panics if there is no current Tokio reactor set, or if
519    /// the `rt` feature flag is not enabled.
520    fn try_from(stream: tokio::net::TcpStream) -> Result<Self, Self::Error> {
521        stream
522            .into_std()
523            .map(Into::into)
524            .and_then(AsyncFd::new)
525            .map(Self::from_inner)
526    }
527}
528
529impl TryFrom<std::net::TcpStream> for UniStream {
530    type Error = io::Error;
531
532    /// Converts a standard library TCP stream into a unified [`UniStream`].
533    ///
534    /// # Panics
535    ///
536    /// This function panics if there is no current Tokio reactor set, or if
537    /// the `rt` feature flag is not enabled.
538    fn try_from(stream: std::net::TcpStream) -> Result<Self, Self::Error> {
539        stream.set_nonblocking(true)?;
540
541        AsyncFd::new(stream.into()).map(Self::from_inner)
542    }
543}
544
545impl TryFrom<tokio::net::UnixStream> for UniStream {
546    type Error = io::Error;
547
548    /// Converts a Tokio Unix stream into a unified [`UniStream`].
549    ///
550    /// # Panics
551    ///
552    /// This function panics if there is no current Tokio reactor set, or if
553    /// the `rt` feature flag is not enabled.
554    fn try_from(stream: tokio::net::UnixStream) -> Result<Self, Self::Error> {
555        stream
556            .into_std()
557            .map(Into::into)
558            .and_then(AsyncFd::new)
559            .map(Self::from_inner)
560    }
561}
562
563impl TryFrom<std::os::unix::net::UnixStream> for UniStream {
564    type Error = io::Error;
565
566    /// Converts a standard library Unix stream into a unified [`UniStream`].
567    ///
568    /// # Panics
569    ///
570    /// This function panics if there is no current Tokio reactor set, or if
571    /// the `rt` feature flag is not enabled.
572    fn try_from(stream: std::os::unix::net::UnixStream) -> Result<Self, Self::Error> {
573        stream.set_nonblocking(true)?;
574
575        AsyncFd::new(stream.into()).map(Self::from_inner)
576    }
577}
578
579impl UniStream {
580    /// Returns the socket address of the remote peer of this socket.
581    ///
582    /// This function directly corresponds to the `getpeername(2)` function on
583    /// Unix.
584    ///
585    /// # Notes
586    ///
587    /// This returns an error if the socket is not
588    /// [`connect`ed](UniSocket::connect).
589    pub fn peer_addr(&self) -> io::Result<UniAddr> {
590        self.inner.get_ref().peer_addr().and_then(TryFrom::try_from)
591    }
592
593    /// Receives data on the socket from the remote adress to which it is
594    /// connected, without removing that data from the queue. On success,
595    /// returns the number of bytes peeked.
596    ///
597    /// Successive calls return the same data. This is accomplished by passing
598    /// `MSG_PEEK` as a flag to the underlying `recv` system call.
599    pub async fn peek(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
600        loop {
601            let mut guard = self.inner.readable().await?;
602
603            match guard.try_io(|inner| inner.get_ref().peek(buf)) {
604                Ok(result) => return result,
605                Err(_would_block) => {}
606            }
607        }
608    }
609
610    /// Receives data on the socket from the remote adress to which it is
611    /// connected, without removing that data from the queue. On success,
612    /// returns the number of bytes peeked.
613    ///
614    /// Successive calls return the same data. This is accomplished by passing
615    /// `MSG_PEEK` as a flag to the underlying `recv` system call.
616    ///
617    /// Notes that on multiple calls to [`poll_peek`](Self::poll_peek), only
618    /// the waker from the [`Context`] passed to the most recent call is
619    /// scheduled to receive a wakeup. Unless you are implementing your own
620    /// future accepting connections, you probably want to use the asynchronous
621    /// [`accept`](UniListener::accept) method instead.
622    pub fn poll_peek(
623        self: Pin<&mut Self>,
624        cx: &mut Context<'_>,
625        buf: &mut ReadBuf<'_>,
626    ) -> Poll<io::Result<usize>> {
627        loop {
628            let mut guard = ready!(self.inner.poll_read_ready(cx))?;
629
630            #[allow(unsafe_code)]
631            let unfilled = unsafe { buf.unfilled_mut() };
632
633            match guard.try_io(|inner| inner.get_ref().peek(unfilled)) {
634                Ok(Ok(len)) => {
635                    // Advance initialized
636                    #[allow(unsafe_code)]
637                    unsafe {
638                        buf.assume_init(len);
639                    };
640
641                    // Advance filled
642                    buf.advance(len);
643
644                    return Poll::Ready(Ok(len));
645                }
646                Ok(Err(e)) => return Poll::Ready(Err(e)),
647                Err(_would_block) => {}
648            }
649        }
650    }
651
652    #[inline]
653    /// Receives data on the socket from the remote address to which it is
654    /// connected.
655    ///
656    /// # Cancel safety
657    ///
658    /// This method is cancel safe. Once a readiness event occurs, the method
659    /// will continue to return immediately until the readiness event is
660    /// consumed by an attempt to read or write that fails with `WouldBlock` or
661    /// `Poll::Pending`.
662    pub async fn read(&mut self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
663        self.read_priv(buf).await
664    }
665
666    async fn read_priv(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
667        loop {
668            let mut guard = self.inner.readable().await?;
669
670            match guard.try_io(|inner| inner.get_ref().recv(buf)) {
671                Ok(result) => return result,
672                Err(_would_block) => {}
673            }
674        }
675    }
676
677    fn poll_read_priv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
678        loop {
679            let mut guard = match self.inner.poll_read_ready(cx) {
680                Poll::Ready(Ok(guard)) => guard,
681                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
682                Poll::Pending => {
683                    return Poll::Pending;
684                }
685            };
686
687            #[allow(unsafe_code)]
688            let unfilled = unsafe { buf.unfilled_mut() };
689
690            match guard.try_io(|inner| {
691                let ret = inner.get_ref().recv(unfilled);
692
693                ret
694            }) {
695                Ok(Ok(len)) => {
696                    // Advance initialized
697                    #[allow(unsafe_code)]
698                    unsafe {
699                        buf.assume_init(len);
700                    };
701
702                    // Advance filled
703                    buf.advance(len);
704
705                    return Poll::Ready(Ok(()));
706                }
707                Ok(Err(e)) => return Poll::Ready(Err(e)),
708                Err(_would_block) => {}
709            }
710        }
711    }
712
713    #[inline]
714    /// Sends data on the socket to a connected peer.
715    ///
716    /// # Cancel safety
717    ///
718    /// This method is cancel safe. Once a readiness event occurs, the method
719    /// will continue to return immediately until the readiness event is
720    /// consumed by an attempt to read or write that fails with `WouldBlock` or
721    /// `Poll::Pending`.
722    pub async fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
723        self.write_priv(buf).await
724    }
725
726    async fn write_priv(&self, buf: &[u8]) -> io::Result<usize> {
727        loop {
728            let mut guard = self.inner.writable().await?;
729
730            match guard.try_io(|inner: &AsyncFd<Socket>| inner.get_ref().send(buf)) {
731                Ok(result) => return result,
732                Err(_would_block) => {}
733            }
734        }
735    }
736
737    fn poll_write_priv(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
738        loop {
739            let mut guard = ready!(self.inner.poll_write_ready(cx))?;
740
741            match guard.try_io(|inner| inner.get_ref().send(buf)) {
742                Ok(result) => return Poll::Ready(result),
743                Err(_would_block) => {}
744            }
745        }
746    }
747
748    /// Shuts down the read, write, or both halves of this connection.
749    ///
750    /// This function will cause all pending and future I/O on the specified
751    /// portions to return immediately with an appropriate value.
752    pub fn shutdown(&self, shutdown: Shutdown) -> io::Result<()> {
753        match self.inner.get_ref().shutdown(shutdown) {
754            Ok(()) => Ok(()),
755            Err(e) if e.kind() == io::ErrorKind::NotConnected => Ok(()),
756            Err(e) => Err(e),
757        }
758    }
759
760    /// Splits a [`UniStream`] into a read half and a write half, which can be
761    /// used to read and write the stream concurrently.
762    ///
763    /// Note: dropping the write half will shutdown the write half of the
764    /// stream.
765    pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
766        let this = Arc::new(self);
767
768        (
769            OwnedReadHalf::from_inner(this.clone()),
770            OwnedWriteHalf::from_inner(this),
771        )
772    }
773}
774
775impl AsyncRead for UniStream {
776    #[inline]
777    /// Receives data on the socket from the remote address to which it is
778    /// connected.
779    ///
780    /// Notes that on multiple calls to [`poll_read`](Self::poll_read), only
781    /// the waker from the [`Context`] passed to the most recent call is
782    /// scheduled to receive a wakeup. Unless you are implementing your own
783    /// future accepting connections, you probably want to use the asynchronous
784    /// [`read`](Self::read) method instead.
785    fn poll_read(
786        self: Pin<&mut Self>,
787        cx: &mut Context<'_>,
788        buf: &mut ReadBuf<'_>,
789    ) -> Poll<io::Result<()>> {
790        self.poll_read_priv(cx, buf)
791    }
792}
793
794impl AsyncWrite for UniStream {
795    #[inline]
796    /// Sends data on the socket to a connected peer.
797    ///
798    /// Notes that on multiple calls to [`poll_write`](Self::poll_write), only
799    /// the waker from the [`Context`] passed to the most recent call is
800    /// scheduled to receive a wakeup. Unless you are implementing your own
801    /// future accepting connections, you probably want to use the asynchronous
802    /// [`write`](Self::write) method instead.
803    fn poll_write(
804        self: Pin<&mut Self>,
805        cx: &mut Context<'_>,
806        buf: &[u8],
807    ) -> Poll<io::Result<usize>> {
808        self.poll_write_priv(cx, buf)
809    }
810
811    #[inline]
812    /// For TCP and Unix domain sockets, `flush` is a no-op.
813    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
814        Poll::Ready(Ok(()))
815    }
816
817    /// See [`shutdown`](Self::shutdown).
818    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
819        Poll::Ready(self.shutdown(Shutdown::Write))
820    }
821}
822
#[cfg(feature = "splice")]
impl tokio_splice2::AsyncReadFd for UniStream {
    /// Polls the wrapped socket for read readiness, discarding the guard.
    fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let _guard = ready!(self.inner.poll_read_ready(cx))?;

        Poll::Ready(Ok(()))
    }

    /// Runs `f` through the wrapped `AsyncFd`'s `try_io` with read interest.
    fn try_io_read<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
        self.inner
            .try_io(tokio::io::Interest::READABLE, |_socket| f())
    }
}
835
#[cfg(feature = "splice")]
impl tokio_splice2::AsyncWriteFd for UniStream {
    /// Polls the wrapped socket for write readiness, discarding the guard.
    fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let _guard = ready!(self.inner.poll_write_ready(cx))?;

        Poll::Ready(Ok(()))
    }

    /// Runs `f` through the wrapped `AsyncFd`'s `try_io` with write interest.
    fn try_io_write<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
        self.inner
            .try_io(tokio::io::Interest::WRITABLE, |_socket| f())
    }
}
848
// Marker impl with no items, only compiled with the `splice` feature.
// NOTE(review): presumably tells `tokio_splice2` that a `UniStream` is a
// socket rather than a regular file — confirm against that crate's docs.
#[cfg(feature = "splice")]
impl tokio_splice2::IsNotFile for UniStream {}
851
852wrapper_lite::wrapper!(
853    #[derive(Debug)]
854    /// An owned read half of a [`UniStream`].
855    pub struct OwnedReadHalf(Arc<UniStream>);
856);
857
858impl AsyncRead for OwnedReadHalf {
859    #[inline]
860    /// See [`poll_read`](UniStream::poll_read).
861    fn poll_read(
862        self: Pin<&mut Self>,
863        cx: &mut Context<'_>,
864        buf: &mut ReadBuf<'_>,
865    ) -> Poll<io::Result<()>> {
866        self.inner.poll_read_priv(cx, buf)
867    }
868}
869
870wrapper_lite::wrapper!(
871    #[derive(Debug)]
872    /// An owned write half of a [`UniStream`].
873    pub struct OwnedWriteHalf(Arc<UniStream>);
874);
875
876impl AsyncWrite for OwnedWriteHalf {
877    #[inline]
878    /// See [`poll_write`](UniStream::poll_write).
879    fn poll_write(
880        self: Pin<&mut Self>,
881        cx: &mut Context<'_>,
882        buf: &[u8],
883    ) -> Poll<io::Result<usize>> {
884        self.inner.poll_write_priv(cx, buf)
885    }
886
887    #[inline]
888    /// See [`poll_flush`](UniStream::poll_flush).
889    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
890        Poll::Ready(Ok(()))
891    }
892
893    /// See [`shutdown`](UniStream::shutdown).
894    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
895        Poll::Ready(self.inner.shutdown(Shutdown::Write))
896    }
897}
898
899impl Drop for OwnedWriteHalf {
900    fn drop(&mut self) {
901        let _ = self.inner.shutdown(Shutdown::Write);
902    }
903}