1use super::{
9 DataFrame, DataQuery, OwnedDataRow, SharedAsyncProvider, Timeframe,
10 async_provider::AsyncDataError,
11};
12use crate::snapshot::{
13 CacheKeySnapshot, CachedDataSnapshot, DEFAULT_CHUNK_LEN, DataCacheSnapshot, LiveBufferSnapshot,
14 SnapshotStore, deserialize_dataframe, load_chunked_vec, serialize_dataframe, store_chunked_vec,
15};
16use anyhow::Result as AnyResult;
17use std::collections::HashMap;
18use std::sync::{Arc, Mutex, RwLock};
19use tokio::task::JoinHandle;
20
/// Uniquely identifies a cached dataset by instrument id and timeframe.
///
/// Used as the key for both the historical cache and the live row buffer.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct CacheKey {
    /// Instrument/symbol identifier (tests use tickers like "AAPL").
    pub id: String,
    /// Timeframe the cached data is aggregated at.
    pub timeframe: Timeframe,
}
29
30impl CacheKey {
31 pub fn new(id: String, timeframe: Timeframe) -> Self {
33 Self { id, timeframe }
34 }
35}
36
/// A fully loaded historical frame plus a cursor into it.
#[derive(Debug, Clone)]
pub struct CachedData {
    /// Historical rows for one (id, timeframe) pair.
    pub historical: DataFrame,
    /// Cursor into `historical`, starts at 0; advanced by callers and
    /// preserved across snapshot/restore.
    pub current_index: usize,
}
45
46impl CachedData {
47 pub fn new(historical: DataFrame) -> Self {
49 Self {
50 historical,
51 current_index: 0,
52 }
53 }
54
55 pub fn row_count(&self) -> usize {
57 self.historical.row_count()
58 }
59}
60
/// Thread-safe cache layering historical frames with live streaming rows.
///
/// Cloning is cheap: every field is (or wraps) an `Arc`, so clones share
/// the same underlying maps and subscription tasks.
#[derive(Clone)]
pub struct DataCache {
    /// Provider used to load historical data and open live streams.
    provider: SharedAsyncProvider,

    /// Fully loaded historical data, keyed by (id, timeframe).
    historical: Arc<RwLock<HashMap<CacheKey, CachedData>>>,

    /// Rows received from live subscriptions; logically appended after
    /// the historical rows of the same key.
    live_buffer: Arc<RwLock<HashMap<CacheKey, Vec<OwnedDataRow>>>>,

    /// Background tasks draining live streams; aborted on unsubscribe,
    /// clear, and drop.
    subscriptions: Arc<Mutex<HashMap<CacheKey, JoinHandle<()>>>>,

    /// Runtime handle the subscription tasks are spawned on.
    runtime: tokio::runtime::Handle,
}
95
96impl DataCache {
97 pub fn new(provider: SharedAsyncProvider, runtime: tokio::runtime::Handle) -> Self {
104 Self {
105 provider,
106 historical: Arc::new(RwLock::new(HashMap::new())),
107 live_buffer: Arc::new(RwLock::new(HashMap::new())),
108 subscriptions: Arc::new(Mutex::new(HashMap::new())),
109 runtime,
110 }
111 }
112
    #[cfg(test)]
    /// Test-only constructor that seeds the historical cache directly,
    /// bypassing any provider I/O.
    pub(crate) fn from_test_data(data: HashMap<CacheKey, DataFrame>) -> Self {
        // Wrap each frame so the cursor starts at 0.
        let historical: HashMap<CacheKey, CachedData> = data
            .into_iter()
            .map(|(k, df)| (k, CachedData::new(df)))
            .collect();
        // NOTE(review): `rt` is dropped when this function returns, so the
        // stored handle points at a shut-down runtime. That is fine as long
        // as tests built this way never call `subscribe_live` — confirm
        // before relying on live subscriptions in tests.
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("test tokio runtime");
        Self {
            provider: Arc::new(super::async_provider::NullAsyncProvider),
            historical: Arc::new(RwLock::new(historical)),
            live_buffer: Arc::new(RwLock::new(HashMap::new())),
            subscriptions: Arc::new(Mutex::new(HashMap::new())),
            runtime: rt.handle().clone(),
        }
    }
134
135 pub async fn prefetch(&self, queries: Vec<DataQuery>) -> Result<(), AsyncDataError> {
158 use futures::future::join_all;
159
160 let futures: Vec<_> = queries
162 .iter()
163 .map(|q| {
164 let provider = self.provider.clone();
165 let query = q.clone();
166 async move {
167 let df = provider.load(&query).await?;
168 Ok::<_, AsyncDataError>((query, df))
169 }
170 })
171 .collect();
172
173 let results = join_all(futures).await;
174
175 let mut historical = self.historical.write().unwrap();
177 for result in results {
178 let (query, df) = result?;
179 let key = CacheKey::new(query.id.clone(), query.timeframe);
180 historical.insert(key, CachedData::new(df));
181 }
182
183 Ok(())
184 }
185
186 pub fn get_row(&self, id: &str, timeframe: &Timeframe, index: usize) -> Option<OwnedDataRow> {
201 let key = CacheKey::new(id.to_string(), *timeframe);
202
203 let historical = self.historical.read().unwrap();
204 historical.get(&key).and_then(|cached| {
205 let hist_len = cached.row_count();
206
207 if index < hist_len {
209 if let Some(row) = cached.historical.get_row(index) {
210 return OwnedDataRow::from_data_row(&row);
211 }
212 }
213
214 if let Ok(live) = self.live_buffer.read() {
216 let live_index = index.saturating_sub(hist_len);
217 if let Some(live_rows) = live.get(&key) {
218 return live_rows.get(live_index).cloned();
219 }
220 }
221
222 None
223 })
224 }
225
226 pub fn get_row_range(
239 &self,
240 id: &str,
241 timeframe: &Timeframe,
242 start: usize,
243 end: usize,
244 ) -> Vec<OwnedDataRow> {
245 let key = CacheKey::new(id.to_string(), *timeframe);
246 let mut rows = Vec::new();
247
248 let historical = self.historical.read().unwrap();
249 if let Some(cached) = historical.get(&key) {
250 let hist_len = cached.row_count();
251
252 for i in start..end.min(hist_len) {
254 if let Some(row) = cached.historical.get_row(i) {
255 if let Some(owned) = OwnedDataRow::from_data_row(&row) {
256 rows.push(owned);
257 }
258 }
259 }
260
261 if end > hist_len {
263 if let Ok(live) = self.live_buffer.read() {
264 if let Some(live_rows) = live.get(&key) {
265 let live_start = start.saturating_sub(hist_len);
266 let live_end = end - hist_len;
267 for row in live_rows
268 .iter()
269 .skip(live_start)
270 .take(live_end.saturating_sub(live_start))
271 {
272 rows.push(row.clone());
273 }
274 }
275 }
276 }
277 }
278
279 rows
280 }
281
282 pub fn subscribe_live(&self, id: &str, timeframe: &Timeframe) -> Result<(), AsyncDataError> {
297 let key = CacheKey::new(id.to_string(), *timeframe);
298
299 {
301 let subscriptions = self.subscriptions.lock().unwrap();
302 if subscriptions.contains_key(&key) {
303 return Ok(());
304 }
305 }
306
307 let mut rx = self.provider.subscribe(id, timeframe)?;
308 let live_buffer = self.live_buffer.clone();
309 let key_clone = key.clone();
310
311 let handle = self.runtime.spawn(async move {
313 while let Some(df) = rx.recv().await {
314 if let Ok(mut buffer) = live_buffer.write() {
316 let rows = buffer.entry(key_clone.clone()).or_insert_with(Vec::new);
317
318 for i in 0..df.row_count() {
319 if let Some(row) = df.get_row(i) {
320 if let Some(owned) = OwnedDataRow::from_data_row(&row) {
321 rows.push(owned);
322 }
323 }
324 }
325 }
326 }
327 });
328
329 let mut subscriptions = self.subscriptions.lock().unwrap();
330 subscriptions.insert(key, handle);
331 Ok(())
332 }
333
334 pub fn unsubscribe_live(&self, symbol: &str, timeframe: &Timeframe) {
343 let key = CacheKey::new(symbol.to_string(), *timeframe);
344
345 let mut subscriptions = self.subscriptions.lock().unwrap();
346 if let Some(handle) = subscriptions.remove(&key) {
347 handle.abort();
348 }
349
350 let _ = self.provider.unsubscribe(symbol, timeframe);
352
353 if let Ok(mut buffer) = self.live_buffer.write() {
355 buffer.remove(&key);
356 }
357 }
358
359 pub fn row_count(&self, id: &str, timeframe: &Timeframe) -> usize {
370 let key = CacheKey::new(id.to_string(), *timeframe);
371
372 let historical = self.historical.read().unwrap();
373 let hist_count = historical.get(&key).map(|c| c.row_count()).unwrap_or(0);
374
375 let live_count = self
376 .live_buffer
377 .read()
378 .ok()
379 .and_then(|b| b.get(&key).map(|v| v.len()))
380 .unwrap_or(0);
381
382 hist_count + live_count
383 }
384
385 pub fn has_cached(&self, symbol: &str, timeframe: &Timeframe) -> bool {
396 let key = CacheKey::new(symbol.to_string(), *timeframe);
397 let historical = self.historical.read().unwrap();
398 historical.contains_key(&key)
399 }
400
401 pub fn cached_keys(&self) -> Vec<(String, Timeframe)> {
407 let historical = self.historical.read().unwrap();
408 historical
409 .keys()
410 .map(|k| (k.id.clone(), k.timeframe))
411 .collect()
412 }
413
414 pub fn clear(&self) {
418 let mut subscriptions = self.subscriptions.lock().unwrap();
420 for (_, handle) in subscriptions.drain() {
421 handle.abort();
422 }
423 drop(subscriptions);
424
425 let mut historical = self.historical.write().unwrap();
427 historical.clear();
428 drop(historical);
429
430 if let Ok(mut buffer) = self.live_buffer.write() {
432 buffer.clear();
433 }
434 }
435
436 pub fn provider(&self) -> SharedAsyncProvider {
440 self.provider.clone()
441 }
442
443 pub fn snapshot(&self, store: &SnapshotStore) -> AnyResult<DataCacheSnapshot> {
445 let historical_guard = self.historical.read().unwrap();
446 let mut historical = Vec::with_capacity(historical_guard.len());
447 for (key, cached) in historical_guard.iter() {
448 let key_snapshot = CacheKeySnapshot {
449 id: key.id.clone(),
450 timeframe: key.timeframe,
451 };
452 historical.push(CachedDataSnapshot {
453 key: key_snapshot,
454 historical: serialize_dataframe(&cached.historical, store)?,
455 current_index: cached.current_index,
456 });
457 }
458 historical.sort_by(|a, b| {
459 a.key
460 .id
461 .cmp(&b.key.id)
462 .then(a.key.timeframe.cmp(&b.key.timeframe))
463 });
464
465 let live_guard = self.live_buffer.read().unwrap();
466 let mut live_buffer = Vec::with_capacity(live_guard.len());
467 for (key, rows) in live_guard.iter() {
468 let key_snapshot = CacheKeySnapshot {
469 id: key.id.clone(),
470 timeframe: key.timeframe,
471 };
472 let rows_blob = store_chunked_vec(rows, DEFAULT_CHUNK_LEN, store)?;
473 live_buffer.push(LiveBufferSnapshot {
474 key: key_snapshot,
475 rows: rows_blob,
476 });
477 }
478 live_buffer.sort_by(|a, b| {
479 a.key
480 .id
481 .cmp(&b.key.id)
482 .then(a.key.timeframe.cmp(&b.key.timeframe))
483 });
484
485 Ok(DataCacheSnapshot {
486 historical,
487 live_buffer,
488 })
489 }
490
491 pub fn restore_from_snapshot(
493 &self,
494 snapshot: DataCacheSnapshot,
495 store: &SnapshotStore,
496 ) -> AnyResult<()> {
497 self.clear();
498
499 let mut historical_guard = self.historical.write().unwrap();
500 for entry in snapshot.historical.into_iter() {
501 let key = CacheKey::new(entry.key.id, entry.key.timeframe);
502 let df = deserialize_dataframe(entry.historical, store)?;
503 historical_guard.insert(
504 key,
505 CachedData {
506 historical: df,
507 current_index: entry.current_index,
508 },
509 );
510 }
511 drop(historical_guard);
512
513 let mut live_guard = self.live_buffer.write().unwrap();
514 for entry in snapshot.live_buffer.into_iter() {
515 let key = CacheKey::new(entry.key.id, entry.key.timeframe);
516 let rows: Vec<OwnedDataRow> = load_chunked_vec(&entry.rows, store)?;
517 live_guard.insert(key, rows);
518 }
519 Ok(())
520 }
521}
522
523impl Drop for DataCache {
524 fn drop(&mut self) {
525 let mut subscriptions = self.subscriptions.lock().unwrap();
527 for (_, handle) in subscriptions.drain() {
528 handle.abort();
529 }
530 }
531}
532
533#[cfg(test)]
534mod tests {
535 use super::*;
536 use crate::data::{DataQuery, NullAsyncProvider};
537 use crate::snapshot::SnapshotStore;
538 use std::sync::Arc;
539 use std::sync::atomic::{AtomicUsize, Ordering};
540 use std::time::{SystemTime, UNIX_EPOCH};
541
542 #[test]
546 fn test_cache_key() {
547 let key1 = CacheKey::new("AAPL".to_string(), Timeframe::d1());
548 let key2 = CacheKey::new("AAPL".to_string(), Timeframe::d1());
549 let key3 = CacheKey::new("MSFT".to_string(), Timeframe::d1());
550
551 assert_eq!(key1, key2);
552 assert_ne!(key1, key3);
553 }
554
555 #[test]
556 fn test_cached_data() {
557 let df = DataFrame::new("TEST", Timeframe::d1());
558 let cached = CachedData::new(df);
559
560 assert_eq!(cached.row_count(), 0);
561 assert_eq!(cached.current_index, 0);
562 }
563
    /// Test double that serves canned frames and counts `load` invocations.
    #[derive(Clone)]
    struct TestAsyncProvider {
        // Pre-built frames returned by `load`, keyed like the cache.
        frames: Arc<HashMap<CacheKey, DataFrame>>,
        // Number of `load` calls observed; used to assert no re-fetch occurs.
        load_calls: Arc<AtomicUsize>,
    }
569
    impl crate::data::AsyncDataProvider for TestAsyncProvider {
        /// Returns the canned frame for the query's key, counting the call;
        /// unknown symbols yield `SymbolNotFound`.
        fn load<'a>(
            &'a self,
            query: &'a DataQuery,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<Output = Result<DataFrame, crate::data::AsyncDataError>>
                    + Send
                    + 'a,
            >,
        > {
            let key = CacheKey::new(query.id.clone(), query.timeframe);
            // Clone the Arcs so the future is self-contained and 'a-bound
            // only through `query`.
            let frames = self.frames.clone();
            let calls = self.load_calls.clone();
            Box::pin(async move {
                // Count every load attempt, successful or not.
                calls.fetch_add(1, Ordering::SeqCst);
                frames
                    .get(&key)
                    .cloned()
                    .ok_or_else(|| crate::data::AsyncDataError::SymbolNotFound(query.id.clone()))
            })
        }

        /// True when a canned frame exists for the key.
        fn has_data(&self, symbol: &str, timeframe: &Timeframe) -> bool {
            let key = CacheKey::new(symbol.to_string(), *timeframe);
            self.frames.contains_key(&key)
        }

        /// All symbols with canned frames (may contain duplicates if the
        /// same id appears under several timeframes).
        fn symbols(&self) -> Vec<String> {
            self.frames.keys().map(|k| k.id.clone()).collect()
        }
    }
602
603 fn temp_snapshot_root(name: &str) -> std::path::PathBuf {
604 let ts = SystemTime::now()
605 .duration_since(UNIX_EPOCH)
606 .unwrap()
607 .as_millis();
608 std::env::temp_dir().join(format!("shape_snapshot_test_{}_{}", name, ts))
609 }
610
611 fn make_df(id: &str, timeframe: Timeframe) -> DataFrame {
612 let mut df = DataFrame::new(id, timeframe);
613 df.timestamps = vec![1, 2, 3];
614 df.add_column("a", vec![10.0, 11.0, 12.0]);
615 df.add_column("b", vec![20.0, 21.0, 22.0]);
616 df
617 }
618
    /// End-to-end snapshot round-trip: prefetch through a call-counting
    /// provider, mutate cursor and live buffer, snapshot, restore into a
    /// cache whose provider would fail every load, and verify all state is
    /// served from the snapshot alone (the load count stays at 1).
    #[tokio::test]
    async fn test_data_cache_snapshot_roundtrip_no_refetch() {
        let tf = Timeframe::d1();
        let df = make_df("TEST", tf);
        let mut frames = HashMap::new();
        frames.insert(CacheKey::new("TEST".to_string(), tf), df);
        let load_calls = Arc::new(AtomicUsize::new(0));
        let provider = Arc::new(TestAsyncProvider {
            frames: Arc::new(frames),
            load_calls: load_calls.clone(),
        });

        let cache = DataCache::new(provider, tokio::runtime::Handle::current());
        cache
            .prefetch(vec![DataQuery::new("TEST", tf)])
            .await
            .unwrap();

        // Mutate cached state so the snapshot has something non-default:
        // advance the cursor and stage one live row.
        let key = CacheKey::new("TEST".to_string(), tf);
        if let Some(entry) = cache.historical.write().unwrap().get_mut(&key) {
            entry.current_index = 2;
        }
        cache.live_buffer.write().unwrap().insert(
            key.clone(),
            vec![OwnedDataRow::new_generic(
                4,
                HashMap::from([("a".to_string(), 13.0)]),
            )],
        );

        let store = SnapshotStore::new(temp_snapshot_root("data_cache")).unwrap();
        let snapshot = cache.snapshot(&store).unwrap();

        // Restore into a cache whose provider can never serve data, proving
        // every assertion below is satisfied from the snapshot.
        let fail_provider = Arc::new(NullAsyncProvider::default());
        let restored = DataCache::new(fail_provider, tokio::runtime::Handle::current());
        restored.restore_from_snapshot(snapshot, &store).unwrap();

        // Historical row 0 survived the round-trip.
        let row = restored
            .get_row("TEST", &tf, 0)
            .expect("row should be cached");
        assert_eq!(row.timestamp, 1);
        assert_eq!(row.fields.get("a"), Some(&10.0));

        // The staged live row survived.
        let live_rows = restored
            .live_buffer
            .read()
            .unwrap()
            .get(&key)
            .cloned()
            .unwrap_or_default();
        assert_eq!(live_rows.len(), 1);
        assert_eq!(live_rows[0].timestamp, 4);

        // The cursor position survived.
        let restored_index = restored
            .historical
            .read()
            .unwrap()
            .get(&key)
            .map(|c| c.current_index)
            .unwrap_or(0);
        assert_eq!(restored_index, 2);

        // Exactly one load: the initial prefetch; restoring never re-fetched.
        assert_eq!(load_calls.load(Ordering::SeqCst), 1);
    }
686}