1use std::{
2 collections::BTreeMap,
3 fs::File,
4 future::Future,
5 io::{BufRead, BufReader, BufWriter, Read},
6 path::{Path, PathBuf},
7 sync::{atomic::AtomicU64, Arc, RwLock},
8 time::SystemTime,
9};
10
11use anyhow::Context;
12use futures::{channel::oneshot, stream::FuturesUnordered, StreamExt};
13use serde::{de::DeserializeOwned, Deserialize, Serialize};
14
15#[derive(Default)]
16struct LoaderTaskSet {
17 #[cfg(feature = "tokio")]
18 inner: tokio::task::JoinSet<anyhow::Result<()>>,
19}
20
21impl LoaderTaskSet {
22 fn spawn_task<F: FnOnce() -> anyhow::Result<()> + Send + 'static>(&mut self, f: F) {
23 #[cfg(feature = "tokio")]
24 {
25 self.inner.spawn_blocking(f);
26 }
27 }
28
29 async fn wait_all(self) -> anyhow::Result<()> {
30 #[cfg(feature = "tokio")]
31 {
32 let res = self.inner.join_all().await;
33 for res in res {
34 res?;
35 }
36 Ok(())
37 }
38 }
39}
40
41pub trait DataFormat {
43 fn read_from<T: Data, R: Read + BufRead>(rdr: R) -> anyhow::Result<T>;
45
46 fn read_from_file<T: Data>(path: &Path) -> anyhow::Result<T> {
48 let file = File::open(path).context("open file")?;
49 let rdr = BufReader::new(file);
50 Self::read_from(rdr)
51 }
52
53 fn write_to<T: Serialize, W: std::io::Write>(data: &T, wtr: W) -> anyhow::Result<()>;
55 fn write_to_file<T: Serialize>(data: &T, path: &Path) -> anyhow::Result<()> {
57 let file = File::create(path).context("create file")?;
58 Self::write_to(data, file)
59 }
60}
61
62pub struct JsonDataFormat;
64impl DataFormat for JsonDataFormat {
65 fn read_from<T: Data, R: Read + BufRead>(rdr: R) -> anyhow::Result<T> {
66 serde_json::from_reader(rdr).context("parse json")
67 }
68
69 fn write_to<T: Serialize, W: std::io::Write>(data: &T, wtr: W) -> anyhow::Result<()> {
70 serde_json::to_writer(wtr, data).context("write json")
71 }
72}
73
74pub struct BinCodeDataFormat;
76impl DataFormat for BinCodeDataFormat {
77 fn read_from<T: Data, R: Read + BufRead>(rdr: R) -> anyhow::Result<T> {
78 bincode::deserialize_from(rdr).context("parse bincode")
79 }
80
81 fn write_to<T: Serialize, W: std::io::Write>(data: &T, wtr: W) -> anyhow::Result<()> {
82 bincode::serialize_into(wtr, data).context("write bincode")
83 }
84}
85
86pub trait DataMapper: Send + 'static {
89 type In: Data;
90 type Out: Send + 'static;
91
92 fn map(self, data: Self::In) -> anyhow::Result<Self::Out>;
93}
94
95pub fn data_mapper_fn<F, In, Out>(f: F) -> MapperFn<F, In, Out> {
97 MapperFn::new(f)
98}
99
/// Adapter that stores a closure together with the input and output types it maps between.
pub struct MapperFn<F, In, Out> {
    /// The wrapped closure.
    f: F,
    /// Records `In` and `Out` without storing values of those types.
    _t: std::marker::PhantomData<(In, Out)>,
}

// Implemented by hand so that only `F` has to be `Clone`; a derive would also demand
// `In: Clone` and `Out: Clone`.
impl<F, In, Out> Clone for MapperFn<F, In, Out>
where
    F: Clone,
{
    fn clone(&self) -> Self {
        Self::new(self.f.clone())
    }
}

impl<F, In, Out> MapperFn<F, In, Out> {
    /// Creates a mapper around `f`.
    pub fn new(f: F) -> Self {
        let _t = std::marker::PhantomData;
        Self { f, _t }
    }
}
127
128impl<F, In: Data, Out: Send + 'static> DataMapper for MapperFn<F, In, Out>
129where
130 F: FnOnce(In) -> anyhow::Result<Out> + Send + 'static,
131{
132 type In = In;
133 type Out = Out;
134
135 fn map(self, data: Self::In) -> anyhow::Result<Self::Out> {
136 (self.f)(data)
137 }
138}
139
140pub trait Data: DeserializeOwned + Send + 'static {}
143impl<T: DeserializeOwned + Send + 'static> Data for T {}
144
145pub struct DataReceiver<T>(oneshot::Receiver<T>);
146
147impl<T> DataReceiver<T> {
148 pub fn get(mut self) -> T {
150 self.0
151 .try_recv()
152 .expect("Data recv closed")
153 .expect("Data recv no value")
154 }
155}
156
157impl<T> Future for DataReceiver<T> {
158 type Output = T;
159
160 fn poll(
161 mut self: std::pin::Pin<&mut Self>,
162 cx: &mut std::task::Context,
163 ) -> std::task::Poll<Self::Output> {
164 std::pin::Pin::new(&mut self.0).poll(cx).map(Result::unwrap)
165 }
166}
167
168pub trait DataFormatHandler {
170 fn load_from_file<T: Data>(p: &Path) -> anyhow::Result<T>;
172}
173
174pub struct AutoDataFormatHandler;
176impl DataFormatHandler for AutoDataFormatHandler {
177 fn load_from_file<T: Data>(p: &Path) -> anyhow::Result<T> {
178 let ext = p.extension().context("no extension")?;
179 match ext.to_string_lossy().to_lowercase().as_str() {
180 "json" => JsonDataFormat::read_from_file(p),
181 "bincode" => BinCodeDataFormat::read_from_file(p),
182 _ => anyhow::bail!("unknown extension: {:?}", ext),
183 }
184 }
185}
186
187#[derive(Debug, Deserialize, Serialize, Clone)]
189struct DataManifestEntry {
190 last_changed: SystemTime,
191 cached_name: String,
192}
193
194#[derive(Debug, Deserialize, Serialize)]
196struct DataManifest {
197 pub entries: BTreeMap<PathBuf, DataManifestEntry>,
198 pub counter: u64,
199}
200
201struct Cache {
203 manifest: RwLock<DataManifest>,
204 dir: PathBuf,
205 counter: AtomicU64,
206}
207
208impl Cache {
209 fn load(dir: &Path) -> anyhow::Result<Self> {
211 if !dir.exists() {
212 std::fs::create_dir(dir).context("create cache dir")?;
213 }
214
215 let manifest_path = dir.join("manifest.json");
216 let manifest = if manifest_path.exists() {
217 JsonDataFormat::read_from_file(&manifest_path)?
218 } else {
219 DataManifest {
220 entries: BTreeMap::new(),
221 counter: 0,
222 }
223 };
224
225 let counter = manifest.counter;
226
227 Ok(Self {
228 manifest: RwLock::new(manifest),
229 dir: dir.to_owned(),
230 counter: counter.into(),
231 })
232 }
233
234 fn save(&self) -> anyhow::Result<()> {
236 let mut manifest = self.manifest.write().expect("write");
237 manifest.counter = self.counter.load(std::sync::atomic::Ordering::Relaxed);
238
239 let manifest_path = self.dir.join("manifest.json");
240 JsonDataFormat::write_to_file::<DataManifest>(&manifest, &manifest_path)
241 }
242
243 fn update_entry<F>(&self, path: &Path, update_cached: F) -> anyhow::Result<()>
245 where
246 F: FnOnce(&mut BufWriter<File>) -> anyhow::Result<()>,
247 {
248 let path = path.canonicalize().expect("canonicalize path");
249 let filename = path.file_name().expect("file_name").to_string_lossy();
250 let num = self
251 .counter
252 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
253 let cache_file = format!("{num}_{filename}.cache");
254
255 let file = File::create(self.dir.join(&cache_file))?;
256 let mut file = BufWriter::new(file);
257 update_cached(&mut file)?;
258
259 if let Some(old) = self.get_entry(&path) {
261 let old_cache = self.dir.join(&old.cached_name);
262 std::fs::remove_file(old_cache).context("remove old cache file")?;
263 }
264
265 let last_changed = path
266 .metadata()
267 .context("metadata")?
268 .modified()
269 .expect("modified");
270 self.manifest.write().expect("write").entries.insert(
271 path.to_owned(),
272 DataManifestEntry {
273 last_changed,
274 cached_name: cache_file,
275 },
276 );
277
278 Ok(())
279 }
280
281 fn get_or_update<T: Data + Serialize>(
283 &self,
284 path: &Path,
285 load: impl FnOnce(&Path) -> anyhow::Result<T>,
286 ) -> anyhow::Result<T> {
287 if let Some(entry) = self.get_validated_entry(path)? {
288 let cache_file = self.dir.join(&entry.cached_name);
289 let rdr = BufReader::new(File::open(cache_file).context("open cache file")?);
290 return BinCodeDataFormat::read_from(rdr);
291 }
292
293 let data = load(path)?;
294 self.update_entry(path, |file| BinCodeDataFormat::write_to(&data, file))?;
295 Ok(data)
296 }
297
298 fn get_entry(&self, path: &Path) -> Option<DataManifestEntry> {
300 let norm = path.canonicalize().expect("canonicalize path");
301 self.manifest
302 .read()
303 .expect("read")
304 .entries
305 .get(&norm)
306 .cloned()
307 }
308
309 fn get_validated_entry(&self, path: &Path) -> anyhow::Result<Option<DataManifestEntry>> {
311 Ok(match self.get_entry(path) {
312 Some(entry) => {
313 let meta = path.metadata().context("metadata")?;
314 if meta.modified().expect("modified") > entry.last_changed {
315 None
316 } else {
317 Some(entry)
318 }
319 }
320 None => None,
321 })
322 }
323}
324
325pub struct DataLoader<F> {
327 pending: LoaderTaskSet,
328 dir: PathBuf,
329 cache: Arc<Cache>,
330 _f: std::marker::PhantomData<F>,
331}
332
333impl<F: DataFormatHandler> DataLoader<F> {
334 pub fn new(dir: &Path) -> anyhow::Result<Self> {
336 anyhow::ensure!(dir.is_dir(), "Dir is not a directory: {:?}", dir);
337 let cache = Cache::load(&dir.join(".cache"))?;
338
339 Ok(Self {
340 pending: LoaderTaskSet::default(),
341 cache: Arc::new(cache),
342 dir: dir.to_owned(),
343 _f: std::marker::PhantomData,
344 })
345 }
346
347 fn spawn<Out: Send + 'static>(
349 &mut self,
350 path: &Path,
351 f: impl FnOnce(&Path) -> anyhow::Result<Out> + Send + 'static,
352 ) -> DataReceiver<Out> {
353 let path = self.dir.join(path);
354
355 let (tx, rx) = oneshot::channel();
356 let path = path.to_owned();
357 self.pending.spawn_task(move || {
358 let out = f(&path).with_context(|| format!("load {path:?}"))?;
359 let _ = tx.send(out);
360 Ok(())
361 });
362
363 DataReceiver(rx)
364 }
365
366 pub fn load_file<T: Data>(&mut self, path: impl AsRef<Path>) -> DataReceiver<T> {
368 self.spawn(path.as_ref(), |path| F::load_from_file::<T>(path))
369 }
370
371 pub fn load_map<M: DataMapper>(
373 &mut self,
374 path: impl AsRef<Path>,
375 mapper: M,
376 ) -> DataReceiver<M::Out> {
377 self.spawn::<M::Out>(path.as_ref(), move |path| {
378 let data = F::load_from_file::<M::In>(path)?;
379 let out = mapper.map(data)?;
380 Ok(out)
381 })
382 }
383
384 pub fn load_map_cached<M: DataMapper>(
387 &mut self,
388 path: impl AsRef<Path>,
389 mapper: M,
390 ) -> DataReceiver<M::Out>
391 where
392 M::Out: Serialize + Data,
393 {
394 let cache = self.cache.clone();
395 self.spawn::<M::Out>(path.as_ref(), move |path| {
396 cache.get_or_update(path, |path| {
397 let data = F::load_from_file::<M::In>(path)?;
398 let out = mapper.map(data)?;
399 Ok(out)
400 })
401 })
402 }
403
404 pub fn load<Out: Send + 'static>(
406 &mut self,
407 path: &Path,
408 f: impl FnOnce(&Path) -> anyhow::Result<Out> + Send + 'static,
409 ) -> DataReceiver<Out> {
410 self.spawn(path, f)
411 }
412
413 fn spawn_all<Out: Send + 'static, P: AsRef<Path>>(
415 &mut self,
416 paths: impl Iterator<Item = P>,
417 f: impl Fn(&Path) -> anyhow::Result<Out> + Clone + Send + 'static,
418 ) -> DataReceiver<Vec<Out>> {
419 let mut tasks: FuturesUnordered<_> = paths.map(|path| self.spawn(path.as_ref(), f.clone())).collect();
420 let (tx, rx) = oneshot::channel();
421 tokio::spawn(async move {
422 let mut res = Vec::new();
423 while let Some(data) = tasks.next().await {
424 res.push(data);
425 }
426 let _ = tx.send(res);
427 });
428
429 DataReceiver(rx)
430 }
431
432 pub fn load_all_files<T: Data, P: AsRef<Path>>(
434 &mut self,
435 paths: impl Iterator<Item = P>,
436 ) -> DataReceiver<Vec<T>> {
437 self.spawn_all(paths, |path| F::load_from_file::<T>(path))
438 }
439
440
441 pub fn load_all_mapped<M: DataMapper + Clone, P: AsRef<Path>>(
443 &mut self,
444 paths: impl Iterator<Item = P>,
445 mapper: M,
446 ) -> DataReceiver<Vec<M::Out>> {
447 self.spawn_all(paths, move |path| {
448 let data = F::load_from_file::<M::In>(path)?;
449 let out = mapper.clone().map(data)?;
450 Ok(out)
451 })
452 }
453
454 pub fn load_all<Out: Send + 'static, P: AsRef<Path>>(
456 &mut self,
457 paths: impl Iterator<Item = P>,
458 f: impl Fn(&Path) -> anyhow::Result<Out> + Clone + Send + 'static,
459 ) -> DataReceiver<Vec<Out>> {
460 self.spawn_all(paths, f)
461 }
462
463 pub async fn wait_all(self) -> anyhow::Result<()> {
465 self.pending.wait_all().await?;
466 self.cache.save()?;
467 Ok(())
468 }
469}
470
471pub type AutoDataLoader = DataLoader<AutoDataFormatHandler>;
472
#[cfg(test)]
mod tests {
    use std::{sync::atomic::AtomicBool, time::Duration};

    use super::*;

    /// Exercises plain, mapped and cached loading of `test_1/a.json`, checking that the cached
    /// mapper only runs when the file content changed, then loads the same file ten times.
    #[tokio::test]
    async fn data_loader() {
        let path = Path::new("test_1");
        // The directory may be left over from a previous run.
        let _ = std::fs::create_dir(path);

        let a_file = path.join("a.json");
        let is_mapped = Arc::new(AtomicBool::new(false));

        const HELLO_WORLD: &str = "Hello, World...";
        const HELLO_UNIVERSE: &str = "Hello, Universe...";

        // (file content, whether the cached mapper is expected to run)
        let seq = [
            (HELLO_WORLD, true),
            (HELLO_WORLD, false),
            (HELLO_UNIVERSE, true),
            (HELLO_UNIVERSE, false),
            (HELLO_WORLD, true),
            (HELLO_UNIVERSE, true),
        ];

        for (inp, map) in seq {
            let json = format!("\"{inp}\"");
            // On a fresh checkout the file does not exist yet; treat that as "content differs"
            // instead of panicking. Rewriting only on change keeps the modification time stable
            // for the iterations that expect a cache hit.
            let content = std::fs::read_to_string(&a_file).unwrap_or_default();
            if content != json {
                std::fs::write(&a_file, json).unwrap();
            }

            is_mapped.store(false, std::sync::atomic::Ordering::Relaxed);
            let mut loader = AutoDataLoader::new(path).unwrap();
            let txt = loader.load_file::<String>("a.json");
            let mapped = loader.load_map("a.json", data_mapper_fn(|s: String| Ok(s.len())));
            let is_mapped_ = is_mapped.clone();
            let mapped_cached = loader.load_map_cached(
                "a.json",
                data_mapper_fn(move |s: String| {
                    is_mapped_.store(true, std::sync::atomic::Ordering::Relaxed);
                    Ok(s.len())
                }),
            );
            loader.wait_all().await.unwrap();

            assert_eq!(map, is_mapped.load(std::sync::atomic::Ordering::Relaxed));
            assert_eq!(txt.get(), inp);
            assert_eq!(mapped.get(), inp.len());
            assert_eq!(mapped_cached.get(), inp.len());
            // Presumably spaces out consecutive writes so their modification times differ.
            std::thread::sleep(Duration::from_millis(1));
        }

        let mut loader = AutoDataLoader::new(path).unwrap();
        let all = loader.load_all_files::<String, _>(std::iter::repeat_n(Path::new("a.json"), 10));
        loader.wait_all().await.unwrap();

        let all = all.get();
        assert_eq!(all.len(), 10);
        for s in all {
            assert_eq!(s, HELLO_UNIVERSE);
        }
    }
}