use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
/// Backend that resolves a whole set of keys in one call.
///
/// A [`DataLoader`] collects keys requested through `load` / `load_many`
/// into batches and hands each batch to [`DataLoaderFetcher::load`].
pub trait DataLoaderFetcher {
/// Key used to request a value. Keys are stored in `HashSet`s and used as
/// `HashMap` keys, hence the `Eq + Hash` bounds; they are cloned when
/// being recorded for a waiting caller.
type Key: Clone + Eq + std::hash::Hash + Send + Sync;
/// Value produced for a key. Values are cloned out of the shared batch
/// result for every caller that asked for the key.
type Value: Clone + Send + Sync;
/// Fetches the values for `keys` in a single operation.
///
/// Returns `Some(map)` on success; keys that are absent from `map` are
/// simply omitted from the results handed back to callers. Returning
/// `None` marks the whole batch as failed: every caller waiting on that
/// batch receives `Err(())`.
fn load(&self, keys: HashSet<Self::Key>) -> impl Future<Output = Option<HashMap<Self::Key, Self::Value>>> + Send;
}
/// Builder holding the tunable settings used to construct a [`DataLoader`].
///
/// Obtain one with [`DataLoaderBuilder::new`] (or `Default`), adjust it with
/// the chained setters, and finish with [`DataLoaderBuilder::build`].
#[derive(Clone, Copy, Debug)]
#[must_use = "builders must be used to create a dataloader"]
pub struct DataLoaderBuilder {
// Requested batch size, passed to `DataLoader::new` by `build`.
batch_size: usize,
// Requested concurrency, passed to `DataLoader::new` by `build`.
concurrency: usize,
// Delay passed to `DataLoader::new` by `build`.
delay: std::time::Duration,
}
impl Default for DataLoaderBuilder {
fn default() -> Self {
Self::new()
}
}
impl DataLoaderBuilder {
    /// Creates a builder with the default settings: a batch size of 1000,
    /// a concurrency of 50 and a delay of 5 milliseconds.
    pub fn new() -> Self {
        Self {
            batch_size: 1000,
            concurrency: 50,
            delay: std::time::Duration::from_millis(5),
        }
    }

    /// Sets the batch size that will be passed to [`DataLoader::new`].
    pub fn batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Sets the concurrency that will be passed to [`DataLoader::new`].
    ///
    /// Without this setter the `concurrency` field could only ever hold its
    /// default value, unlike the other two settings.
    pub fn concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency;
        self
    }

    /// Sets the delay that will be passed to [`DataLoader::new`].
    pub fn delay(mut self, delay: std::time::Duration) -> Self {
        self.delay = delay;
        self
    }

    /// Consumes the builder and constructs a [`DataLoader`] around `executor`
    /// using the configured batch size, concurrency and delay.
    pub fn build<E>(self, executor: E) -> DataLoader<E>
    where
        E: DataLoaderFetcher + Send + Sync + 'static,
    {
        DataLoader::new(executor, self.batch_size, self.concurrency, self.delay)
    }
}
/// Batches individual key lookups into grouped calls to a
/// [`DataLoaderFetcher`].
///
/// Keys passed to `load` / `load_many` are accumulated in a shared "current"
/// batch. A batch is handed to the fetcher either when it reaches
/// `batch_size` keys or when the background `batch_loop` task picks it up.
#[must_use = "dataloaders must be used to load data"]
pub struct DataLoader<E>
where
E: DataLoaderFetcher + Send + Sync + 'static,
{
// Handle of the spawned `batch_loop` task.
_auto_spawn: tokio::task::JoinHandle<()>,
// Fetcher shared with every dispatched batch.
executor: Arc<E>,
// Signalled by `load_many` after it dispatches a full batch; wakes `batch_loop` early.
notify: Arc<tokio::sync::Notify>,
// Each `Batch` acquires one permit when it is created and holds it until it is dropped.
semaphore: Arc<tokio::sync::Semaphore>,
// The batch currently accepting keys, if any. Shared with `batch_loop`.
current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>,
// Number of keys at which `load_many` dispatches the current batch itself.
batch_size: usize,
// Source of ids for newly created batches.
batch_id: AtomicU64,
}
impl<E> DataLoader<E>
where
    E: DataLoaderFetcher + Send + Sync + 'static,
{
    /// Creates a dataloader around `executor` and spawns its background
    /// flush task on the current tokio runtime.
    ///
    /// * `batch_size` - number of distinct keys at which a batch is
    ///   dispatched immediately. Values below 1 are raised to 1.
    /// * `concurrency` - number of semaphore permits; every batch holds one
    ///   permit from creation until its fetch has finished. Values below 1
    ///   are raised to 1 so batch creation can never block forever.
    /// * `delay` - wake-up interval of the background flush loop.
    pub fn new(executor: E, batch_size: usize, concurrency: usize, delay: std::time::Duration) -> Self {
        // `max`, not `min`: the goal is a lower bound of one, not an upper bound.
        let semaphore = Arc::new(tokio::sync::Semaphore::new(concurrency.max(1)));
        let notify = Arc::new(tokio::sync::Notify::new());
        let current_batch = Arc::new(tokio::sync::Mutex::new(None));
        let executor = Arc::new(executor);
        let join_handle = tokio::spawn(batch_loop(executor.clone(), current_batch.clone(), notify.clone(), delay));
        Self {
            executor,
            _auto_spawn: join_handle,
            notify,
            semaphore,
            current_batch,
            batch_size: batch_size.max(1),
            batch_id: AtomicU64::new(0),
        }
    }

    /// Returns a [`DataLoaderBuilder`] with the default settings.
    pub fn builder() -> DataLoaderBuilder {
        DataLoaderBuilder::new()
    }

    /// Loads the value for a single key.
    ///
    /// Returns `Ok(None)` when the fetcher produced no value for the key and
    /// `Err(())` when the batch containing the key failed.
    pub async fn load(&self, items: E::Key) -> Result<Option<E::Value>, ()> {
        Ok(self.load_many(std::iter::once(items)).await?.into_values().next())
    }

    /// Loads the values for many keys, batching them together with keys
    /// requested by other concurrent callers.
    ///
    /// Duplicate keys are collapsed. The returned map only contains the keys
    /// the fetcher produced a value for. Returns `Err(())` if any batch that
    /// carried one of the keys failed.
    pub async fn load_many<I>(&self, items: I) -> Result<HashMap<E::Key, E::Value>, ()>
    where
        I: IntoIterator<Item = E::Key> + Send,
    {
        /// The subset of this call's keys that was placed into one batch.
        struct BatchWaiting<K, V> {
            id: u64,
            keys: HashSet<K>,
            result: Arc<BatchResult<K, V>>,
        }

        let items = items.into_iter().collect::<HashSet<_>>();
        let requested = items.len();
        let mut waiters = Vec::<BatchWaiting<E::Key, E::Value>>::new();

        // Enqueue phase. The lock is confined to this scope on purpose: the
        // background flush loop needs it to dispatch a partially filled
        // batch, so it must be released before we start waiting for results.
        {
            let mut batch = self.current_batch.lock().await;
            for item in items {
                if batch.is_none() {
                    *batch = Some(
                        Batch::new(
                            self.batch_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
                            self.semaphore.clone(),
                        )
                        .await,
                    );
                }
                let batch_mut = batch.as_mut().expect("current batch was populated just above");
                batch_mut.items.insert(item.clone());

                // Start a new waiter entry whenever the keys spill into a new batch.
                let spilled_over = waiters.last().map_or(true, |last| last.id != batch_mut.id);
                if spilled_over {
                    waiters.push(BatchWaiting {
                        id: batch_mut.id,
                        keys: HashSet::new(),
                        result: batch_mut.result.clone(),
                    });
                }
                if let Some(waiting) = waiters.last_mut() {
                    waiting.keys.insert(item);
                }

                if batch_mut.items.len() >= self.batch_size {
                    // Run the fetch on its own task so the lock is not held
                    // for its duration and the batch completes even if this
                    // caller is cancelled.
                    if let Some(full) = batch.take() {
                        tokio::spawn(full.spawn(self.executor.clone()));
                    }
                    self.notify.notify_one();
                }
            }
        }

        let mut results = HashMap::with_capacity(requested);
        for waiting in waiters {
            let result = waiting.result.wait().await?;
            results.extend(waiting.keys.into_iter().filter_map(|key| {
                let value = result.get(&key)?.clone();
                Some((key, value))
            }));
        }
        Ok(results)
    }
}

impl<E> Drop for DataLoader<E>
where
    E: DataLoaderFetcher + Send + Sync + 'static,
{
    /// Stops the background flush task. Dropping a `JoinHandle` only detaches
    /// the task, and the flush loop never returns on its own, so without this
    /// it would keep running (and keep the fetcher alive) indefinitely.
    fn drop(&mut self) {
        self._auto_spawn.abort();
    }
}
/// Background task that dispatches batches which did not fill up on their own.
///
/// On every wake-up (after `delay`, or earlier when `notify` is signalled)
/// the loop inspects the current batch. A batch is dispatched only if it was
/// already observed on the previous wake-up and contains at least one key;
/// otherwise its id is merely remembered for the next round. The loop never
/// returns.
async fn batch_loop<E>(
    executor: Arc<E>,
    current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>,
    notify: Arc<tokio::sync::Notify>,
    delay: std::time::Duration,
) where
    E: DataLoaderFetcher + Send + Sync + 'static,
{
    // Id of the batch seen on the previous wake-up, if there was one.
    let mut last_seen: Option<u64> = None;
    loop {
        // Whether we woke by timeout or by notification is irrelevant.
        let _ = tokio::time::timeout(delay, notify.notified()).await;

        let mut slot = current_batch.lock().await;
        let observed = slot.as_ref().map(|pending| (pending.id, pending.items.is_empty()));
        match observed {
            // Nothing pending: forget whatever was seen before.
            None => last_seen = None,
            // Same non-empty batch as last time: send it off on its own task.
            Some((id, false)) if last_seen == Some(id) => {
                let ready = slot.take().unwrap();
                tokio::spawn(ready.spawn(executor.clone()));
            }
            // First sighting (or still empty): remember it for the next round.
            Some((id, _)) => last_seen = Some(id),
        }
    }
}
/// Shared completion slot for one batch; callers of `load_many` hold it via
/// `Arc` and wait on it for the batch outcome.
struct BatchResult<K, V> {
// Outcome of the fetch, written at most once by `Batch::spawn`.
// The inner `None` is the fetcher reporting failure.
values: tokio::sync::OnceCell<Option<HashMap<K, V>>>,
// Cancelled when `Batch::spawn` finishes; `wait` blocks until then.
token: tokio_util::sync::CancellationToken,
}
impl<K, V> BatchResult<K, V> {
    /// Creates a slot with no value stored and a fresh, uncancelled token.
    fn new() -> Self {
        let values = tokio::sync::OnceCell::new();
        let token = tokio_util::sync::CancellationToken::new();
        Self { values, token }
    }

    /// Waits until the token is cancelled, then returns the stored map.
    ///
    /// Yields `Err(())` when no value was stored at all or when the stored
    /// value is `None`.
    async fn wait(&self) -> Result<&HashMap<K, V>, ()> {
        self.token.cancelled().await;
        match self.values.get() {
            Some(Some(loaded)) => Ok(loaded),
            Some(None) | None => Err(()),
        }
    }
}
/// A set of keys that will be resolved by a single call to the fetcher.
struct Batch<E>
where
E: DataLoaderFetcher + Send + Sync + 'static,
{
// Identifier assigned at creation; used to tell successive batches apart.
id: u64,
// Keys collected so far.
items: HashSet<E::Key>,
// Completion slot shared with every caller waiting on this batch.
result: Arc<BatchResult<E::Key, E::Value>>,
// Semaphore permit acquired in `Batch::new` and held for as long as the batch exists.
_ticket: tokio::sync::OwnedSemaphorePermit,
}
impl<E> Batch<E>
where
    E: DataLoaderFetcher + Send + Sync + 'static,
{
    /// Creates an empty batch with the given `id`, waiting for a permit from
    /// `semaphore` first. Panics if the semaphore has been closed.
    async fn new(id: u64, semaphore: Arc<tokio::sync::Semaphore>) -> Self {
        let ticket = semaphore.acquire_owned().await.unwrap();
        Self {
            id,
            items: HashSet::new(),
            result: Arc::new(BatchResult::new()),
            _ticket: ticket,
        }
    }

    /// Runs the fetch for this batch and publishes the outcome.
    ///
    /// The token's drop guard is created before the fetch, so the token is
    /// cancelled whenever this function is left, whether or not a value was
    /// stored.
    async fn spawn(self, executor: Arc<E>) {
        let _wake_waiters = self.result.token.clone().drop_guard();
        let fetched = executor.load(self.items).await;
        if self.result.values.set(fetched).is_err() {
            unreachable!(
                "batch result already set, this is a bug please report it https://github.com/scufflecloud/scuffle/issues"
            );
        }
    }
}