#![forbid(unsafe_code)]
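//! A small fixed-size thread pool built on `std::sync::mpsc`.
//!
//! Jobs are pushed onto a bounded channel and executed by named worker
//! threads. Workers that die (for example because a job panicked) are
//! respawned so the pool stays at its configured size.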
use core::fmt::{Display, Formatter};
use core::sync::atomic::{AtomicUsize, Ordering};
use core::time::Duration;
use std::error::Error;
use std::sync::mpsc::{Receiver, RecvTimeoutError, SyncSender, TrySendError};
use std::sync::{Arc, Mutex};
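/// Counter that hands out monotonically increasing values; used to give
/// worker threads unique names.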
struct AtomicCounter {
next_value: AtomicUsize,
}
impl AtomicCounter {
pub fn new() -> Self {
Self {
next_value: AtomicUsize::new(0),
}
}
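/// Returns the current value and advances the counter.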
pub fn next(&self) -> usize {
self.next_value.fetch_add(1, Ordering::AcqRel)
}
}
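/// Error returned by [`ThreadPool::try_schedule`] when the job queue is full.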
#[derive(Debug)]
pub struct QueueFull {}
impl Display for QueueFull {
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
core::fmt::Debug::fmt(self, f)
}
}
impl Error for QueueFull {}
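/// State shared between the `ThreadPool` handle and its worker threads
/// through an `Arc`.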
struct Inner {
name: &'static str,
next_name_num: AtomicCounter,
size: usize,
receiver: Mutex<Receiver<Box<dyn FnOnce() + Send>>>,
}
impl Inner {
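/// Number of worker threads currently alive, derived from the `Arc` strong
/// count minus the reference held by the `ThreadPool` handle itself.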
pub fn num_live_threads(self: &Arc<Inner>) -> usize {
Arc::strong_count(self) - 1
}
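/// The worker loop: wait up to 100 ms for a job, keep the pool topped up,
/// run the job, and repeat. Returns once the sending half of the channel is
/// dropped, i.e. once the `ThreadPool` itself has been dropped.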
fn work(self: &Arc<Inner>) {
loop {
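// The receiver mutex is locked only for this statement; the guard is a
// temporary that is dropped before the job runs.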
let recv_result = self
.receiver
.lock()
.unwrap()
.recv_timeout(Duration::from_millis(100));
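// Replace any workers that died earlier before running the job, since this
// thread may now be busy (or may panic) for a while.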
self.start_threads();
match recv_result {
Ok(f) => f(),
Err(RecvTimeoutError::Timeout) => {}
Err(RecvTimeoutError::Disconnected) => return,
};
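// Top up again afterwards in case other workers died while the job ran.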
self.start_threads();
}
}
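/// Spawns one additional worker if the pool is below capacity.
/// `self_clone` is created before the check and already counts as one live
/// thread, so the `<=` comparison here is equivalent to the `<` used in
/// `start_threads`.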
fn start_thread(self: &Arc<Inner>) {
let self_clone = self.clone();
if self.num_live_threads() <= self.size {
std::thread::Builder::new()
.name(format!("{}{}", self.name, self.next_name_num.next()))
.spawn(move || self_clone.work())
.unwrap();
}
}
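/// Spawns workers until the pool reaches its configured size.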
fn start_threads(self: &Arc<Inner>) {
while self.num_live_threads() < self.size {
self.start_thread();
}
}
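/// Ensures at least one worker is alive so that a queued job gets picked up.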
fn start_thread_if_none(self: &Arc<Inner>) {
if self.num_live_threads() < 1 {
self.start_thread();
}
}
}
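/// A fixed-size thread pool. Jobs are queued on a bounded channel and run on
/// named worker threads that are respawned if they panic.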
pub struct ThreadPool {
inner: Arc<Inner>,
sender: SyncSender<Box<dyn FnOnce() + Send>>,
}
impl ThreadPool {
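/// Creates a pool of `size` workers whose thread names start with `name`.
/// The job queue is bounded at `size * 200` entries.
///
/// # Panics
/// Panics if `name` is empty or `size` is zero.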
pub fn new(name: &'static str, size: usize) -> Self {
if name.is_empty() {
panic!("ThreadPool::new called with empty name")
}
if size < 1 {
panic!("ThreadPool::new called with invalid size value: {:?}", size)
}
let (sender, receiver) = std::sync::mpsc::sync_channel(size * 200);
let pool = ThreadPool {
inner: Arc::new(Inner {
name,
next_name_num: AtomicCounter::new(),
size,
receiver: Mutex::new(receiver),
}),
sender,
};
pool.inner.start_threads();
pool
}
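/// The configured number of worker threads.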
pub fn size(&self) -> usize {
self.inner.size
}
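/// The number of worker threads currently alive.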
pub fn num_live_threads(&self) -> usize {
self.inner.num_live_threads()
}
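/// Queues `f` for execution, blocking while the job queue is full.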
pub fn schedule(&self, f: impl FnOnce() + Send + 'static) {
self.inner.start_thread_if_none();
self.sender.send(Box::new(f)).unwrap();
self.inner.start_threads();
}
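/// Queues `f` for execution, returning `QueueFull` instead of blocking if
/// the job queue is at capacity.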
pub fn try_schedule(&self, f: impl FnOnce() + Send + 'static) -> Result<(), QueueFull> {
self.inner.start_thread_if_none();
match self.sender.try_send(Box::new(f)) {
Ok(_) => {
self.inner.start_threads();
Ok(())
}
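// The receiver lives in `inner`, which this pool keeps alive, so the
// channel cannot be disconnected here.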
Err(TrySendError::Disconnected(_)) => unreachable!(),
Err(TrySendError::Full(_)) => Err(QueueFull {}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use core::fmt::Debug;
use core::ops::Range;
use core::time::Duration;
use std::time::Instant;
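// Panics unless `value` lies inside `range`; prints both for easier debugging.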
fn assert_in_range<T: PartialOrd + Debug>(range: Range<T>, value: T) {
if range.is_empty() {
panic!("invalid range {:?}", range)
}
println!(
"measured concurrency value {:?}, expected range {:?}",
value, range,
);
if !range.contains(&value) {
panic!(
"measured concurrency value {:?} out of range {:?}",
value, range,
);
}
}
pub fn assert_elapsed(before: Instant, range_ms: Range<u64>) {
if range_ms.is_empty() {
panic!("invalid range {:?}", range_ms)
}
let elapsed = before.elapsed();
let duration_range =
Duration::from_millis(range_ms.start)..Duration::from_millis(range_ms.end);
if !duration_range.contains(&elapsed) {
panic!("{:?} elapsed, out of range {:?}", elapsed, duration_range);
}
}
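// Schedules `num_jobs` sleeping jobs and returns the elapsed time divided by
// a single job's sleep time, i.e. roughly the number of sequential rounds
// the pool needed to drain them.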
fn measure_concurrency(pool: &ThreadPool, num_jobs: usize) -> f32 {
const WAIT_DURATION: Duration = Duration::from_millis(100);
let before = Instant::now();
let receiver = {
let (sender, receiver) = std::sync::mpsc::channel();
for _ in 0..num_jobs {
let sender_clone = sender.clone();
pool.schedule(move || {
std::thread::sleep(WAIT_DURATION);
sender_clone.send(()).unwrap();
});
}
receiver
};
for _ in 0..num_jobs {
receiver.recv_timeout(Duration::from_millis(500)).unwrap();
}
let elapsed = before.elapsed();
elapsed.as_secs_f32() / WAIT_DURATION.as_secs_f32()
}
fn sleep(ms: u64) {
std::thread::sleep(Duration::from_millis(ms));
}
#[test]
fn atomic_counter() {
let counter = Arc::new(AtomicCounter::new());
assert_eq!(0, counter.next());
assert_eq!(1, counter.next());
assert_eq!(2, counter.next());
}
#[test]
fn atomic_counter_many_readers() {
let receiver = {
let counter = Arc::new(AtomicCounter::new());
let (sender, receiver) = std::sync::mpsc::channel();
for _ in 0..10 {
let counter_clone = counter.clone();
let sender_clone = sender.clone();
std::thread::spawn(move || {
for _ in 0..10 {
sender_clone.send(counter_clone.next()).unwrap();
}
});
}
receiver
};
let mut values: Vec<usize> = receiver.iter().collect();
values.sort();
assert_eq!((0usize..100).collect::<Vec<usize>>(), values);
}
#[test]
fn queue_full_display() {
assert_eq!("QueueFull", format!("{}", QueueFull {}));
}
#[test]
fn empty_name() {
assert!(std::panic::catch_unwind(|| ThreadPool::new("", 1)).is_err());
}
#[test]
fn zero_size() {
assert!(std::panic::catch_unwind(|| ThreadPool::new("pool1", 0)).is_err());
}
#[test]
fn test_size() {
let pool = ThreadPool::new("pool1", 3);
assert_eq!(3, pool.size());
}
#[test]
fn test_num_live_threads() {
let pool = ThreadPool::new("pool1", 3);
sleep(100);
assert_eq!(3, pool.num_live_threads());
pool.schedule(move || {
sleep(100);
panic!("ignore this panic")
});
pool.schedule(move || {
sleep(100);
panic!("ignore this panic")
});
pool.schedule(move || {
sleep(100);
panic!("ignore this panic")
});
sleep(200);
assert_eq!(0, pool.num_live_threads());
pool.schedule(move || {});
assert_eq!(3, pool.num_live_threads());
}
#[test]
fn schedule_should_run_the_fn() {
let pool = ThreadPool::new("pool1", 1);
let before = Instant::now();
let (sender, receiver) = std::sync::mpsc::channel();
pool.schedule(move || {
sender.send(()).unwrap();
});
receiver.recv_timeout(Duration::from_millis(500)).unwrap();
assert_elapsed(before, 0..100);
}
#[test]
fn schedule_should_start_a_thread_if_none() {
let pool = ThreadPool::new("pool1", 3);
sleep(100);
pool.schedule(move || {
sleep(100);
panic!("ignore this panic")
});
pool.schedule(move || {
sleep(100);
panic!("ignore this panic")
});
pool.schedule(move || {
sleep(100);
panic!("ignore this panic")
});
sleep(200);
assert_eq!(0, pool.num_live_threads());
pool.schedule(|| {});
assert_eq!(3, pool.num_live_threads());
}
#[test]
fn try_schedule_should_run_the_fn() {
let pool = ThreadPool::new("pool1", 1);
let before = Instant::now();
let (sender, receiver) = std::sync::mpsc::channel();
pool.try_schedule(move || {
sender.send(()).unwrap();
})
.unwrap();
receiver.recv_timeout(Duration::from_millis(500)).unwrap();
assert_elapsed(before, 0..100);
}
#[test]
fn try_schedule_queue_full() {
let pool = ThreadPool::new("pool1", 1);
let before = Instant::now();
while before.elapsed() < Duration::from_millis(500) {
match pool.try_schedule(move || panic!("ignore this panic")) {
Ok(_) => {}
Err(e) => {
println!("try_schedule got {:?}", e);
sleep(100);
return;
}
}
}
panic!("timeout");
}
#[test]
fn check_concurrency1() {
let pool = ThreadPool::new("pool1", 1);
assert_in_range(1.0..1.99, measure_concurrency(&pool, 1));
assert_in_range(2.0..2.99, measure_concurrency(&pool, 2));
}
#[test]
fn check_concurrency2() {
let pool = ThreadPool::new("pool1", 2);
assert_in_range(1.0..1.99, measure_concurrency(&pool, 1));
assert_in_range(1.0..1.99, measure_concurrency(&pool, 2));
assert_in_range(2.0..2.99, measure_concurrency(&pool, 3));
assert_in_range(2.0..2.99, measure_concurrency(&pool, 4));
}
#[test]
fn check_concurrency5() {
let pool = ThreadPool::new("pool1", 5);
assert_in_range(1.0..1.99, measure_concurrency(&pool, 5));
assert_in_range(2.0..2.99, measure_concurrency(&pool, 6));
}
#[test]
fn should_quickly_respawn_panicked_threads() {
let pool = ThreadPool::new("pool1", 2);
sleep(50);
pool.schedule(move || panic!("ignore this panic"));
sleep(100);
assert_eq!(1, pool.num_live_threads());
sleep(200);
assert_eq!(2, pool.num_live_threads());
assert_in_range(1.0..1.99, measure_concurrency(&pool, 2));
}
#[test]
fn should_respawn_after_recv() {
let pool = ThreadPool::new("pool1", 2);
sleep(50);
pool.schedule(move || panic!("ignore this panic"));
sleep(100);
assert_eq!(1, pool.num_live_threads());
pool.schedule(move || sleep(200));
sleep(100);
assert_eq!(2, pool.num_live_threads());
}
#[test]
fn should_respawn_after_executing_job() {
let pool = ThreadPool::new("pool1", 2);
pool.schedule(move || sleep(200));
pool.schedule(move || panic!("ignore this panic"));
sleep(100);
assert_eq!(1, pool.num_live_threads());
sleep(200);
assert_eq!(2, pool.num_live_threads());
}
}