1use std::future::poll_fn;
7use std::mem::MaybeUninit;
8use std::net::{Shutdown, SocketAddr};
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::{fmt, io};
12
13use socket2::SockRef;
14use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15use tokio::net::TcpSocket;
16use uni_addr::{UniAddr, UniAddrInner};
17
18pub struct UniSocket {
20 inner: tokio::net::TcpSocket,
21}
22
23impl fmt::Debug for UniSocket {
24 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25 f.debug_tuple("UniSocket").field(&self.inner).finish()
26 }
27}
28
29impl UniSocket {
30 #[inline]
31 const fn from_inner(inner: TcpSocket) -> Self {
32 Self { inner }
33 }
34
35 pub fn new(addr: &UniAddr) -> io::Result<Self> {
41 match addr.as_inner() {
42 UniAddrInner::Inet(SocketAddr::V4(_)) => TcpSocket::new_v4().map(Self::from_inner),
43 UniAddrInner::Inet(SocketAddr::V6(_)) => TcpSocket::new_v6().map(Self::from_inner),
44 _ => Err(io::Error::new(
45 io::ErrorKind::Other,
46 "unsupported address type",
47 )),
48 }
49 }
50
51 pub fn bind(self, addr: &UniAddr) -> io::Result<Self> {
55 match addr.as_inner() {
56 UniAddrInner::Inet(addr) => self.inner.bind(*addr)?,
57 UniAddrInner::Host(_) => {
58 return Err(io::Error::new(
59 io::ErrorKind::Other,
60 "The Host address type must be resolved before creating a socket",
61 ))
62 }
63 _ => {
64 return Err(io::Error::new(
65 io::ErrorKind::Other,
66 "unsupported address type",
67 ))
68 }
69 }
70
71 Ok(self)
72 }
73
74 pub fn listen(self, backlog: u32) -> io::Result<UniListener> {
80 self.inner.listen(backlog).map(UniListener::from_inner)
81 }
82
83 pub async fn connect(self, addr: &UniAddr) -> io::Result<UniStream> {
89 match addr.as_inner() {
90 UniAddrInner::Inet(addr) => self.inner.connect(*addr).await.map(UniStream::from_inner),
91 _ => Err(io::Error::new(
92 io::ErrorKind::Other,
93 "unsupported address type",
94 )),
95 }
96 }
97
98 pub fn local_addr(&self) -> io::Result<UniAddr> {
108 self.inner.local_addr().map(UniAddr::from)
109 }
110
111 pub fn as_socket_ref(&self) -> SockRef<'_> {
113 SockRef::from(&self.inner)
114 }
115}
116
// `UniListener` wraps a tokio `TcpListener`. The wrapper type is generated by
// `wrapper_lite::wrapper!`; this file relies on the generated `inner` field
// and `from_inner` constructor.
wrapper_lite::wrapper!(
    pub struct UniListener(tokio::net::TcpListener);
);
121
122impl fmt::Debug for UniListener {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 f.debug_struct("UniListener")
125 .field("local_addr", &self.local_addr().ok())
126 .finish()
127 }
128}
129
130impl TryFrom<std::net::TcpListener> for UniListener {
131 type Error = io::Error;
132
133 fn try_from(listener: std::net::TcpListener) -> Result<Self, Self::Error> {
135 listener.set_nonblocking(true)?;
136
137 Ok(Self::from_inner(listener.try_into()?))
138 }
139}
140
141impl TryFrom<tokio::net::TcpListener> for UniListener {
142 type Error = io::Error;
143
144 fn try_from(listener: tokio::net::TcpListener) -> Result<Self, Self::Error> {
151 Ok(Self::from_inner(listener))
152 }
153}
154
155impl UniListener {
156 pub async fn accept(&self) -> io::Result<(UniStream, UniAddr)> {
165 loop {
166 match self.inner.accept().await {
167 Ok((stream, addr)) => {
168 return Ok((UniStream::from_inner(stream), UniAddr::from(addr)))
169 }
170 Err(e)
171 if matches!(
172 e.kind(),
173 io::ErrorKind::ConnectionRefused
174 | io::ErrorKind::ConnectionAborted
175 | io::ErrorKind::ConnectionReset
176 ) => {}
177 Err(e) => return Err(e),
178 }
179 }
180 }
181
182 pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(UniStream, UniAddr)>> {
191 loop {
192 match self.inner.poll_accept(cx) {
193 Poll::Ready(Ok((stream, addr))) => {
194 return Poll::Ready(Ok((UniStream::from_inner(stream), UniAddr::from(addr))))
195 }
196 Poll::Ready(Err(e))
197 if matches!(
198 e.kind(),
199 io::ErrorKind::ConnectionRefused
200 | io::ErrorKind::ConnectionAborted
201 | io::ErrorKind::ConnectionReset
202 ) => {}
203 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
204 Poll::Pending => return Poll::Pending,
205 }
206 }
207 }
208
209 pub fn local_addr(&self) -> io::Result<UniAddr> {
219 self.inner.local_addr().map(UniAddr::from)
220 }
221
222 pub fn as_socket_ref(&self) -> SockRef<'_> {
224 SockRef::from(&self.inner)
225 }
226}
227
// `UniStream` wraps a tokio `TcpStream`. The wrapper type is generated by
// `wrapper_lite::wrapper!`; this file relies on the generated `inner` field
// and `from_inner` constructor.
wrapper_lite::wrapper!(
    pub struct UniStream(tokio::net::TcpStream);
);
232
233impl fmt::Debug for UniStream {
234 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
235 f.debug_struct("UniStream")
236 .field("local_addr", &self.local_addr().ok())
237 .field("peer_addr", &self.peer_addr().ok())
238 .finish()
239 }
240}
241
242impl TryFrom<tokio::net::TcpStream> for UniStream {
243 type Error = io::Error;
244
245 fn try_from(value: tokio::net::TcpStream) -> Result<Self, Self::Error> {
251 Ok(Self::from_inner(value))
252 }
253}
254
255impl TryFrom<std::net::TcpStream> for UniStream {
256 type Error = std::io::Error;
257
258 fn try_from(stream: std::net::TcpStream) -> Result<Self, Self::Error> {
269 stream.set_nonblocking(true)?;
270
271 Ok(Self::from_inner(stream.try_into()?))
272 }
273}
274
275impl UniStream {
276 pub fn local_addr(&self) -> io::Result<UniAddr> {
286 self.inner.local_addr().map(UniAddr::from)
287 }
288
289 pub fn peer_addr(&self) -> io::Result<UniAddr> {
299 self.inner.peer_addr().map(UniAddr::from)
300 }
301
302 pub async fn peek(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
309 buf.fill(MaybeUninit::new(0));
310
311 #[allow(unsafe_code)]
312 let buf = unsafe { std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len()) };
313
314 self.inner.peek(buf).await
315 }
316
317 pub fn poll_peek(
330 self: Pin<&mut Self>,
331 cx: &mut Context<'_>,
332 buf: &mut ReadBuf<'_>,
333 ) -> Poll<io::Result<usize>> {
334 self.get_mut().inner.poll_peek(cx, buf)
335 }
336
337 #[inline]
338 pub async fn read(&mut self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
341 let mut this = Pin::new(&mut self.inner);
342
343 let buf = &mut ReadBuf::uninit(buf);
344
345 poll_fn(|cx| this.as_mut().poll_read(cx, buf)).await?;
346
347 Ok(buf.filled().len())
348 }
349
350 #[inline]
351 pub async fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
353 let mut this = Pin::new(&mut self.inner);
354
355 poll_fn(|cx| this.as_mut().poll_write(cx, buf)).await
356 }
357
358 pub fn shutdown(&self, shutdown: Shutdown) -> io::Result<()> {
363 match self.as_socket_ref().shutdown(shutdown) {
364 Ok(()) => Ok(()),
365 Err(e) if e.kind() == io::ErrorKind::NotConnected => Ok(()),
366 Err(e) => Err(e),
367 }
368 }
369
370 pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
372 self.inner.into_split()
373 }
374
375 pub fn as_socket_ref(&self) -> SockRef<'_> {
377 SockRef::from(&self.inner)
378 }
379}
380
381impl AsyncRead for UniStream {
382 #[inline]
383 fn poll_read(
384 mut self: Pin<&mut Self>,
385 cx: &mut Context<'_>,
386 buf: &mut ReadBuf<'_>,
387 ) -> Poll<io::Result<()>> {
388 Pin::new(&mut self.inner).poll_read(cx, buf)
389 }
390}
391
392impl AsyncWrite for UniStream {
393 #[inline]
394 fn poll_write(
395 mut self: Pin<&mut Self>,
396 cx: &mut Context<'_>,
397 buf: &[u8],
398 ) -> Poll<io::Result<usize>> {
399 Pin::new(&mut self.inner).poll_write(cx, buf)
400 }
401
402 #[inline]
403 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
404 Pin::new(&mut self.inner).poll_flush(cx)
405 }
406
407 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
408 Pin::new(&mut self.inner).poll_shutdown(cx)
409 }
410}
411
412pub use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
414
#[cfg(windows)]
mod sys {
    //! Windows socket-handle accessors for the TCP wrapper types.

    use std::os::windows::io::{AsRawSocket, AsSocket, BorrowedSocket, RawSocket};

    use super::{UniListener, UniSocket, UniStream};

    /// Implements `AsSocket` and `AsRawSocket` for each listed wrapper by
    /// forwarding to its `inner` field.
    macro_rules! forward_socket_traits {
        ($($wrapper:ty),+ $(,)?) => {
            $(
                impl AsSocket for $wrapper {
                    fn as_socket(&self) -> BorrowedSocket<'_> {
                        self.inner.as_socket()
                    }
                }

                impl AsRawSocket for $wrapper {
                    fn as_raw_socket(&self) -> RawSocket {
                        self.inner.as_raw_socket()
                    }
                }
            )+
        };
    }

    forward_socket_traits!(UniSocket, UniListener, UniStream);
}