use std::{
io::{self, Error, ErrorKind},
mem,
os::unix::io::{AsRawFd, FromRawFd, RawFd},
os::unix::net::UnixStream,
process::Stdio,
ptr,
};
use crate::Result;
/// Size in bytes of one raw file descriptor; multiplied by a descriptor count to size the
/// payload of a control message carrying descriptors.
const RAW_FD_SIZE: usize = mem::size_of::<RawFd>();
/// Extension methods for exchanging file descriptors (and, on Linux/L4Re, process credentials)
/// over a Unix domain socket.
///
/// Every method is `unsafe`. The implementation in this file for [`UnixStream`] drives raw
/// `libc` calls (`sendmsg`, `recvmsg`, control-message macros) and, on the receiving side, turns
/// received integers into owning handles via [`FromRawFd`].
pub trait UdsxUnixStream {
    /// Sends the raw file descriptors of `streams` to the peer.
    ///
    /// # Safety
    ///
    /// Every item in `streams` must yield a raw file descriptor that is open for the duration of
    /// the call.
    unsafe fn send_streams(&self, streams: &[&dyn AsRawFd]) -> Result<()>;

    /// Sends this process' standard input, standard output and standard error, in that order.
    ///
    /// # Safety
    ///
    /// Same contract as [`send_streams`](UdsxUnixStream::send_streams).
    unsafe fn send_ioe(&self) -> Result<()>;

    /// Sends the calling process' credentials (process ID, user ID, group ID) to the peer.
    ///
    /// # Safety
    ///
    /// Performs raw `libc` socket calls on the underlying descriptor.
    #[cfg(any(target_os = "linux", target_os = "l4re"))]
    unsafe fn send_credentials(&self) -> Result<()>;

    /// Receives `count` file descriptors from the peer and wraps each one in a `T`.
    ///
    /// # Safety
    ///
    /// Each received descriptor is handed to `T::from_raw_fd`, so `T` must be a type for which
    /// taking ownership of the received descriptors is valid.
    unsafe fn recv_streams<T: FromRawFd>(&self, count: usize) -> Result<Vec<T>>;

    /// Receives three file descriptors and returns them as [`Stdio`] handles, in received order.
    ///
    /// # Safety
    ///
    /// Same contract as [`recv_streams`](UdsxUnixStream::recv_streams).
    unsafe fn recv_ioe(&self) -> Result<(Stdio, Stdio, Stdio)>;

    /// Receives the credentials of the peer process.
    ///
    /// # Safety
    ///
    /// Performs raw `libc` socket calls on the underlying descriptor.
    #[cfg(any(target_os = "linux", target_os = "l4re"))]
    unsafe fn recv_credentials(&self) -> Result<crate::Credentials>;
}
impl UdsxUnixStream for UnixStream {
unsafe fn send_streams(&self, streams: &[&dyn AsRawFd]) -> Result<()> {
let streams: Vec<_> = streams.into_iter().map(|s| s.as_raw_fd()).collect();
let size_of_streams = RAW_FD_SIZE.checked_mul(streams.len())
.ok_or_else(|| Error::new(ErrorKind::InvalidData, format!("Stream is too large: {} * {} (bytes)", RAW_FD_SIZE, streams.len())))?;
let (_io_buf, _iovec, _msg_control, msg) = make_io_buf_and_msghdr(size_of_streams)?;
let cmsg = setup_first_cmsg(&msg, libc::SCM_RIGHTS, size_of_streams)?;
let fd_ptr = libc::CMSG_DATA(cmsg);
match fd_ptr.is_null() {
true => return Err(Error::new(ErrorKind::Other, format!("libc::CMSG_DATA() returned null"))),
false => if libc::memcpy(mem::transmute(fd_ptr), mem::transmute(streams.as_ptr()), size_of_streams).is_null() {
return Err(Error::new(ErrorKind::Other, format!("libc::memcpy() returned null")));
},
};
match libc::sendmsg(self.as_raw_fd(), &msg, 0) {
-1 => Err(Error::new(ErrorKind::Other, crate::format_errno("sendmsg", None))),
_ => Ok(()),
}
}
unsafe fn send_ioe(&self) -> Result<()> {
self.send_streams(&[&io::stdin() as &dyn AsRawFd, &io::stdout(), &io::stderr()])
}
#[cfg(any(target_os = "linux", target_os = "l4re"))]
unsafe fn send_credentials(&self) -> Result<()> {
let ucred = libc::ucred {
pid: libc::getpid(),
uid: libc::getuid(),
gid: libc::getgid(),
};
let data_size = mem::size_of_val(&ucred);
let (_io_buf, _iovec, _msg_control, msg) = make_io_buf_and_msghdr(data_size)?;
let cmsg = setup_first_cmsg(&msg, libc::SCM_CREDENTIALS, data_size)?;
let fd_ptr = libc::CMSG_DATA(cmsg);
match fd_ptr.is_null() {
true => return Err(Error::new(ErrorKind::Other, format!("libc::CMSG_DATA() returned null"))),
false => if libc::memcpy(mem::transmute(fd_ptr), mem::transmute(&ucred), data_size).is_null() {
return Err(Error::new(ErrorKind::Other, format!("libc::memcpy() returned null")));
},
};
match libc::sendmsg(self.as_raw_fd(), &msg, 0) {
-1 => Err(Error::new(ErrorKind::Other, crate::format_errno("sendmsg", None))),
_ => Ok(()),
}
}
unsafe fn recv_streams<T>(&self, count: usize) -> Result<Vec<T>> where T: FromRawFd {
let size_of_streams = RAW_FD_SIZE.checked_mul(count)
.ok_or_else(|| Error::new(ErrorKind::InvalidData, format!("Stream is too large: {} * {} (bytes)", RAW_FD_SIZE, count)))?;
let (_io_buf, _iovec, _msg_control, mut msg) = make_io_buf_and_msghdr(size_of_streams)?;
match libc::recvmsg(self.as_raw_fd(), &mut msg, 0) {
-1 => Err(Error::new(ErrorKind::Other, crate::format_errno("recvmsg", None))),
_ => {
let cmsg = libc::CMSG_FIRSTHDR(&msg);
if cmsg.is_null() {
return Err(Error::new(ErrorKind::Other, "libc::CMSG_FIRSTHDR() returned null"));
}
match ((*cmsg).cmsg_level, (*cmsg).cmsg_type) {
(libc::AF_UNIX, libc::SCM_RIGHTS) => {
let data = libc::CMSG_DATA(cmsg);
match data.is_null() {
true => Err(Error::new(ErrorKind::Other, format!("libc::CMSG_DATA() returned null"))),
false => {
let mut fds = vec![-1 as RawFd; count];
match libc::memcpy(mem::transmute(fds.as_mut_ptr()), mem::transmute(data), size_of_streams).is_null() {
false => {
let mut result = Vec::with_capacity(count);
for fd in fds {
verify_fd(fd)?;
result.push(T::from_raw_fd(fd));
}
Ok(result)
},
true => Err(Error::new(ErrorKind::Other, "libc::memcpy() returned null")),
}
},
}
},
(level, r#type) => Err(Error::new(ErrorKind::InvalidData, format!("Unknown message level {} and type {}", level, r#type))),
}
},
}
}
unsafe fn recv_ioe(&self) -> Result<(Stdio, Stdio, Stdio)> {
const COUNT: usize = 3;
let mut result = self.recv_streams::<Stdio>(COUNT)?;
match result.len() {
COUNT => Ok((result.remove(0), result.remove(0), result.remove(0))),
other => Err(Error::new(
ErrorKind::InvalidData, format!("recv_streams() returned a vector of {} item(s); expected: {}", other, COUNT),
)),
}
}
#[cfg(any(target_os = "linux", target_os = "l4re"))]
unsafe fn recv_credentials(&self) -> Result<crate::Credentials> {
enable_receiving_credentials(self)?;
let data_size = mem::size_of::<libc::ucred>();
let (_io_buf, _iovec, _msg_control, mut msg) = make_io_buf_and_msghdr(data_size)?;
match libc::recvmsg(self.as_raw_fd(), &mut msg, 0) {
-1 => Err(Error::new(ErrorKind::Other, crate::format_errno("recvmsg", None))),
_ => {
let cmsg = libc::CMSG_FIRSTHDR(&msg);
if cmsg.is_null() {
return Err(Error::new(ErrorKind::Other, "libc::CMSG_FIRSTHDR() returned null"));
}
match ((*cmsg).cmsg_level, (*cmsg).cmsg_type) {
(libc::AF_UNIX, libc::SCM_CREDENTIALS) => {
let data = libc::CMSG_DATA(cmsg);
match data.is_null() {
true => Err(Error::new(ErrorKind::Other, format!("libc::CMSG_DATA() returned null"))),
false => {
let mut result = libc::ucred {
pid: -1,
uid: u32::max_value(),
gid: u32::max_value(),
};
match libc::memcpy(mem::transmute(&mut result), mem::transmute(data), data_size).is_null() {
false => match result.pid >= 0 && result.uid < u32::max_value() && result.gid < u32::max_value() {
true => Ok(result.into()),
false => Err(Error::new(
ErrorKind::Other,
format!("Invalid credentials: pid:{}, uid:{}, gid:{}", result.pid, result.uid, result.gid),
)),
},
true => Err(Error::new(ErrorKind::Other, "libc::memcpy() returned null")),
}
},
}
},
(level, r#type) => Err(Error::new(ErrorKind::InvalidData, format!("Unknown message level {} and type {}", level, r#type))),
}
},
}
}
}
/// Builds a `msghdr` with a one-byte data buffer and a zeroed control buffer large enough for
/// one control message carrying `data_size` payload bytes.
///
/// Returns `(io_buf, iovec, msg_control, msg)`. The first three elements own the memory that
/// `msg` points into, so the caller must keep them alive for as long as `msg` is used.
///
/// The data buffer and the `iovec` are heap-allocated. Returning them by value from the stack
/// would move them, leaving `iov_base` and `msg_iov` pointing into this function's dead frame.
/// Moving a `Vec` leaves its heap allocation in place, so the pointers stored in `msg` stay valid.
///
/// # Errors
///
/// Fails if `data_size` (or the resulting control-buffer length) does not survive the integer
/// conversions done by `crate::usize_as_u32` / `crate::u32_as_usize`.
unsafe fn make_io_buf_and_msghdr(
    data_size: usize,
) -> Result<(Vec<u8>, Vec<libc::iovec>, Vec<u8>, libc::msghdr)> {
    let mut io_buf = vec![0_u8; 1];
    let mut iovec = vec![libc::iovec {
        iov_base: io_buf.as_mut_ptr() as *mut libc::c_void,
        iov_len: io_buf.len(),
    }];

    // Never go below one header, so that CMSG_FIRSTHDR() has something to return.
    #[cfg(any(target_os = "linux", target_os = "l4re"))]
    let msg_controllen = crate::u32_as_usize(libc::CMSG_SPACE(crate::usize_as_u32(data_size)?))?
        .max(mem::size_of::<libc::cmsghdr>());
    #[cfg(not(any(target_os = "linux", target_os = "l4re")))]
    let msg_controllen = libc::CMSG_SPACE(crate::usize_as_u32(data_size)?)
        .max(crate::usize_as_u32(mem::size_of::<libc::cmsghdr>())?);

    #[cfg(any(target_os = "linux", target_os = "l4re"))]
    let mut msg_control = vec![0_u8; msg_controllen];
    #[cfg(not(any(target_os = "linux", target_os = "l4re")))]
    let mut msg_control = vec![0_u8; crate::u32_as_usize(msg_controllen)?];

    let msg = libc::msghdr {
        msg_name: ptr::null_mut(),
        msg_namelen: 0,
        msg_iov: iovec.as_mut_ptr(),
        msg_iovlen: 1,
        msg_control: msg_control.as_mut_ptr() as *mut libc::c_void,
        msg_controllen,
        msg_flags: 0,
    };
    Ok((io_buf, iovec, msg_control, msg))
}
/// Initialises the first control-message header of `msg` for a payload of `data_size` bytes.
///
/// The header is written with level `SOL_SOCKET`, the given `r#type`, and a length of
/// `CMSG_LEN(data_size)`. `SOL_SOCKET` is the level at which `SCM_RIGHTS` / `SCM_CREDENTIALS`
/// messages are exchanged; `AF_UNIX` only happens to have the same numeric value on Linux.
///
/// Returns the pointer to the header so the caller can fill in the payload.
///
/// # Errors
///
/// Fails if `CMSG_FIRSTHDR()` returns null or if `data_size` does not survive the integer
/// conversions.
///
/// # Safety
///
/// `msg.msg_control` must point to a writable buffer of `msg.msg_controllen` bytes.
unsafe fn setup_first_cmsg(msg: &libc::msghdr, r#type: i32, data_size: usize) -> Result<*mut libc::cmsghdr> {
    let cmsg = libc::CMSG_FIRSTHDR(msg);
    if cmsg.is_null() {
        return Err(Error::new(ErrorKind::Other, "libc::CMSG_FIRSTHDR() returned null"));
    }
    let cmsg_len = libc::CMSG_LEN(crate::usize_as_u32(data_size)?);

    (*cmsg).cmsg_level = libc::SOL_SOCKET;
    (*cmsg).cmsg_type = r#type;
    // The width of `cmsg_len` differs between platforms, hence the two assignments.
    #[cfg(any(target_os = "linux", target_os = "l4re"))]
    {
        (*cmsg).cmsg_len = crate::u32_as_usize(cmsg_len)?;
    }
    #[cfg(not(any(target_os = "linux", target_os = "l4re")))]
    {
        (*cmsg).cmsg_len = cmsg_len;
    }
    Ok(cmsg)
}
/// Checks that `fd` is usable by querying its descriptor flags with `fcntl(F_GETFD)`.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidData` (including the formatted `errno`) when `fcntl()` returns -1.
unsafe fn verify_fd(fd: RawFd) -> Result<()> {
    if libc::fcntl(fd, libc::F_GETFD) != -1 {
        return Ok(());
    }
    let reason = crate::format_errno("fcntl", None);
    Err(Error::new(
        ErrorKind::InvalidData,
        format!("Invalid file descriptor: {} -> {}", fd, reason),
    ))
}
/// Turns on `SO_PASSCRED` for `stream` so that received messages can carry `SCM_CREDENTIALS`.
///
/// The option is set at level `SOL_SOCKET`, which is where `SO_PASSCRED` is defined; `AF_UNIX`
/// is an address family and only coincides numerically with `SOL_SOCKET` on Linux.
///
/// # Errors
///
/// Fails if the option length does not fit the integer conversion or if `setsockopt()` returns
/// a non-zero value.
#[cfg(any(target_os = "linux", target_os = "l4re"))]
unsafe fn enable_receiving_credentials(stream: &UnixStream) -> Result<()> {
    const ENABLED: i32 = 1;
    let option_len = crate::usize_as_u32(mem::size_of_val(&ENABLED))?;
    match libc::setsockopt(
        stream.as_raw_fd(),
        libc::SOL_SOCKET,
        libc::SO_PASSCRED,
        &ENABLED as *const i32 as *const libc::c_void,
        option_len,
    ) {
        0 => Ok(()),
        _ => Err(Error::new(ErrorKind::Other, crate::format_errno("setsockopt", None))),
    }
}