ultra_batch/batch_fetcher.rs
1use crate::cache::{CacheLookup, CacheLookupState, CacheStore};
2use crate::Fetcher;
3use std::borrow::Cow;
4use std::collections::HashSet;
5use std::sync::Arc;
6
/// Batches and caches loads from some datastore. A `BatchFetcher` can be
/// used with any type that implements [`Fetcher`]. `BatchFetcher`s are
/// asynchronous and designed to be passed and shared between threads or tasks.
/// Cloning a `BatchFetcher` is shallow and will use the same [`Fetcher`].
///
/// `BatchFetcher` is designed primarily around batching database lookups--
/// for example, fetching a user from a user ID, where a single query to
/// retrieve 50 users by ID is significantly faster than 50 separate queries to
/// look up the same set of users.
///
/// A `BatchFetcher` is designed to be ephemeral. In the context of a web
/// service, this means callers should most likely create a new `BatchFetcher`
/// for each request, and **not** a `BatchFetcher` shared across multiple
/// requests. `BatchFetcher`s have no concept of cache invalidation, so old
/// values are stored indefinitely (which means callers may get stale data or
/// may exhaust memory endlessly).
///
/// `BatchFetcher`s introduce a small amount of latency for loads. Each time a
/// `BatchFetcher` receives a key to fetch that hasn't been cached (or a set of
/// keys), it will first wait for more keys to build a batch. The load will only
/// trigger after a timeout is reached or once enough keys have been queued in
/// the batch. See [`BatchFetcherBuilder`] for options to tweak latency and
/// batch sizes.
///
/// See also [`BatchExecutor`](crate::BatchExecutor) for a more general type
/// designed primarily for mutations, but can also be used for fetching with
/// more control over how batches are fetched.
///
/// ## Load semantics
///
/// If the underlying [`Fetcher`] returns an error during the batch request,
/// then all pending [`load`](BatchFetcher::load) and [`load_many`](BatchFetcher::load_many)
/// requests will fail. Subsequent calls to [`load`](BatchFetcher::load) or
/// [`load_many`](BatchFetcher::load_many) with the same keys **will retry**.
///
/// If the underlying [`Fetcher`] succeeds but does not return a value for a
/// given key during a batch request, then the `BatchFetcher` will mark that key
/// as "not found" and an error value of [`NotFound`](LoadError::NotFound) will
/// be returned to all pending [`load`](BatchFetcher::load) and
/// [`load_many`](BatchFetcher::load_many) requests. The "not found" status will
/// be preserved, so subsequent calls with the same key will fail and **will
/// not retry**.
pub struct BatchFetcher<F>
where
    F: Fetcher,
{
    // Human-readable label, used only for diagnostics (log/trace messages).
    label: Cow<'static, str>,
    // Cache of fetched values (and "not found" markers), shared with the
    // background fetch task and with every clone of this `BatchFetcher`.
    cache_store: CacheStore<F::Key, F::Value>,
    // Handle to the background batching task spawned by
    // `BatchFetcherBuilder::finish`; held (never awaited) so the handle
    // lives as long as any clone of this `BatchFetcher`.
    _fetch_task: Arc<tokio::task::JoinHandle<()>>,
    // Sends batches of pending keys to the background fetch task.
    fetch_request_tx: tokio::sync::mpsc::Sender<FetchRequest<F::Key>>,
}
58
59impl<F> BatchFetcher<F>
60where
61 F: Fetcher + Send + Sync + 'static,
62{
63 /// Create a new `BatchFetcher` that uses the given [`Fetcher`] to retrieve
64 /// data. Returns a [`BatchFetcherBuilder`], which can be used to customize
65 /// the `BatchFetcher`. Call [`.finish()`](BatchFetcherBuilder::finish) to
66 /// create the `BatchFetcher`.
67 ///
68 /// # Examples
69 ///
70 /// Creating a `BatchFetcher` with default options:
71 ///
72 /// ```
73 /// # use ultra_batch::{BatchFetcher, Fetcher, Cache};
74 /// # struct UserFetcher;
75 /// # impl UserFetcher {
76 /// # fn new(db_conn: ()) -> Self { UserFetcher }
77 /// # }
78 /// # impl Fetcher for UserFetcher {
79 /// # type Key = ();
80 /// # type Value = ();
81 /// # type Error = anyhow::Error;
82 /// # async fn fetch(&self, keys: &[()], values: &mut Cache<'_, (), ()>) -> anyhow::Result<()> {
83 /// # unimplemented!();
84 /// # }
85 /// # }
86 /// # #[tokio::main] async fn main() -> anyhow::Result<()> {
87 /// # let db_conn = ();
88 /// let user_fetcher = UserFetcher::new(db_conn);
89 /// let batch_fetcher = BatchFetcher::build(user_fetcher).finish();
90 /// # Ok(())
91 /// # }
92 /// ```
93 ///
94 /// Creating a `BatchFetcher` with custom options:
95 ///
96 /// ```
97 /// # use ultra_batch::{BatchFetcher, Fetcher, Cache};
98 /// # struct UserFetcher;
99 /// # impl UserFetcher {
100 /// # fn new(db_conn: ()) -> Self { UserFetcher }
101 /// # }
102 /// # impl Fetcher for UserFetcher {
103 /// # type Key = ();
104 /// # type Value = ();
105 /// # type Error = anyhow::Error;
106 /// # async fn fetch(&self, keys: &[()], values: &mut Cache<'_, (), ()>) -> anyhow::Result<()> {
107 /// # unimplemented!();
108 /// # }
109 /// # }
110 /// # #[tokio::main] async fn main() -> anyhow::Result<()> {
111 /// # let db_conn = ();
112 /// let user_fetcher = UserFetcher::new(db_conn);
113 /// let batch_fetcher = BatchFetcher::build(user_fetcher)
114 /// .eager_batch_size(Some(50))
115 /// .delay_duration(tokio::time::Duration::from_millis(5))
116 /// .finish();
117 /// # Ok(()) }
118 /// ```
119 pub fn build(fetcher: F) -> BatchFetcherBuilder<F> {
120 BatchFetcherBuilder {
121 fetcher,
122 delay_duration: tokio::time::Duration::from_millis(10),
123 eager_batch_size: Some(100),
124 label: "unlabeled-batch-fetcher".into(),
125 }
126 }
127
128 /// Load the value with the associated key, either by calling the `Fetcher`
129 /// or by loading the cached value. Returns an error if the value could
130 /// not be loaded or if a value for the given key was not found.
131 ///
132 /// See the type-level docs for [`BatchFetcher`](#load-semantics) for more
133 /// detailed loading semantics.
134 #[tracing::instrument(skip_all, fields(batch_fetcher = %self.label))]
135 pub async fn load(&self, key: F::Key) -> Result<F::Value, LoadError> {
136 let mut values = self.load_keys(&[key]).await?;
137 Ok(values.remove(0))
138 }
139
140 /// Load all the values for the given keys, either by calling the `Fetcher`
141 /// or by loading cached values. Values are returned in the same order as
142 /// the input keys. Returns an error if _any_ load fails.
143 ///
144 /// See the type-level docs for [`BatchFetcher`](#load-semantics) for more
145 /// detailed loading semantics.
146 #[tracing::instrument(skip_all, fields(batch_fetcher = %self.label, num_keys = keys.len()))]
147 pub async fn load_many(&self, keys: &[F::Key]) -> Result<Vec<F::Value>, LoadError> {
148 let values = self.load_keys(keys).await?;
149 Ok(values)
150 }
151
152 async fn load_keys(&self, keys: &[F::Key]) -> Result<Vec<F::Value>, LoadError> {
153 let mut cache_lookup = CacheLookup::new(keys.to_vec());
154
155 match cache_lookup.lookup(&self.cache_store) {
156 CacheLookupState::Done(result) => {
157 tracing::debug!(batch_fetcher = %self.label, "all keys have already been looked up");
158 return result;
159 }
160 CacheLookupState::Pending => {}
161 }
162 let pending_keys = cache_lookup.pending_keys();
163
164 let fetch_request_tx = self.fetch_request_tx.clone();
165 let (result_tx, result_rx) = tokio::sync::oneshot::channel();
166
167 tracing::debug!(
168 num_pending_keys = pending_keys.len(),
169 batch_fetcher = %self.label,
170 "sending a batch of keys to fetch",
171 );
172 let fetch_request = FetchRequest {
173 keys: pending_keys,
174 result_tx,
175 };
176 fetch_request_tx
177 .send(fetch_request)
178 .await
179 .map_err(|_| LoadError::SendError)?;
180
181 match result_rx.await {
182 Ok(Ok(())) => {
183 tracing::debug!(batch_fetcher = %self.label, "fetch response returned successfully");
184 }
185 Ok(Err(fetch_error)) => {
186 tracing::info!("error returned while fetching keys: {fetch_error}");
187 return Err(LoadError::FetchError(fetch_error));
188 }
189 Err(recv_error) => {
190 panic!(
191 "Batch result channel for batch fetcher {} hung up with error: {recv_error}",
192 self.label,
193 );
194 }
195 }
196
197 match cache_lookup.lookup(&self.cache_store) {
198 CacheLookupState::Done(result) => {
199 tracing::debug!("all keys have now been looked up");
200 result
201 }
202 CacheLookupState::Pending => {
203 panic!(
204 "Batch result for batch fetcher {} is still pending after result channel was sent",
205 self.label,
206 );
207 }
208 }
209 }
210}
211
212impl<F> Clone for BatchFetcher<F>
213where
214 F: Fetcher,
215{
216 fn clone(&self) -> Self {
217 BatchFetcher {
218 cache_store: self.cache_store.clone(),
219 _fetch_task: self._fetch_task.clone(),
220 fetch_request_tx: self.fetch_request_tx.clone(),
221 label: self.label.clone(),
222 }
223 }
224}
225
/// Used to configure a new [`BatchFetcher`]. A `BatchFetcherBuilder` is
/// returned from [`BatchFetcher::build`].
pub struct BatchFetcherBuilder<F>
where
    F: Fetcher + Send + Sync + 'static,
{
    // The underlying fetcher used to load each batch of keys.
    fetcher: F,
    // How long to wait for additional keys before dispatching a batch.
    delay_duration: tokio::time::Duration,
    // If `Some(n)`, a batch is dispatched as soon as `n` or more keys are
    // queued; if `None`, batches always wait the full `delay_duration`.
    eager_batch_size: Option<usize>,
    // Diagnostic label attached to log/trace messages.
    label: Cow<'static, str>,
}
237
impl<F> BatchFetcherBuilder<F>
where
    F: Fetcher + Send + Sync + 'static,
{
    /// The maximum amount of time the [`BatchFetcher`] will wait to queue up
    /// more keys before calling the [`Fetcher`].
    pub fn delay_duration(mut self, delay: tokio::time::Duration) -> Self {
        self.delay_duration = delay;
        self
    }

    /// The maximum number of keys to wait for before eagerly calling the
    /// [`Fetcher`]. A value of `Some(n)` will load the batch once `n` or more
    /// keys have been queued (or once the timeout set by
    /// [`delay_duration`](BatchFetcherBuilder::delay_duration) is reached,
    /// whichever comes first). A value of `None` will never eagerly dispatch
    /// the queue, and the [`BatchFetcher`] will always wait for the timeout set
    /// by [`delay_duration`](BatchFetcherBuilder::delay_duration).
    ///
    /// Note that `eager_batch_size` **does not** set an upper limit on the
    /// batch! For example, if [`BatchFetcher::load_many`] is called with more
    /// than `eager_batch_size` items, then the batch will be sent immediately
    /// with _all_ of the provided keys.
    pub fn eager_batch_size(mut self, eager_batch_size: Option<usize>) -> Self {
        self.eager_batch_size = eager_batch_size;
        self
    }

    /// Set a label for the [`BatchFetcher`]. This is only used to improve
    /// diagnostic messages, such as log messages.
    pub fn label(mut self, label: impl Into<Cow<'static, str>>) -> Self {
        self.label = label.into();
        self
    }

    /// Create and return a [`BatchFetcher`] with the given options. Spawns the
    /// background task that collects fetch requests into batches and runs the
    /// [`Fetcher`]; the task exits when every clone of the returned
    /// `BatchFetcher` (and thus the request channel) has been dropped.
    pub fn finish(self) -> BatchFetcher<F> {
        // Shared between the spawned task below and the returned
        // `BatchFetcher` (and all of its clones).
        let cache_store = CacheStore::new();

        // Bounded channel (capacity 1) over which `load`/`load_many` submit
        // their pending keys to the background task.
        let (fetch_request_tx, mut fetch_request_rx) =
            tokio::sync::mpsc::channel::<FetchRequest<F::Key>>(1);
        let label = self.label.clone();

        let fetch_task = tokio::spawn({
            let cache_store = cache_store.clone();
            async move {
                // Each iteration of `'task` handles one full batch: gather
                // keys, run the fetcher, notify every waiting caller.
                'task: loop {
                    // Wait for some keys to come in
                    let mut pending_keys = HashSet::new();
                    let mut result_txs = vec![];

                    tracing::trace!(batch_fetcher = %self.label, "waiting for keys to fetch...");
                    match fetch_request_rx.recv().await {
                        Some(fetch_request) => {
                            tracing::trace!(batch_fetcher = %self.label, num_fetch_request_keys = fetch_request.keys.len(), "received initial fetch request");

                            for key in fetch_request.keys {
                                pending_keys.insert(key);
                            }
                            result_txs.push(fetch_request.result_tx);
                        }
                        None => {
                            // Fetch queue closed, so we're done
                            break 'task;
                        }
                    };

                    // Wait for more keys
                    'wait_for_more_keys: loop {
                        // With `eager_batch_size: None` this is always false,
                        // so the batch only dispatches on the delay timeout.
                        let should_run_batch_now = match self.eager_batch_size {
                            Some(eager_batch_size) => pending_keys.len() >= eager_batch_size,
                            None => false,
                        };
                        if should_run_batch_now {
                            // We have enough keys already, so don't wait for more
                            tracing::trace!(
                                batch_fetcher = %self.label,
                                num_pending_keys = pending_keys.len(),
                                eager_batch_size = ?self.eager_batch_size,
                                "batch filled up, ready to fetch keys now",
                            );

                            break 'wait_for_more_keys;
                        }

                        // NOTE: the delay timer is re-created on every
                        // iteration, so each newly received request restarts
                        // the wait for `delay_duration`.
                        let delay = tokio::time::sleep(self.delay_duration);
                        tokio::pin!(delay);

                        // Race the next incoming request against the delay.
                        tokio::select! {
                            fetch_request = fetch_request_rx.recv() => {
                                match fetch_request {
                                    Some(fetch_request) => {
                                        tracing::trace!(batch_fetcher = %self.label, num_fetch_request_keys = fetch_request.keys.len(), "retrieved additional fetch request");

                                        for key in fetch_request.keys {
                                            pending_keys.insert(key);
                                        }
                                        result_txs.push(fetch_request.result_tx);
                                    }
                                    None => {
                                        // Fetch queue closed, so we're done waiting for keys
                                        tracing::debug!(batch_fetcher = %self.label, num_pending_keys = pending_keys.len(), "fetch channel closed");
                                        break 'wait_for_more_keys;
                                    }
                                }

                            }
                            _ = &mut delay => {
                                // Reached delay, so we're done waiting for keys
                                tracing::trace!(
                                    batch_fetcher = %self.label,
                                    num_pending_keys = pending_keys.len(),
                                    "delay reached while waiting for more keys to fetch"
                                );
                                break 'wait_for_more_keys;
                            }
                        };
                    }

                    // Run the batch. The fetcher writes values through
                    // `cache`; errors are flattened to their string form so
                    // the result can be cloned to every waiting caller.
                    let result = {
                        let mut cache = cache_store.as_cache();

                        tracing::trace!(batch_fetcher = %self.label, num_pending_keys = pending_keys.len(), num_pending_channels = result_txs.len(), "fetching keys");
                        let pending_keys: Vec<_> = pending_keys.into_iter().collect();
                        let result = self
                            .fetcher
                            .fetch(&pending_keys, &mut cache)
                            .await
                            .map_err(|error| error.to_string());

                        // On success, record "not found" for keys the fetcher
                        // left unresolved (per the `BatchFetcher` load
                        // semantics, those will not be retried). Presumably
                        // `mark_keys_not_found` skips keys that now have
                        // cached values — see `Cache::mark_keys_not_found`.
                        if result.is_ok() {
                            cache.mark_keys_not_found(pending_keys);
                        }

                        result
                    };

                    // Notify every caller that contributed keys to this batch.
                    for result_tx in result_txs {
                        // Ignore error if receiver was already closed
                        let _ = result_tx.send(result.clone());
                    }
                }
            }
        });

        BatchFetcher {
            label,
            cache_store,
            _fetch_task: Arc::new(fetch_task),
            fetch_request_tx,
        }
    }
}
391
/// A message sent from `BatchFetcher::load_keys` to the background fetch
/// task: a set of keys to fetch, plus a channel for reporting the outcome of
/// the batch that ends up containing those keys.
struct FetchRequest<K> {
    // Keys that were not already resolved by the cache.
    keys: Vec<K>,
    // Receives `Ok(())` once the batch completes (results land in the shared
    // cache store), or `Err(message)` if the `Fetcher` returned an error.
    result_tx: tokio::sync::oneshot::Sender<Result<(), String>>,
}
396
/// Error indicating that loading one or more values from a [`BatchFetcher`]
/// failed.
#[derive(Debug, thiserror::Error)]
pub enum LoadError {
    /// The [`Fetcher`] returned an error while loading the batch. The message
    /// contains the error message specified by [`Fetcher::Error`].
    ///
    /// Per the [`BatchFetcher`] load semantics, loads that fail with this
    /// error **will retry** the fetch if the same keys are requested again.
    #[error("error while fetching from batch: {}", _0)]
    FetchError(String),

    /// The request could not be sent to the [`BatchFetcher`].
    #[error("error sending fetch request")]
    SendError,

    /// The [`Fetcher`] did not return a value for one or more keys in the batch.
    ///
    /// Per the [`BatchFetcher`] load semantics, the "not found" status is
    /// cached, so later loads of the same key fail again **without retrying**.
    #[error("value not found")]
    NotFound,
}
413}