wapc_pool/
hostpool.rs

1type Result<T> = std::result::Result<T, wapc::errors::Error>;
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use crossbeam::channel::{Receiver as SyncReceiver, SendTimeoutError, Sender as SyncSender};
7use rusty_pool::ThreadPool;
8use tokio::sync::oneshot::Sender as OneshotSender;
9use wapc::WapcHost;
10
11use crate::errors::Error;
12
13/// The [HostPool] initializes a number of workers for the passed [WapcHost] factory function.
14///
15#[must_use]
16pub struct HostPool {
17  /// The name of the [HostPool] (for debugging purposes).
18  pub name: String,
19  pool: Option<ThreadPool>,
20  factory: Arc<dyn Fn() -> WapcHost + Send + Sync + 'static>,
21  max_threads: usize,
22  max_wait: Duration,
23  max_idle: Duration,
24  tx: SyncSender<WorkerMessage>,
25  rx: SyncReceiver<WorkerMessage>,
26}
27
28impl std::fmt::Debug for HostPool {
29  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30    f.debug_struct("HostPool")
31      .field("name", &self.name)
32      .field("tx", &self.tx)
33      .field("rx", &self.rx)
34      .finish()
35  }
36}
37
38type WorkerMessage = (
39  OneshotSender<std::result::Result<Vec<u8>, wapc::errors::Error>>,
40  String,
41  Vec<u8>,
42);
43
44impl HostPool {
45  /// Instantiate a new HostPool.
46  pub fn new<N, F>(
47    name: N,
48    factory: F,
49    min_threads: usize,
50    max_threads: usize,
51    max_wait: Duration,
52    max_idle: Duration,
53  ) -> Self
54  where
55    N: AsRef<str>,
56    F: Fn() -> WapcHost + Send + Sync + 'static,
57  {
58    debug!("Creating new wapc host pool with size {}", max_threads);
59    let arcfn = Arc::new(factory);
60    let pool = rusty_pool::Builder::new()
61      .name(name.as_ref().to_owned())
62      .core_size(min_threads)
63      .max_size(max_threads)
64      .keep_alive(Duration::from_millis(0))
65      .build();
66
67    let (tx, rx) = crossbeam::channel::bounded::<WorkerMessage>(1);
68
69    let pool = Self {
70      name: name.as_ref().to_owned(),
71      factory: arcfn,
72      pool: Some(pool),
73      max_threads,
74      max_wait,
75      max_idle,
76      tx,
77      rx,
78    };
79
80    for _ in 0..min_threads {
81      pool.spawn(None).unwrap();
82    }
83
84    pool
85  }
86
87  /// Get the current number of active workers.
88  #[must_use]
89  pub fn num_active_workers(&self) -> usize {
90    self.pool.as_ref().map_or(0, |pool| pool.get_current_worker_count())
91  }
92
93  fn spawn(&self, max_idle: Option<Duration>) -> Result<()> {
94    self.pool.as_ref().map_or_else(
95      || Err(Error::NoPool.into()),
96      |pool| {
97        let name = self.name.clone();
98        let i = pool.get_current_worker_count();
99        let factory = self.factory.clone();
100        let rx = self.rx.clone();
101        pool.execute(move || {
102          trace!("Host thread {}.{} started...", name, i);
103          let host = factory();
104          loop {
105            let message = max_idle.map_or_else(
106              || rx.recv().map_err(|e| e.to_string()),
107              |duration| rx.recv_timeout(duration).map_err(|e| e.to_string()),
108            );
109            if let Err(e) = message {
110              debug!("Host thread {}.{} closing: {}", name, i, e);
111              break;
112            }
113            let (tx, op, payload) = message.unwrap();
114            trace!(
115              "Host thread {}.{} received call for {} with {} byte payload",
116              name,
117              i,
118              op,
119              payload.len()
120            );
121            let result = host.call(&op, &payload);
122            if tx.send(result).is_err() {
123              error!("Host thread {}.{} failed when returning a value...", name, i);
124            }
125          }
126
127          trace!("Host thread {}.{} stopped.", name, i);
128        });
129        Ok(())
130      },
131    )
132  }
133
134  /// Call an operation on one of the workers.
135  pub async fn call<T: AsRef<str> + Sync + Send>(&self, op: T, payload: Vec<u8>) -> Result<Vec<u8>> {
136    let (tx, rx) = tokio::sync::oneshot::channel();
137    // Start the call with a timeout of max_wait.
138    let result = match self
139      .tx
140      .send_timeout((tx, op.as_ref().to_owned(), payload), self.max_wait)
141    {
142      Ok(_) => Ok(()),
143      Err(e) => {
144        // If we didn't get a response in time...
145        let args = match e {
146          SendTimeoutError::Timeout(args) => {
147            debug!("Timeout on pool '{}'", self.name);
148            args
149          }
150          SendTimeoutError::Disconnected(args) => {
151            warn!("Pool worker disconnected on pool '{}'", self.name);
152            args
153          }
154        };
155        // grow the pool...
156        if self.num_active_workers() < self.max_threads {
157          if let Err(e) = self.spawn(Some(self.max_idle)) {
158            error!("Error spawning worker for host pool '{}': {}", self.name, e);
159          };
160        }
161        // ...and wait.
162        self.tx.send(args)
163      }
164    };
165    if let Err(e) = result {
166      return Err(wapc::errors::Error::General(e.to_string()));
167    }
168    match rx.await {
169      Ok(res) => res,
170      Err(e) => Err(wapc::errors::Error::General(e.to_string())),
171    }
172  }
173
174  /// Shut down the host pool.
175  pub fn shutdown(&mut self) -> Result<()> {
176    let pool = self
177      .pool
178      .take()
179      .ok_or_else(|| wapc::errors::Error::from(crate::errors::Error::NoPool))?;
180
181    pool.shutdown_join();
182    Ok(())
183  }
184}
185
186#[must_use]
187/// Builder for a [HostPool]
188pub struct HostPoolBuilder {
189  name: Option<String>,
190  factory: Option<Box<dyn Fn() -> WapcHost + Send + Sync + 'static>>,
191  min_threads: usize,
192  max_threads: usize,
193  max_wait: Duration,
194  max_idle: Duration,
195}
196
197impl std::fmt::Debug for HostPoolBuilder {
198  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199    f.debug_struct("HostPoolBuilder")
200      .field("name", &self.name)
201      .field("factory", if self.factory.is_some() { &"Some(Fn)" } else { &"None" })
202      .field("min_threads", &self.min_threads)
203      .field("max_threads", &self.max_threads)
204      .field("max_wait", &self.max_wait)
205      .field("max_idle", &self.max_idle)
206      .finish()
207  }
208}
209
210impl Default for HostPoolBuilder {
211  fn default() -> Self {
212    Self {
213      name: None,
214      factory: None,
215      min_threads: 1,
216      max_threads: 2,
217      max_wait: Duration::from_millis(100),
218      max_idle: Duration::from_secs(5 * 60),
219    }
220  }
221}
222
223impl HostPoolBuilder {
224  /// Instantiate a nnew [HostPoolBuilder] with default settings.
225  ///
226  /// ```
227  /// # use wapc_pool::HostPoolBuilder;
228  /// let builder = HostPoolBuilder::new();
229  /// ```
230  ///
231  pub fn new() -> Self {
232    Self::default()
233  }
234
235  /// Set the name for the HostPool.
236  ///
237  /// ```
238  /// # use wapc_pool::HostPoolBuilder;
239  /// let builder = HostPoolBuilder::new().name("My Module");
240  /// ```
241  ///
242  pub fn name<T: AsRef<str>>(mut self, name: T) -> Self {
243    self.name = Some(name.as_ref().to_owned());
244    self
245  }
246
247  /// Set the [WapcHost] generator function to use when spawning new workers.
248  ///
249  /// ```
250  /// # use wapc_pool::HostPoolBuilder;
251  /// # use wapc::WapcHost;
252  /// # let bytes = std::fs::read("../../wasm/crates/wapc-guest-test/build/wapc_guest_test.wasm").unwrap();
253  /// let engine = wasmtime_provider::WasmtimeEngineProvider::new(&bytes, None).unwrap();
254  /// let pool = HostPoolBuilder::new()
255  ///   .factory(move || {
256  ///     let engine = engine.clone();
257  ///     WapcHost::new(Box::new(engine), None).unwrap()
258  ///   })
259  ///   .build();
260  /// ```
261  ///
262  pub fn factory<F>(mut self, factory: F) -> Self
263  where
264    F: Fn() -> WapcHost + Send + Sync + 'static,
265  {
266    self.factory = Some(Box::new(factory));
267    self
268  }
269
270  /// Set the minimum, base number of threads to spawn.
271  ///
272  /// ```
273  /// # use wapc_pool::HostPoolBuilder;
274  /// let builder = HostPoolBuilder::new().min_threads(1);
275  /// ```
276  ///
277  pub fn min_threads(mut self, min: usize) -> Self {
278    self.min_threads = min;
279    self
280  }
281
282  /// Set the upper limit on the number of threads to spawn.
283  ///
284  /// ```
285  /// # use wapc_pool::HostPoolBuilder;
286  /// let builder = HostPoolBuilder::new().max_threads(5);
287  /// ```
288  ///
289  pub fn max_threads(mut self, max: usize) -> Self {
290    self.max_threads = max;
291    self
292  }
293
294  /// Set the timeout for threads to self-close.
295  ///
296  /// ```
297  /// # use wapc_pool::HostPoolBuilder;
298  /// # use std::time::Duration;
299  /// let builder = HostPoolBuilder::new().max_idle(Duration::from_secs(60));
300  /// ```
301  ///
302  pub fn max_idle(mut self, timeout: Duration) -> Self {
303    self.max_idle = timeout;
304    self
305  }
306
307  /// Set the maximum amount of time to wait before spawning a new worker.
308  ///
309  /// ```
310  /// # use wapc_pool::HostPoolBuilder;
311  /// # use std::time::Duration;
312  /// let builder = HostPoolBuilder::new().max_wait(Duration::from_millis(100));
313  /// ```
314  ///
315  pub fn max_wait(mut self, duration: Duration) -> Self {
316    self.max_wait = duration;
317    self
318  }
319
320  /// Builds a [HostPool] with the current configuration. Warning: this will panic if a factory function is not supplied.
321  ///
322  /// ```
323  /// # use wapc_pool::HostPoolBuilder;
324  /// # use wapc::WapcHost;
325  /// # let bytes = std::fs::read("../../wasm/crates/wapc-guest-test/build/wapc_guest_test.wasm").unwrap();
326  /// let engine = wasmtime_provider::WasmtimeEngineProvider::new(&bytes, None).unwrap();
327  /// let pool = HostPoolBuilder::new()
328  ///   .factory(move || {
329  ///     let engine = engine.clone();
330  ///     WapcHost::new(Box::new(engine), None).unwrap()
331  ///   })
332  ///   .build();
333  /// ```
334  ///
335  pub fn build(mut self) -> HostPool {
336    #[allow(clippy::expect_used)]
337    let factory = self
338      .factory
339      .take()
340      .expect("A waPC host pool must have a factory function.");
341    HostPool::new(
342      self.name.unwrap_or_else(|| "waPC host pool".to_owned()),
343      factory,
344      self.min_threads,
345      self.max_threads,
346      self.max_wait,
347      self.max_idle,
348    )
349  }
350}
351
352#[cfg(test)]
353mod tests {
354
355  use std::time::{Duration, Instant};
356
357  use tokio::join;
358  use wapc::WebAssemblyEngineProvider;
359
360  use super::*;
361
362  #[test_log::test(tokio::test)]
363  async fn test_basic() -> Result<()> {
364    #[derive(Default)]
365    struct Test {
366      host: Option<Arc<wapc::ModuleState>>,
367    }
368    impl WebAssemblyEngineProvider for Test {
369      fn init(
370        &mut self,
371        host: Arc<wapc::ModuleState>,
372      ) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
373        self.host = Some(host);
374        Ok(())
375      }
376
377      fn call(
378        &mut self,
379        op_length: i32,
380        msg_length: i32,
381      ) -> std::result::Result<i32, Box<dyn std::error::Error + Send + Sync>> {
382        println!("op len:{}", op_length);
383        println!("msg len:{}", msg_length);
384        std::thread::sleep(Duration::from_millis(100));
385        let host = self.host.take().unwrap();
386        host.set_guest_response(b"{}".to_vec());
387        self.host.replace(host);
388        Ok(1)
389      }
390
391      fn replace(&mut self, _bytes: &[u8]) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
392        Ok(())
393      }
394    }
395    let pool = HostPoolBuilder::new()
396      .name("test")
397      .factory(move || WapcHost::new(Box::new(Test::default()), None).unwrap())
398      .min_threads(5)
399      .max_threads(5)
400      .build();
401
402    let now = Instant::now();
403    let result = pool.call("test", b"hello world".to_vec()).await.unwrap();
404    assert_eq!(result, b"{}");
405    let _res = join!(
406      pool.call("test", b"hello world".to_vec()),
407      pool.call("test", b"hello world".to_vec()),
408      pool.call("test", b"hello world".to_vec()),
409      pool.call("test", b"hello world".to_vec()),
410      pool.call("test", b"hello world".to_vec()),
411      pool.call("test", b"hello world".to_vec()),
412      pool.call("test", b"hello world".to_vec()),
413      pool.call("test", b"hello world".to_vec()),
414    );
415    let duration = now.elapsed();
416    println!("Took {}ms", duration.as_millis());
417    assert!(duration.as_millis() < 600);
418
419    Ok(())
420  }
421
422  #[test_log::test(tokio::test)]
423  async fn test_elasticity() -> Result<()> {
424    #[derive(Default)]
425    struct Test {
426      host: Option<Arc<wapc::ModuleState>>,
427    }
428    impl WebAssemblyEngineProvider for Test {
429      fn init(
430        &mut self,
431        host: Arc<wapc::ModuleState>,
432      ) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
433        self.host = Some(host);
434        Ok(())
435      }
436
437      fn call(&mut self, _: i32, _: i32) -> std::result::Result<i32, Box<dyn std::error::Error + Send + Sync>> {
438        std::thread::sleep(Duration::from_millis(100));
439        let host = self.host.take().unwrap();
440        host.set_guest_response(b"{}".to_vec());
441        self.host.replace(host);
442        Ok(1)
443      }
444
445      fn replace(&mut self, _bytes: &[u8]) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
446        Ok(())
447      }
448    }
449    let pool = HostPoolBuilder::new()
450      .name("test")
451      .factory(move || WapcHost::new(Box::new(Test::default()), None).unwrap())
452      .min_threads(1)
453      .max_threads(5)
454      .max_wait(Duration::from_millis(10))
455      .max_idle(Duration::from_secs(1))
456      .build();
457    assert_eq!(pool.num_active_workers(), 1);
458    let _ = futures::future::join_all(vec![
459      pool.call("test", b"hello world".to_vec()),
460      pool.call("test", b"hello world".to_vec()),
461      pool.call("test", b"hello world".to_vec()),
462    ])
463    .await;
464    assert_eq!(pool.num_active_workers(), 2);
465    let _ = futures::future::join_all(vec![
466      pool.call("test", b"hello world".to_vec()),
467      pool.call("test", b"hello world".to_vec()),
468      pool.call("test", b"hello world".to_vec()),
469      pool.call("test", b"hello world".to_vec()),
470      pool.call("test", b"hello world".to_vec()),
471      pool.call("test", b"hello world".to_vec()),
472      pool.call("test", b"hello world".to_vec()),
473      pool.call("test", b"hello world".to_vec()),
474      pool.call("test", b"hello world".to_vec()),
475    ])
476    .await;
477    assert_eq!(pool.num_active_workers(), 5);
478    std::thread::sleep(Duration::from_millis(1500));
479    assert_eq!(pool.num_active_workers(), 1);
480
481    Ok(())
482  }
483}