tokio_anyfd/
lib.rs

1use std::io::{Result, Error};
2use std::os::unix::io::AsRawFd;
3
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use futures::ready;
8
9use tokio::io::{AsyncRead, AsyncWrite};
10use tokio::io::unix::AsyncFd;
11use tokio::io::ReadBuf;
12
13pub struct Anyfd<T: AsRawFd> {
14  afd: AsyncFd<T>,
15}
16
17/// Wrap any suitable file descriptor `fd` as
18/// [`AsyncRead`] and [`AsyncWrite`].
19///
20/// You need to make sure the file descriptor is
21/// non-blocking. Set it with [`set_nonblocking`] if not
22/// already.
23///
24/// [`AsyncRead`]: ../tokio/io/trait.AsyncRead.html
25/// [`AsyncWrite`]: ../tokio/io/trait.AsyncWrite.html
26/// [`set_nonblocking`]: fn.set_nonblocking.html
27pub fn anyfd<T: AsRawFd>(fd: T) -> Result<Anyfd<T>> {
28  Ok(Anyfd { afd: AsyncFd::new(fd)? })
29}
30
31/// Set `fd` as non-blocking (the [`O_NONBLOCK`] flag).
32///
33/// [`O_NONBLOCK`]: ../libc/constant.O_NONBLOCK.html
34pub fn set_nonblocking(fd: impl AsRawFd) -> Result<()> {
35  let fd = fd.as_raw_fd();
36  unsafe {
37    let mut flags = libc::fcntl(fd, libc::F_GETFL);
38    if flags < 0 {
39      return Err(Error::last_os_error());
40    }
41    flags |= libc::O_NONBLOCK;
42    let r = libc::fcntl(fd, libc::F_SETFL, flags);
43    if r < 0 {
44      return Err(Error::last_os_error());
45    }
46  }
47  Ok(())
48}
49
50impl<T: AsRawFd> AsyncRead for Anyfd<T> {
51  fn poll_read(
52    self: Pin<&mut Self>,
53    cx: &mut Context<'_>,
54    buf: &mut ReadBuf<'_>
55  ) -> Poll<Result<()>> {
56    let fd = self.afd.as_raw_fd();
57    loop {
58      let mut guard = ready!(self.afd.poll_read_ready(cx))?;
59
60      match guard.try_io(|_| {
61        let r = unsafe {
62          let unfilled = buf.unfilled_mut();
63          libc::read(fd, unfilled.as_ptr() as *mut _, unfilled.len())
64        };
65        if r < 0 {
66          let err = Error::last_os_error();
67          Err(err)
68        } else {
69          unsafe { buf.assume_init(r as usize) };
70          buf.advance(r as usize);
71          Ok(())
72        }
73      }) {
74        Ok(result) => return Poll::Ready(result),
75        Err(_would_block) => continue,
76      }
77    }
78  }
79}
80
81impl<T: AsRawFd> AsyncWrite for Anyfd<T> {
82  fn poll_write(
83    self: Pin<&mut Self>,
84    cx: &mut Context<'_>,
85    buf: &[u8]
86  ) -> Poll<Result<usize>> {
87    let fd = self.afd.as_raw_fd();
88    loop {
89      let mut guard = ready!(self.afd.poll_write_ready(cx))?;
90
91      match guard.try_io(|_| {
92        let r = unsafe {
93          libc::write(fd, buf.as_ptr() as *const _, buf.len())
94        };
95        if r < 0 {
96          let err = Error::last_os_error();
97          Err(err)
98        } else {
99          Ok(r as usize)
100        }
101      }) {
102        Ok(result) => return Poll::Ready(result),
103        Err(_would_block) => continue,
104      }
105    }
106  }
107
108  fn poll_flush(
109    self: Pin<&mut Self>,
110    _cx: &mut Context<'_>,
111  ) -> Poll<Result<()>> {
112    Poll::Ready(Ok(()))
113  }
114
115  fn poll_shutdown(
116    self: Pin<&mut Self>,
117    _cx: &mut Context<'_>,
118  ) -> Poll<Result<()>> {
119    let fd = self.afd.as_raw_fd();
120    let r = unsafe {
121      libc::shutdown(fd, libc::SHUT_WR)
122    };
123    if r == 0 {
124      Poll::Ready(Ok(()))
125    } else {
126      Poll::Ready(Err(Error::last_os_error()))
127    }
128  }
129}