use bitflags::bitflags;
use std::io;
use std::ops::Drop;
use std::os::raw::c_int;
use std::os::windows::io::{AsRawHandle, AsRawSocket, RawHandle};
use std::ptr;
use std::time::Duration;
use wepoll_sys as sys;
/// The kind of change an `epoll_ctl` call applies to a socket's registration.
///
/// Each variant's discriminant is the matching `wepoll_sys` constant, so a
/// variant can be cast straight to the integer `epoll_ctl` expects.
#[repr(u32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Operation {
    /// Register a socket (`EPOLL_CTL_ADD`); used by `Epoll::register`.
    Add = sys::EPOLL_CTL_ADD,
    /// Update the flags/data of a registered socket (`EPOLL_CTL_MOD`); used by
    /// `Epoll::reregister`.
    Modify = sys::EPOLL_CTL_MOD,
    /// Remove a socket's registration (`EPOLL_CTL_DEL`); used by
    /// `Epoll::deregister`.
    Delete = sys::EPOLL_CTL_DEL,
}
bitflags! {
/// Bit set of wepoll event flags, stored in the `events` field of an
/// `epoll_event`.
///
/// Every constant mirrors the `wepoll_sys` constant of the same name. This file
/// converts raw values with `from_bits_truncate`, so bits not listed here are
/// discarded when reading events back.
pub struct EventFlag: u32 {
/// Mirrors `EPOLLERR` (in epoll semantics: an error condition on the socket).
const ERR = sys::EPOLLERR;
/// Mirrors `EPOLLHUP` (in epoll semantics: the connection was hung up).
const HUP = sys::EPOLLHUP;
/// Mirrors `EPOLLIN`; reported when the socket has data to read.
const IN = sys::EPOLLIN;
/// Mirrors `EPOLLMSG`.
const MSG = sys::EPOLLMSG;
/// Mirrors `EPOLLONESHOT` (in epoll semantics: report one event, then disarm
/// until the socket is re-registered with `reregister`).
const ONESHOT = sys::EPOLLONESHOT;
/// Mirrors `EPOLLOUT`; reported when the socket is writable.
const OUT = sys::EPOLLOUT;
/// Mirrors `EPOLLPRI` (in epoll semantics: urgent/out-of-band data).
const PRI = sys::EPOLLPRI;
/// Mirrors `EPOLLRDBAND`.
const RDBAND = sys::EPOLLRDBAND;
/// Mirrors `EPOLLRDHUP`; reported once the peer has shut down its write side.
const RDHUP = sys::EPOLLRDHUP;
/// Mirrors `EPOLLRDNORM`.
const RDNORM = sys::EPOLLRDNORM;
/// Mirrors `EPOLLWRBAND`.
const WRBAND = sys::EPOLLWRBAND;
/// Mirrors `EPOLLWRNORM`.
const WRNORM = sys::EPOLLWRNORM;
}
}
/// One readiness event: a set of [`EventFlag`]s paired with a 64-bit user
/// value.
///
/// The same type describes both the interest passed to `epoll_ctl` and the
/// events read back after a poll.
pub struct Event {
    raw: sys::epoll_event,
}

impl Event {
    /// Builds an event carrying `flags` and the user value `data`.
    ///
    /// The user value is stored through the `u64` member of the data union.
    pub fn new(flags: EventFlag, data: u64) -> Self {
        let payload = sys::epoll_data { u64: data };
        let raw = sys::epoll_event {
            events: flags.bits(),
            data: payload,
        };
        Self { raw }
    }

    /// Returns the event's flags, dropping any bits `EventFlag` does not define.
    pub fn flags(&self) -> EventFlag {
        let bits = self.raw.events;
        EventFlag::from_bits_truncate(bits)
    }

    /// Returns the 64-bit user value attached to this event.
    pub fn data(&self) -> u64 {
        // SAFETY: the union is written through its `u64` member in `new`, and
        // every bit pattern is a valid `u64`.
        unsafe { self.raw.data.u64 }
    }
}
/// A reusable buffer that `Epoll::poll` fills with received events.
///
/// A poll receives at most `capacity()` events, and the buffer never grows on
/// its own. Size it with [`Events::with_capacity`].
pub struct Events {
    raw: Vec<sys::epoll_event>,
}

impl Events {
    /// Creates an empty buffer whose capacity is at least `amount` events.
    pub fn with_capacity(amount: usize) -> Self {
        Events {
            raw: Vec::with_capacity(amount),
        }
    }

    /// Returns the number of events received by the last successful poll.
    pub fn len(&self) -> usize {
        self.raw.len()
    }

    /// Returns `true` if the buffer currently holds no events.
    pub fn is_empty(&self) -> bool {
        self.raw.is_empty()
    }

    /// Returns the maximum number of events a single poll can store.
    pub fn capacity(&self) -> usize {
        self.raw.capacity()
    }

    /// Returns an iterator yielding a copy of each stored event, in order.
    pub fn iter(&self) -> Iter<'_> {
        Iter {
            events: self,
            index: 0,
        }
    }

    /// Removes all stored events, keeping the allocated capacity.
    pub fn clear(&mut self) {
        // `epoll_event` is plain data, so a safe `clear` is equivalent to the
        // former `set_len(0)` without needing `unsafe`.
        self.raw.clear();
    }
}

// SAFETY: `Events` only stores, copies and reads the bits of `epoll_event`
// values. The pointer-typed union members in `epoll_data` (presumably why the
// type is not automatically Send/Sync) are never dereferenced here.
unsafe impl Sync for Events {}
unsafe impl Send for Events {}
/// Iterator over the events stored in an [`Events`] buffer.
///
/// Yields an owned [`Event`] copy per stored entry. Flags are passed through
/// `EventFlag::from_bits_truncate`, so unknown bits are dropped.
pub struct Iter<'a> {
    events: &'a Events,
    index: usize,
}

impl<'a> Iterator for Iter<'a> {
    type Item = Event;

    fn next(&mut self) -> Option<Event> {
        // `get` ends iteration cleanly at the end of the buffer without a
        // separate length check or a panicking index.
        let raw = self.events.raw.get(self.index)?;
        self.index += 1;
        // SAFETY: `epoll_data` is written through its `u64` member at
        // registration (`Event::new`), and every bit pattern is a valid `u64`.
        let data = unsafe { raw.data.u64 };
        Some(Event::new(EventFlag::from_bits_truncate(raw.events), data))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // The remaining count is known exactly, which lets `collect` preallocate.
        let remaining = self.events.raw.len().saturating_sub(self.index);
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for Iter<'_> {}
/// An owned wepoll instance, used to wait for readiness events on sockets.
///
/// The underlying handle is released when the value is dropped.
pub struct Epoll {
    handle: sys::HANDLE,
}

impl Epoll {
    /// Creates a new wepoll instance.
    ///
    /// # Errors
    ///
    /// Returns the last OS error if `epoll_create` yields a null handle.
    pub fn new() -> io::Result<Epoll> {
        // SAFETY: `epoll_create` takes no pointer arguments; a null result is
        // checked below.
        let handle = unsafe { sys::epoll_create(1) };
        if handle.is_null() {
            return Err(io::Error::last_os_error());
        }
        Ok(Epoll { handle })
    }

    /// Waits for readiness events and stores them in `events`.
    ///
    /// At most `events.capacity()` events are received. On success, the
    /// returned count equals `events.len()`. `events` is emptied before
    /// waiting, so after an error it holds no stale entries from an earlier
    /// poll.
    ///
    /// With `timeout == None` the call waits indefinitely (`-1` is passed).
    /// A `Some` duration is converted to whole milliseconds:
    /// - any sub-millisecond remainder is rounded up, so a tiny non-zero
    ///   timeout does not turn into a non-blocking poll;
    /// - the value saturates at `c_int::MAX` instead of wrapping around,
    ///   which could otherwise produce a negative, "no timeout" value.
    ///
    /// # Errors
    ///
    /// Returns the last OS error if `epoll_wait` reports failure.
    pub fn poll(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<usize> {
        let timeout_ms = timeout.map_or(-1, Self::timeout_to_ms);
        // `maxevents` is a C int; a very large buffer must not wrap it negative.
        let max_events = events.capacity().min(c_int::MAX as usize) as c_int;
        // Drop the previous poll's results so a failure leaves `events` empty.
        events.raw.clear();
        // SAFETY: the buffer owns room for `capacity() >= max_events` entries,
        // and `epoll_wait` writes at most `max_events` of them.
        let received = unsafe {
            sys::epoll_wait(
                self.handle,
                events.raw.as_mut_ptr(),
                max_events,
                timeout_ms,
            )
        };
        // Reject any negative result, not only -1: it must never reach `set_len`.
        if received < 0 {
            return Err(io::Error::last_os_error());
        }
        let received = received as usize;
        // SAFETY: `epoll_wait` initialized the first `received` entries, and
        // `received <= max_events <= capacity()`.
        unsafe { events.raw.set_len(received) };
        Ok(received)
    }

    /// Starts watching `socket` for `flags`; its events will carry `data`.
    ///
    /// # Errors
    ///
    /// Returns the OS error from `epoll_ctl`, e.g. when `socket` is already
    /// registered with this instance.
    pub fn register<T: AsRawSocket>(
        &self,
        socket: &T,
        flags: EventFlag,
        data: u64,
    ) -> io::Result<()> {
        self.register_or_reregister(socket, flags, data, Operation::Add)
    }

    /// Replaces the flags and user data of an already registered `socket`.
    ///
    /// # Errors
    ///
    /// Returns the OS error from `epoll_ctl`, e.g. when `socket` is not
    /// registered with this instance.
    pub fn reregister<T: AsRawSocket>(
        &self,
        socket: &T,
        flags: EventFlag,
        data: u64,
    ) -> io::Result<()> {
        self.register_or_reregister(socket, flags, data, Operation::Modify)
    }

    /// Stops watching `socket`.
    ///
    /// # Errors
    ///
    /// Returns the OS error from `epoll_ctl`, e.g. when `socket` is not
    /// registered with this instance.
    pub fn deregister<T: AsRawSocket>(&self, socket: &T) -> io::Result<()> {
        // A delete carries no event description, so a null pointer is passed.
        self.ctl(socket, Operation::Delete, ptr::null_mut())
    }

    /// Shared implementation of `register` and `reregister`.
    fn register_or_reregister<T: AsRawSocket>(
        &self,
        socket: &T,
        flags: EventFlag,
        data: u64,
        operation: Operation,
    ) -> io::Result<()> {
        let mut event = Event::new(flags, data);
        self.ctl(socket, operation, &mut event.raw)
    }

    /// Issues a single `epoll_ctl` call for `socket` and maps a negative
    /// result to the last OS error.
    fn ctl<T: AsRawSocket>(
        &self,
        socket: &T,
        operation: Operation,
        event: *mut sys::epoll_event,
    ) -> io::Result<()> {
        // SAFETY: `event` is either null (delete) or points to an
        // `epoll_event` the caller keeps alive for the duration of this call.
        let result = unsafe {
            sys::epoll_ctl(
                self.handle,
                operation as c_int,
                socket.as_raw_socket() as sys::SOCKET,
                event,
            )
        };
        if result < 0 {
            return Err(io::Error::last_os_error());
        }
        Ok(())
    }

    /// Converts `timeout` to the millisecond `c_int` expected by `epoll_wait`.
    ///
    /// Partial milliseconds are rounded up, and the result saturates at
    /// `c_int::MAX`.
    fn timeout_to_ms(timeout: Duration) -> c_int {
        let mut millis = timeout.as_millis();
        if timeout.subsec_nanos() % 1_000_000 != 0 {
            millis += 1;
        }
        millis.min(c_int::MAX as u128) as c_int
    }
}

// SAFETY: the handle is only passed to wepoll functions. wepoll advertises
// itself as thread-safe, including polling one port from several threads.
// NOTE(review): confirm against the wepoll version in use.
unsafe impl Sync for Epoll {}
unsafe impl Send for Epoll {}
impl Drop for Epoll {
    /// Closes the wepoll handle.
    ///
    /// # Panics
    ///
    /// Panics if `epoll_close` fails, unless the thread is already panicking.
    /// A second panic during unwinding would abort the whole process.
    fn drop(&mut self) {
        // SAFETY: `handle` was produced by a successful `epoll_create` in
        // `Epoll::new` and is closed only here.
        let result = unsafe { sys::epoll_close(self.handle) };
        if result == -1 {
            // Capture the error before anything else can overwrite it.
            let err = io::Error::last_os_error();
            if !std::thread::panicking() {
                panic!("epoll_close() failed: {}", err);
            }
        }
    }
}

impl AsRawHandle for Epoll {
    /// Returns the underlying wepoll handle without transferring ownership.
    ///
    /// The handle is closed when this `Epoll` is dropped, so it must not be
    /// used after that.
    fn as_raw_handle(&self) -> RawHandle {
        self.handle
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::net::{TcpListener, TcpStream, UdpSocket};

    /// Builds a fully initialized raw event for seeding an `Events` buffer.
    fn raw_event(flags: EventFlag, data: u64) -> sys::epoll_event {
        Event::new(flags, data).raw
    }

    #[test]
    fn test_event_new() {
        let event = Event::new(EventFlag::IN, 42);
        assert_eq!(event.flags(), EventFlag::IN);
        assert_eq!(event.data(), 42);
    }

    #[test]
    fn test_events_with_capacity() {
        let events = Events::with_capacity(2);
        // `Vec::with_capacity` only guarantees *at least* the requested capacity.
        assert!(events.raw.capacity() >= 2);
    }

    #[test]
    fn test_events_len() {
        let events = Events::with_capacity(1);
        assert_eq!(events.len(), 0);
    }

    #[test]
    fn test_events_capacity() {
        let events = Events::with_capacity(1);
        assert!(events.capacity() >= 1);
    }

    #[test]
    fn test_events_clear() {
        let mut events = Events::with_capacity(1);
        // Push an initialized element. Calling `set_len` over uninitialized
        // memory would violate `Vec::set_len`'s safety contract.
        events.raw.push(raw_event(EventFlag::IN, 1));
        assert_eq!(events.len(), 1);
        events.clear();
        assert_eq!(events.len(), 0);
    }

    #[test]
    fn test_events_iter() {
        let mut events = Events::with_capacity(2);
        events.raw.push(raw_event(EventFlag::IN, 1));
        events.raw.push(raw_event(EventFlag::OUT | EventFlag::HUP, 2));
        let seen: Vec<(EventFlag, u64)> = events.iter().map(|e| (e.flags(), e.data())).collect();
        assert_eq!(
            seen,
            vec![(EventFlag::IN, 1), (EventFlag::OUT | EventFlag::HUP, 2)]
        );
    }

    #[test]
    fn test_poll_poll_without_timeout() {
        let epoll = Epoll::new().unwrap();
        let socket = UdpSocket::bind("0.0.0.0:0").unwrap();
        let mut events = Events::with_capacity(1);
        epoll
            .register(&socket, EventFlag::OUT | EventFlag::ONESHOT, 42)
            .unwrap();
        assert_eq!(epoll.poll(&mut events, None).unwrap(), 1);
        assert_eq!(events.len(), 1);
        let event = events.iter().next().unwrap();
        assert_eq!(event.data(), 42);
        assert_eq!(event.flags(), EventFlag::OUT);
    }

    #[test]
    fn test_poll_in_and_rdhup() {
        let epoll = Epoll::new().unwrap();
        let l = TcpListener::bind("127.0.0.1:0").unwrap();
        let socket = TcpStream::connect(l.local_addr().unwrap()).unwrap();
        let mut s1 = l.incoming().next().unwrap().unwrap();
        s1.write_all(b"hello").unwrap();
        let mut events = Events::with_capacity(1);
        epoll
            .register(&socket, EventFlag::IN | EventFlag::RDHUP, 42)
            .unwrap();
        assert_eq!(epoll.poll(&mut events, None).unwrap(), 1);
        let event = events.iter().next().unwrap();
        assert_eq!(event.flags(), EventFlag::IN);
        s1.shutdown(std::net::Shutdown::Write).unwrap();
        assert_eq!(epoll.poll(&mut events, None).unwrap(), 1);
        let event = events.iter().next().unwrap();
        assert_eq!(event.flags(), EventFlag::IN | EventFlag::RDHUP);
    }

    #[test]
    fn test_poll_poll_with_timeout() {
        let epoll = Epoll::new().unwrap();
        let mut events = Events::with_capacity(1);
        let received = epoll
            .poll(&mut events, Some(Duration::from_millis(5)))
            .unwrap();
        assert_eq!(received, 0);
        assert_eq!(events.len(), 0);
        assert!(events.iter().next().is_none());
    }

    #[test]
    fn test_poll_poll_with_sub_millisecond_timeout() {
        let epoll = Epoll::new().unwrap();
        let mut events = Events::with_capacity(1);
        let received = epoll
            .poll(&mut events, Some(Duration::from_micros(100)))
            .unwrap();
        assert_eq!(received, 0);
        assert_eq!(events.len(), 0);
    }

    #[test]
    fn test_poll_register_valid() {
        let epoll = Epoll::new().unwrap();
        let socket = UdpSocket::bind("0.0.0.0:0").unwrap();
        assert!(epoll
            .register(&socket, EventFlag::OUT | EventFlag::ONESHOT, 42)
            .is_ok());
    }

    #[test]
    fn test_poll_register_already_registered() {
        let epoll = Epoll::new().unwrap();
        let socket = UdpSocket::bind("0.0.0.0:0").unwrap();
        assert!(epoll
            .register(&socket, EventFlag::OUT | EventFlag::ONESHOT, 42)
            .is_ok());
        assert!(epoll
            .register(&socket, EventFlag::OUT | EventFlag::ONESHOT, 45)
            .is_err());
    }

    #[test]
    fn test_poll_reregister_invalid() {
        let epoll = Epoll::new().unwrap();
        let socket = UdpSocket::bind("0.0.0.0:0").unwrap();
        assert!(epoll
            .reregister(&socket, EventFlag::OUT | EventFlag::ONESHOT, 42)
            .is_err());
    }

    #[test]
    fn test_poll_reregister_already_registered() {
        let epoll = Epoll::new().unwrap();
        let socket = UdpSocket::bind("0.0.0.0:0").unwrap();
        assert!(epoll
            .register(&socket, EventFlag::OUT | EventFlag::ONESHOT, 42)
            .is_ok());
        assert!(epoll
            .reregister(&socket, EventFlag::OUT | EventFlag::ONESHOT, 45)
            .is_ok());
    }

    #[test]
    fn test_poll_deregister_invalid() {
        let epoll = Epoll::new().unwrap();
        let socket = UdpSocket::bind("0.0.0.0:0").unwrap();
        assert!(epoll.deregister(&socket).is_err());
    }

    #[test]
    fn test_poll_deregister_already_registered() {
        let epoll = Epoll::new().unwrap();
        let socket = UdpSocket::bind("0.0.0.0:0").unwrap();
        assert!(epoll
            .register(&socket, EventFlag::OUT | EventFlag::ONESHOT, 42)
            .is_ok());
        assert!(epoll.deregister(&socket).is_ok());
    }
}