#![cfg_attr(docsrs, feature(doc_cfg))]
use std::{
io::{IoSlice, IoSliceMut, Read, Write},
os::fd::{AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd},
};
use nix::sys::socket::ControlMessageOwned;
/// A Unix-domain stream wrapper that can send and receive file descriptors
/// (`SCM_RIGHTS` ancillary data) alongside ordinary bytes.
///
/// Descriptors that arrive during a read are buffered inside the wrapper and
/// can be retrieved with [`WithFd::take_fds`].
#[cfg_attr(feature = "async-io", pin_project::pin_project)]
pub struct WithFd<T> {
/// The wrapped stream. Structurally pinned when the `async-io` feature is
/// enabled so the `async_io` impls can project a `Pin<&mut T>` to it.
#[cfg_attr(feature = "async-io", pin)]
inner: T,
/// Descriptors received by reads but not yet handed out by `take_fds`.
fds: Vec<OwnedFd>,
/// Reusable control-message buffer passed to `recvmsg`; the `From` impls in
/// this file allocate it with room for `SCM_MAX_FD` descriptors.
cmsg: Vec<u8>,
}
/// Extension trait offering `.with_fd()` to wrap a stream in a [`WithFd`].
pub trait WithFdExt: Sized {
/// Consumes the stream and returns it wrapped for file-descriptor passing.
fn with_fd(self) -> WithFd<Self>;
}
/// Number of descriptors the receive control buffer is sized for
/// (`cmsg_space!([RawFd; SCM_MAX_FD])`).
/// NOTE(review): the value presumably mirrors Linux's kernel `SCM_MAX_FD`
/// per-message limit — confirm before relying on it on other platforms.
pub const SCM_MAX_FD: usize = 253;
impl Read for WithFd<std::os::unix::net::UnixStream> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.read_with_fd(buf)
}
}
impl Write for WithFd<std::os::unix::net::UnixStream> {
    // Plain writes carry no descriptors, so every method forwards straight to
    // the wrapped stream; `write_with_fd` is the path that attaches fds.

    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let stream = &mut self.inner;
        Write::write(stream, buf)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        let stream = &mut self.inner;
        Write::flush(stream)
    }

    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        let stream = &mut self.inner;
        Write::write_all(stream, buf)
    }

    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> std::io::Result<usize> {
        let stream = &mut self.inner;
        Write::write_vectored(stream, bufs)
    }

    fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> std::io::Result<()> {
        let stream = &mut self.inner;
        Write::write_fmt(stream, fmt)
    }
}
impl<T: AsRawFd> WithFd<T> {
    /// Sends `buf` on socket `fd` in a single `sendmsg` call, attaching `fds`
    /// as an `SCM_RIGHTS` control message.
    ///
    /// Returns the number of bytes `sendmsg` reports as written, which may be
    /// less than `buf.len()`; the caller is responsible for the remainder.
    /// NOTE(review): on a short write the descriptors presumably travel with
    /// the bytes that were sent — callers should not resend them.
    ///
    /// When `fds` is empty no control message is attached at all.
    fn write_with_fd_impl(fd: RawFd, buf: &[u8], fds: &[BorrowedFd<'_>]) -> std::io::Result<usize> {
        // SAFETY: `BorrowedFd` is `#[repr(transparent)]` over `RawFd`, so a
        // `[BorrowedFd]` has exactly the layout of a `[RawFd]` of the same
        // length. The new slice borrows from `fds` and cannot outlive it.
        let raw_fds: &[RawFd] =
            unsafe { std::slice::from_raw_parts(fds.as_ptr().cast::<RawFd>(), fds.len()) };
        let rights = [nix::sys::socket::ControlMessage::ScmRights(raw_fds)];
        // Don't emit an `SCM_RIGHTS` header that carries zero descriptors.
        let cmsgs = if fds.is_empty() { &rights[..0] } else { &rights[..] };
        let sent = nix::sys::socket::sendmsg::<()>(
            fd,
            &[IoSlice::new(buf)],
            cmsgs,
            nix::sys::socket::MsgFlags::empty(),
            None,
        )?;
        Ok(sent)
    }

    /// Receives into `buf` from socket `fd` with a single `recvmsg` call,
    /// using `cmsg` as the control-message buffer.
    ///
    /// Every descriptor delivered in an `SCM_RIGHTS` message is wrapped in an
    /// [`OwnedFd`] and appended to `out_fds` in arrival order; other control
    /// messages are ignored. Returns the number of data bytes received.
    ///
    /// On Linux/Android the descriptors are received with `MSG_CMSG_CLOEXEC`,
    /// so they are close-on-exec from the moment they enter this process and
    /// cannot leak into a concurrently spawned child.
    fn raw_read_with_fd(
        fd: RawFd,
        cmsg: &mut Vec<u8>,
        out_fds: &mut Vec<OwnedFd>,
        buf: &mut [u8],
    ) -> std::io::Result<usize> {
        #[cfg(any(target_os = "linux", target_os = "android"))]
        let flags = nix::sys::socket::MsgFlags::MSG_CMSG_CLOEXEC;
        #[cfg(not(any(target_os = "linux", target_os = "android")))]
        let flags = nix::sys::socket::MsgFlags::empty();

        let mut iov = [IoSliceMut::new(buf)];
        let msg = nix::sys::socket::recvmsg::<()>(fd, &mut iov, Some(cmsg), flags)?;
        // NOTE(review): if `cmsgs()` fails (nix reports truncated control data
        // this way — confirm for the pinned nix version), the data bytes that
        // `recvmsg` already consumed are not reported to the caller.
        for message in msg.cmsgs()? {
            if let ControlMessageOwned::ScmRights(received) = message {
                out_fds.extend(received.into_iter().map(|raw| {
                    // SAFETY: descriptors delivered via SCM_RIGHTS were just
                    // installed in this process for this call and are owned
                    // by nothing else; each is wrapped exactly once.
                    unsafe { OwnedFd::from_raw_fd(raw) }
                }));
            }
        }
        Ok(msg.bytes)
    }

    /// Reads from the wrapped stream, buffering any received descriptors in `self.fds`.
    fn read_with_fd(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let fd = self.inner.as_raw_fd();
        Self::raw_read_with_fd(fd, &mut self.cmsg, &mut self.fds, buf)
    }

    /// Removes and yields the descriptors received so far, oldest first.
    ///
    /// Descriptors the caller does not consume stay buffered, in their
    /// original order, for a later call; dropping the iterator early closes
    /// nothing. (If the iterator is leaked with `mem::forget`, the remaining
    /// descriptors are left in reverse order.)
    pub fn take_fds(&mut self) -> impl Iterator<Item = OwnedFd> + '_ {
        // The vector is reversed on construction so that `pop` (O(1)) yields
        // the oldest descriptor first; `Drop` reverses whatever is left so the
        // stored order is restored.
        struct Iter<'a>(&'a mut Vec<OwnedFd>);
        impl Iterator for Iter<'_> {
            type Item = OwnedFd;
            fn next(&mut self) -> Option<Self::Item> {
                self.0.pop()
            }
            fn size_hint(&self) -> (usize, Option<usize>) {
                (self.0.len(), Some(self.0.len()))
            }
        }
        impl Drop for Iter<'_> {
            fn drop(&mut self) {
                self.0.reverse();
            }
        }
        self.fds.reverse();
        Iter(&mut self.fds)
    }
}
impl WithFd<std::os::unix::net::UnixStream> {
    /// Writes `buf` and passes the descriptors in `fds` to the peer within the
    /// same `sendmsg` call.
    ///
    /// Returns the number of bytes written, which may be fewer than `buf.len()`.
    pub fn write_with_fd(&mut self, buf: &[u8], fds: &[BorrowedFd<'_>]) -> std::io::Result<usize> {
        Self::write_with_fd_impl(self.inner.as_raw_fd(), buf, fds)
    }
}
impl WithFdExt for std::os::unix::net::UnixStream {
fn with_fd(self) -> WithFd<Self> {
self.into()
}
}
impl From<std::os::unix::net::UnixStream> for WithFd<std::os::unix::net::UnixStream> {
    /// Wraps `inner` with no buffered descriptors and a control buffer sized
    /// for `SCM_MAX_FD` descriptors per message.
    fn from(inner: std::os::unix::net::UnixStream) -> Self {
        let fds = Vec::new();
        let cmsg = nix::cmsg_space!([RawFd; SCM_MAX_FD]);
        Self { inner, fds, cmsg }
    }
}
// Every test passes a memfd, which is Linux-only, so the whole module is gated
// on Linux (the async tests previously lacked that gate and would not compile
// elsewhere because `MemFdCreateFlag` was only imported on Linux).
#[cfg(all(test, target_os = "linux"))]
mod test {
    use std::{
        fs::File,
        io::{Read, Seek, Write},
        os::fd::AsFd,
    };

    use cstr::cstr;
    use nix::sys::memfd::MemFdCreateFlag;

    /// Sends a memfd over a blocking socket pair and checks that the received
    /// descriptor refers to the same file.
    #[test]
    fn test_send_fd() {
        let (a, b) = std::os::unix::net::UnixStream::pair().unwrap();
        let mut a = super::WithFd::from(a);
        let mut b = super::WithFd::from(b);
        let memfd =
            nix::sys::memfd::memfd_create(cstr!("test"), MemFdCreateFlag::MFD_CLOEXEC).unwrap();
        let mut memfd: File = memfd.into();
        a.write_with_fd(b"hello", &[memfd.as_fd()]).unwrap();
        let mut buf = [0u8; 5];
        b.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], b"hello");
        let fds = b.take_fds().collect::<Vec<_>>();
        assert_eq!(fds.len(), 1);
        let mut memfd2: File = fds.into_iter().next().unwrap().into();
        memfd.write_all(b"Hello").unwrap();
        drop(memfd);
        memfd2.rewind().unwrap();
        memfd2.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], b"Hello");
    }

    /// Same round trip over `async_io::Async`, followed by a data-only write
    /// through the `&WithFd` `AsyncWrite` impl.
    #[cfg(feature = "async-io")]
    #[tokio::test]
    async fn test_send_fd_async_async_io() {
        use futures_util::io::{AsyncReadExt, AsyncWriteExt};
        let (a, b) = async_io::Async::<std::os::unix::net::UnixStream>::pair().unwrap();
        let a = super::WithFd::from(a);
        let mut b = super::WithFd::from(b);
        let memfd =
            nix::sys::memfd::memfd_create(cstr!("test"), MemFdCreateFlag::MFD_CLOEXEC).unwrap();
        let mut memfd: File = memfd.into();
        tokio::spawn(async move {
            memfd.write_all(b"Hello").unwrap();
            a.write_with_fd(b"hello", &[memfd.as_fd()]).await.unwrap();
            (&a).write_all(b"world").await.unwrap();
            drop(memfd);
        });
        let mut buf = [0u8; 5];
        b.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf[..], b"hello");
        let fds = b.take_fds().collect::<Vec<_>>();
        assert_eq!(fds.len(), 1);
        b.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf[..], b"world");
        let mut memfd2: File = fds.into_iter().next().unwrap().into();
        memfd2.rewind().unwrap();
        memfd2.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], b"Hello");
    }

    /// Same round trip over tokio, with a second fd-less `write_with_fd`
    /// picked up by a read that is already pending.
    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_send_fd_async_tokio() {
        use tokio::io::AsyncReadExt;
        let (a, b) = tokio::net::UnixStream::pair().unwrap();
        let mut a = super::WithFd::from(a);
        let mut b = super::WithFd::from(b);
        // `memfd_create` already returns an `OwnedFd`, so it converts straight
        // into a `File`; re-wrapping it via `from_raw_fd` did not compile.
        let memfd =
            nix::sys::memfd::memfd_create(cstr!("test"), MemFdCreateFlag::MFD_CLOEXEC).unwrap();
        let mut memfd: File = memfd.into();
        a.write_with_fd(b"hello", &[memfd.as_fd()]).await.unwrap();
        let mut buf = [0u8; 5];
        b.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf[..], b"hello");
        let read_handle = tokio::spawn(async move {
            b.read_exact(&mut buf).await.unwrap();
            (b, buf)
        });
        tokio::task::yield_now().await;
        a.write_with_fd(b"world", &[]).await.unwrap();
        let (mut b, mut buf) = read_handle.await.unwrap();
        assert_eq!(&buf[..], b"world");
        let fds = b.take_fds().collect::<Vec<_>>();
        assert_eq!(fds.len(), 1);
        let mut memfd2: File = fds.into_iter().next().unwrap().into();
        memfd.write_all(b"Hello").unwrap();
        drop(memfd);
        memfd2.rewind().unwrap();
        memfd2.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], b"Hello");
    }
}
#[cfg(any(feature = "tokio", docsrs))]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
#[doc(hidden)]
pub mod tokio {
    //! File-descriptor passing for `tokio::net::UnixStream`.

    use std::{
        io,
        os::fd::{AsRawFd, BorrowedFd, RawFd},
        pin::Pin,
        task::{ready, Context, Poll},
    };

    use tokio::io::{AsyncRead, AsyncWrite, Interest, ReadBuf};

    use crate::WithFd;

    impl AsyncRead for WithFd<tokio::net::UnixStream> {
        /// Reads into the unfilled part of `buf`, buffering any descriptors
        /// that arrive with the data for `take_fds`.
        ///
        /// The `recvmsg` runs inside `UnixStream::try_io`, so a `WouldBlock`
        /// result is reported back to tokio's readiness tracking and the loop
        /// waits for readiness again instead of spinning.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let unfilled = buf.initialize_unfilled();
            let Self { inner, cmsg, fds } = self.get_mut();
            let fd = inner.as_raw_fd();
            loop {
                ready!(inner.poll_read_ready(cx))?;
                match inner.try_io(Interest::READABLE, || {
                    Self::raw_read_with_fd(fd, cmsg, fds, unfilled)
                }) {
                    Ok(bytes) => {
                        buf.advance(bytes);
                        return Poll::Ready(Ok(()));
                    }
                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue,
                    Err(e) => return Poll::Ready(Err(e)),
                }
            }
        }
    }

    /// Plain (descriptor-less) writes forward directly to the wrapped stream.
    impl AsyncWrite for WithFd<tokio::net::UnixStream> {
        fn poll_write(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            Pin::new(&mut self.inner).poll_write(cx, buf)
        }

        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            Pin::new(&mut self.inner).poll_flush(cx)
        }

        fn poll_shutdown(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Pin::new(&mut self.inner).poll_shutdown(cx)
        }

        fn poll_write_vectored(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[io::IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            self.inner.is_write_vectored()
        }
    }

    impl WithFd<tokio::net::UnixStream> {
        /// Sends `buf` together with the descriptors in `fds`, waiting for the
        /// socket to become writable as needed.
        ///
        /// Returns the number of bytes sent, which may be fewer than `buf.len()`.
        pub async fn write_with_fd(
            &mut self,
            buf: &[u8],
            fds: &[BorrowedFd<'_>],
        ) -> io::Result<usize> {
            let fd = self.inner.as_raw_fd();
            loop {
                self.inner.writable().await?;
                match self
                    .inner
                    .try_io(Interest::WRITABLE, || Self::write_with_fd_impl(fd, buf, fds))
                {
                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue,
                    result => break result,
                }
            }
        }
    }

    impl From<tokio::net::UnixStream> for WithFd<tokio::net::UnixStream> {
        /// Wraps `inner` with no buffered descriptors and a control buffer
        /// sized for `SCM_MAX_FD` descriptors per message.
        fn from(inner: tokio::net::UnixStream) -> Self {
            Self {
                inner,
                fds: Vec::new(),
                cmsg: nix::cmsg_space!([RawFd; super::SCM_MAX_FD]),
            }
        }
    }

    impl super::WithFdExt for tokio::net::UnixStream {
        fn with_fd(self) -> super::WithFd<Self> {
            self.into()
        }
    }
}
#[cfg(any(feature = "async-io", docsrs))]
#[cfg_attr(docsrs, doc(cfg(feature = "async-io")))]
#[doc(hidden)]
pub mod async_io {
    //! File-descriptor passing for `async_io::Async`-wrapped Unix streams.

    use std::{
        io,
        os::fd::{AsRawFd, BorrowedFd},
        pin::Pin,
        task::{ready, Context, Poll},
    };

    use async_io::Async;
    use futures_io::{AsyncRead, AsyncWrite};

    use crate::WithFd;

    impl AsyncRead for WithFd<Async<std::os::unix::net::UnixStream>> {
        /// Reads into `buf`, buffering any descriptors that arrive with the
        /// data for `take_fds`.
        ///
        /// The read is attempted first; only on `WouldBlock` does it register
        /// for readability with the reactor and retry once woken.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let this = self.project();
            let fd = this.inner.as_raw_fd();
            loop {
                match Self::raw_read_with_fd(fd, this.cmsg, this.fds, buf) {
                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
                    result => return Poll::Ready(result),
                }
                ready!(this.inner.poll_readable(cx))?;
            }
        }
    }

    /// Plain writes through a shared reference, mirroring `&Async<T>`'s impl
    /// so a reader and writer can use the same `WithFd` concurrently.
    impl<T> AsyncWrite for &WithFd<Async<T>>
    where
        for<'a> &'a Async<T>: AsyncWrite,
    {
        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut &self.inner).poll_close(cx)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut &self.inner).poll_flush(cx)
        }

        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Pin::new(&mut &self.inner).poll_write(cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[io::IoSlice<'_>],
        ) -> Poll<io::Result<usize>> {
            Pin::new(&mut &self.inner).poll_write_vectored(cx, bufs)
        }
    }

    /// Plain (descriptor-less) writes forward to the pinned inner stream.
    impl<T> AsyncWrite for WithFd<Async<T>>
    where
        Async<T>: AsyncWrite,
    {
        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            self.project().inner.poll_close(cx)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            self.project().inner.poll_flush(cx)
        }

        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            self.project().inner.poll_write(cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[io::IoSlice<'_>],
        ) -> Poll<io::Result<usize>> {
            self.project().inner.poll_write_vectored(cx, bufs)
        }
    }

    impl WithFd<Async<std::os::unix::net::UnixStream>> {
        /// Sends `buf` together with the descriptors in `fds`.
        ///
        /// `Async::write_with` attempts the send immediately and waits for
        /// writability only if it would block, then retries.
        /// Returns the number of bytes sent, which may be fewer than `buf.len()`.
        pub async fn write_with_fd(
            &self,
            buf: &[u8],
            fds: &[BorrowedFd<'_>],
        ) -> io::Result<usize> {
            self.inner
                .write_with(|stream| Self::write_with_fd_impl(stream.as_raw_fd(), buf, fds))
                .await
        }
    }

    impl From<Async<std::os::unix::net::UnixStream>> for WithFd<Async<std::os::unix::net::UnixStream>> {
        /// Wraps `inner` with no buffered descriptors and a control buffer
        /// sized for `SCM_MAX_FD` descriptors per message.
        fn from(inner: Async<std::os::unix::net::UnixStream>) -> Self {
            Self {
                inner,
                fds: Vec::new(),
                cmsg: nix::cmsg_space!([std::os::unix::io::RawFd; super::SCM_MAX_FD]),
            }
        }
    }

    impl super::WithFdExt for Async<std::os::unix::net::UnixStream> {
        fn with_fd(self) -> super::WithFd<Self> {
            self.into()
        }
    }
}