1use std::marker::PhantomData;
4use std::mem::MaybeUninit;
5use std::net::{Shutdown, SocketAddr};
6use std::os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{ready, Context, Poll};
10use std::time::Duration;
11use std::{fmt, io};
12
13use socket2::{SockRef, Socket};
14use tokio::io::unix::AsyncFd;
15use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16use tokio::time::sleep;
17use uni_addr::{UniAddr, UniAddrInner};
18
/// An asynchronous stream socket that can be TCP over IPv4/IPv6 or a Unix-domain
/// stream socket.
///
/// `Ty` is a zero-sized type-state marker: the default `()` is a freshly created
/// socket, while `ListenerTy` and `StreamTy` select the listener and connected
/// stream APIs respectively.
pub struct UniSocket<Ty = ()> {
    // The underlying socket, registered with tokio through `AsyncFd` for readiness.
    inner: AsyncFd<Socket>,
    // Carries the type-state marker without storing any data.
    ty: PhantomData<Ty>,
}
24
25impl fmt::Debug for UniSocket {
26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 f.debug_tuple("UniSocket").field(&self.inner).finish()
28 }
29}
30
31impl<Ty> AsFd for UniSocket<Ty> {
32 #[inline]
33 fn as_fd(&self) -> BorrowedFd<'_> {
34 self.inner.as_fd()
35 }
36}
37
38impl<Ty> AsRawFd for UniSocket<Ty> {
39 #[inline]
40 fn as_raw_fd(&self) -> RawFd {
41 self.inner.as_raw_fd()
42 }
43}
44
45impl UniSocket {
46 pub fn new(addr: &UniAddr) -> io::Result<Self> {
52 Self::new_priv(addr)
53 }
54}
55
56impl<Ty> UniSocket<Ty> {
57 #[inline]
58 const fn from_inner(inner: AsyncFd<Socket>) -> Self {
59 Self {
60 inner,
61 ty: PhantomData,
62 }
63 }
64
65 fn new_priv(addr: &UniAddr) -> io::Result<Self> {
66 let ty = socket2::Type::STREAM;
67
68 #[cfg(any(
69 target_os = "android",
70 target_os = "dragonfly",
71 target_os = "freebsd",
72 target_os = "fuchsia",
73 target_os = "illumos",
74 target_os = "linux",
75 target_os = "netbsd",
76 target_os = "openbsd"
77 ))]
78 let ty = ty.nonblocking();
79
80 let inner = match addr.as_inner() {
81 UniAddrInner::Inet(SocketAddr::V4(_)) => {
82 Socket::new(socket2::Domain::IPV4, ty, Some(socket2::Protocol::TCP))
83 }
84 UniAddrInner::Inet(SocketAddr::V6(_)) => {
85 Socket::new(socket2::Domain::IPV6, ty, Some(socket2::Protocol::TCP))
86 }
87 UniAddrInner::Unix(_) => Socket::new(socket2::Domain::UNIX, ty, None),
88 UniAddrInner::Host(_) => Err(io::Error::new(
89 io::ErrorKind::Other,
90 "The Host address type must be resolved before creating a socket",
91 )),
92 _ => Err(io::Error::new(
93 io::ErrorKind::Other,
94 "Unsupported address type",
95 )),
96 }?;
97
98 #[cfg(not(any(
99 target_os = "android",
100 target_os = "dragonfly",
101 target_os = "freebsd",
102 target_os = "fuchsia",
103 target_os = "illumos",
104 target_os = "linux",
105 target_os = "netbsd",
106 target_os = "openbsd"
107 )))]
108 inner.set_nonblocking(true)?;
109
110 #[cfg(not(windows))]
118 inner.set_reuse_address(true)?;
119
120 AsyncFd::new(inner).map(Self::from_inner)
121 }
122
123 pub fn bind(self, addr: &UniAddr) -> io::Result<Self> {
127 self.inner.get_ref().bind(&addr.try_into()?)?;
128
129 Ok(Self::from_inner(self.inner))
130 }
131
132 #[cfg(any(
133 target_os = "ios",
134 target_os = "visionos",
135 target_os = "macos",
136 target_os = "tvos",
137 target_os = "watchos",
138 target_os = "illumos",
139 target_os = "solaris",
140 target_os = "linux",
141 target_os = "android",
142 target_os = "fuchsia",
143 ))]
144 pub fn bind_device(self, addr: &UniAddr, device: Option<&str>) -> io::Result<Self> {
159 if let Some(device) = device {
160 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
161 {
162 self.inner.get_ref().bind_device(Some(device.as_bytes()))?;
163 }
164
165 #[cfg(all(
166 not(any(target_os = "android", target_os = "fuchsia", target_os = "linux")),
167 any(
168 target_os = "ios",
169 target_os = "visionos",
170 target_os = "macos",
171 target_os = "tvos",
172 target_os = "watchos",
173 target_os = "illumos",
174 target_os = "solaris",
175 )
176 ))]
177 {
178 use std::num::NonZeroU32;
179
180 #[allow(unsafe_code)]
181 let if_index = unsafe { libc::if_nametoindex(device.as_ptr().cast()) };
182
183 let Some(if_index) = NonZeroU32::new(if_index) else {
184 return Err(io::Error::last_os_error());
185 };
186
187 match addr.as_inner() {
188 UniAddrInner::Inet(SocketAddr::V4(_)) => {
189 self.inner
190 .get_ref()
191 .bind_device_by_index_v4(Some(if_index))?;
192 }
193 UniAddrInner::Inet(SocketAddr::V6(_)) => {
194 self.inner
195 .get_ref()
196 .bind_device_by_index_v6(Some(if_index))?;
197 }
198 _ => {
199 return Err(io::Error::new(
200 io::ErrorKind::Other,
201 "`bind_device_by_index` only works for IPv4 and IPv6 addresses",
202 ))
203 }
204 }
205 }
206 }
207
208 self.bind(addr)
209 }
210
211 pub fn listen(self, backlog: u32) -> io::Result<UniListener> {
219 #[allow(clippy::cast_possible_wrap)]
220 self.inner.get_ref().listen(backlog as i32)?;
221
222 Ok(UniListener::from_inner(self.inner))
223 }
224
225 pub async fn connect(self, addr: &UniAddr) -> io::Result<UniStream> {
232 if let Err(e) = self.inner.get_ref().connect(&addr.try_into()?) {
233 if e.raw_os_error() != Some(libc::EINPROGRESS) {
234 return Err(e);
235 }
236 }
237
238 let this = UniStream::from_inner(self.inner);
239
240 loop {
242 let mut guard = this.inner.writable().await?;
243
244 match guard.try_io(|inner| inner.get_ref().take_error()) {
245 Ok(Ok(None)) => break,
246 Ok(Ok(Some(e)) | Err(e)) => return Err(e),
247 Err(_would_block) => {}
248 }
249 }
250
251 Ok(this)
252 }
253
254 pub fn local_addr(&self) -> io::Result<UniAddr> {
264 self.inner
265 .get_ref()
266 .local_addr()
267 .and_then(TryFrom::try_from)
268 }
269
270 pub fn as_socket_ref(&self) -> SockRef<'_> {
272 SockRef::from(&self.inner)
273 }
274}
275
/// Type-state marker for a [`UniSocket`] that is listening for connections.
#[derive(Debug)]
pub struct ListenerTy;

/// A listening socket, produced by [`UniSocket::listen`] or converted from a
/// std or tokio `TcpListener`.
pub type UniListener = UniSocket<ListenerTy>;
282
283impl fmt::Debug for UniListener {
284 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285 f.debug_struct("UniListener")
286 .field("local_addr", &self.local_addr().ok())
287 .finish()
288 }
289}
290
291impl TryFrom<std::net::TcpListener> for UniListener {
292 type Error = io::Error;
293
294 fn try_from(listener: std::net::TcpListener) -> Result<Self, Self::Error> {
301 listener.set_nonblocking(true)?;
302
303 AsyncFd::new(listener.into()).map(Self::from_inner)
304 }
305}
306
307impl TryFrom<tokio::net::TcpListener> for UniListener {
308 type Error = io::Error;
309
310 fn try_from(listener: tokio::net::TcpListener) -> Result<Self, Self::Error> {
317 listener
318 .into_std()
319 .map(Into::into)
320 .and_then(AsyncFd::new)
321 .map(Self::from_inner)
322 }
323}
324
325impl UniListener {
326 pub async fn accept(&self) -> io::Result<(UniStream, UniAddr)> {
336 fn accept(socket: &Socket) -> io::Result<(UniStream, UniAddr)> {
337 #[cfg(any(
342 all(not(target_arch = "x86"), target_os = "android"),
343 target_os = "dragonfly",
344 target_os = "freebsd",
345 target_os = "fuchsia",
346 target_os = "hurd",
347 target_os = "illumos",
348 target_os = "linux",
349 target_os = "netbsd",
350 target_os = "openbsd",
351 target_os = "solaris",
352 target_os = "cygwin",
353 ))]
354 let (accepted, peer_addr) = socket.accept4(libc::SOCK_CLOEXEC | libc::SOCK_NONBLOCK)?;
355
356 #[cfg(any(
360 target_os = "aix",
361 target_os = "haiku",
362 target_os = "ios",
363 target_os = "macos",
364 target_os = "redox",
365 target_os = "tvos",
366 target_os = "visionos",
367 target_os = "watchos",
368 target_os = "espidf",
369 target_os = "vita",
370 target_os = "hermit",
371 target_os = "nto",
372 all(target_arch = "x86", target_os = "android"),
373 ))]
374 let (accepted, peer_addr) = socket.accept_raw().and_then(|(accepted, peer_addr)| {
375 #[cfg(not(any(target_os = "espidf", target_os = "vita")))]
376 accepted.set_cloexec(true)?;
377
378 #[cfg(any(
379 all(target_arch = "x86", target_os = "android"),
380 target_os = "aix",
381 target_os = "espidf",
382 target_os = "vita",
383 target_os = "hermit",
384 target_os = "nto",
385 ))]
386 accepted.set_nonblocking(true)?;
387
388 #[cfg(any(
390 target_os = "ios",
391 target_os = "visionos",
392 target_os = "macos",
393 target_os = "tvos",
394 target_os = "watchos",
395 ))]
396 socket.set_nosigpipe(true)?;
397
398 Ok((accepted, peer_addr))
399 })?;
400
401 #[cfg(windows)]
404 let (accepted, peer_addr) = socket.accept_raw()?;
405
406 Ok((
407 UniStream::from_inner(AsyncFd::new(accepted)?),
408 peer_addr.try_into()?,
409 ))
410 }
411
412 loop {
413 let accepted = self
414 .inner
415 .readable()
416 .await?
417 .try_io(|socket| accept(socket.get_ref()));
418
419 match accepted {
420 Ok(ret @ Ok(_)) => {
421 return ret;
422 }
423 Ok(Err(e))
424 if matches!(
425 e.kind(),
426 io::ErrorKind::ConnectionRefused
427 | io::ErrorKind::ConnectionAborted
428 | io::ErrorKind::ConnectionReset
429 ) =>
430 {
431 }
433 Ok(Err(e)) if matches!(e.raw_os_error(), Some(libc::EMFILE)) => {
434 sleep(Duration::from_secs(1)).await;
436 }
437 Ok(Err(e)) => {
438 return Err(e);
439 }
440 Err(_would_block) => {}
441 }
442 }
443 }
444
445 pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(UniStream, UniAddr)>> {
458 loop {
459 let mut guard = ready!(self.inner.poll_read_ready(cx))?;
460
461 match guard.try_io(|socket| socket.get_ref().accept()) {
462 Ok(Ok((socket, addr))) => {
463 let addr = if let Some(addr) = addr.as_socket() {
464 UniAddr::from(addr)
465 } else if let Some(addr) = addr.as_unix() {
466 UniAddr::from(addr)
467 } else {
468 return Poll::Ready(Err(io::Error::new(
469 io::ErrorKind::Other,
470 "unsupported address type",
471 )));
472 };
473
474 return Poll::Ready(Ok((UniStream::from_inner(AsyncFd::new(socket)?), addr)));
475 }
476 Ok(Err(e))
477 if matches!(
478 e.kind(),
479 io::ErrorKind::ConnectionRefused
480 | io::ErrorKind::ConnectionAborted
481 | io::ErrorKind::ConnectionReset
482 ) =>
483 {
484 }
486 Ok(Err(e)) => {
487 return Poll::Ready(Err(e));
488 }
489 Err(_would_block) => {}
490 }
491 }
492 }
493}
494
/// Type-state marker for a [`UniSocket`] that is a connected stream.
#[derive(Debug)]
pub struct StreamTy;

/// A connected stream socket, produced by [`UniSocket::connect`], by accepting on a
/// [`UniListener`], or by converting a std/tokio TCP or Unix stream.
pub type UniStream = UniSocket<StreamTy>;
501
502impl fmt::Debug for UniStream {
503 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
504 f.debug_struct("UniStream")
505 .field("local_addr", &self.local_addr().ok())
506 .field("peer_addr", &self.peer_addr().ok())
507 .finish()
508 }
509}
510
511impl TryFrom<tokio::net::TcpStream> for UniStream {
512 type Error = io::Error;
513
514 fn try_from(stream: tokio::net::TcpStream) -> Result<Self, Self::Error> {
521 stream
522 .into_std()
523 .map(Into::into)
524 .and_then(AsyncFd::new)
525 .map(Self::from_inner)
526 }
527}
528
529impl TryFrom<std::net::TcpStream> for UniStream {
530 type Error = io::Error;
531
532 fn try_from(stream: std::net::TcpStream) -> Result<Self, Self::Error> {
539 stream.set_nonblocking(true)?;
540
541 AsyncFd::new(stream.into()).map(Self::from_inner)
542 }
543}
544
545impl TryFrom<tokio::net::UnixStream> for UniStream {
546 type Error = io::Error;
547
548 fn try_from(stream: tokio::net::UnixStream) -> Result<Self, Self::Error> {
555 stream
556 .into_std()
557 .map(Into::into)
558 .and_then(AsyncFd::new)
559 .map(Self::from_inner)
560 }
561}
562
563impl TryFrom<std::os::unix::net::UnixStream> for UniStream {
564 type Error = io::Error;
565
566 fn try_from(stream: std::os::unix::net::UnixStream) -> Result<Self, Self::Error> {
573 stream.set_nonblocking(true)?;
574
575 AsyncFd::new(stream.into()).map(Self::from_inner)
576 }
577}
578
579impl UniStream {
580 pub fn peer_addr(&self) -> io::Result<UniAddr> {
590 self.inner.get_ref().peer_addr().and_then(TryFrom::try_from)
591 }
592
593 pub async fn peek(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
600 loop {
601 let mut guard = self.inner.readable().await?;
602
603 match guard.try_io(|inner| inner.get_ref().peek(buf)) {
604 Ok(result) => return result,
605 Err(_would_block) => {}
606 }
607 }
608 }
609
610 pub fn poll_peek(
623 self: Pin<&mut Self>,
624 cx: &mut Context<'_>,
625 buf: &mut ReadBuf<'_>,
626 ) -> Poll<io::Result<usize>> {
627 loop {
628 let mut guard = ready!(self.inner.poll_read_ready(cx))?;
629
630 #[allow(unsafe_code)]
631 let unfilled = unsafe { buf.unfilled_mut() };
632
633 match guard.try_io(|inner| inner.get_ref().peek(unfilled)) {
634 Ok(Ok(len)) => {
635 #[allow(unsafe_code)]
637 unsafe {
638 buf.assume_init(len);
639 };
640
641 buf.advance(len);
643
644 return Poll::Ready(Ok(len));
645 }
646 Ok(Err(e)) => return Poll::Ready(Err(e)),
647 Err(_would_block) => {}
648 }
649 }
650 }
651
652 #[inline]
653 pub async fn read(&mut self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
663 self.read_priv(buf).await
664 }
665
666 async fn read_priv(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
667 loop {
668 let mut guard = self.inner.readable().await?;
669
670 match guard.try_io(|inner| inner.get_ref().recv(buf)) {
671 Ok(result) => return result,
672 Err(_would_block) => {}
673 }
674 }
675 }
676
677 fn poll_read_priv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
678 loop {
679 let mut guard = match self.inner.poll_read_ready(cx) {
680 Poll::Ready(Ok(guard)) => guard,
681 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
682 Poll::Pending => {
683 return Poll::Pending;
684 }
685 };
686
687 #[allow(unsafe_code)]
688 let unfilled = unsafe { buf.unfilled_mut() };
689
690 match guard.try_io(|inner| {
691 let ret = inner.get_ref().recv(unfilled);
692
693 ret
694 }) {
695 Ok(Ok(len)) => {
696 #[allow(unsafe_code)]
698 unsafe {
699 buf.assume_init(len);
700 };
701
702 buf.advance(len);
704
705 return Poll::Ready(Ok(()));
706 }
707 Ok(Err(e)) => return Poll::Ready(Err(e)),
708 Err(_would_block) => {}
709 }
710 }
711 }
712
713 #[inline]
714 pub async fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
723 self.write_priv(buf).await
724 }
725
726 async fn write_priv(&self, buf: &[u8]) -> io::Result<usize> {
727 loop {
728 let mut guard = self.inner.writable().await?;
729
730 match guard.try_io(|inner: &AsyncFd<Socket>| inner.get_ref().send(buf)) {
731 Ok(result) => return result,
732 Err(_would_block) => {}
733 }
734 }
735 }
736
737 fn poll_write_priv(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
738 loop {
739 let mut guard = ready!(self.inner.poll_write_ready(cx))?;
740
741 match guard.try_io(|inner| inner.get_ref().send(buf)) {
742 Ok(result) => return Poll::Ready(result),
743 Err(_would_block) => {}
744 }
745 }
746 }
747
748 pub fn shutdown(&self, shutdown: Shutdown) -> io::Result<()> {
753 match self.inner.get_ref().shutdown(shutdown) {
754 Ok(()) => Ok(()),
755 Err(e) if e.kind() == io::ErrorKind::NotConnected => Ok(()),
756 Err(e) => Err(e),
757 }
758 }
759
760 pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
766 let this = Arc::new(self);
767
768 (
769 OwnedReadHalf::from_inner(this.clone()),
770 OwnedWriteHalf::from_inner(this),
771 )
772 }
773}
774
775impl AsyncRead for UniStream {
776 #[inline]
777 fn poll_read(
786 self: Pin<&mut Self>,
787 cx: &mut Context<'_>,
788 buf: &mut ReadBuf<'_>,
789 ) -> Poll<io::Result<()>> {
790 self.poll_read_priv(cx, buf)
791 }
792}
793
794impl AsyncWrite for UniStream {
795 #[inline]
796 fn poll_write(
804 self: Pin<&mut Self>,
805 cx: &mut Context<'_>,
806 buf: &[u8],
807 ) -> Poll<io::Result<usize>> {
808 self.poll_write_priv(cx, buf)
809 }
810
811 #[inline]
812 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
814 Poll::Ready(Ok(()))
815 }
816
817 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
819 Poll::Ready(self.shutdown(Shutdown::Write))
820 }
821}
822
#[cfg(feature = "splice")]
impl tokio_splice2::AsyncReadFd for UniStream {
    /// Resolves once tokio reports the socket readable; the readiness guard is
    /// discarded without clearing readiness.
    fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.inner.poll_read_ready(cx) {
            Poll::Ready(Ok(_guard)) => Poll::Ready(Ok(())),
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Runs `f` as a read-interest I/O operation on the registered socket.
    fn try_io_read<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
        self.inner
            .try_io(tokio::io::Interest::READABLE, move |_socket| f())
    }
}
835
#[cfg(feature = "splice")]
impl tokio_splice2::AsyncWriteFd for UniStream {
    /// Resolves once tokio reports the socket writable; the readiness guard is
    /// discarded without clearing readiness.
    fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.inner.poll_write_ready(cx) {
            Poll::Ready(Ok(_guard)) => Poll::Ready(Ok(())),
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Runs `f` as a write-interest I/O operation on the registered socket.
    fn try_io_write<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
        self.inner
            .try_io(tokio::io::Interest::WRITABLE, move |_socket| f())
    }
}
848
// Marker impl: declares to `tokio_splice2` that `UniStream` is not a regular file.
// NOTE(review): how the crate uses this marker is not visible here — see
// `tokio_splice2::IsNotFile`.
#[cfg(feature = "splice")]
impl tokio_splice2::IsNotFile for UniStream {}
851
// Owned read half of a `UniStream`, produced by `UniStream::into_split`.
// It holds the stream through an `Arc` shared with the matching `OwnedWriteHalf`.
wrapper_lite::wrapper!(
    #[derive(Debug)]
    pub struct OwnedReadHalf(Arc<UniStream>);
);
857
858impl AsyncRead for OwnedReadHalf {
859 #[inline]
860 fn poll_read(
862 self: Pin<&mut Self>,
863 cx: &mut Context<'_>,
864 buf: &mut ReadBuf<'_>,
865 ) -> Poll<io::Result<()>> {
866 self.inner.poll_read_priv(cx, buf)
867 }
868}
869
// Owned write half of a `UniStream`, produced by `UniStream::into_split`.
// It holds the stream through an `Arc` shared with the matching `OwnedReadHalf`.
// Dropping it shuts down the write direction of the stream (see its `Drop` impl).
wrapper_lite::wrapper!(
    #[derive(Debug)]
    pub struct OwnedWriteHalf(Arc<UniStream>);
);
875
876impl AsyncWrite for OwnedWriteHalf {
877 #[inline]
878 fn poll_write(
880 self: Pin<&mut Self>,
881 cx: &mut Context<'_>,
882 buf: &[u8],
883 ) -> Poll<io::Result<usize>> {
884 self.inner.poll_write_priv(cx, buf)
885 }
886
887 #[inline]
888 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
890 Poll::Ready(Ok(()))
891 }
892
893 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
895 Poll::Ready(self.inner.shutdown(Shutdown::Write))
896 }
897}
898
899impl Drop for OwnedWriteHalf {
900 fn drop(&mut self) {
901 let _ = self.inner.shutdown(Shutdown::Write);
902 }
903}