web_thread_pool/
lib.rs

1// Copyright (c) Zefchain Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{
5    pin::Pin,
6    sync::RwLock,
7    task::{Context, Poll},
8};
9
10use web_thread_select as web_thread;
11
12type Id = usize;
13
14pub use web_thread::Error;
15pub type Task<T> = Guard<web_thread::Task<T>>;
16pub type SendTask<T> = Guard<web_thread::SendTask<T>>;
17
18struct ResourceHandle {
19    id: Id,
20    sender: flume::Sender<Id>,
21}
22
23impl Drop for ResourceHandle {
24    fn drop(&mut self) {
25        let _ = self.sender.send(self.id);
26    }
27}
28
29/// A pool of shared resources, each of which can only be used once at a time.
30pub struct Pool {
31    threads: RwLock<Vec<web_thread::Thread>>,
32    capacity: usize,
33    sender: flume::Sender<Id>,
34    // we have to use an mpmc receiver here in order to be able to
35    // receive using a reference: otherwise we would have to hold the
36    // mutex guard over the await
37    receiver: flume::Receiver<Id>,
38}
39
40pin_project_lite::pin_project! {
41    /// A future that, while running, causes the thread to be considered
42    /// claimed.
43    pub struct Guard<F> {
44        #[pin]
45        future: F,
46        handle: ResourceHandle,
47    }
48}
49
50impl<F: Future> Future for Guard<F> {
51    type Output = F::Output;
52
53    fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
54        self.project().future.poll(context)
55    }
56}
57
58impl Pool {
59    /// Create a new pool of `capacity` items, using `factory` to
60    /// generate new items.
61    pub fn new(capacity: usize) -> Self {
62        let (sender, receiver) = flume::unbounded();
63        Self {
64            threads: RwLock::new(Vec::with_capacity(capacity)),
65            capacity,
66            sender,
67            receiver,
68        }
69    }
70
71    async fn get(&self) -> Id {
72        let mut id = self.receiver.try_recv().ok();
73
74        if id.is_none() {
75            let mut threads = self.threads.write().unwrap();
76            let len = threads.len();
77            if len < self.capacity {
78                threads.push(web_thread::Thread::new());
79                id = Some(len);
80            }
81        }
82
83        if id.is_none() {
84            id = self.receiver.recv_async().await.ok();
85        }
86
87        id.expect("we hold a sender")
88    }
89
90    /// Run a job, creating a new thread if necessary or waiting for one to become available.
91    pub async fn run<Context: web_thread::Post, F: Future<Output: web_thread::Post> + 'static>(
92        &self,
93        context: Context,
94        code: impl FnOnce(Context) -> F + Send + 'static,
95    ) -> Task<F::Output> {
96        let id = self.get().await;
97        Guard {
98            future: self.threads.read().unwrap()[id].run(context, code),
99            handle: ResourceHandle {
100                sender: self.sender.clone(),
101                id,
102            },
103        }
104    }
105
106    /// Like [`Pool::run`], but the output can be sent through Rust
107    /// memory without `Post`ing.
108    pub async fn run_send<Context: web_thread::Post, F: Future<Output: Send> + 'static>(
109        &self,
110        context: Context,
111        code: impl FnOnce(Context) -> F + Send + 'static,
112    ) -> SendTask<F::Output> {
113        let id = self.get().await;
114        Guard {
115            future: self.threads.read().unwrap()[id].run_send(context, code),
116            handle: ResourceHandle {
117                sender: self.sender.clone(),
118                id,
119            },
120        }
121    }
122}