Skip to main content

shape_runtime/data/
cache.rs

1//! Data cache for prefetched historical and live streaming data
2//!
3//! This module provides a caching layer that:
4//! - Prefetches historical data asynchronously before execution
5//! - Maintains a live data buffer updated by background tasks
6//! - Provides synchronous access methods for the execution hot path
7
8use super::{
9    DataFrame, DataQuery, OwnedDataRow, SharedAsyncProvider, Timeframe,
10    async_provider::AsyncDataError,
11};
12use crate::snapshot::{
13    CacheKeySnapshot, CachedDataSnapshot, DEFAULT_CHUNK_LEN, DataCacheSnapshot, LiveBufferSnapshot,
14    SnapshotStore, load_chunked_vec, store_chunked_vec,
15};
16use anyhow::Result as AnyResult;
17use std::collections::HashMap;
18use std::sync::{Arc, Mutex, RwLock};
19use tokio::task::JoinHandle;
20
21/// Cache key for data lookups
22#[derive(Debug, Clone, Hash, Eq, PartialEq)]
23pub struct CacheKey {
24    /// Identifier (symbol, sensor ID, etc)
25    pub id: String,
26    /// Timeframe
27    pub timeframe: Timeframe,
28}
29
30impl CacheKey {
31    /// Create a new cache key
32    pub fn new(id: String, timeframe: Timeframe) -> Self {
33        Self { id, timeframe }
34    }
35}
36
37/// Cached data for a symbol/timeframe combination
38#[derive(Debug, Clone)]
39pub struct CachedData {
40    /// Historical data (immutable after prefetch)
41    pub historical: DataFrame,
42    /// Current row index for iteration (used by ExecutionContext)
43    pub current_index: usize,
44}
45
46impl CachedData {
47    /// Create new cached data
48    pub fn new(historical: DataFrame) -> Self {
49        Self {
50            historical,
51            current_index: 0,
52        }
53    }
54
55    /// Get total number of historical rows
56    pub fn row_count(&self) -> usize {
57        self.historical.row_count()
58    }
59}
60
/// Data cache with prefetched historical and live data buffers
///
/// The cache provides:
/// - Async prefetch of historical data (called before execution)
/// - Sync access to cached data (during execution)
/// - Live data streaming via background tasks
///
/// # Sharing and clones
///
/// Every stateful field (provider, historical map, live buffer, subscription
/// table) is behind an `Arc`, so `clone()` is cheap and all clones observe and
/// mutate the same underlying state.
///
/// # Thread Safety
///
/// The historical map and the live buffer use `Arc<RwLock<...>>` so that many
/// readers can run concurrently during execution. Writers (prefetch, clearing,
/// unsubscribing, and the background subscription tasks appending new bars)
/// take exclusive access.
///
/// Subscriptions use `Arc<Mutex<...>>` to allow shared ownership across clones.
#[derive(Clone)]
pub struct DataCache {
    /// Async data provider (used for prefetch and live subscriptions)
    provider: SharedAsyncProvider,

    /// Prefetched historical data, keyed by id/timeframe (populated by prefetch())
    /// Wrapped in Arc for cheap cloning
    historical: Arc<RwLock<HashMap<CacheKey, CachedData>>>,

    /// Live data buffer (rows appended by background subscription tasks)
    /// Row lookups address it with absolute indices past the historical rows
    /// of the same key. RwLock allows many readers during execution while a
    /// writer appends or clears.
    live_buffer: Arc<RwLock<HashMap<CacheKey, Vec<OwnedDataRow>>>>,

    /// Active subscription handles
    /// Tracks background tasks so we can cancel them
    /// Mutex allows mutation through shared reference
    subscriptions: Arc<Mutex<HashMap<CacheKey, JoinHandle<()>>>>,

    /// Tokio runtime handle for spawning subscription tasks
    runtime: tokio::runtime::Handle,
}
95
96impl DataCache {
97    /// Create a new data cache
98    ///
99    /// # Arguments
100    ///
101    /// * `provider` - Async data provider for loading data
102    /// * `runtime` - Tokio runtime handle for spawning background tasks
103    pub fn new(provider: SharedAsyncProvider, runtime: tokio::runtime::Handle) -> Self {
104        Self {
105            provider,
106            historical: Arc::new(RwLock::new(HashMap::new())),
107            live_buffer: Arc::new(RwLock::new(HashMap::new())),
108            subscriptions: Arc::new(Mutex::new(HashMap::new())),
109            runtime,
110        }
111    }
112
113    /// Create a DataCache pre-loaded with historical data (for tests).
114    ///
115    /// Uses a NullAsyncProvider and a temporary tokio runtime.
116    #[cfg(test)]
117    pub(crate) fn from_test_data(data: HashMap<CacheKey, DataFrame>) -> Self {
118        let historical: HashMap<CacheKey, CachedData> = data
119            .into_iter()
120            .map(|(k, df)| (k, CachedData::new(df)))
121            .collect();
122        let rt = tokio::runtime::Builder::new_current_thread()
123            .enable_all()
124            .build()
125            .expect("test tokio runtime");
126        Self {
127            provider: Arc::new(super::async_provider::NullAsyncProvider),
128            historical: Arc::new(RwLock::new(historical)),
129            live_buffer: Arc::new(RwLock::new(HashMap::new())),
130            subscriptions: Arc::new(Mutex::new(HashMap::new())),
131            runtime: rt.handle().clone(),
132        }
133    }
134
135    /// Prefetch historical data for given queries (async)
136    ///
137    /// This loads all queries concurrently and populates the cache.
138    /// Should be called before execution starts.
139    ///
140    /// # Arguments
141    ///
142    /// * `queries` - List of data queries to prefetch
143    ///
144    /// # Returns
145    ///
146    /// Ok if all queries loaded successfully, error otherwise.
147    ///
148    /// # Example
149    ///
150    /// ```ignore
151    /// let queries = vec![
152    ///     DataQuery::new("AAPL", Timeframe::d1()).limit(1000),
153    ///     DataQuery::new("MSFT", Timeframe::d1()).limit(1000),
154    /// ];
155    /// cache.prefetch(queries).await?;
156    /// ```
157    pub async fn prefetch(&self, queries: Vec<DataQuery>) -> Result<(), AsyncDataError> {
158        use futures::future::join_all;
159
160        // Load all queries concurrently
161        let futures: Vec<_> = queries
162            .iter()
163            .map(|q| {
164                let provider = self.provider.clone();
165                let query = q.clone();
166                async move {
167                    let df = provider.load(&query).await?;
168                    Ok::<_, AsyncDataError>((query, df))
169                }
170            })
171            .collect();
172
173        let results = join_all(futures).await;
174
175        // Process results and populate cache
176        let mut historical = self.historical.write().unwrap();
177        for result in results {
178            let (query, df) = result?;
179            let key = CacheKey::new(query.id.clone(), query.timeframe);
180            historical.insert(key, CachedData::new(df));
181        }
182
183        Ok(())
184    }
185
186    /// Get row at index (sync - reads from cache)
187    ///
188    /// This is the hot path - called frequently during execution.
189    /// Reads are lock-free for historical data, read-locked for live data.
190    ///
191    /// # Arguments
192    ///
193    /// * `symbol` - Symbol to query
194    /// * `timeframe` - Timeframe
195    /// * `index` - Absolute row index
196    ///
197    /// # Returns
198    ///
199    /// The row if available, None otherwise.
200    pub fn get_row(&self, id: &str, timeframe: &Timeframe, index: usize) -> Option<OwnedDataRow> {
201        let key = CacheKey::new(id.to_string(), *timeframe);
202
203        let historical = self.historical.read().unwrap();
204        historical.get(&key).and_then(|cached| {
205            let hist_len = cached.row_count();
206
207            // First try historical data (no lock needed)
208            if index < hist_len {
209                if let Some(row) = cached.historical.get_row(index) {
210                    return OwnedDataRow::from_data_row(&row);
211                }
212            }
213
214            // Then check live buffer for newer data
215            if let Ok(live) = self.live_buffer.read() {
216                let live_index = index.saturating_sub(hist_len);
217                if let Some(live_rows) = live.get(&key) {
218                    return live_rows.get(live_index).cloned();
219                }
220            }
221
222            None
223        })
224    }
225
226    /// Get row range (sync - reads from cache)
227    ///
228    /// # Arguments
229    ///
230    /// * `symbol` - Symbol to query
231    /// * `timeframe` - Timeframe
232    /// * `start` - Start index (inclusive)
233    /// * `end` - End index (exclusive)
234    ///
235    /// # Returns
236    ///
237    /// Vector of rows in the range. May be shorter than requested if data unavailable.
238    pub fn get_row_range(
239        &self,
240        id: &str,
241        timeframe: &Timeframe,
242        start: usize,
243        end: usize,
244    ) -> Vec<OwnedDataRow> {
245        let key = CacheKey::new(id.to_string(), *timeframe);
246        let mut rows = Vec::new();
247
248        let historical = self.historical.read().unwrap();
249        if let Some(cached) = historical.get(&key) {
250            let hist_len = cached.row_count();
251
252            // Get rows from historical data
253            for i in start..end.min(hist_len) {
254                if let Some(row) = cached.historical.get_row(i) {
255                    if let Some(owned) = OwnedDataRow::from_data_row(&row) {
256                        rows.push(owned);
257                    }
258                }
259            }
260
261            // Get rows from live buffer if needed
262            if end > hist_len {
263                if let Ok(live) = self.live_buffer.read() {
264                    if let Some(live_rows) = live.get(&key) {
265                        let live_start = start.saturating_sub(hist_len);
266                        let live_end = end - hist_len;
267                        for row in live_rows
268                            .iter()
269                            .skip(live_start)
270                            .take(live_end.saturating_sub(live_start))
271                        {
272                            rows.push(row.clone());
273                        }
274                    }
275                }
276            }
277        }
278
279        rows
280    }
281
282    /// Start live data subscription (spawns background task)
283    ///
284    /// This subscribes to live bar updates and spawns a background task
285    /// that appends new bars to the live buffer as they arrive.
286    ///
287    /// # Arguments
288    ///
289    /// * `symbol` - Symbol to subscribe to
290    /// * `timeframe` - Timeframe for bars
291    ///
292    /// # Returns
293    ///
294    /// Ok if subscription started, error otherwise.
295    /// Returns Ok without action if already subscribed.
296    pub fn subscribe_live(&self, id: &str, timeframe: &Timeframe) -> Result<(), AsyncDataError> {
297        let key = CacheKey::new(id.to_string(), *timeframe);
298
299        // Don't subscribe twice
300        {
301            let subscriptions = self.subscriptions.lock().unwrap();
302            if subscriptions.contains_key(&key) {
303                return Ok(());
304            }
305        }
306
307        let mut rx = self.provider.subscribe(id, timeframe)?;
308        let live_buffer = self.live_buffer.clone();
309        let key_clone = key.clone();
310
311        // Spawn background task to receive bars
312        let handle = self.runtime.spawn(async move {
313            while let Some(df) = rx.recv().await {
314                // Convert DataFrame rows to OwnedDataRow and append to buffer
315                if let Ok(mut buffer) = live_buffer.write() {
316                    let rows = buffer.entry(key_clone.clone()).or_insert_with(Vec::new);
317
318                    for i in 0..df.row_count() {
319                        if let Some(row) = df.get_row(i) {
320                            if let Some(owned) = OwnedDataRow::from_data_row(&row) {
321                                rows.push(owned);
322                            }
323                        }
324                    }
325                }
326            }
327        });
328
329        let mut subscriptions = self.subscriptions.lock().unwrap();
330        subscriptions.insert(key, handle);
331        Ok(())
332    }
333
334    /// Stop live data subscription
335    ///
336    /// Cancels the background task and unsubscribes from the provider.
337    ///
338    /// # Arguments
339    ///
340    /// * `symbol` - Symbol to unsubscribe from
341    /// * `timeframe` - Timeframe
342    pub fn unsubscribe_live(&self, symbol: &str, timeframe: &Timeframe) {
343        let key = CacheKey::new(symbol.to_string(), *timeframe);
344
345        let mut subscriptions = self.subscriptions.lock().unwrap();
346        if let Some(handle) = subscriptions.remove(&key) {
347            handle.abort();
348        }
349
350        // Also tell the provider (best effort)
351        let _ = self.provider.unsubscribe(symbol, timeframe);
352
353        // Clear live buffer for this key
354        if let Ok(mut buffer) = self.live_buffer.write() {
355            buffer.remove(&key);
356        }
357    }
358
359    /// Get total row count (historical + live)
360    ///
361    /// # Arguments
362    ///
363    /// * `symbol` - Symbol to query
364    /// * `timeframe` - Timeframe
365    ///
366    /// # Returns
367    ///
368    /// Total number of rows available (historical + live).
369    pub fn row_count(&self, id: &str, timeframe: &Timeframe) -> usize {
370        let key = CacheKey::new(id.to_string(), *timeframe);
371
372        let historical = self.historical.read().unwrap();
373        let hist_count = historical.get(&key).map(|c| c.row_count()).unwrap_or(0);
374
375        let live_count = self
376            .live_buffer
377            .read()
378            .ok()
379            .and_then(|b| b.get(&key).map(|v| v.len()))
380            .unwrap_or(0);
381
382        hist_count + live_count
383    }
384
385    /// Check if data is cached
386    ///
387    /// # Arguments
388    ///
389    /// * `symbol` - Symbol to check
390    /// * `timeframe` - Timeframe to check
391    ///
392    /// # Returns
393    ///
394    /// true if historical data is cached for this key.
395    pub fn has_cached(&self, symbol: &str, timeframe: &Timeframe) -> bool {
396        let key = CacheKey::new(symbol.to_string(), *timeframe);
397        let historical = self.historical.read().unwrap();
398        historical.contains_key(&key)
399    }
400
401    /// Get list of cached symbols
402    ///
403    /// # Returns
404    ///
405    /// Vector of (symbol, timeframe) pairs that are cached.
406    pub fn cached_keys(&self) -> Vec<(String, Timeframe)> {
407        let historical = self.historical.read().unwrap();
408        historical
409            .keys()
410            .map(|k| (k.id.clone(), k.timeframe))
411            .collect()
412    }
413
414    /// Clear all cached data
415    ///
416    /// Stops all subscriptions and clears all cached data.
417    pub fn clear(&self) {
418        // Abort all background tasks
419        let mut subscriptions = self.subscriptions.lock().unwrap();
420        for (_, handle) in subscriptions.drain() {
421            handle.abort();
422        }
423        drop(subscriptions);
424
425        // Clear historical cache
426        let mut historical = self.historical.write().unwrap();
427        historical.clear();
428        drop(historical);
429
430        // Clear live buffer
431        if let Ok(mut buffer) = self.live_buffer.write() {
432            buffer.clear();
433        }
434    }
435
436    /// Get the async provider
437    ///
438    /// Returns a clone of the SharedAsyncProvider for use in other components.
439    pub fn provider(&self) -> SharedAsyncProvider {
440        self.provider.clone()
441    }
442
443    /// Create a snapshot of the data cache (historical + live buffers).
444    ///
445    /// **W17-snapshot-resume surface — see ADR-006 §2.7.4 + §2.7.5.1.**
446    /// The DataFrame (de)serializers were deleted alongside the broader
447    /// nanboxed-slot snapshot helpers. The kind-threaded replacement
448    /// lands in the Phase 2c snapshot rebuild session; until then, this
449    /// method returns a structured `anyhow!` error rather than panicking
450    /// via `todo!()` (the strict improvement over a `todo!()`-driven
451    /// process abort).
452    pub fn snapshot(&self, store: &SnapshotStore) -> AnyResult<DataCacheSnapshot> {
453        let _ = (
454            store,
455            &self.historical,
456            &self.live_buffer,
457            DEFAULT_CHUNK_LEN,
458        );
459        let _: Option<CacheKeySnapshot> = None;
460        let _: Option<CachedDataSnapshot> = None;
461        let _: Option<LiveBufferSnapshot> = None;
462        let _ = store_chunked_vec::<u8>;
463        anyhow::bail!(
464            "DataCache::snapshot: W17-snapshot-resume surface — \
465             DataFrame / cached-row (de)serializers were deleted alongside \
466             the kind-threaded `slot_to_serializable` rebuild. The kinded \
467             replacement uses `store_chunked_vec` over the parallel \
468             (bits, NativeKind) per-row track. Tracked as \
469             W17-snapshot-resume per docs/cluster-audits/phase-2d-playbook.md §3. \
470             ADR-006 §2.7.4 (snapshot serialization deferral) + §2.7.5.1 \
471             (post-proof wire-format shape for new HeapKinds).",
472        );
473    }
474
475    /// Restore data cache contents from a snapshot.
476    ///
477    /// See [`Self::snapshot`] — W17-snapshot-resume surface.
478    pub fn restore_from_snapshot(
479        &self,
480        _snapshot: DataCacheSnapshot,
481        _store: &SnapshotStore,
482    ) -> AnyResult<()> {
483        let _ = load_chunked_vec::<OwnedDataRow>;
484        anyhow::bail!(
485            "DataCache::restore_from_snapshot: W17-snapshot-resume \
486             surface — symmetric to `snapshot()`. The kinded \
487             `serializable_to_slot(sv, expected_kind, store)` inverse \
488             reconstructs row-storage parallel kind tracks from the \
489             persisted discriminator. Tracked as W17-snapshot-resume per \
490             docs/cluster-audits/phase-2d-playbook.md §3. ADR-006 \
491             §2.7.4 + §2.7.5.1.",
492        );
493    }
494}
495
496impl Drop for DataCache {
497    fn drop(&mut self) {
498        // Clean shutdown: abort all background tasks
499        let mut subscriptions = self.subscriptions.lock().unwrap();
500        for (_, handle) in subscriptions.drain() {
501            handle.abort();
502        }
503    }
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::{DataQuery, NullAsyncProvider};
    use crate::snapshot::SnapshotStore;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Structural checks only: end-to-end coverage of prefetch and live
    // streaming needs a mock provider plus a running tokio runtime.

    #[test]
    fn test_cache_key() {
        let daily = Timeframe::d1();
        let first = CacheKey::new(String::from("AAPL"), daily);
        let same = CacheKey::new(String::from("AAPL"), daily);
        let other = CacheKey::new(String::from("MSFT"), daily);

        assert_eq!(first, same);
        assert_ne!(first, other);
    }

    #[test]
    fn test_cached_data() {
        let cached = CachedData::new(DataFrame::new("TEST", Timeframe::d1()));

        assert_eq!(cached.current_index, 0);
        assert_eq!(cached.row_count(), 0);
    }

    /// In-memory provider serving fixed frames and counting `load` calls.
    #[derive(Clone)]
    struct TestAsyncProvider {
        frames: Arc<HashMap<CacheKey, DataFrame>>,
        load_calls: Arc<AtomicUsize>,
    }

    impl crate::data::AsyncDataProvider for TestAsyncProvider {
        fn load<'a>(
            &'a self,
            query: &'a DataQuery,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<Output = Result<DataFrame, crate::data::AsyncDataError>>
                    + Send
                    + 'a,
            >,
        > {
            let lookup = CacheKey::new(query.id.clone(), query.timeframe);
            let frames = Arc::clone(&self.frames);
            let counter = Arc::clone(&self.load_calls);
            Box::pin(async move {
                counter.fetch_add(1, Ordering::SeqCst);
                match frames.get(&lookup) {
                    Some(frame) => Ok(frame.clone()),
                    None => Err(crate::data::AsyncDataError::SymbolNotFound(query.id.clone())),
                }
            })
        }

        fn has_data(&self, symbol: &str, timeframe: &Timeframe) -> bool {
            self.frames
                .contains_key(&CacheKey::new(symbol.to_owned(), *timeframe))
        }

        fn symbols(&self) -> Vec<String> {
            let mut ids = Vec::with_capacity(self.frames.len());
            for key in self.frames.keys() {
                ids.push(key.id.clone());
            }
            ids
        }
    }

    // `test_data_cache_snapshot_roundtrip_no_refetch` deleted — see
    // `DataCache::snapshot` doc comment. Phase 2c rebuilds the snapshot
    // helpers and the test returns alongside. Until then this function keeps
    // the otherwise-unused test imports referenced.
    #[allow(dead_code)]
    fn _unused_test_imports(
        _provider: TestAsyncProvider,
        _df: DataFrame,
        _query: DataQuery,
        _kind: NullAsyncProvider,
        _store: SnapshotStore,
        _arc: Arc<()>,
        _atomic: AtomicUsize,
        _ordering: Ordering,
    ) {
        let _ = SystemTime::UNIX_EPOCH;
        let _ = UNIX_EPOCH;
    }

    /// W17-snapshot-resume gate: `DataCache::snapshot` must return a
    /// structured `anyhow::Error` carrying the W17 surface marker and the
    /// ADR-006 citation, never a `todo!()` panic that would abort the host
    /// process.
    #[test]
    fn test_w17_data_cache_snapshot_returns_structured_error() {
        let dir = tempfile::tempdir().expect("tempdir");
        let store = SnapshotStore::new(dir.path()).expect("snapshot store");
        let cache = DataCache::from_test_data(HashMap::new());

        let msg = cache
            .snapshot(&store)
            .expect_err("expected Err, got Ok")
            .to_string();

        for (needle, what) in [
            ("W17-snapshot-resume surface", "missing W17 marker"),
            ("§2.7.4", "missing ADR-006 §2.7.4 cite"),
        ] {
            assert!(msg.contains(needle), "{what}; got: {msg}");
        }
    }
}