// shape_runtime/data/cache.rs
//! Data cache for prefetched historical and live streaming data
//!
//! This module provides a caching layer that:
//! - Prefetches historical data asynchronously before execution
//! - Maintains a live data buffer updated by background tasks
//! - Provides synchronous access methods for the execution hot path

8use super::{
9    DataFrame, DataQuery, OwnedDataRow, SharedAsyncProvider, Timeframe,
10    async_provider::AsyncDataError,
11};
12use crate::snapshot::{
13    CacheKeySnapshot, CachedDataSnapshot, DEFAULT_CHUNK_LEN, DataCacheSnapshot, LiveBufferSnapshot,
14    SnapshotStore, deserialize_dataframe, load_chunked_vec, serialize_dataframe, store_chunked_vec,
15};
16use anyhow::Result as AnyResult;
17use std::collections::HashMap;
18use std::sync::{Arc, Mutex, RwLock};
19use tokio::task::JoinHandle;
20
/// Cache key for data lookups.
///
/// Uniquely identifies one data series (identifier + timeframe) and is used
/// as the `HashMap` key for both the historical cache and the live buffer.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct CacheKey {
    /// Identifier (symbol, sensor ID, etc)
    pub id: String,
    /// Timeframe the series is sampled at
    pub timeframe: Timeframe,
}
29
30impl CacheKey {
31    /// Create a new cache key
32    pub fn new(id: String, timeframe: Timeframe) -> Self {
33        Self { id, timeframe }
34    }
35}
36
/// Cached data for a symbol/timeframe combination.
///
/// Holds the prefetched frame plus the iteration cursor that execution
/// advances as it walks the rows.
#[derive(Debug, Clone)]
pub struct CachedData {
    /// Historical data (immutable after prefetch)
    pub historical: DataFrame,
    /// Current row index for iteration (used by ExecutionContext)
    pub current_index: usize,
}
45
46impl CachedData {
47    /// Create new cached data
48    pub fn new(historical: DataFrame) -> Self {
49        Self {
50            historical,
51            current_index: 0,
52        }
53    }
54
55    /// Get total number of historical rows
56    pub fn row_count(&self) -> usize {
57        self.historical.row_count()
58    }
59}
60
/// Data cache with prefetched historical and live data buffers
///
/// The cache provides:
/// - Async prefetch of historical data (called before execution)
/// - Sync access to cached data (during execution)
/// - Live data streaming via background tasks
///
/// # Thread Safety
///
/// The live buffer uses `Arc<RwLock<...>>` to allow concurrent reads during
/// execution while background tasks write new bars.
///
/// Subscriptions use `Arc<Mutex<...>>` to allow shared ownership across clones.
///
/// Note: `Clone` is shallow — all clones share the same maps and
/// subscription table via the inner `Arc`s.
#[derive(Clone)]
pub struct DataCache {
    /// Async data provider
    provider: SharedAsyncProvider,

    /// Prefetched historical data (populated by prefetch())
    /// Wrapped in Arc for cheap cloning
    historical: Arc<RwLock<HashMap<CacheKey, CachedData>>>,

    /// Live data buffer (updated by background tasks)
    /// RwLock allows many readers during execution, one writer in bg task
    live_buffer: Arc<RwLock<HashMap<CacheKey, Vec<OwnedDataRow>>>>,

    /// Active subscription handles
    /// Tracks background tasks so we can cancel them
    /// Mutex allows mutation through shared reference
    subscriptions: Arc<Mutex<HashMap<CacheKey, JoinHandle<()>>>>,

    /// Tokio runtime handle for spawning tasks
    runtime: tokio::runtime::Handle,
}
95
96impl DataCache {
97    /// Create a new data cache
98    ///
99    /// # Arguments
100    ///
101    /// * `provider` - Async data provider for loading data
102    /// * `runtime` - Tokio runtime handle for spawning background tasks
103    pub fn new(provider: SharedAsyncProvider, runtime: tokio::runtime::Handle) -> Self {
104        Self {
105            provider,
106            historical: Arc::new(RwLock::new(HashMap::new())),
107            live_buffer: Arc::new(RwLock::new(HashMap::new())),
108            subscriptions: Arc::new(Mutex::new(HashMap::new())),
109            runtime,
110        }
111    }
112
113    /// Create a DataCache pre-loaded with historical data (for tests).
114    ///
115    /// Uses a NullAsyncProvider and a temporary tokio runtime.
116    #[cfg(test)]
117    pub(crate) fn from_test_data(data: HashMap<CacheKey, DataFrame>) -> Self {
118        let historical: HashMap<CacheKey, CachedData> = data
119            .into_iter()
120            .map(|(k, df)| (k, CachedData::new(df)))
121            .collect();
122        let rt = tokio::runtime::Builder::new_current_thread()
123            .enable_all()
124            .build()
125            .expect("test tokio runtime");
126        Self {
127            provider: Arc::new(super::async_provider::NullAsyncProvider),
128            historical: Arc::new(RwLock::new(historical)),
129            live_buffer: Arc::new(RwLock::new(HashMap::new())),
130            subscriptions: Arc::new(Mutex::new(HashMap::new())),
131            runtime: rt.handle().clone(),
132        }
133    }
134
135    /// Prefetch historical data for given queries (async)
136    ///
137    /// This loads all queries concurrently and populates the cache.
138    /// Should be called before execution starts.
139    ///
140    /// # Arguments
141    ///
142    /// * `queries` - List of data queries to prefetch
143    ///
144    /// # Returns
145    ///
146    /// Ok if all queries loaded successfully, error otherwise.
147    ///
148    /// # Example
149    ///
150    /// ```ignore
151    /// let queries = vec![
152    ///     DataQuery::new("AAPL", Timeframe::d1()).limit(1000),
153    ///     DataQuery::new("MSFT", Timeframe::d1()).limit(1000),
154    /// ];
155    /// cache.prefetch(queries).await?;
156    /// ```
157    pub async fn prefetch(&self, queries: Vec<DataQuery>) -> Result<(), AsyncDataError> {
158        use futures::future::join_all;
159
160        // Load all queries concurrently
161        let futures: Vec<_> = queries
162            .iter()
163            .map(|q| {
164                let provider = self.provider.clone();
165                let query = q.clone();
166                async move {
167                    let df = provider.load(&query).await?;
168                    Ok::<_, AsyncDataError>((query, df))
169                }
170            })
171            .collect();
172
173        let results = join_all(futures).await;
174
175        // Process results and populate cache
176        let mut historical = self.historical.write().unwrap();
177        for result in results {
178            let (query, df) = result?;
179            let key = CacheKey::new(query.id.clone(), query.timeframe);
180            historical.insert(key, CachedData::new(df));
181        }
182
183        Ok(())
184    }
185
186    /// Get row at index (sync - reads from cache)
187    ///
188    /// This is the hot path - called frequently during execution.
189    /// Reads are lock-free for historical data, read-locked for live data.
190    ///
191    /// # Arguments
192    ///
193    /// * `symbol` - Symbol to query
194    /// * `timeframe` - Timeframe
195    /// * `index` - Absolute row index
196    ///
197    /// # Returns
198    ///
199    /// The row if available, None otherwise.
200    pub fn get_row(&self, id: &str, timeframe: &Timeframe, index: usize) -> Option<OwnedDataRow> {
201        let key = CacheKey::new(id.to_string(), *timeframe);
202
203        let historical = self.historical.read().unwrap();
204        historical.get(&key).and_then(|cached| {
205            let hist_len = cached.row_count();
206
207            // First try historical data (no lock needed)
208            if index < hist_len {
209                if let Some(row) = cached.historical.get_row(index) {
210                    return OwnedDataRow::from_data_row(&row);
211                }
212            }
213
214            // Then check live buffer for newer data
215            if let Ok(live) = self.live_buffer.read() {
216                let live_index = index.saturating_sub(hist_len);
217                if let Some(live_rows) = live.get(&key) {
218                    return live_rows.get(live_index).cloned();
219                }
220            }
221
222            None
223        })
224    }
225
226    /// Get row range (sync - reads from cache)
227    ///
228    /// # Arguments
229    ///
230    /// * `symbol` - Symbol to query
231    /// * `timeframe` - Timeframe
232    /// * `start` - Start index (inclusive)
233    /// * `end` - End index (exclusive)
234    ///
235    /// # Returns
236    ///
237    /// Vector of rows in the range. May be shorter than requested if data unavailable.
238    pub fn get_row_range(
239        &self,
240        id: &str,
241        timeframe: &Timeframe,
242        start: usize,
243        end: usize,
244    ) -> Vec<OwnedDataRow> {
245        let key = CacheKey::new(id.to_string(), *timeframe);
246        let mut rows = Vec::new();
247
248        let historical = self.historical.read().unwrap();
249        if let Some(cached) = historical.get(&key) {
250            let hist_len = cached.row_count();
251
252            // Get rows from historical data
253            for i in start..end.min(hist_len) {
254                if let Some(row) = cached.historical.get_row(i) {
255                    if let Some(owned) = OwnedDataRow::from_data_row(&row) {
256                        rows.push(owned);
257                    }
258                }
259            }
260
261            // Get rows from live buffer if needed
262            if end > hist_len {
263                if let Ok(live) = self.live_buffer.read() {
264                    if let Some(live_rows) = live.get(&key) {
265                        let live_start = start.saturating_sub(hist_len);
266                        let live_end = end - hist_len;
267                        for row in live_rows
268                            .iter()
269                            .skip(live_start)
270                            .take(live_end.saturating_sub(live_start))
271                        {
272                            rows.push(row.clone());
273                        }
274                    }
275                }
276            }
277        }
278
279        rows
280    }
281
282    /// Start live data subscription (spawns background task)
283    ///
284    /// This subscribes to live bar updates and spawns a background task
285    /// that appends new bars to the live buffer as they arrive.
286    ///
287    /// # Arguments
288    ///
289    /// * `symbol` - Symbol to subscribe to
290    /// * `timeframe` - Timeframe for bars
291    ///
292    /// # Returns
293    ///
294    /// Ok if subscription started, error otherwise.
295    /// Returns Ok without action if already subscribed.
296    pub fn subscribe_live(&self, id: &str, timeframe: &Timeframe) -> Result<(), AsyncDataError> {
297        let key = CacheKey::new(id.to_string(), *timeframe);
298
299        // Don't subscribe twice
300        {
301            let subscriptions = self.subscriptions.lock().unwrap();
302            if subscriptions.contains_key(&key) {
303                return Ok(());
304            }
305        }
306
307        let mut rx = self.provider.subscribe(id, timeframe)?;
308        let live_buffer = self.live_buffer.clone();
309        let key_clone = key.clone();
310
311        // Spawn background task to receive bars
312        let handle = self.runtime.spawn(async move {
313            while let Some(df) = rx.recv().await {
314                // Convert DataFrame rows to OwnedDataRow and append to buffer
315                if let Ok(mut buffer) = live_buffer.write() {
316                    let rows = buffer.entry(key_clone.clone()).or_insert_with(Vec::new);
317
318                    for i in 0..df.row_count() {
319                        if let Some(row) = df.get_row(i) {
320                            if let Some(owned) = OwnedDataRow::from_data_row(&row) {
321                                rows.push(owned);
322                            }
323                        }
324                    }
325                }
326            }
327        });
328
329        let mut subscriptions = self.subscriptions.lock().unwrap();
330        subscriptions.insert(key, handle);
331        Ok(())
332    }
333
334    /// Stop live data subscription
335    ///
336    /// Cancels the background task and unsubscribes from the provider.
337    ///
338    /// # Arguments
339    ///
340    /// * `symbol` - Symbol to unsubscribe from
341    /// * `timeframe` - Timeframe
342    pub fn unsubscribe_live(&self, symbol: &str, timeframe: &Timeframe) {
343        let key = CacheKey::new(symbol.to_string(), *timeframe);
344
345        let mut subscriptions = self.subscriptions.lock().unwrap();
346        if let Some(handle) = subscriptions.remove(&key) {
347            handle.abort();
348        }
349
350        // Also tell the provider (best effort)
351        let _ = self.provider.unsubscribe(symbol, timeframe);
352
353        // Clear live buffer for this key
354        if let Ok(mut buffer) = self.live_buffer.write() {
355            buffer.remove(&key);
356        }
357    }
358
359    /// Get total row count (historical + live)
360    ///
361    /// # Arguments
362    ///
363    /// * `symbol` - Symbol to query
364    /// * `timeframe` - Timeframe
365    ///
366    /// # Returns
367    ///
368    /// Total number of rows available (historical + live).
369    pub fn row_count(&self, id: &str, timeframe: &Timeframe) -> usize {
370        let key = CacheKey::new(id.to_string(), *timeframe);
371
372        let historical = self.historical.read().unwrap();
373        let hist_count = historical.get(&key).map(|c| c.row_count()).unwrap_or(0);
374
375        let live_count = self
376            .live_buffer
377            .read()
378            .ok()
379            .and_then(|b| b.get(&key).map(|v| v.len()))
380            .unwrap_or(0);
381
382        hist_count + live_count
383    }
384
385    /// Check if data is cached
386    ///
387    /// # Arguments
388    ///
389    /// * `symbol` - Symbol to check
390    /// * `timeframe` - Timeframe to check
391    ///
392    /// # Returns
393    ///
394    /// true if historical data is cached for this key.
395    pub fn has_cached(&self, symbol: &str, timeframe: &Timeframe) -> bool {
396        let key = CacheKey::new(symbol.to_string(), *timeframe);
397        let historical = self.historical.read().unwrap();
398        historical.contains_key(&key)
399    }
400
401    /// Get list of cached symbols
402    ///
403    /// # Returns
404    ///
405    /// Vector of (symbol, timeframe) pairs that are cached.
406    pub fn cached_keys(&self) -> Vec<(String, Timeframe)> {
407        let historical = self.historical.read().unwrap();
408        historical
409            .keys()
410            .map(|k| (k.id.clone(), k.timeframe))
411            .collect()
412    }
413
414    /// Clear all cached data
415    ///
416    /// Stops all subscriptions and clears all cached data.
417    pub fn clear(&self) {
418        // Abort all background tasks
419        let mut subscriptions = self.subscriptions.lock().unwrap();
420        for (_, handle) in subscriptions.drain() {
421            handle.abort();
422        }
423        drop(subscriptions);
424
425        // Clear historical cache
426        let mut historical = self.historical.write().unwrap();
427        historical.clear();
428        drop(historical);
429
430        // Clear live buffer
431        if let Ok(mut buffer) = self.live_buffer.write() {
432            buffer.clear();
433        }
434    }
435
436    /// Get the async provider
437    ///
438    /// Returns a clone of the SharedAsyncProvider for use in other components.
439    pub fn provider(&self) -> SharedAsyncProvider {
440        self.provider.clone()
441    }
442
443    /// Create a snapshot of the data cache (historical + live buffers).
444    pub fn snapshot(&self, store: &SnapshotStore) -> AnyResult<DataCacheSnapshot> {
445        let historical_guard = self.historical.read().unwrap();
446        let mut historical = Vec::with_capacity(historical_guard.len());
447        for (key, cached) in historical_guard.iter() {
448            let key_snapshot = CacheKeySnapshot {
449                id: key.id.clone(),
450                timeframe: key.timeframe,
451            };
452            historical.push(CachedDataSnapshot {
453                key: key_snapshot,
454                historical: serialize_dataframe(&cached.historical, store)?,
455                current_index: cached.current_index,
456            });
457        }
458        historical.sort_by(|a, b| {
459            a.key
460                .id
461                .cmp(&b.key.id)
462                .then(a.key.timeframe.cmp(&b.key.timeframe))
463        });
464
465        let live_guard = self.live_buffer.read().unwrap();
466        let mut live_buffer = Vec::with_capacity(live_guard.len());
467        for (key, rows) in live_guard.iter() {
468            let key_snapshot = CacheKeySnapshot {
469                id: key.id.clone(),
470                timeframe: key.timeframe,
471            };
472            let rows_blob = store_chunked_vec(rows, DEFAULT_CHUNK_LEN, store)?;
473            live_buffer.push(LiveBufferSnapshot {
474                key: key_snapshot,
475                rows: rows_blob,
476            });
477        }
478        live_buffer.sort_by(|a, b| {
479            a.key
480                .id
481                .cmp(&b.key.id)
482                .then(a.key.timeframe.cmp(&b.key.timeframe))
483        });
484
485        Ok(DataCacheSnapshot {
486            historical,
487            live_buffer,
488        })
489    }
490
491    /// Restore data cache contents from a snapshot.
492    pub fn restore_from_snapshot(
493        &self,
494        snapshot: DataCacheSnapshot,
495        store: &SnapshotStore,
496    ) -> AnyResult<()> {
497        self.clear();
498
499        let mut historical_guard = self.historical.write().unwrap();
500        for entry in snapshot.historical.into_iter() {
501            let key = CacheKey::new(entry.key.id, entry.key.timeframe);
502            let df = deserialize_dataframe(entry.historical, store)?;
503            historical_guard.insert(
504                key,
505                CachedData {
506                    historical: df,
507                    current_index: entry.current_index,
508                },
509            );
510        }
511        drop(historical_guard);
512
513        let mut live_guard = self.live_buffer.write().unwrap();
514        for entry in snapshot.live_buffer.into_iter() {
515            let key = CacheKey::new(entry.key.id, entry.key.timeframe);
516            let rows: Vec<OwnedDataRow> = load_chunked_vec(&entry.rows, store)?;
517            live_guard.insert(key, rows);
518        }
519        Ok(())
520    }
521}
522
523impl Drop for DataCache {
524    fn drop(&mut self) {
525        // Clean shutdown: abort all background tasks
526        let mut subscriptions = self.subscriptions.lock().unwrap();
527        for (_, handle) in subscriptions.drain() {
528            handle.abort();
529        }
530    }
531}
532
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::{DataQuery, NullAsyncProvider};
    use crate::snapshot::SnapshotStore;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Note: Full tests require a mock provider and tokio runtime
    // These tests verify basic structure

    #[test]
    fn test_cache_key() {
        // Keys with the same (id, timeframe) compare equal; different ids differ.
        let key1 = CacheKey::new("AAPL".to_string(), Timeframe::d1());
        let key2 = CacheKey::new("AAPL".to_string(), Timeframe::d1());
        let key3 = CacheKey::new("MSFT".to_string(), Timeframe::d1());

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_cached_data() {
        // A fresh, empty frame yields zero rows and a zeroed cursor.
        let df = DataFrame::new("TEST", Timeframe::d1());
        let cached = CachedData::new(df);

        assert_eq!(cached.row_count(), 0);
        assert_eq!(cached.current_index, 0);
    }

    /// In-memory provider serving pre-built frames; counts `load` calls so
    /// tests can prove whether the cache re-fetched.
    #[derive(Clone)]
    struct TestAsyncProvider {
        // Frames served by `load`, keyed by (id, timeframe).
        frames: Arc<HashMap<CacheKey, DataFrame>>,
        // Incremented on every `load` call.
        load_calls: Arc<AtomicUsize>,
    }

    impl crate::data::AsyncDataProvider for TestAsyncProvider {
        fn load<'a>(
            &'a self,
            query: &'a DataQuery,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<Output = Result<DataFrame, crate::data::AsyncDataError>>
                    + Send
                    + 'a,
            >,
        > {
            let key = CacheKey::new(query.id.clone(), query.timeframe);
            let frames = self.frames.clone();
            let calls = self.load_calls.clone();
            Box::pin(async move {
                // Record the call, then serve the canned frame (or NotFound).
                calls.fetch_add(1, Ordering::SeqCst);
                frames
                    .get(&key)
                    .cloned()
                    .ok_or_else(|| crate::data::AsyncDataError::SymbolNotFound(query.id.clone()))
            })
        }

        fn has_data(&self, symbol: &str, timeframe: &Timeframe) -> bool {
            let key = CacheKey::new(symbol.to_string(), *timeframe);
            self.frames.contains_key(&key)
        }

        fn symbols(&self) -> Vec<String> {
            self.frames.keys().map(|k| k.id.clone()).collect()
        }
    }

    /// Unique temp directory per test run (millisecond timestamp suffix).
    fn temp_snapshot_root(name: &str) -> std::path::PathBuf {
        let ts = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_millis();
        std::env::temp_dir().join(format!("shape_snapshot_test_{}_{}", name, ts))
    }

    /// Three-row frame with two columns for roundtrip checks.
    fn make_df(id: &str, timeframe: Timeframe) -> DataFrame {
        let mut df = DataFrame::new(id, timeframe);
        df.timestamps = vec![1, 2, 3];
        df.add_column("a", vec![10.0, 11.0, 12.0]);
        df.add_column("b", vec![20.0, 21.0, 22.0]);
        df
    }

    #[tokio::test]
    async fn test_data_cache_snapshot_roundtrip_no_refetch() {
        let tf = Timeframe::d1();
        let df = make_df("TEST", tf);
        let mut frames = HashMap::new();
        frames.insert(CacheKey::new("TEST".to_string(), tf), df);
        let load_calls = Arc::new(AtomicUsize::new(0));
        let provider = Arc::new(TestAsyncProvider {
            frames: Arc::new(frames),
            load_calls: load_calls.clone(),
        });

        let cache = DataCache::new(provider, tokio::runtime::Handle::current());
        cache
            .prefetch(vec![DataQuery::new("TEST", tf)])
            .await
            .unwrap();

        // Inject live buffer rows and tweak current index for snapshot fidelity
        // (pokes internal state directly; only valid from this crate's tests).
        let key = CacheKey::new("TEST".to_string(), tf);
        if let Some(entry) = cache.historical.write().unwrap().get_mut(&key) {
            entry.current_index = 2;
        }
        cache.live_buffer.write().unwrap().insert(
            key.clone(),
            vec![OwnedDataRow::new_generic(
                4,
                HashMap::from([("a".to_string(), 13.0)]),
            )],
        );

        let store = SnapshotStore::new(temp_snapshot_root("data_cache")).unwrap();
        let snapshot = cache.snapshot(&store).unwrap();

        // Restore into a cache with a provider that always fails (proves no refetch)
        let fail_provider = Arc::new(NullAsyncProvider::default());
        let restored = DataCache::new(fail_provider, tokio::runtime::Handle::current());
        restored.restore_from_snapshot(snapshot, &store).unwrap();

        // Historical row survives the roundtrip.
        let row = restored
            .get_row("TEST", &tf, 0)
            .expect("row should be cached");
        assert_eq!(row.timestamp, 1);
        assert_eq!(row.fields.get("a"), Some(&10.0));

        // Live buffer contents survive the roundtrip.
        let live_rows = restored
            .live_buffer
            .read()
            .unwrap()
            .get(&key)
            .cloned()
            .unwrap_or_default();
        assert_eq!(live_rows.len(), 1);
        assert_eq!(live_rows[0].timestamp, 4);

        // The iteration cursor is restored exactly.
        let restored_index = restored
            .historical
            .read()
            .unwrap()
            .get(&key)
            .map(|c| c.current_index)
            .unwrap_or(0);
        assert_eq!(restored_index, 2);

        // Ensure we only loaded once during prefetch
        assert_eq!(load_calls.load(Ordering::SeqCst), 1);
    }
}