use crate::cache::{CacheLookup, CacheLookupState, CacheStore};
use crate::Fetcher;
use std::collections::HashSet;
use std::sync::Arc;

/// Used to batch and cache loads from some datastore. A `Batcher` can be used
/// with any type that implements [`Fetcher`]. `Batcher`s are asynchronous and
/// designed to be passed and shared between threads or tasks. Cloning a
/// `Batcher` is shallow, so clones share the same `Fetcher` and cache across
/// threads or tasks.
///
/// A `Batcher` is designed primarily around batching database lookups: for
/// example, fetching a user from a user ID, where a single query to retrieve
/// 50 users by ID is significantly faster than 50 separate queries to look up
/// the same set of users.
///
/// A `Batcher` is designed to be ephemeral. In the context of a web service,
/// this means callers should most likely create a new `Batcher` for each
/// request, and **not** share a single `Batcher` across multiple requests.
/// `Batcher`s have no concept of cache invalidation, so old values are stored
/// indefinitely: a long-lived `Batcher` may return stale data and will grow
/// its cache without bound.
///
/// `Batcher`s introduce a small amount of latency for loads. Each time a
/// `Batcher` receives a key (or set of keys) to fetch that hasn't been cached,
/// it will first wait for more keys to build a batch. The load will only
/// trigger after a timeout is reached or once enough keys have been queued in
/// the batch. See [`BatcherBuilder`] for options to tweak latency and batch
/// sizes.
///
/// ## Load semantics
///
/// If the underlying [`Fetcher`] returns an error during the batch request,
/// then all pending [`load`](Batcher::load) and [`load_many`](Batcher::load_many)
/// requests will fail. Subsequent calls to [`load`](Batcher::load) or
/// [`load_many`](Batcher::load_many) with the same keys **will retry**.
///
/// If the underlying [`Fetcher`] succeeds but does not return a value for a
/// given key during a batch request, then the `Batcher` will mark that key as
/// "not found" and an error value of [`NotFound`](LoadError::NotFound) will be
/// returned to all pending [`load`](Batcher::load) and
/// [`load_many`](Batcher::load_many) requests. The
/// "not found" status will be preserved, so subsequent calls with the same key
/// will fail and **will not retry**.
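///
/// For example, here is a minimal sketch of the "not found" semantics, using a
/// hypothetical fetcher that succeeds but never inserts any values:
///
/// ```
/// # use async_trait::async_trait;
/// # use ultra_batch::{Batcher, Fetcher, Cache, LoadError};
/// # struct EmptyFetcher;
/// # #[async_trait]
/// # impl Fetcher for EmptyFetcher {
/// #     type Key = u64;
/// #     type Value = String;
/// #     type Error = anyhow::Error;
/// #     async fn fetch(&self, _keys: &[u64], _values: &mut Cache<'_, u64, String>) -> anyhow::Result<()> {
/// #         Ok(()) // succeed, but don't insert a value for any key
/// #     }
/// # }
/// # #[tokio::main] async fn main() -> anyhow::Result<()> {
/// let batcher = Batcher::build(EmptyFetcher).finish();
///
/// // The fetch succeeds but returns no value, so the key is marked "not found"
/// assert!(matches!(batcher.load(1).await, Err(LoadError::NotFound)));
///
/// // The "not found" status is cached, so this call fails without fetching again
/// assert!(matches!(batcher.load(1).await, Err(LoadError::NotFound)));
/// # Ok(())
/// # }
/// ```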
pub struct Batcher<F>
where
    F: Fetcher,
{
    label: String,
    cache_store: CacheStore<F::Key, F::Value>,
    _fetch_task: Arc<tokio::task::JoinHandle<()>>,
    fetch_request_tx: tokio::sync::mpsc::Sender<FetchRequest<F::Key>>,
}

impl<F> Batcher<F>
where
    F: Fetcher + Send + Sync + 'static,
{
    /// Begin building a new `Batcher` that uses the given [`Fetcher`] to
    /// retrieve data. Returns a [`BatcherBuilder`], which can be used to
    /// customize the `Batcher`. Call [`.finish()`](BatcherBuilder::finish)
    /// to create the `Batcher`.
    ///
    /// # Examples
    ///
    /// Creating a `Batcher` with default options:
    ///
    /// ```
    /// # use async_trait::async_trait;
    /// # use ultra_batch::{Batcher, Fetcher, Cache};
    /// # struct UserFetcher;
    /// # impl UserFetcher {
    /// #     fn new(db_conn: ()) -> Self { UserFetcher }
    /// # }
    /// # #[async_trait]
    /// # impl Fetcher for UserFetcher {
    /// #     type Key = ();
    /// #     type Value = ();
    /// #     type Error = anyhow::Error;
    /// #     async fn fetch(&self, keys: &[()], values: &mut Cache<'_, (), ()>) -> anyhow::Result<()> {
    /// #         unimplemented!();
    /// #     }
    /// # }
    /// # #[tokio::main] async fn main() -> anyhow::Result<()> {
    /// # let db_conn = ();
    /// let user_fetcher = UserFetcher::new(db_conn);
    /// let batcher = Batcher::build(user_fetcher).finish();
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// Creating a `Batcher` with custom options:
    ///
    /// ```
    /// # use async_trait::async_trait;
    /// # use ultra_batch::{Batcher, Fetcher, Cache};
    /// # struct UserFetcher;
    /// # impl UserFetcher {
    /// #     fn new(db_conn: ()) -> Self { UserFetcher }
    /// # }
    /// # #[async_trait]
    /// # impl Fetcher for UserFetcher {
    /// #     type Key = ();
    /// #     type Value = ();
    /// #     type Error = anyhow::Error;
    /// #     async fn fetch(&self, keys: &[()], values: &mut Cache<'_, (), ()>) -> anyhow::Result<()> {
    /// #         unimplemented!();
    /// #     }
    /// # }
    /// # #[tokio::main] async fn main() -> anyhow::Result<()> {
    /// # let db_conn = ();
    /// let user_fetcher = UserFetcher::new(db_conn);
    /// let batcher = Batcher::build(user_fetcher)
    ///     .eager_batch_size(Some(50))
    ///     .delay_duration(tokio::time::Duration::from_millis(5))
    ///     .finish();
    /// # Ok(()) }
    /// ```
    pub fn build(fetcher: F) -> BatcherBuilder<F> {
        BatcherBuilder {
            fetcher,
            delay_duration: tokio::time::Duration::from_millis(10),
            eager_batch_size: Some(100),
            label: "unlabeled-batcher".to_string(),
        }
    }

    /// Load the value with the associated key, either by calling the `Fetcher`
    /// or by loading the cached value. Returns an error if the value could
    /// not be loaded or if a value for the given key was not found.
    ///
    /// See the type-level docs for [`Batcher`](#load-semantics) for more
    /// detailed loading semantics.
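    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming a hypothetical `UserFetcher` that maps
    /// numeric IDs to strings:
    ///
    /// ```
    /// # use async_trait::async_trait;
    /// # use ultra_batch::{Batcher, Fetcher, Cache};
    /// # struct UserFetcher;
    /// # #[async_trait]
    /// # impl Fetcher for UserFetcher {
    /// #     type Key = u64;
    /// #     type Value = String;
    /// #     type Error = anyhow::Error;
    /// #     async fn fetch(&self, keys: &[u64], values: &mut Cache<'_, u64, String>) -> anyhow::Result<()> {
    /// #         for key in keys {
    /// #             values.insert(*key, format!("user-{}", key));
    /// #         }
    /// #         Ok(())
    /// #     }
    /// # }
    /// # #[tokio::main] async fn main() -> anyhow::Result<()> {
    /// let batcher = Batcher::build(UserFetcher).finish();
    /// let value = batcher.load(42).await?;
    /// assert_eq!(value, "user-42");
    /// # Ok(())
    /// # }
    /// ```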
    #[tracing::instrument(skip_all, fields(batcher = %self.label))]
    pub async fn load(&self, key: F::Key) -> Result<F::Value, LoadError> {
        let mut values = self.load_keys(&[key]).await?;
        Ok(values.remove(0))
    }

    /// Load all the values for the given keys, either by calling the `Fetcher`
    /// or by loading cached values. Values are returned in the same order as
    /// the input keys. Returns an error if _any_ load fails.
    ///
    /// See the type-level docs for [`Batcher`](#load-semantics) for more
    /// detailed loading semantics.
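    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming the same hypothetical `UserFetcher` as in
    /// [`load`](Batcher::load):
    ///
    /// ```
    /// # use async_trait::async_trait;
    /// # use ultra_batch::{Batcher, Fetcher, Cache};
    /// # struct UserFetcher;
    /// # #[async_trait]
    /// # impl Fetcher for UserFetcher {
    /// #     type Key = u64;
    /// #     type Value = String;
    /// #     type Error = anyhow::Error;
    /// #     async fn fetch(&self, keys: &[u64], values: &mut Cache<'_, u64, String>) -> anyhow::Result<()> {
    /// #         for key in keys {
    /// #             values.insert(*key, format!("user-{}", key));
    /// #         }
    /// #         Ok(())
    /// #     }
    /// # }
    /// # #[tokio::main] async fn main() -> anyhow::Result<()> {
    /// let batcher = Batcher::build(UserFetcher).finish();
    /// let values = batcher.load_many(&[1, 2, 3]).await?;
    /// assert_eq!(values, ["user-1", "user-2", "user-3"]);
    /// # Ok(())
    /// # }
    /// ```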
    #[tracing::instrument(skip_all, fields(batcher = %self.label, num_keys = keys.len()))]
    pub async fn load_many(&self, keys: &[F::Key]) -> Result<Vec<F::Value>, LoadError> {
        let values = self.load_keys(keys).await?;
        Ok(values)
    }

    async fn load_keys(&self, keys: &[F::Key]) -> Result<Vec<F::Value>, LoadError> {
        let mut cache_lookup = CacheLookup::new(keys.to_vec());

        match cache_lookup.lookup(&self.cache_store) {
            CacheLookupState::Done(result) => {
                tracing::debug!(batcher = %self.label, "all keys have already been looked up");
                return result;
            }
            CacheLookupState::Pending => {}
        }
        let pending_keys = cache_lookup.pending_keys();

        let fetch_request_tx = self.fetch_request_tx.clone();
        let (result_tx, result_rx) = tokio::sync::oneshot::channel();

        tracing::debug!(
            num_pending_keys = pending_keys.len(),
            batcher = %self.label,
            "sending a batch of keys to fetch",
        );
        let fetch_request = FetchRequest {
            keys: pending_keys,
            result_tx,
        };
        fetch_request_tx
            .send(fetch_request)
            .await
            .map_err(|_| LoadError::SendError)?;

        match result_rx.await {
            Ok(Ok(())) => {
                tracing::debug!(batcher = %self.label, "fetch response returned successfully");
            }
            Ok(Err(fetch_error)) => {
                tracing::info!(batcher = %self.label, "error returned while fetching keys: {}", fetch_error);
                return Err(LoadError::FetchError(fetch_error));
            }
            Err(recv_error) => {
                panic!(
                    "Batch result channel for batcher {batcher} hung up with error: {error}",
                    batcher = self.label,
                    error = recv_error,
                );
            }
        }

        match cache_lookup.lookup(&self.cache_store) {
            CacheLookupState::Done(result) => {
                tracing::debug!(batcher = %self.label, "all keys have now been looked up");
                result
            }
            CacheLookupState::Pending => {
                panic!(
                    "Batch result for batcher {batcher} is still pending after result channel was sent",
                    batcher = self.label,
                );
            }
        }
    }
}

impl<F> Clone for Batcher<F>
where
    F: Fetcher,
{
    fn clone(&self) -> Self {
        Batcher {
            cache_store: self.cache_store.clone(),
            _fetch_task: self._fetch_task.clone(),
            fetch_request_tx: self.fetch_request_tx.clone(),
            label: self.label.clone(),
        }
    }
}

/// Used to configure a new [`Batcher`]. A `BatcherBuilder` is returned from
/// [`Batcher::build`].
pub struct BatcherBuilder<F>
where
    F: Fetcher + Send + Sync + 'static,
{
    fetcher: F,
    delay_duration: tokio::time::Duration,
    eager_batch_size: Option<usize>,
    label: String,
}

impl<F> BatcherBuilder<F>
where
    F: Fetcher + Send + Sync + 'static,
{
    /// The maximum amount of time the [`Batcher`] will wait to queue up more
    /// keys before calling the [`Fetcher`].
    pub fn delay_duration(mut self, delay: tokio::time::Duration) -> Self {
        self.delay_duration = delay;
        self
    }

    /// The maximum number of keys to wait for before eagerly calling the
    /// [`Fetcher`]. A value of `Some(n)` will load the batch once `n` or more
    /// keys have been queued (or once the timeout set by
    /// [`delay_duration`](BatcherBuilder::delay_duration) is reached, whichever
    /// comes first). A value of `None` will never eagerly dispatch the queue,
    /// and the [`Batcher`] will always wait for the timeout set by
    /// [`delay_duration`](BatcherBuilder::delay_duration).
    ///
    /// Note that `eager_batch_size` **does not** set an upper limit on the
    /// batch! For example, if [`Batcher::load_many`] is called with more than
    /// `eager_batch_size` items, then the batch will be sent immediately with
    /// _all_ of the provided keys.
    pub fn eager_batch_size(mut self, eager_batch_size: Option<usize>) -> Self {
        self.eager_batch_size = eager_batch_size;
        self
    }

    /// Set a label for the [`Batcher`]. This is only used to improve diagnostic
    /// messages, such as logs.
    pub fn label(mut self, label: impl Into<String>) -> Self {
        self.label = label.into();
        self
    }

    /// Create and return a [`Batcher`] with the given options.
    pub fn finish(self) -> Batcher<F> {
        let cache_store = CacheStore::new();

        let (fetch_request_tx, mut fetch_request_rx) =
            tokio::sync::mpsc::channel::<FetchRequest<F::Key>>(1);
        let label = self.label.clone();

        let fetch_task = tokio::spawn({
            let cache_store = cache_store.clone();
            async move {
                'task: loop {
                    // Wait for some keys to come in
                    let mut pending_keys = HashSet::new();
                    let mut result_txs = vec![];

                    tracing::trace!(batcher = %self.label, "waiting for keys to fetch...");
                    match fetch_request_rx.recv().await {
                        Some(fetch_request) => {
                            tracing::trace!(batcher = %self.label, num_fetch_request_keys = fetch_request.keys.len(), "received initial fetch request");

                            for key in fetch_request.keys {
                                pending_keys.insert(key);
                            }
                            result_txs.push(fetch_request.result_tx);
                        }
                        None => {
                            // Fetch queue closed, so we're done
                            break 'task;
                        }
                    };

                    // Wait for more keys
                    'wait_for_more_keys: loop {
                        let should_run_batch_now = match self.eager_batch_size {
                            Some(eager_batch_size) => pending_keys.len() >= eager_batch_size,
                            None => false,
                        };
                        if should_run_batch_now {
                            // We have enough keys already, so don't wait for more
                            tracing::trace!(
                                batcher = %self.label,
                                num_pending_keys = pending_keys.len(),
                                eager_batch_size = ?self.eager_batch_size,
                                "batch filled up, ready to fetch keys now",
                            );

                            break 'wait_for_more_keys;
                        }

                        let delay = tokio::time::sleep(self.delay_duration);
                        tokio::pin!(delay);

                        tokio::select! {
                            fetch_request = fetch_request_rx.recv() => {
                                match fetch_request {
                                    Some(fetch_request) => {
                                        tracing::trace!(batcher = %self.label, num_fetch_request_keys = fetch_request.keys.len(), "retrieved additional fetch request");

                                        for key in fetch_request.keys {
                                            pending_keys.insert(key);
                                        }
                                        result_txs.push(fetch_request.result_tx);
                                    }
                                    None => {
                                        // Fetch queue closed, so we're done waiting for keys
                                        tracing::debug!(batcher = %self.label, num_pending_keys = pending_keys.len(), "fetch channel closed");
                                        break 'wait_for_more_keys;
                                    }
                                }

                            }
                            _ = &mut delay => {
                                // Reached delay, so we're done waiting for keys
                                tracing::trace!(
                                    batcher = %self.label,
                                    num_pending_keys = pending_keys.len(),
                                    "delay reached while waiting for more keys to fetch"
                                );
                                break 'wait_for_more_keys;
                            }
                        };
                    }

                    let result = {
                        let mut cache = cache_store.as_cache();

                        tracing::trace!(batcher = %self.label, num_pending_keys = pending_keys.len(), num_pending_channels = result_txs.len(), "fetching keys");
                        let pending_keys: Vec<_> = pending_keys.into_iter().collect();
                        let result = self
                            .fetcher
                            .fetch(&pending_keys, &mut cache)
                            .await
                            .map_err(|error| error.to_string());

                        if result.is_ok() {
                            // The fetch succeeded, so mark any keys the
                            // fetcher didn't insert a value for as "not found"
                            cache.mark_keys_not_found(pending_keys);
                        }

                        result
                    };

                    for result_tx in result_txs {
                        // Ignore error if receiver was already closed
                        let _ = result_tx.send(result.clone());
                    }
                }
            }
        });

        Batcher {
            label,
            cache_store,
            _fetch_task: Arc::new(fetch_task),
            fetch_request_tx,
        }
    }
}

struct FetchRequest<K> {
    keys: Vec<K>,
    result_tx: tokio::sync::oneshot::Sender<Result<(), String>>,
}

/// Error indicating that loading one or more values from a [`Batcher`]
/// failed.
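///
/// # Examples
///
/// A sketch of distinguishing the error cases, assuming a `batcher` built
/// around a hypothetical fetcher as in [`Batcher::build`]:
///
/// ```
/// # use async_trait::async_trait;
/// # use ultra_batch::{Batcher, Fetcher, Cache, LoadError};
/// # struct EmptyFetcher;
/// # #[async_trait]
/// # impl Fetcher for EmptyFetcher {
/// #     type Key = u64;
/// #     type Value = String;
/// #     type Error = anyhow::Error;
/// #     async fn fetch(&self, _keys: &[u64], _values: &mut Cache<'_, u64, String>) -> anyhow::Result<()> {
/// #         Ok(())
/// #     }
/// # }
/// # #[tokio::main] async fn main() -> anyhow::Result<()> {
/// # let batcher = Batcher::build(EmptyFetcher).finish();
/// match batcher.load(42).await {
///     Ok(value) => println!("loaded value: {}", value),
///     Err(LoadError::NotFound) => println!("no value found for key 42"),
///     Err(error) => eprintln!("load failed: {}", error),
/// }
/// # Ok(())
/// # }
/// ```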
#[derive(Debug, thiserror::Error)]
pub enum LoadError {
    /// The [`Fetcher`] returned an error while loading the batch. The message
    /// contains the error message specified by [`Fetcher::Error`].
    #[error("error while fetching from batch: {}", _0)]
    FetchError(String),

    /// The request could not be sent to the [`Batcher`].
    #[error("error sending fetch request")]
    SendError,

    /// The [`Fetcher`] did not return a value for one or more keys in the batch.
    #[error("value not found")]
    NotFound,
}