use super::Errno;
use nix::libc;
use std::os::fd::RawFd;
use thiserror::Error;
use yash_syntax::syntax::Fd;
#[derive(Clone, Debug, Eq, Error, Hash, PartialEq)]
#[error("invalid file descriptor")]
pub struct InvalidFd;
impl From<InvalidFd> for Errno {
fn from(InvalidFd: InvalidFd) -> Errno {
Errno::EBADF
}
}
/// Checks that `fd` can be stored in a `libc::fd_set`.
///
/// Returns the raw descriptor if it lies in `0..FD_SETSIZE`, and
/// [`InvalidFd`] otherwise. Every unsafe `FD_*` call in this module relies
/// on this check to stay within the bounds of the underlying set.
fn validate(fd: Fd) -> Result<RawFd, InvalidFd> {
    let raw = fd.0;
    if raw < 0 || raw >= libc::FD_SETSIZE as RawFd {
        return Err(InvalidFd);
    }
    Ok(raw)
}
/// Set of file descriptors backed by a `libc::fd_set`.
///
/// Besides the raw set, this type tracks an upper bound: every file
/// descriptor in the set is less than `upper_bound`. The bound may be larger
/// than necessary (removing a descriptor does not lower it), so it is
/// excluded from equality comparison and hashing.
#[derive(Copy, Clone, Eq, Debug)]
pub struct FdSet {
    /// Underlying libc set; `pub(crate)` so other parts of the crate can
    /// access the raw value directly.
    pub(crate) inner: libc::fd_set,
    /// Exclusive upper bound of the descriptors stored in `inner`.
    upper_bound: Fd,
}
impl FdSet {
    /// Creates an empty set whose upper bound is `Fd(0)`.
    #[must_use]
    pub fn new() -> Self {
        // SAFETY: `FD_ZERO` initializes the pointed-to `fd_set` to the empty
        // set, so the value is fully initialized before `assume_init`.
        let inner = unsafe {
            let mut inner = std::mem::MaybeUninit::uninit();
            libc::FD_ZERO(inner.as_mut_ptr());
            inner.assume_init()
        };
        let upper_bound = Fd(0);
        Self { inner, upper_bound }
    }

    /// Adds a file descriptor to the set, raising the upper bound if needed.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidFd`] if `fd` is negative or not less than
    /// `libc::FD_SETSIZE`. The set is left unchanged in that case.
    pub fn insert(&mut self, fd: Fd) -> Result<(), InvalidFd> {
        let fd = validate(fd)?;
        // SAFETY: `validate` guarantees `0 <= fd < FD_SETSIZE`, so `FD_SET`
        // stays within the bounds of `self.inner`.
        unsafe { libc::FD_SET(fd, &mut self.inner) };
        // `fd + 1` cannot overflow because `fd < FD_SETSIZE`.
        self.upper_bound = self.upper_bound.max(Fd(fd + 1));
        Ok(())
    }

    /// Removes a file descriptor from the set.
    ///
    /// Descriptors outside `0..FD_SETSIZE` are ignored because they can never
    /// be in the set. The upper bound is not lowered.
    pub fn remove(&mut self, fd: Fd) {
        if let Ok(fd) = validate(fd) {
            // SAFETY: `validate` guarantees `0 <= fd < FD_SETSIZE`.
            unsafe { libc::FD_CLR(fd, &mut self.inner) }
        }
    }

    /// Removes all file descriptors from the set and resets the upper bound
    /// to `Fd(0)`.
    pub fn clear(&mut self) {
        // SAFETY: `self.inner` is a valid, initialized `fd_set`.
        unsafe { libc::FD_ZERO(&mut self.inner) }
        // The set is now empty, so zero is the tightest bound. Keeping the old
        // bound would make `upper_bound()` stale and force `iter()` to scan
        // descriptors that can no longer be present.
        self.upper_bound = Fd(0);
    }

    /// Tests whether the set contains `fd`.
    ///
    /// Returns `false` for descriptors outside `0..FD_SETSIZE`.
    #[must_use]
    pub fn contains(&self, fd: Fd) -> bool {
        match validate(fd) {
            // SAFETY: `validate` guarantees `0 <= fd < FD_SETSIZE`.
            Ok(fd) => unsafe { libc::FD_ISSET(fd, &self.inner) },
            Err(_) => false,
        }
    }

    /// Returns an exclusive upper bound of the descriptors in the set.
    ///
    /// Every descriptor in the set is less than the returned value. The bound
    /// may exceed the largest member because [`remove`](Self::remove) does not
    /// lower it.
    #[must_use]
    pub fn upper_bound(&self) -> Fd {
        self.upper_bound
    }

    /// Returns an iterator over the descriptors in the set, in ascending
    /// order.
    pub fn iter(&self) -> Iter<'_> {
        Iter {
            fd_set: self,
            range: 0..self.upper_bound.0,
        }
    }
}
impl Default for FdSet {
fn default() -> Self {
Self::new()
}
}
/// Compares only the underlying `fd_set`. The cached upper bound is ignored,
/// so two sets with the same members are equal even if their bounds differ.
impl PartialEq for FdSet {
    fn eq(&self, other: &Self) -> bool {
        let Self { inner: lhs, .. } = self;
        let Self { inner: rhs, .. } = other;
        lhs == rhs
    }
}
/// Hashes only the underlying `fd_set`. This is consistent with the
/// `PartialEq` implementation, which also ignores the upper bound.
impl std::hash::Hash for FdSet {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.inner, state);
    }
}
/// Iterator over the file descriptors in an [`FdSet`].
///
/// Created by [`FdSet::iter`]. It visits candidates below the set's upper
/// bound, yielding only those that are members of the set.
#[derive(Clone, Debug)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct Iter<'a> {
    /// Set being iterated.
    fd_set: &'a FdSet,
    /// Remaining candidate descriptors; both ends shrink as items are yielded.
    range: std::ops::Range<RawFd>,
}
impl Iterator for Iter<'_> {
type Item = Fd;
fn next(&mut self) -> Option<Fd> {
loop {
let fd = Fd(self.range.next()?);
if self.fd_set.contains(fd) {
return Some(fd);
}
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.range.size_hint()
}
}
impl DoubleEndedIterator for Iter<'_> {
fn next_back(&mut self) -> Option<Fd> {
loop {
let fd = Fd(self.range.next_back()?);
if self.fd_set.contains(fd) {
return Some(fd);
}
}
}
}
impl std::iter::FusedIterator for Iter<'_> {}
impl<'a> IntoIterator for &'a FdSet {
type Item = Fd;
type IntoIter = Iter<'a>;
fn into_iter(self) -> Iter<'a> {
self.iter()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that none of stdin, stdout and stderr is in `fds`.
    fn assert_no_std_fds(fds: &FdSet) {
        for &fd in &[Fd::STDIN, Fd::STDOUT, Fd::STDERR] {
            assert!(!fds.contains(fd), "{:?} unexpectedly in set", fd);
        }
    }

    #[test]
    fn new_set_is_empty() {
        let fds = FdSet::new();
        assert_no_std_fds(&fds);
        assert_eq!(fds.iter().next(), None);
    }

    #[test]
    fn adding_fd_to_set() {
        let mut fds = FdSet::new();
        fds.insert(Fd::STDIN).unwrap();
        assert!(fds.contains(Fd::STDIN));
    }

    #[test]
    fn adding_many_fds_to_set() {
        let members = [Fd::STDOUT, Fd::STDERR, Fd(3)];
        let mut fds = FdSet::new();
        for &fd in &members {
            fds.insert(fd).unwrap();
        }
        assert!(!fds.contains(Fd::STDIN));
        for &fd in &members {
            assert!(fds.contains(fd), "{:?} missing from set", fd);
        }
    }

    #[test]
    fn adding_invalid_fd_to_set() {
        let mut fds = FdSet::new();
        let too_large = Fd(libc::FD_SETSIZE as _);
        for &fd in &[Fd(-1), too_large] {
            fds.insert(fd).unwrap_err();
        }
    }

    #[test]
    fn removing_fd_from_set() {
        let mut fds = FdSet::new();
        fds.insert(Fd::STDIN).unwrap();
        fds.remove(Fd::STDIN);
        assert!(!fds.contains(Fd::STDIN));
    }

    #[test]
    fn clearing_set() {
        let mut fds = FdSet::new();
        fds.insert(Fd::STDOUT).unwrap();
        fds.insert(Fd::STDERR).unwrap();
        fds.clear();
        assert_no_std_fds(&fds);
        assert_eq!(fds.iter().next(), None);
    }

    #[test]
    fn adding_fd_updates_upper_bound() {
        let mut fds = FdSet::new();
        assert_eq!(fds.upper_bound(), Fd(0));
        // (descriptor to insert, expected upper bound afterwards)
        let steps = [(Fd(1), Fd(2)), (Fd(0), Fd(2)), (Fd(2), Fd(3))];
        for &(fd, expected) in &steps {
            fds.insert(fd).unwrap();
            assert_eq!(fds.upper_bound(), expected);
        }
        fds.remove(Fd(2));
        assert!(fds.upper_bound() >= Fd(2), "{:?}", fds.upper_bound());
    }

    #[test]
    fn equality_ignores_upper_bound() {
        let mut original = FdSet::new();
        assert_eq!(original, original);
        original.insert(Fd(1)).unwrap();
        original.insert(Fd(4)).unwrap();
        assert_eq!(original, original);

        let mut modified = original;
        modified.insert(Fd(5)).unwrap();
        assert_ne!(original, modified);
        // `remove` does not lower the bound, so the bounds now differ while
        // the members are the same.
        modified.remove(Fd(5));
        assert_eq!(original, modified);
    }

    #[test]
    fn iterating_fds() {
        let mut fds = FdSet::new();
        for &fd in &[Fd(1), Fd(6), Fd(3)] {
            fds.insert(fd).unwrap();
        }
        let collected: Vec<Fd> = fds.iter().collect();
        assert_eq!(collected, [Fd(1), Fd(3), Fd(6)]);
    }

    #[test]
    fn reverse_iterating_fds() {
        let fd_max = Fd((libc::FD_SETSIZE - 1) as _);
        let mut fds = FdSet::new();
        fds.insert(Fd(0)).unwrap();
        fds.insert(fd_max).unwrap();
        let collected: Vec<Fd> = fds.iter().rev().collect();
        assert_eq!(collected, [fd_max, Fd(0)]);
    }
}