1#[cfg(all(target_os = "linux", feature = "splice"))]
4pub mod splice;
5
6use std::marker::PhantomData;
7use std::mem::MaybeUninit;
8use std::net::{Shutdown, SocketAddr};
9use std::os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd};
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{ready, Context, Poll};
13use std::time::Duration;
14use std::{fmt, io};
15
16use socket2::{SockRef, Socket};
17use tokio::io::unix::AsyncFd;
18use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
19use tokio::time::sleep;
20use uni_addr::{UniAddr, UniAddrInner};
21
// A stream socket (TCP or Unix-domain) driven by Tokio through `AsyncFd`.
//
// `Ty` is a zero-sized type-state marker: `()` for a freshly created socket,
// `ListenerTy` once it is listening (`UniListener`) and `StreamTy` once it is
// connected or accepted (`UniStream`). It selects which methods are available
// and carries no data at runtime.
wrapper_lite::wrapper!(
    #[wrapper_impl(AsRef)]
    pub struct UniSocket<Ty = ()> {
        // The socket2 socket, wrapped in `AsyncFd` for readiness notifications.
        inner: AsyncFd<Socket>,
        // Type-state marker only; never read.
        ty: PhantomData<Ty>,
    }
);
30
31impl fmt::Debug for UniSocket {
32 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33 f.debug_tuple("UniSocket").field(&self.inner).finish()
34 }
35}
36
37impl<Ty> AsFd for UniSocket<Ty> {
38 #[inline]
39 fn as_fd(&self) -> BorrowedFd<'_> {
40 self.inner.as_fd()
41 }
42}
43
44impl<Ty> AsRawFd for UniSocket<Ty> {
45 #[inline]
46 fn as_raw_fd(&self) -> RawFd {
47 self.inner.as_raw_fd()
48 }
49}
50
51impl UniSocket {
52 pub fn new(addr: &UniAddr) -> io::Result<Self> {
58 Self::new_priv(addr)
59 }
60}
61
62impl<Ty> UniSocket<Ty> {
63 #[inline]
64 const fn from_inner(inner: AsyncFd<Socket>) -> Self {
65 Self {
66 inner,
67 ty: PhantomData,
68 }
69 }
70
71 fn new_priv(addr: &UniAddr) -> io::Result<Self> {
72 let ty = socket2::Type::STREAM;
73
74 #[cfg(any(
75 target_os = "android",
76 target_os = "dragonfly",
77 target_os = "freebsd",
78 target_os = "fuchsia",
79 target_os = "illumos",
80 target_os = "linux",
81 target_os = "netbsd",
82 target_os = "openbsd"
83 ))]
84 let ty = ty.nonblocking();
85
86 let inner = match addr.as_inner() {
87 UniAddrInner::Inet(SocketAddr::V4(_)) => {
88 Socket::new(socket2::Domain::IPV4, ty, Some(socket2::Protocol::TCP))
89 }
90 UniAddrInner::Inet(SocketAddr::V6(_)) => {
91 Socket::new(socket2::Domain::IPV6, ty, Some(socket2::Protocol::TCP))
92 }
93 UniAddrInner::Unix(_) => Socket::new(socket2::Domain::UNIX, ty, None),
94 UniAddrInner::Host(_) => Err(io::Error::new(
95 io::ErrorKind::Other,
96 "The Host address type must be resolved before creating a socket",
97 )),
98 _ => Err(io::Error::new(
99 io::ErrorKind::Other,
100 "Unsupported address type",
101 )),
102 }?;
103
104 #[cfg(not(any(
105 target_os = "android",
106 target_os = "dragonfly",
107 target_os = "freebsd",
108 target_os = "fuchsia",
109 target_os = "illumos",
110 target_os = "linux",
111 target_os = "netbsd",
112 target_os = "openbsd"
113 )))]
114 inner.set_nonblocking(true)?;
115
116 #[cfg(not(windows))]
124 inner.set_reuse_address(true)?;
125
126 AsyncFd::new(inner).map(Self::from_inner)
127 }
128
129 pub fn bind(self, addr: &UniAddr) -> io::Result<Self> {
133 self.inner.get_ref().bind(&addr.try_into()?)?;
134
135 Ok(Self::from_inner(self.inner))
136 }
137
138 #[cfg(any(
139 target_os = "ios",
140 target_os = "visionos",
141 target_os = "macos",
142 target_os = "tvos",
143 target_os = "watchos",
144 target_os = "illumos",
145 target_os = "solaris",
146 target_os = "linux",
147 target_os = "android",
148 target_os = "fuchsia",
149 ))]
150 pub fn bind_device(self, addr: &UniAddr, device: Option<&str>) -> io::Result<Self> {
165 if let Some(device) = device {
166 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
167 {
168 self.inner.get_ref().bind_device(Some(device.as_bytes()))?;
169 }
170
171 #[cfg(all(
172 not(any(target_os = "android", target_os = "fuchsia", target_os = "linux")),
173 any(
174 target_os = "ios",
175 target_os = "visionos",
176 target_os = "macos",
177 target_os = "tvos",
178 target_os = "watchos",
179 target_os = "illumos",
180 target_os = "solaris",
181 )
182 ))]
183 {
184 use std::num::NonZeroU32;
185
186 #[allow(unsafe_code)]
187 let if_index = unsafe { libc::if_nametoindex(device.as_ptr().cast()) };
188
189 let Some(if_index) = NonZeroU32::new(if_index) else {
190 return Err(io::Error::last_os_error());
191 };
192
193 match addr.as_inner() {
194 UniAddrInner::Inet(SocketAddr::V4(_)) => {
195 self.inner
196 .get_ref()
197 .bind_device_by_index_v4(Some(if_index))?;
198 }
199 UniAddrInner::Inet(SocketAddr::V6(_)) => {
200 self.inner
201 .get_ref()
202 .bind_device_by_index_v6(Some(if_index))?;
203 }
204 _ => {
205 return Err(io::Error::new(
206 io::ErrorKind::Other,
207 "`bind_device_by_index` only works for IPv4 and IPv6 addresses",
208 ))
209 }
210 }
211 }
212 }
213
214 self.bind(addr)
215 }
216
217 pub fn listen(self, backlog: u32) -> io::Result<UniListener> {
225 #[allow(clippy::cast_possible_wrap)]
226 self.inner.get_ref().listen(backlog as i32)?;
227
228 Ok(UniListener::from_inner(self.inner))
229 }
230
231 pub async fn connect(self, addr: &UniAddr) -> io::Result<UniStream> {
238 if let Err(e) = self.inner.get_ref().connect(&addr.try_into()?) {
239 if e.raw_os_error() != Some(libc::EINPROGRESS) {
240 return Err(e);
241 }
242 }
243
244 let this = UniStream::from_inner(self.inner);
245
246 loop {
248 let mut guard = this.inner.writable().await?;
249
250 match guard.try_io(|inner| inner.get_ref().take_error()) {
251 Ok(Ok(None)) => break,
252 Ok(Ok(Some(e)) | Err(e)) => return Err(e),
253 Err(_would_block) => {}
254 }
255 }
256
257 Ok(this)
258 }
259
260 pub fn local_addr(&self) -> io::Result<UniAddr> {
270 self.inner
271 .get_ref()
272 .local_addr()
273 .and_then(TryFrom::try_from)
274 }
275
276 pub fn as_socket_ref(&self) -> SockRef<'_> {
278 SockRef::from(&self.inner)
279 }
280}
281
/// Type-state marker for [`UniSocket`] values that are listening for
/// connections (see [`UniListener`]). Zero-sized; carries no data.
#[derive(Debug)]
pub struct ListenerTy;
285
/// A listening socket, produced by [`UniSocket::listen`] or converted from a
/// std/Tokio `TcpListener`.
pub type UniListener = UniSocket<ListenerTy>;
288
289impl fmt::Debug for UniListener {
290 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
291 f.debug_struct("UniListener")
292 .field("local_addr", &self.local_addr().ok())
293 .finish()
294 }
295}
296
297impl TryFrom<std::net::TcpListener> for UniListener {
298 type Error = io::Error;
299
300 fn try_from(listener: std::net::TcpListener) -> Result<Self, Self::Error> {
307 listener.set_nonblocking(true)?;
308
309 AsyncFd::new(listener.into()).map(Self::from_inner)
310 }
311}
312
313impl TryFrom<tokio::net::TcpListener> for UniListener {
314 type Error = io::Error;
315
316 fn try_from(listener: tokio::net::TcpListener) -> Result<Self, Self::Error> {
323 listener
324 .into_std()
325 .map(Into::into)
326 .and_then(AsyncFd::new)
327 .map(Self::from_inner)
328 }
329}
330
331impl UniListener {
332 pub async fn accept(&self) -> io::Result<(UniStream, UniAddr)> {
342 fn accept(socket: &Socket) -> io::Result<(UniStream, UniAddr)> {
343 #[cfg(any(
348 all(not(target_arch = "x86"), target_os = "android"),
349 target_os = "dragonfly",
350 target_os = "freebsd",
351 target_os = "fuchsia",
352 target_os = "hurd",
353 target_os = "illumos",
354 target_os = "linux",
355 target_os = "netbsd",
356 target_os = "openbsd",
357 target_os = "solaris",
358 target_os = "cygwin",
359 ))]
360 let (accepted, peer_addr) = socket.accept4(libc::SOCK_CLOEXEC | libc::SOCK_NONBLOCK)?;
361
362 #[cfg(any(
366 target_os = "aix",
367 target_os = "haiku",
368 target_os = "ios",
369 target_os = "macos",
370 target_os = "redox",
371 target_os = "tvos",
372 target_os = "visionos",
373 target_os = "watchos",
374 target_os = "espidf",
375 target_os = "vita",
376 target_os = "hermit",
377 target_os = "nto",
378 all(target_arch = "x86", target_os = "android"),
379 ))]
380 let (accepted, peer_addr) = socket.accept_raw().and_then(|(accepted, peer_addr)| {
381 #[cfg(not(any(target_os = "espidf", target_os = "vita")))]
382 accepted.set_cloexec(true)?;
383
384 #[cfg(any(
385 all(target_arch = "x86", target_os = "android"),
386 target_os = "aix",
387 target_os = "espidf",
388 target_os = "vita",
389 target_os = "hermit",
390 target_os = "nto",
391 ))]
392 accepted.set_nonblocking(true)?;
393
394 #[cfg(any(
396 target_os = "ios",
397 target_os = "visionos",
398 target_os = "macos",
399 target_os = "tvos",
400 target_os = "watchos",
401 ))]
402 socket.set_nosigpipe(true)?;
403
404 Ok((accepted, peer_addr))
405 })?;
406
407 #[cfg(windows)]
410 let (accepted, peer_addr) = socket.accept_raw()?;
411
412 Ok((
413 UniStream::from_inner(AsyncFd::new(accepted)?),
414 peer_addr.try_into()?,
415 ))
416 }
417
418 loop {
419 let accepted = self
420 .inner
421 .readable()
422 .await?
423 .try_io(|socket| accept(socket.get_ref()));
424
425 match accepted {
426 Ok(ret @ Ok(_)) => {
427 return ret;
428 }
429 Ok(Err(e))
430 if matches!(
431 e.kind(),
432 io::ErrorKind::ConnectionRefused
433 | io::ErrorKind::ConnectionAborted
434 | io::ErrorKind::ConnectionReset
435 ) =>
436 {
437 }
439 Ok(Err(e)) if matches!(e.raw_os_error(), Some(libc::EMFILE)) => {
440 sleep(Duration::from_secs(1)).await;
442 }
443 Ok(Err(e)) => {
444 return Err(e);
445 }
446 Err(_would_block) => {}
447 }
448 }
449 }
450
451 pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(UniStream, UniAddr)>> {
464 loop {
465 let mut guard = ready!(self.inner.poll_read_ready(cx))?;
466
467 match guard.try_io(|socket| socket.get_ref().accept()) {
468 Ok(Ok((socket, addr))) => {
469 let addr = if let Some(addr) = addr.as_socket() {
470 UniAddr::from(addr)
471 } else if let Some(addr) = addr.as_unix() {
472 UniAddr::from(addr)
473 } else {
474 return Poll::Ready(Err(io::Error::new(
475 io::ErrorKind::Other,
476 "unsupported address type",
477 )));
478 };
479
480 return Poll::Ready(Ok((UniStream::from_inner(AsyncFd::new(socket)?), addr)));
481 }
482 Ok(Err(e))
483 if matches!(
484 e.kind(),
485 io::ErrorKind::ConnectionRefused
486 | io::ErrorKind::ConnectionAborted
487 | io::ErrorKind::ConnectionReset
488 ) =>
489 {
490 }
492 Ok(Err(e)) => {
493 return Poll::Ready(Err(e));
494 }
495 Err(_would_block) => {}
496 }
497 }
498 }
499}
500
/// Type-state marker for [`UniSocket`] values that are connected streams
/// (see [`UniStream`]). Zero-sized; carries no data.
#[derive(Debug)]
pub struct StreamTy;
504
/// A connected stream socket, produced by [`UniSocket::connect`], by
/// [`UniListener::accept`]/[`UniListener::poll_accept`], or converted from a
/// std/Tokio TCP or Unix stream.
pub type UniStream = UniSocket<StreamTy>;
507
508impl fmt::Debug for UniStream {
509 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
510 f.debug_struct("UniStream")
511 .field("local_addr", &self.local_addr().ok())
512 .field("peer_addr", &self.peer_addr().ok())
513 .finish()
514 }
515}
516
517impl TryFrom<tokio::net::TcpStream> for UniStream {
518 type Error = io::Error;
519
520 fn try_from(stream: tokio::net::TcpStream) -> Result<Self, Self::Error> {
527 stream
528 .into_std()
529 .map(Into::into)
530 .and_then(AsyncFd::new)
531 .map(Self::from_inner)
532 }
533}
534
535impl TryFrom<std::net::TcpStream> for UniStream {
536 type Error = io::Error;
537
538 fn try_from(stream: std::net::TcpStream) -> Result<Self, Self::Error> {
545 stream.set_nonblocking(true)?;
546
547 AsyncFd::new(stream.into()).map(Self::from_inner)
548 }
549}
550
551impl TryFrom<tokio::net::UnixStream> for UniStream {
552 type Error = io::Error;
553
554 fn try_from(stream: tokio::net::UnixStream) -> Result<Self, Self::Error> {
561 stream
562 .into_std()
563 .map(Into::into)
564 .and_then(AsyncFd::new)
565 .map(Self::from_inner)
566 }
567}
568
569impl TryFrom<std::os::unix::net::UnixStream> for UniStream {
570 type Error = io::Error;
571
572 fn try_from(stream: std::os::unix::net::UnixStream) -> Result<Self, Self::Error> {
579 stream.set_nonblocking(true)?;
580
581 AsyncFd::new(stream.into()).map(Self::from_inner)
582 }
583}
584
585impl UniStream {
586 pub fn peer_addr(&self) -> io::Result<UniAddr> {
596 self.inner.get_ref().peer_addr().and_then(TryFrom::try_from)
597 }
598
599 pub async fn peek(&mut self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
606 loop {
607 let mut guard = self.inner.readable().await?;
608
609 match guard.try_io(|inner| inner.get_ref().peek(buf)) {
610 Ok(result) => return result,
611 Err(_would_block) => {}
612 }
613 }
614 }
615
616 pub fn poll_peek(
629 self: Pin<&mut Self>,
630 cx: &mut Context<'_>,
631 buf: &mut ReadBuf<'_>,
632 ) -> Poll<io::Result<usize>> {
633 loop {
634 let mut guard = ready!(self.inner.poll_read_ready(cx))?;
635
636 #[allow(unsafe_code)]
637 let unfilled = unsafe { buf.unfilled_mut() };
638
639 match guard.try_io(|inner| inner.get_ref().peek(unfilled)) {
640 Ok(Ok(len)) => {
641 #[allow(unsafe_code)]
643 unsafe {
644 buf.assume_init(len);
645 };
646
647 buf.advance(len);
649
650 return Poll::Ready(Ok(len));
651 }
652 Ok(Err(e)) => return Poll::Ready(Err(e)),
653 Err(_would_block) => {}
654 }
655 }
656 }
657
658 #[inline]
659 pub async fn read(&mut self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
669 self.read_priv(buf).await
670 }
671
672 async fn read_priv(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
673 loop {
674 let mut guard = self.inner.readable().await?;
675
676 match guard.try_io(|inner| inner.get_ref().recv(buf)) {
677 Ok(result) => return result,
678 Err(_would_block) => {}
679 }
680 }
681 }
682
683 fn poll_read_priv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
684 loop {
685 let mut guard = match self.inner.poll_read_ready(cx) {
686 Poll::Ready(Ok(guard)) => guard,
687 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
688 Poll::Pending => {
689 return Poll::Pending;
690 }
691 };
692
693 #[allow(unsafe_code)]
694 let unfilled = unsafe { buf.unfilled_mut() };
695
696 match guard.try_io(|inner| {
697 let ret = inner.get_ref().recv(unfilled);
698
699 ret
700 }) {
701 Ok(Ok(len)) => {
702 #[allow(unsafe_code)]
704 unsafe {
705 buf.assume_init(len);
706 };
707
708 buf.advance(len);
710
711 return Poll::Ready(Ok(()));
712 }
713 Ok(Err(e)) => return Poll::Ready(Err(e)),
714 Err(_would_block) => {}
715 }
716 }
717 }
718
719 #[inline]
720 pub async fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
729 self.write_priv(buf).await
730 }
731
732 async fn write_priv(&self, buf: &[u8]) -> io::Result<usize> {
733 loop {
734 let mut guard = self.inner.writable().await?;
735
736 match guard.try_io(|inner| inner.get_ref().send(buf)) {
737 Ok(result) => return result,
738 Err(_would_block) => {}
739 }
740 }
741 }
742
743 fn poll_write_priv(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
744 loop {
745 let mut guard = ready!(self.inner.poll_write_ready(cx))?;
746
747 match guard.try_io(|inner| inner.get_ref().send(buf)) {
748 Ok(result) => return Poll::Ready(result),
749 Err(_would_block) => {}
750 }
751 }
752 }
753
754 pub fn shutdown(&mut self, shutdown: Shutdown) -> io::Result<()> {
759 match self.inner.get_ref().shutdown(shutdown) {
760 Ok(()) => Ok(()),
761 Err(e) if e.kind() == io::ErrorKind::NotConnected => Ok(()),
762 Err(e) => Err(e),
763 }
764 }
765
766 pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
772 let this = Arc::new(self);
773
774 (
775 OwnedReadHalf::from_inner(this.clone()),
776 OwnedWriteHalf::from_inner(this),
777 )
778 }
779}
780
781impl AsyncRead for UniStream {
782 #[inline]
783 fn poll_read(
792 self: Pin<&mut Self>,
793 cx: &mut Context<'_>,
794 buf: &mut ReadBuf<'_>,
795 ) -> Poll<io::Result<()>> {
796 self.poll_read_priv(cx, buf)
797 }
798}
799
800impl AsyncWrite for UniStream {
801 #[inline]
802 fn poll_write(
810 self: Pin<&mut Self>,
811 cx: &mut Context<'_>,
812 buf: &[u8],
813 ) -> Poll<io::Result<usize>> {
814 self.poll_write_priv(cx, buf)
815 }
816
817 #[inline]
818 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
820 Poll::Ready(Ok(()))
821 }
822
823 fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
825 Poll::Ready(self.shutdown(Shutdown::Write))
826 }
827}
828
#[cfg(feature = "splice-legacy")]
impl tokio_splice2::AsyncReadFd for UniStream {
    /// Resolves once the socket is readable; the readiness guard is discarded
    /// without clearing readiness.
    fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let _guard = ready!(self.inner.poll_read_ready(cx))?;
        Poll::Ready(Ok(()))
    }

    /// Runs `f` under read interest through `AsyncFd::try_io`.
    fn try_io_read<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
        self.inner
            .try_io(tokio::io::Interest::READABLE, move |_socket| f())
    }
}
841
#[cfg(feature = "splice-legacy")]
impl tokio_splice2::AsyncWriteFd for UniStream {
    /// Resolves once the socket is writable; the readiness guard is discarded
    /// without clearing readiness.
    fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let _guard = ready!(self.inner.poll_write_ready(cx))?;
        Poll::Ready(Ok(()))
    }

    /// Runs `f` under write interest through `AsyncFd::try_io`.
    fn try_io_write<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
        self.inner
            .try_io(tokio::io::Interest::WRITABLE, move |_socket| f())
    }
}
854
// Marker impl for the legacy `tokio_splice2` integration, declaring that a
// `UniStream` is not a regular file. NOTE(review): presumably lets
// tokio_splice2 pick socket-appropriate splice handling — confirm upstream.
#[cfg(feature = "splice-legacy")]
impl tokio_splice2::IsNotFile for UniStream {}
857
// The read half of a `UniStream`, created by `UniStream::into_split`.
//
// Shares the stream with the matching `OwnedWriteHalf` through an `Arc`; the
// stream itself is dropped once both halves are gone.
wrapper_lite::wrapper!(
    #[wrapper_impl(AsRef<UniStream>)]
    #[derive(Debug)]
    pub struct OwnedReadHalf(Arc<UniStream>);
);
864
865impl AsyncRead for OwnedReadHalf {
866 #[inline]
867 fn poll_read(
869 self: Pin<&mut Self>,
870 cx: &mut Context<'_>,
871 buf: &mut ReadBuf<'_>,
872 ) -> Poll<io::Result<()>> {
873 self.inner.poll_read_priv(cx, buf)
874 }
875}
876
// The write half of a `UniStream`, created by `UniStream::into_split`.
//
// Shares the stream with the matching `OwnedReadHalf` through an `Arc`.
// Dropping this half shuts down the write direction of the socket (see its
// `Drop` impl).
wrapper_lite::wrapper!(
    #[wrapper_impl(AsRef<UniStream>)]
    #[derive(Debug)]
    pub struct OwnedWriteHalf(Arc<UniStream>);
);
883
884impl AsyncWrite for OwnedWriteHalf {
885 #[inline]
886 fn poll_write(
888 self: Pin<&mut Self>,
889 cx: &mut Context<'_>,
890 buf: &[u8],
891 ) -> Poll<io::Result<usize>> {
892 self.inner.poll_write_priv(cx, buf)
893 }
894
895 #[inline]
896 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
898 Poll::Ready(Ok(()))
899 }
900
901 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
903 Poll::Ready(self.inner.as_socket_ref().shutdown(Shutdown::Write))
904 }
905}
906
907impl Drop for OwnedWriteHalf {
908 fn drop(&mut self) {
909 let _ = self.inner.as_socket_ref().shutdown(Shutdown::Write);
910 }
911}