uni_socket/
unix.rs

1//! A unified stream type for both TCP and Unix domain sockets.
2
3#[cfg(all(target_os = "linux", feature = "splice"))]
4pub mod splice;
5
6use std::marker::PhantomData;
7use std::mem::MaybeUninit;
8use std::net::{Shutdown, SocketAddr};
9use std::os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd};
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{ready, Context, Poll};
13use std::time::Duration;
14use std::{fmt, io};
15
16use socket2::{SockRef, Socket};
17use tokio::io::unix::AsyncFd;
18use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
19use tokio::time::sleep;
20use uni_addr::{UniAddr, UniAddrInner};
21
22wrapper_lite::wrapper!(
23    #[wrapper_impl(AsRef)]
24    /// An async [`Socket`].
25    pub struct UniSocket<Ty = ()> {
26        inner: AsyncFd<Socket>,
27        ty: PhantomData<Ty>,
28    }
29);
30
31impl fmt::Debug for UniSocket {
32    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33        f.debug_tuple("UniSocket").field(&self.inner).finish()
34    }
35}
36
37impl<Ty> AsFd for UniSocket<Ty> {
38    #[inline]
39    fn as_fd(&self) -> BorrowedFd<'_> {
40        self.inner.as_fd()
41    }
42}
43
44impl<Ty> AsRawFd for UniSocket<Ty> {
45    #[inline]
46    fn as_raw_fd(&self) -> RawFd {
47        self.inner.as_raw_fd()
48    }
49}
50
51impl UniSocket {
52    /// Creates a new [`UniSocket`], and applies the given initialization
53    /// function to the underlying socket.
54    ///
55    /// The given address determines the socket type, and the caller should bind
56    /// to / connect to the address later.
57    pub fn new(addr: &UniAddr) -> io::Result<Self> {
58        Self::new_priv(addr)
59    }
60}
61
62impl<Ty> UniSocket<Ty> {
63    #[inline]
64    const fn from_inner(inner: AsyncFd<Socket>) -> Self {
65        Self {
66            inner,
67            ty: PhantomData,
68        }
69    }
70
71    fn new_priv(addr: &UniAddr) -> io::Result<Self> {
72        let ty = socket2::Type::STREAM;
73
74        #[cfg(any(
75            target_os = "android",
76            target_os = "dragonfly",
77            target_os = "freebsd",
78            target_os = "fuchsia",
79            target_os = "illumos",
80            target_os = "linux",
81            target_os = "netbsd",
82            target_os = "openbsd"
83        ))]
84        let ty = ty.nonblocking();
85
86        let inner = match addr.as_inner() {
87            UniAddrInner::Inet(SocketAddr::V4(_)) => {
88                Socket::new(socket2::Domain::IPV4, ty, Some(socket2::Protocol::TCP))
89            }
90            UniAddrInner::Inet(SocketAddr::V6(_)) => {
91                Socket::new(socket2::Domain::IPV6, ty, Some(socket2::Protocol::TCP))
92            }
93            UniAddrInner::Unix(_) => Socket::new(socket2::Domain::UNIX, ty, None),
94            UniAddrInner::Host(_) => Err(io::Error::new(
95                io::ErrorKind::Other,
96                "The Host address type must be resolved before creating a socket",
97            )),
98            _ => Err(io::Error::new(
99                io::ErrorKind::Other,
100                "Unsupported address type",
101            )),
102        }?;
103
104        #[cfg(not(any(
105            target_os = "android",
106            target_os = "dragonfly",
107            target_os = "freebsd",
108            target_os = "fuchsia",
109            target_os = "illumos",
110            target_os = "linux",
111            target_os = "netbsd",
112            target_os = "openbsd"
113        )))]
114        inner.set_nonblocking(true)?;
115
116        // On platforms with Berkeley-derived sockets, this allows to quickly
117        // rebind a socket, without needing to wait for the OS to clean up the
118        // previous one.
119        //
120        // On Windows, this allows rebinding sockets which are actively in use,
121        // which allows “socket hijacking”, so we explicitly don't set it here.
122        // https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
123        #[cfg(not(windows))]
124        inner.set_reuse_address(true)?;
125
126        AsyncFd::new(inner).map(Self::from_inner)
127    }
128
129    /// Binds the socket to the specified address.
130    ///
131    /// Notes that the address must be the one used to create the socket.
132    pub fn bind(self, addr: &UniAddr) -> io::Result<Self> {
133        self.inner.get_ref().bind(&addr.try_into()?)?;
134
135        Ok(Self::from_inner(self.inner))
136    }
137
138    #[cfg(any(
139        target_os = "ios",
140        target_os = "visionos",
141        target_os = "macos",
142        target_os = "tvos",
143        target_os = "watchos",
144        target_os = "illumos",
145        target_os = "solaris",
146        target_os = "linux",
147        target_os = "android",
148        target_os = "fuchsia",
149    ))]
150    /// Sets the value for the `SO_BINDTODEVICE` option on this socket, then
151    /// [`bind`s](Self::bind) the socket to the specified address.
152    ///
153    /// If a socket is bound to an interface, only packets received from that
154    /// particular interface are processed by the socket. Note that this only
155    /// works for some socket types, particularly `AF_INET` sockets.
156    ///
157    /// For those platforms, like macOS, that do not support `SO_BINDTODEVICE`,
158    /// this function will fallback to `bind_device_by_index_v(4|6)`, while the
159    /// `if_index` obtained from the interface name with `if_nametoindex(3)`.
160    ///
161    /// When `device` is `None`, this is equivalent to calling
162    /// [`bind`](Self::bind), instead of clearing the option like
163    /// [`Socket::bind_device`].
164    pub fn bind_device(self, addr: &UniAddr, device: Option<&str>) -> io::Result<Self> {
165        if let Some(device) = device {
166            #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
167            {
168                self.inner.get_ref().bind_device(Some(device.as_bytes()))?;
169            }
170
171            #[cfg(all(
172                not(any(target_os = "android", target_os = "fuchsia", target_os = "linux")),
173                any(
174                    target_os = "ios",
175                    target_os = "visionos",
176                    target_os = "macos",
177                    target_os = "tvos",
178                    target_os = "watchos",
179                    target_os = "illumos",
180                    target_os = "solaris",
181                )
182            ))]
183            {
184                use std::num::NonZeroU32;
185
186                #[allow(unsafe_code)]
187                let if_index = unsafe { libc::if_nametoindex(device.as_ptr().cast()) };
188
189                let Some(if_index) = NonZeroU32::new(if_index) else {
190                    return Err(io::Error::last_os_error());
191                };
192
193                match addr.as_inner() {
194                    UniAddrInner::Inet(SocketAddr::V4(_)) => {
195                        self.inner
196                            .get_ref()
197                            .bind_device_by_index_v4(Some(if_index))?;
198                    }
199                    UniAddrInner::Inet(SocketAddr::V6(_)) => {
200                        self.inner
201                            .get_ref()
202                            .bind_device_by_index_v6(Some(if_index))?;
203                    }
204                    _ => {
205                        return Err(io::Error::new(
206                            io::ErrorKind::Other,
207                            "`bind_device_by_index` only works for IPv4 and IPv6 addresses",
208                        ))
209                    }
210                }
211            }
212        }
213
214        self.bind(addr)
215    }
216
217    /// Mark a socket as ready to accept incoming connection requests using
218    /// [`UniListener::accept`].
219    ///
220    /// This function directly corresponds to the `listen(2)` function on Unix.
221    ///
222    /// An error will be returned if `listen` or `connect` has already been
223    /// called on this builder.
224    pub fn listen(self, backlog: u32) -> io::Result<UniListener> {
225        #[allow(clippy::cast_possible_wrap)]
226        self.inner.get_ref().listen(backlog as i32)?;
227
228        Ok(UniListener::from_inner(self.inner))
229    }
230
231    /// Initiates and completes a connection on this socket to the specified
232    /// address.
233    ///
234    /// This function directly corresponds to the `connect(2)` function on Unix.
235    ///
236    /// An error will be returned if `connect` has already been called.
237    pub async fn connect(self, addr: &UniAddr) -> io::Result<UniStream> {
238        if let Err(e) = self.inner.get_ref().connect(&addr.try_into()?) {
239            if e.raw_os_error() != Some(libc::EINPROGRESS) {
240                return Err(e);
241            }
242        }
243
244        let this = UniStream::from_inner(self.inner);
245
246        // Poll connection completion
247        loop {
248            let mut guard = this.inner.writable().await?;
249
250            match guard.try_io(|inner| inner.get_ref().take_error()) {
251                Ok(Ok(None)) => break,
252                Ok(Ok(Some(e)) | Err(e)) => return Err(e),
253                Err(_would_block) => {}
254            }
255        }
256
257        Ok(this)
258    }
259
260    /// Returns the socket address of the local half of this socket.
261    ///
262    /// This function directly corresponds to the `getsockname(2)` function on
263    /// Windows and Unix.
264    ///
265    /// # Notes
266    ///
267    /// Depending on the OS this may return an error if the socket is not
268    /// [bound](Self::bind).
269    pub fn local_addr(&self) -> io::Result<UniAddr> {
270        self.inner
271            .get_ref()
272            .local_addr()
273            .and_then(TryFrom::try_from)
274    }
275
276    /// Returns a [`SockRef`] to the underlying socket for configuration.
277    pub fn as_socket_ref(&self) -> SockRef<'_> {
278        SockRef::from(&self.inner)
279    }
280}
281
282#[derive(Debug)]
283/// Marker type: this socket is a listener.
284pub struct ListenerTy;
285
286/// A [`UniSocket`] used as a listener.
287pub type UniListener = UniSocket<ListenerTy>;
288
289impl fmt::Debug for UniListener {
290    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
291        f.debug_struct("UniListener")
292            .field("local_addr", &self.local_addr().ok())
293            .finish()
294    }
295}
296
297impl TryFrom<std::net::TcpListener> for UniListener {
298    type Error = io::Error;
299
300    /// Converts a standard library TCP listener into a unified [`UniListener`].
301    ///
302    /// # Panics
303    ///
304    /// This function panics if there is no current Tokio reactor set, or if
305    /// the `rt` feature flag is not enabled.
306    fn try_from(listener: std::net::TcpListener) -> Result<Self, Self::Error> {
307        listener.set_nonblocking(true)?;
308
309        AsyncFd::new(listener.into()).map(Self::from_inner)
310    }
311}
312
313impl TryFrom<tokio::net::TcpListener> for UniListener {
314    type Error = io::Error;
315
316    /// Converts a Tokio library TCP listener into a unified [`UniListener`].
317    ///
318    /// # Panics
319    ///
320    /// This function panics if there is no current Tokio reactor set, or if
321    /// the `rt` feature flag is not enabled.
322    fn try_from(listener: tokio::net::TcpListener) -> Result<Self, Self::Error> {
323        listener
324            .into_std()
325            .map(Into::into)
326            .and_then(AsyncFd::new)
327            .map(Self::from_inner)
328    }
329}
330
331impl UniListener {
332    /// Accepts an incoming connection to this listener, and returns the
333    /// accepted stream and the peer address.
334    ///
335    /// This method will retry on non-deadly errors, including:
336    ///
337    /// - `ECONNREFUSED`.
338    /// - `ECONNABORTED`.
339    /// - `ECONNRESET`.
340    /// - `EMFILE`.
341    pub async fn accept(&self) -> io::Result<(UniStream, UniAddr)> {
342        fn accept(socket: &Socket) -> io::Result<(UniStream, UniAddr)> {
343            // On platforms that support it we can use `accept4(2)` to set `NONBLOCK`
344            // and `CLOEXEC` in the call to accept the connection.
345            // Android x86's seccomp profile forbids calls to `accept4(2)`
346            // See https://github.com/tokio-rs/mio/issues/1445 for details
347            #[cfg(any(
348                all(not(target_arch = "x86"), target_os = "android"),
349                target_os = "dragonfly",
350                target_os = "freebsd",
351                target_os = "fuchsia",
352                target_os = "hurd",
353                target_os = "illumos",
354                target_os = "linux",
355                target_os = "netbsd",
356                target_os = "openbsd",
357                target_os = "solaris",
358                target_os = "cygwin",
359            ))]
360            let (accepted, peer_addr) = socket.accept4(libc::SOCK_CLOEXEC | libc::SOCK_NONBLOCK)?;
361
362            // But not all platforms have the `accept4(2)` call. Luckily BSD (derived)
363            // OSs inherit the non-blocking flag from the listener, so we just have to
364            // set `CLOEXEC`.
365            #[cfg(any(
366                target_os = "aix",
367                target_os = "haiku",
368                target_os = "ios",
369                target_os = "macos",
370                target_os = "redox",
371                target_os = "tvos",
372                target_os = "visionos",
373                target_os = "watchos",
374                target_os = "espidf",
375                target_os = "vita",
376                target_os = "hermit",
377                target_os = "nto",
378                all(target_arch = "x86", target_os = "android"),
379            ))]
380            let (accepted, peer_addr) = socket.accept_raw().and_then(|(accepted, peer_addr)| {
381                #[cfg(not(any(target_os = "espidf", target_os = "vita")))]
382                accepted.set_cloexec(true)?;
383
384                #[cfg(any(
385                    all(target_arch = "x86", target_os = "android"),
386                    target_os = "aix",
387                    target_os = "espidf",
388                    target_os = "vita",
389                    target_os = "hermit",
390                    target_os = "nto",
391                ))]
392                accepted.set_nonblocking(true)?;
393
394                // On Apple platforms set `NOSIGPIPE`.
395                #[cfg(any(
396                    target_os = "ios",
397                    target_os = "visionos",
398                    target_os = "macos",
399                    target_os = "tvos",
400                    target_os = "watchos",
401                ))]
402                socket.set_nosigpipe(true)?;
403
404                Ok((accepted, peer_addr))
405            })?;
406
407            // The non-blocking state of `listener` is inherited. See
408            // https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#remarks.
409            #[cfg(windows)]
410            let (accepted, peer_addr) = socket.accept_raw()?;
411
412            Ok((
413                UniStream::from_inner(AsyncFd::new(accepted)?),
414                peer_addr.try_into()?,
415            ))
416        }
417
418        loop {
419            let accepted = self
420                .inner
421                .readable()
422                .await?
423                .try_io(|socket| accept(socket.get_ref()));
424
425            match accepted {
426                Ok(ret @ Ok(_)) => {
427                    return ret;
428                }
429                Ok(Err(e))
430                    if matches!(
431                        e.kind(),
432                        io::ErrorKind::ConnectionRefused
433                            | io::ErrorKind::ConnectionAborted
434                            | io::ErrorKind::ConnectionReset
435                    ) =>
436                {
437                    // This is not a deadly error, too, just continue
438                }
439                Ok(Err(e)) if matches!(e.raw_os_error(), Some(libc::EMFILE)) => {
440                    // This is not a deadly error, but we may wait for a while.
441                    sleep(Duration::from_secs(1)).await;
442                }
443                Ok(Err(e)) => {
444                    return Err(e);
445                }
446                Err(_would_block) => {}
447            }
448        }
449    }
450
451    /// Accepts an incoming connection to this listener, and returns the
452    /// accepted stream and the peer address.
453    ///
454    /// Notes that on multiple calls to [`poll_accept`](Self::poll_accept), only
455    /// the waker from the [`Context`] passed to the most recent call is
456    /// scheduled to receive a wakeup. Unless you are implementing your own
457    /// future accepting connections, you probably want to use the asynchronous
458    /// [`accept`](Self::accept) method instead.
459    ///
460    /// Unlike [`accept`](Self::accept), this method does not handle `EMFILE`
461    /// (i.e., too many open files) errors and the caller may need to handle
462    /// it by itself.
463    pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(UniStream, UniAddr)>> {
464        loop {
465            let mut guard = ready!(self.inner.poll_read_ready(cx))?;
466
467            match guard.try_io(|socket| socket.get_ref().accept()) {
468                Ok(Ok((socket, addr))) => {
469                    let addr = if let Some(addr) = addr.as_socket() {
470                        UniAddr::from(addr)
471                    } else if let Some(addr) = addr.as_unix() {
472                        UniAddr::from(addr)
473                    } else {
474                        return Poll::Ready(Err(io::Error::new(
475                            io::ErrorKind::Other,
476                            "unsupported address type",
477                        )));
478                    };
479
480                    return Poll::Ready(Ok((UniStream::from_inner(AsyncFd::new(socket)?), addr)));
481                }
482                Ok(Err(e))
483                    if matches!(
484                        e.kind(),
485                        io::ErrorKind::ConnectionRefused
486                            | io::ErrorKind::ConnectionAborted
487                            | io::ErrorKind::ConnectionReset
488                    ) =>
489                {
490                    // This is not a deadly error, too, just continue
491                }
492                Ok(Err(e)) => {
493                    return Poll::Ready(Err(e));
494                }
495                Err(_would_block) => {}
496            }
497        }
498    }
499}
500
501#[derive(Debug)]
502/// Marker type: this socket is a (TCP) stream.
503pub struct StreamTy;
504
505/// A [`UniSocket`] used as a stream.
506pub type UniStream = UniSocket<StreamTy>;
507
508impl fmt::Debug for UniStream {
509    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
510        f.debug_struct("UniStream")
511            .field("local_addr", &self.local_addr().ok())
512            .field("peer_addr", &self.peer_addr().ok())
513            .finish()
514    }
515}
516
517impl TryFrom<tokio::net::TcpStream> for UniStream {
518    type Error = io::Error;
519
520    /// Converts a Tokio TCP stream into a unified [`UniStream`].
521    ///
522    /// # Panics
523    ///
524    /// This function panics if there is no current Tokio reactor set, or if
525    /// the `rt` feature flag is not enabled.
526    fn try_from(stream: tokio::net::TcpStream) -> Result<Self, Self::Error> {
527        stream
528            .into_std()
529            .map(Into::into)
530            .and_then(AsyncFd::new)
531            .map(Self::from_inner)
532    }
533}
534
535impl TryFrom<std::net::TcpStream> for UniStream {
536    type Error = io::Error;
537
538    /// Converts a standard library TCP stream into a unified [`UniStream`].
539    ///
540    /// # Panics
541    ///
542    /// This function panics if there is no current Tokio reactor set, or if
543    /// the `rt` feature flag is not enabled.
544    fn try_from(stream: std::net::TcpStream) -> Result<Self, Self::Error> {
545        stream.set_nonblocking(true)?;
546
547        AsyncFd::new(stream.into()).map(Self::from_inner)
548    }
549}
550
551impl TryFrom<tokio::net::UnixStream> for UniStream {
552    type Error = io::Error;
553
554    /// Converts a Tokio Unix stream into a unified [`UniStream`].
555    ///
556    /// # Panics
557    ///
558    /// This function panics if there is no current Tokio reactor set, or if
559    /// the `rt` feature flag is not enabled.
560    fn try_from(stream: tokio::net::UnixStream) -> Result<Self, Self::Error> {
561        stream
562            .into_std()
563            .map(Into::into)
564            .and_then(AsyncFd::new)
565            .map(Self::from_inner)
566    }
567}
568
569impl TryFrom<std::os::unix::net::UnixStream> for UniStream {
570    type Error = io::Error;
571
572    /// Converts a standard library Unix stream into a unified [`UniStream`].
573    ///
574    /// # Panics
575    ///
576    /// This function panics if there is no current Tokio reactor set, or if
577    /// the `rt` feature flag is not enabled.
578    fn try_from(stream: std::os::unix::net::UnixStream) -> Result<Self, Self::Error> {
579        stream.set_nonblocking(true)?;
580
581        AsyncFd::new(stream.into()).map(Self::from_inner)
582    }
583}
584
585impl UniStream {
586    /// Returns the socket address of the remote peer of this socket.
587    ///
588    /// This function directly corresponds to the `getpeername(2)` function on
589    /// Unix.
590    ///
591    /// # Notes
592    ///
593    /// This returns an error if the socket is not
594    /// [`connect`ed](UniSocket::connect).
595    pub fn peer_addr(&self) -> io::Result<UniAddr> {
596        self.inner.get_ref().peer_addr().and_then(TryFrom::try_from)
597    }
598
599    /// Receives data on the socket from the remote adress to which it is
600    /// connected, without removing that data from the queue. On success,
601    /// returns the number of bytes peeked.
602    ///
603    /// Successive calls return the same data. This is accomplished by passing
604    /// `MSG_PEEK` as a flag to the underlying `recv` system call.
605    pub async fn peek(&mut self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
606        loop {
607            let mut guard = self.inner.readable().await?;
608
609            match guard.try_io(|inner| inner.get_ref().peek(buf)) {
610                Ok(result) => return result,
611                Err(_would_block) => {}
612            }
613        }
614    }
615
616    /// Receives data on the socket from the remote adress to which it is
617    /// connected, without removing that data from the queue. On success,
618    /// returns the number of bytes peeked.
619    ///
620    /// Successive calls return the same data. This is accomplished by passing
621    /// `MSG_PEEK` as a flag to the underlying `recv` system call.
622    ///
623    /// Notes that on multiple calls to [`poll_peek`](Self::poll_peek), only
624    /// the waker from the [`Context`] passed to the most recent call is
625    /// scheduled to receive a wakeup. Unless you are implementing your own
626    /// future accepting connections, you probably want to use the asynchronous
627    /// [`accept`](UniListener::accept) method instead.
628    pub fn poll_peek(
629        self: Pin<&mut Self>,
630        cx: &mut Context<'_>,
631        buf: &mut ReadBuf<'_>,
632    ) -> Poll<io::Result<usize>> {
633        loop {
634            let mut guard = ready!(self.inner.poll_read_ready(cx))?;
635
636            #[allow(unsafe_code)]
637            let unfilled = unsafe { buf.unfilled_mut() };
638
639            match guard.try_io(|inner| inner.get_ref().peek(unfilled)) {
640                Ok(Ok(len)) => {
641                    // Advance initialized
642                    #[allow(unsafe_code)]
643                    unsafe {
644                        buf.assume_init(len);
645                    };
646
647                    // Advance filled
648                    buf.advance(len);
649
650                    return Poll::Ready(Ok(len));
651                }
652                Ok(Err(e)) => return Poll::Ready(Err(e)),
653                Err(_would_block) => {}
654            }
655        }
656    }
657
658    #[inline]
659    /// Receives data on the socket from the remote address to which it is
660    /// connected.
661    ///
662    /// # Cancel safety
663    ///
664    /// This method is cancel safe. Once a readiness event occurs, the method
665    /// will continue to return immediately until the readiness event is
666    /// consumed by an attempt to read or write that fails with `WouldBlock` or
667    /// `Poll::Pending`.
668    pub async fn read(&mut self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
669        self.read_priv(buf).await
670    }
671
672    async fn read_priv(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
673        loop {
674            let mut guard = self.inner.readable().await?;
675
676            match guard.try_io(|inner| inner.get_ref().recv(buf)) {
677                Ok(result) => return result,
678                Err(_would_block) => {}
679            }
680        }
681    }
682
683    fn poll_read_priv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
684        loop {
685            let mut guard = match self.inner.poll_read_ready(cx) {
686                Poll::Ready(Ok(guard)) => guard,
687                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
688                Poll::Pending => {
689                    return Poll::Pending;
690                }
691            };
692
693            #[allow(unsafe_code)]
694            let unfilled = unsafe { buf.unfilled_mut() };
695
696            match guard.try_io(|inner| {
697                let ret = inner.get_ref().recv(unfilled);
698
699                ret
700            }) {
701                Ok(Ok(len)) => {
702                    // Advance initialized
703                    #[allow(unsafe_code)]
704                    unsafe {
705                        buf.assume_init(len);
706                    };
707
708                    // Advance filled
709                    buf.advance(len);
710
711                    return Poll::Ready(Ok(()));
712                }
713                Ok(Err(e)) => return Poll::Ready(Err(e)),
714                Err(_would_block) => {}
715            }
716        }
717    }
718
719    #[inline]
720    /// Sends data on the socket to a connected peer.
721    ///
722    /// # Cancel safety
723    ///
724    /// This method is cancel safe. Once a readiness event occurs, the method
725    /// will continue to return immediately until the readiness event is
726    /// consumed by an attempt to read or write that fails with `WouldBlock` or
727    /// `Poll::Pending`.
728    pub async fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
729        self.write_priv(buf).await
730    }
731
732    async fn write_priv(&self, buf: &[u8]) -> io::Result<usize> {
733        loop {
734            let mut guard = self.inner.writable().await?;
735
736            match guard.try_io(|inner| inner.get_ref().send(buf)) {
737                Ok(result) => return result,
738                Err(_would_block) => {}
739            }
740        }
741    }
742
743    fn poll_write_priv(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
744        loop {
745            let mut guard = ready!(self.inner.poll_write_ready(cx))?;
746
747            match guard.try_io(|inner| inner.get_ref().send(buf)) {
748                Ok(result) => return Poll::Ready(result),
749                Err(_would_block) => {}
750            }
751        }
752    }
753
754    /// Shuts down the read, write, or both halves of this connection.
755    ///
756    /// This function will cause all pending and future I/O on the specified
757    /// portions to return immediately with an appropriate value.
758    pub fn shutdown(&mut self, shutdown: Shutdown) -> io::Result<()> {
759        match self.inner.get_ref().shutdown(shutdown) {
760            Ok(()) => Ok(()),
761            Err(e) if e.kind() == io::ErrorKind::NotConnected => Ok(()),
762            Err(e) => Err(e),
763        }
764    }
765
766    /// Splits a [`UniStream`] into a read half and a write half, which can be
767    /// used to read and write the stream concurrently.
768    ///
769    /// Note: dropping the write half will shutdown the write half of the
770    /// stream.
771    pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
772        let this = Arc::new(self);
773
774        (
775            OwnedReadHalf::from_inner(this.clone()),
776            OwnedWriteHalf::from_inner(this),
777        )
778    }
779}
780
781impl AsyncRead for UniStream {
782    #[inline]
783    /// Receives data on the socket from the remote address to which it is
784    /// connected.
785    ///
786    /// Notes that on multiple calls to [`poll_read`](Self::poll_read), only
787    /// the waker from the [`Context`] passed to the most recent call is
788    /// scheduled to receive a wakeup. Unless you are implementing your own
789    /// future accepting connections, you probably want to use the asynchronous
790    /// [`read`](Self::read) method instead.
791    fn poll_read(
792        self: Pin<&mut Self>,
793        cx: &mut Context<'_>,
794        buf: &mut ReadBuf<'_>,
795    ) -> Poll<io::Result<()>> {
796        self.poll_read_priv(cx, buf)
797    }
798}
799
800impl AsyncWrite for UniStream {
801    #[inline]
802    /// Sends data on the socket to a connected peer.
803    ///
804    /// Notes that on multiple calls to [`poll_write`](Self::poll_write), only
805    /// the waker from the [`Context`] passed to the most recent call is
806    /// scheduled to receive a wakeup. Unless you are implementing your own
807    /// future accepting connections, you probably want to use the asynchronous
808    /// [`write`](Self::write) method instead.
809    fn poll_write(
810        self: Pin<&mut Self>,
811        cx: &mut Context<'_>,
812        buf: &[u8],
813    ) -> Poll<io::Result<usize>> {
814        self.poll_write_priv(cx, buf)
815    }
816
817    #[inline]
818    /// For TCP and Unix domain sockets, `flush` is a no-op.
819    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
820        Poll::Ready(Ok(()))
821    }
822
823    /// See [`shutdown`](Self::shutdown).
824    fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
825        Poll::Ready(self.shutdown(Shutdown::Write))
826    }
827}
828
#[cfg(feature = "splice-legacy")]
impl tokio_splice2::AsyncReadFd for UniStream {
    /// Polls for read readiness; the readiness guard is discarded.
    fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match ready!(self.inner.poll_read_ready(cx)) {
            Ok(_guard) => Poll::Ready(Ok(())),
            Err(e) => Poll::Ready(Err(e)),
        }
    }

    /// Runs `f` as a read attempt through [`AsyncFd::try_io`] with readable
    /// interest.
    fn try_io_read<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
        self.inner
            .try_io(tokio::io::Interest::READABLE, move |_socket| f())
    }
}
841
#[cfg(feature = "splice-legacy")]
impl tokio_splice2::AsyncWriteFd for UniStream {
    /// Polls for write readiness; the readiness guard is discarded.
    fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match ready!(self.inner.poll_write_ready(cx)) {
            Ok(_guard) => Poll::Ready(Ok(())),
            Err(e) => Poll::Ready(Err(e)),
        }
    }

    /// Runs `f` as a write attempt through [`AsyncFd::try_io`] with writable
    /// interest.
    fn try_io_write<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
        self.inner
            .try_io(tokio::io::Interest::WRITABLE, move |_socket| f())
    }
}

// Marker impl: a `UniStream` wraps a socket, never a regular file.
#[cfg(feature = "splice-legacy")]
impl tokio_splice2::IsNotFile for UniStream {}
857
858wrapper_lite::wrapper!(
859    #[wrapper_impl(AsRef<UniStream>)]
860    #[derive(Debug)]
861    /// An owned read half of a [`UniStream`].
862    pub struct OwnedReadHalf(Arc<UniStream>);
863);
864
865impl AsyncRead for OwnedReadHalf {
866    #[inline]
867    /// See [`poll_read`](UniStream::poll_read).
868    fn poll_read(
869        self: Pin<&mut Self>,
870        cx: &mut Context<'_>,
871        buf: &mut ReadBuf<'_>,
872    ) -> Poll<io::Result<()>> {
873        self.inner.poll_read_priv(cx, buf)
874    }
875}
876
877wrapper_lite::wrapper!(
878    #[wrapper_impl(AsRef<UniStream>)]
879    #[derive(Debug)]
880    /// An owned write half of a [`UniStream`].
881    pub struct OwnedWriteHalf(Arc<UniStream>);
882);
883
884impl AsyncWrite for OwnedWriteHalf {
885    #[inline]
886    /// See [`poll_write`](UniStream::poll_write).
887    fn poll_write(
888        self: Pin<&mut Self>,
889        cx: &mut Context<'_>,
890        buf: &[u8],
891    ) -> Poll<io::Result<usize>> {
892        self.inner.poll_write_priv(cx, buf)
893    }
894
895    #[inline]
896    /// See [`poll_flush`](UniStream::poll_flush).
897    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
898        Poll::Ready(Ok(()))
899    }
900
901    /// See [`shutdown`](UniStream::shutdown).
902    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
903        Poll::Ready(self.inner.as_socket_ref().shutdown(Shutdown::Write))
904    }
905}
906
907impl Drop for OwnedWriteHalf {
908    fn drop(&mut self) {
909        let _ = self.inner.as_socket_ref().shutdown(Shutdown::Write);
910    }
911}