// ultra_batch/batch_fetcher.rs

1use crate::cache::{CacheLookup, CacheLookupState, CacheStore};
2use crate::Fetcher;
3use std::borrow::Cow;
4use std::collections::HashSet;
5use std::sync::Arc;
6
/// Batches and caches loads from some datastore. A `BatchFetcher` can be
/// used with any type that implements [`Fetcher`]. `BatchFetcher`s are
/// asynchronous and designed to be passed and shared between threads or tasks.
/// Cloning a `BatchFetcher` is shallow and will use the same [`Fetcher`].
///
/// `BatchFetcher` is designed primarily around batching database lookups--
/// for example, fetching a user from a user ID, where a single query to
/// retrieve 50 users by ID is significantly faster than 50 separate queries to
/// look up the same set of users.
///
/// A `BatchFetcher` is designed to be ephemeral. In the context of a web
/// service, this means callers should most likely create a new `BatchFetcher`
/// for each request, and **not** a `BatchFetcher` shared across multiple
/// requests. `BatchFetcher`s have no concept of cache invalidation, so old
/// values are stored indefinitely (which means callers may get stale data or
/// may exhaust memory endlessly).
///
/// `BatchFetcher`s introduce a small amount of latency for loads. Each time a
/// `BatchFetcher` receives a key to fetch that hasn't been cached (or a set of
/// keys), it will first wait for more keys to build a batch. The load will only
/// trigger after a timeout is reached or once enough keys have been queued in
/// the batch. See [`BatchFetcherBuilder`] for options to tweak latency and
/// batch sizes.
///
/// See also [`BatchExecutor`](crate::BatchExecutor) for a more general type
/// designed primarily for mutations, but can also be used for fetching with
/// more control over how batches are fetched.
///
/// ## Load semantics
///
/// If the underlying [`Fetcher`] returns an error during the batch request,
/// then all pending [`load`](BatchFetcher::load) and [`load_many`](BatchFetcher::load_many)
/// requests will fail. Subsequent calls to [`load`](BatchFetcher::load) or
/// [`load_many`](BatchFetcher::load_many) with the same keys **will retry**.
///
/// If the underlying [`Fetcher`] succeeds but does not return a value for a
/// given key during a batch request, then the `BatchFetcher` will mark that key
/// as "not found" and an error value of [`NotFound`](LoadError::NotFound) will
/// be returned to all pending [`load`](BatchFetcher::load) and
/// [`load_many`](BatchFetcher::load_many) requests. The "not found" status will
/// be preserved, so subsequent calls with the same key will fail and **will
/// not retry**.
pub struct BatchFetcher<F>
where
    F: Fetcher,
{
    // Diagnostic label included in log/trace output.
    label: Cow<'static, str>,
    // Shared cache of fetched values (and "not found" markers). Clones of
    // this `BatchFetcher` share the same store.
    cache_store: CacheStore<F::Key, F::Value>,
    // Background task that assembles batches and runs the fetcher. Held in an
    // `Arc` so shallow clones keep the same task handle alive; never awaited
    // directly (hence the leading underscore).
    _fetch_task: Arc<tokio::task::JoinHandle<()>>,
    // Channel for submitting pending keys to the background fetch task.
    fetch_request_tx: tokio::sync::mpsc::Sender<FetchRequest<F::Key>>,
}
58
59impl<F> BatchFetcher<F>
60where
61    F: Fetcher + Send + Sync + 'static,
62{
63    /// Create a new `BatchFetcher` that uses the given [`Fetcher`] to retrieve
64    /// data. Returns a [`BatchFetcherBuilder`], which can be used to customize
65    /// the `BatchFetcher`. Call [`.finish()`](BatchFetcherBuilder::finish) to
66    /// create the `BatchFetcher`.
67    ///
68    /// # Examples
69    ///
70    /// Creating a `BatchFetcher` with default options:
71    ///
72    /// ```
73    /// # use ultra_batch::{BatchFetcher, Fetcher, Cache};
74    /// # struct UserFetcher;
75    /// # impl UserFetcher {
76    /// #     fn new(db_conn: ()) -> Self { UserFetcher }
77    /// #  }
78    /// # impl Fetcher for UserFetcher {
79    /// #     type Key = ();
80    /// #     type Value = ();
81    /// #     type Error = anyhow::Error;
82    /// #     async fn fetch(&self, keys: &[()], values: &mut Cache<'_, (), ()>) -> anyhow::Result<()> {
83    /// #         unimplemented!();
84    /// #     }
85    /// # }
86    /// # #[tokio::main] async fn main() -> anyhow::Result<()> {
87    /// # let db_conn = ();
88    /// let user_fetcher = UserFetcher::new(db_conn);
89    /// let batch_fetcher = BatchFetcher::build(user_fetcher).finish();
90    /// # Ok(())
91    /// # }
92    /// ```
93    ///
94    /// Creating a `BatchFetcher` with custom options:
95    ///
96    /// ```
97    /// # use ultra_batch::{BatchFetcher, Fetcher, Cache};
98    /// # struct UserFetcher;
99    /// # impl UserFetcher {
100    /// #     fn new(db_conn: ()) -> Self { UserFetcher }
101    /// #  }
102    /// # impl Fetcher for UserFetcher {
103    /// #     type Key = ();
104    /// #     type Value = ();
105    /// #     type Error = anyhow::Error;
106    /// #     async fn fetch(&self, keys: &[()], values: &mut Cache<'_, (), ()>) -> anyhow::Result<()> {
107    /// #         unimplemented!();
108    /// #     }
109    /// # }
110    /// # #[tokio::main] async fn main() -> anyhow::Result<()> {
111    /// # let db_conn = ();
112    /// let user_fetcher = UserFetcher::new(db_conn);
113    /// let batch_fetcher = BatchFetcher::build(user_fetcher)
114    ///     .eager_batch_size(Some(50))
115    ///     .delay_duration(tokio::time::Duration::from_millis(5))
116    ///     .finish();
117    /// # Ok(()) }
118    /// ```
119    pub fn build(fetcher: F) -> BatchFetcherBuilder<F> {
120        BatchFetcherBuilder {
121            fetcher,
122            delay_duration: tokio::time::Duration::from_millis(10),
123            eager_batch_size: Some(100),
124            label: "unlabeled-batch-fetcher".into(),
125        }
126    }
127
128    /// Load the value with the associated key, either by calling the `Fetcher`
129    /// or by loading the cached value. Returns an error if the value could
130    /// not be loaded or if a value for the given key was not found.
131    ///
132    /// See the type-level docs for [`BatchFetcher`](#load-semantics) for more
133    /// detailed loading semantics.
134    #[tracing::instrument(skip_all, fields(batch_fetcher = %self.label))]
135    pub async fn load(&self, key: F::Key) -> Result<F::Value, LoadError> {
136        let mut values = self.load_keys(&[key]).await?;
137        Ok(values.remove(0))
138    }
139
140    /// Load all the values for the given keys, either by calling the `Fetcher`
141    /// or by loading cached values. Values are returned in the same order as
142    /// the input keys. Returns an error if _any_ load fails.
143    ///
144    /// See the type-level docs for [`BatchFetcher`](#load-semantics) for more
145    /// detailed loading semantics.
146    #[tracing::instrument(skip_all, fields(batch_fetcher = %self.label, num_keys = keys.len()))]
147    pub async fn load_many(&self, keys: &[F::Key]) -> Result<Vec<F::Value>, LoadError> {
148        let values = self.load_keys(keys).await?;
149        Ok(values)
150    }
151
152    async fn load_keys(&self, keys: &[F::Key]) -> Result<Vec<F::Value>, LoadError> {
153        let mut cache_lookup = CacheLookup::new(keys.to_vec());
154
155        match cache_lookup.lookup(&self.cache_store) {
156            CacheLookupState::Done(result) => {
157                tracing::debug!(batch_fetcher = %self.label, "all keys have already been looked up");
158                return result;
159            }
160            CacheLookupState::Pending => {}
161        }
162        let pending_keys = cache_lookup.pending_keys();
163
164        let fetch_request_tx = self.fetch_request_tx.clone();
165        let (result_tx, result_rx) = tokio::sync::oneshot::channel();
166
167        tracing::debug!(
168            num_pending_keys = pending_keys.len(),
169            batch_fetcher = %self.label,
170            "sending a batch of keys to fetch",
171        );
172        let fetch_request = FetchRequest {
173            keys: pending_keys,
174            result_tx,
175        };
176        fetch_request_tx
177            .send(fetch_request)
178            .await
179            .map_err(|_| LoadError::SendError)?;
180
181        match result_rx.await {
182            Ok(Ok(())) => {
183                tracing::debug!(batch_fetcher = %self.label, "fetch response returned successfully");
184            }
185            Ok(Err(fetch_error)) => {
186                tracing::info!("error returned while fetching keys: {fetch_error}");
187                return Err(LoadError::FetchError(fetch_error));
188            }
189            Err(recv_error) => {
190                panic!(
191                    "Batch result channel for batch fetcher {} hung up with error: {recv_error}",
192                    self.label,
193                );
194            }
195        }
196
197        match cache_lookup.lookup(&self.cache_store) {
198            CacheLookupState::Done(result) => {
199                tracing::debug!("all keys have now been looked up");
200                result
201            }
202            CacheLookupState::Pending => {
203                panic!(
204                    "Batch result for batch fetcher {} is still pending after result channel was sent",
205                    self.label,
206                );
207            }
208        }
209    }
210}
211
212impl<F> Clone for BatchFetcher<F>
213where
214    F: Fetcher,
215{
216    fn clone(&self) -> Self {
217        BatchFetcher {
218            cache_store: self.cache_store.clone(),
219            _fetch_task: self._fetch_task.clone(),
220            fetch_request_tx: self.fetch_request_tx.clone(),
221            label: self.label.clone(),
222        }
223    }
224}
225
/// Used to configure a new [`BatchFetcher`]. A `BatchFetcherBuilder` is
/// returned from [`BatchFetcher::build`].
pub struct BatchFetcherBuilder<F>
where
    F: Fetcher + Send + Sync + 'static,
{
    // The fetcher used to load each batch of keys.
    fetcher: F,
    // How long to wait for additional keys before dispatching a batch
    // (default: 10ms, set in `BatchFetcher::build`).
    delay_duration: tokio::time::Duration,
    // Dispatch a batch early once this many keys are pending; `None` means
    // always wait the full `delay_duration` (default: `Some(100)`).
    eager_batch_size: Option<usize>,
    // Diagnostic label included in log/trace output.
    label: Cow<'static, str>,
}
237
impl<F> BatchFetcherBuilder<F>
where
    F: Fetcher + Send + Sync + 'static,
{
    /// The maximum amount of time the [`BatchFetcher`] will wait to queue up
    /// more keys before calling the [`Fetcher`].
    pub fn delay_duration(mut self, delay: tokio::time::Duration) -> Self {
        self.delay_duration = delay;
        self
    }

    /// The maximum number of keys to wait for before eagerly calling the
    /// [`Fetcher`]. A value of `Some(n)` will load the batch once `n` or more
    /// keys have been queued (or once the timeout set by
    /// [`delay_duration`](BatchFetcherBuilder::delay_duration) is reached,
    /// whichever comes first). A value of `None` will never eagerly dispatch
    /// the queue, and the [`BatchFetcher`] will always wait for the timeout set
    /// by [`delay_duration`](BatchFetcherBuilder::delay_duration).
    ///
    /// Note that `eager_batch_size` **does not** set an upper limit on the
    /// batch! For example, if [`BatchFetcher::load_many`] is called with more
    /// than `eager_batch_size` items, then the batch will be sent immediately
    /// with _all_ of the provided keys.
    pub fn eager_batch_size(mut self, eager_batch_size: Option<usize>) -> Self {
        self.eager_batch_size = eager_batch_size;
        self
    }

    /// Set a label for the [`BatchFetcher`]. This is only used to improve
    /// diagnostic messages, such as log messages.
    pub fn label(mut self, label: impl Into<Cow<'static, str>>) -> Self {
        self.label = label.into();
        self
    }

    /// Create and return a [`BatchFetcher`] with the given options. Spawns the
    /// background task that assembles batches of keys, runs the fetcher, and
    /// notifies every waiting caller when a batch completes.
    pub fn finish(self) -> BatchFetcher<F> {
        let cache_store = CacheStore::new();

        // Bounded channel of capacity 1: senders await while the task is busy
        // assembling or fetching a batch, providing natural backpressure.
        let (fetch_request_tx, mut fetch_request_rx) =
            tokio::sync::mpsc::channel::<FetchRequest<F::Key>>(1);
        let label = self.label.clone();

        // Background batching loop. `self` (fetcher + options) is moved into
        // the task; the task exits when all request senders are dropped.
        let fetch_task = tokio::spawn({
            let cache_store = cache_store.clone();
            async move {
                'task: loop {
                    // Wait for some keys to come in
                    let mut pending_keys = HashSet::new();
                    // One result channel per pending `load`/`load_many` call
                    // that contributed keys to this batch.
                    let mut result_txs = vec![];

                    tracing::trace!(batch_fetcher = %self.label, "waiting for keys to fetch...");
                    match fetch_request_rx.recv().await {
                        Some(fetch_request) => {
                            tracing::trace!(batch_fetcher = %self.label, num_fetch_request_keys = fetch_request.keys.len(), "received initial fetch request");

                            // Deduplicate keys across requests via the set.
                            for key in fetch_request.keys {
                                pending_keys.insert(key);
                            }
                            result_txs.push(fetch_request.result_tx);
                        }
                        None => {
                            // Fetch queue closed, so we're done
                            break 'task;
                        }
                    };

                    // Wait for more keys
                    'wait_for_more_keys: loop {
                        let should_run_batch_now = match self.eager_batch_size {
                            Some(eager_batch_size) => pending_keys.len() >= eager_batch_size,
                            None => false,
                        };
                        if should_run_batch_now {
                            // We have enough keys already, so don't wait for more
                            tracing::trace!(
                                batch_fetcher = %self.label,
                                num_pending_keys = pending_keys.len(),
                                eager_batch_size = ?self.eager_batch_size,
                                "batch filled up, ready to fetch keys now",
                            );

                            break 'wait_for_more_keys;
                        }

                        // NOTE: the sleep is re-created on every iteration, so
                        // each new fetch request extends the waiting window by
                        // another `delay_duration`.
                        let delay = tokio::time::sleep(self.delay_duration);
                        tokio::pin!(delay);

                        // Race the next fetch request against the delay.
                        tokio::select! {
                            fetch_request = fetch_request_rx.recv() => {
                                match fetch_request {
                                    Some(fetch_request) => {
                                        tracing::trace!(batch_fetcher = %self.label, num_fetch_request_keys = fetch_request.keys.len(), "retrieved additional fetch request");

                                        for key in fetch_request.keys {
                                            pending_keys.insert(key);
                                        }
                                        result_txs.push(fetch_request.result_tx);
                                    }
                                    None => {
                                        // Fetch queue closed, so we're done waiting for keys
                                        tracing::debug!(batch_fetcher = %self.label, num_pending_keys = pending_keys.len(), "fetch channel closed");
                                        break 'wait_for_more_keys;
                                    }
                                }

                            }
                            _ = &mut delay => {
                                // Reached delay, so we're done waiting for keys
                                tracing::trace!(
                                    batch_fetcher = %self.label,
                                    num_pending_keys = pending_keys.len(),
                                    "delay reached while waiting for more keys to fetch"
                                );
                                break 'wait_for_more_keys;
                            }
                        };
                    }

                    // Run the fetcher for the accumulated batch.
                    let result = {
                        let mut cache = cache_store.as_cache();

                        tracing::trace!(batch_fetcher = %self.label, num_pending_keys = pending_keys.len(), num_pending_channels = result_txs.len(), "fetching keys");
                        let pending_keys: Vec<_> = pending_keys.into_iter().collect();
                        // Errors are stringified so one result can be cloned
                        // and delivered to every waiting caller.
                        let result = self
                            .fetcher
                            .fetch(&pending_keys, &mut cache)
                            .await
                            .map_err(|error| error.to_string());

                        if result.is_ok() {
                            // Any batch key the fetcher did not write a value
                            // for is marked "not found" in the cache.
                            cache.mark_keys_not_found(pending_keys);
                        }

                        result
                    };

                    // Wake every caller waiting on this batch.
                    for result_tx in result_txs {
                        // Ignore error if receiver was already closed
                        let _ = result_tx.send(result.clone());
                    }
                }
            }
        });

        BatchFetcher {
            label,
            cache_store,
            _fetch_task: Arc::new(fetch_task),
            fetch_request_tx,
        }
    }
}
391
/// A batch of keys submitted to the background fetch task, paired with a
/// channel for reporting the batch's outcome back to the caller.
struct FetchRequest<K> {
    // Keys the caller needs fetched (i.e. not resolved from the cache).
    keys: Vec<K>,
    // Receives `Ok(())` once the fetch task has populated the cache for this
    // batch, or `Err` with the stringified fetcher error.
    result_tx: tokio::sync::oneshot::Sender<Result<(), String>>,
}
396
397/// Error indicating that loading one or more values from a [`BatchFetcher`]
398/// failed.
399#[derive(Debug, thiserror::Error)]
400pub enum LoadError {
401    /// The [`Fetcher`] returned an error while loading the batch. The message
402    /// contains the error message specified by [`Fetcher::Error`].
403    #[error("error while fetching from batch: {}", _0)]
404    FetchError(String),
405
406    /// The request could not be sent to the [`BatchFetcher`].
407    #[error("error sending fetch request")]
408    SendError,
409
410    /// The [`Fetcher`] did not return a value for one or more keys in the batch.
411    #[error("value not found")]
412    NotFound,
413}