1use super::{
9 DataFrame, DataQuery, OwnedDataRow, SharedAsyncProvider, Timeframe,
10 async_provider::AsyncDataError,
11};
12use crate::snapshot::{
13 CacheKeySnapshot, CachedDataSnapshot, DEFAULT_CHUNK_LEN, DataCacheSnapshot, LiveBufferSnapshot,
14 SnapshotStore, load_chunked_vec, store_chunked_vec,
15};
16use anyhow::Result as AnyResult;
17use std::collections::HashMap;
18use std::sync::{Arc, Mutex, RwLock};
19use tokio::task::JoinHandle;
20
21#[derive(Debug, Clone, Hash, Eq, PartialEq)]
23pub struct CacheKey {
24 pub id: String,
26 pub timeframe: Timeframe,
28}
29
30impl CacheKey {
31 pub fn new(id: String, timeframe: Timeframe) -> Self {
33 Self { id, timeframe }
34 }
35}
36
37#[derive(Debug, Clone)]
39pub struct CachedData {
40 pub historical: DataFrame,
42 pub current_index: usize,
44}
45
46impl CachedData {
47 pub fn new(historical: DataFrame) -> Self {
49 Self {
50 historical,
51 current_index: 0,
52 }
53 }
54
55 pub fn row_count(&self) -> usize {
57 self.historical.row_count()
58 }
59}
60
61#[derive(Clone)]
75pub struct DataCache {
76 provider: SharedAsyncProvider,
78
79 historical: Arc<RwLock<HashMap<CacheKey, CachedData>>>,
82
83 live_buffer: Arc<RwLock<HashMap<CacheKey, Vec<OwnedDataRow>>>>,
86
87 subscriptions: Arc<Mutex<HashMap<CacheKey, JoinHandle<()>>>>,
91
92 runtime: tokio::runtime::Handle,
94}
95
96impl DataCache {
97 pub fn new(provider: SharedAsyncProvider, runtime: tokio::runtime::Handle) -> Self {
104 Self {
105 provider,
106 historical: Arc::new(RwLock::new(HashMap::new())),
107 live_buffer: Arc::new(RwLock::new(HashMap::new())),
108 subscriptions: Arc::new(Mutex::new(HashMap::new())),
109 runtime,
110 }
111 }
112
/// Test-only constructor: seeds the cache with the given frames and uses a
/// `NullAsyncProvider`.
///
/// NOTE(review): the runtime built here is a local and is dropped when this
/// function returns; only a clone of its handle is kept. Presumably tasks
/// spawned through that handle cannot make progress — confirm before using
/// this helper in tests that exercise live subscriptions.
#[cfg(test)]
pub(crate) fn from_test_data(data: HashMap<CacheKey, DataFrame>) -> Self {
    let mut seeded: HashMap<CacheKey, CachedData> = HashMap::with_capacity(data.len());
    for (key, frame) in data {
        seeded.insert(key, CachedData::new(frame));
    }

    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .expect("test tokio runtime");
    let runtime = rt.handle().clone();

    DataCache {
        provider: Arc::new(super::async_provider::NullAsyncProvider),
        historical: Arc::new(RwLock::new(seeded)),
        live_buffer: Arc::new(RwLock::new(HashMap::new())),
        subscriptions: Arc::new(Mutex::new(HashMap::new())),
        runtime,
    }
}
134
135 pub async fn prefetch(&self, queries: Vec<DataQuery>) -> Result<(), AsyncDataError> {
158 use futures::future::join_all;
159
160 let futures: Vec<_> = queries
162 .iter()
163 .map(|q| {
164 let provider = self.provider.clone();
165 let query = q.clone();
166 async move {
167 let df = provider.load(&query).await?;
168 Ok::<_, AsyncDataError>((query, df))
169 }
170 })
171 .collect();
172
173 let results = join_all(futures).await;
174
175 let mut historical = self.historical.write().unwrap();
177 for result in results {
178 let (query, df) = result?;
179 let key = CacheKey::new(query.id.clone(), query.timeframe);
180 historical.insert(key, CachedData::new(df));
181 }
182
183 Ok(())
184 }
185
186 pub fn get_row(&self, id: &str, timeframe: &Timeframe, index: usize) -> Option<OwnedDataRow> {
201 let key = CacheKey::new(id.to_string(), *timeframe);
202
203 let historical = self.historical.read().unwrap();
204 historical.get(&key).and_then(|cached| {
205 let hist_len = cached.row_count();
206
207 if index < hist_len {
209 if let Some(row) = cached.historical.get_row(index) {
210 return OwnedDataRow::from_data_row(&row);
211 }
212 }
213
214 if let Ok(live) = self.live_buffer.read() {
216 let live_index = index.saturating_sub(hist_len);
217 if let Some(live_rows) = live.get(&key) {
218 return live_rows.get(live_index).cloned();
219 }
220 }
221
222 None
223 })
224 }
225
226 pub fn get_row_range(
239 &self,
240 id: &str,
241 timeframe: &Timeframe,
242 start: usize,
243 end: usize,
244 ) -> Vec<OwnedDataRow> {
245 let key = CacheKey::new(id.to_string(), *timeframe);
246 let mut rows = Vec::new();
247
248 let historical = self.historical.read().unwrap();
249 if let Some(cached) = historical.get(&key) {
250 let hist_len = cached.row_count();
251
252 for i in start..end.min(hist_len) {
254 if let Some(row) = cached.historical.get_row(i) {
255 if let Some(owned) = OwnedDataRow::from_data_row(&row) {
256 rows.push(owned);
257 }
258 }
259 }
260
261 if end > hist_len {
263 if let Ok(live) = self.live_buffer.read() {
264 if let Some(live_rows) = live.get(&key) {
265 let live_start = start.saturating_sub(hist_len);
266 let live_end = end - hist_len;
267 for row in live_rows
268 .iter()
269 .skip(live_start)
270 .take(live_end.saturating_sub(live_start))
271 {
272 rows.push(row.clone());
273 }
274 }
275 }
276 }
277 }
278
279 rows
280 }
281
282 pub fn subscribe_live(&self, id: &str, timeframe: &Timeframe) -> Result<(), AsyncDataError> {
297 let key = CacheKey::new(id.to_string(), *timeframe);
298
299 {
301 let subscriptions = self.subscriptions.lock().unwrap();
302 if subscriptions.contains_key(&key) {
303 return Ok(());
304 }
305 }
306
307 let mut rx = self.provider.subscribe(id, timeframe)?;
308 let live_buffer = self.live_buffer.clone();
309 let key_clone = key.clone();
310
311 let handle = self.runtime.spawn(async move {
313 while let Some(df) = rx.recv().await {
314 if let Ok(mut buffer) = live_buffer.write() {
316 let rows = buffer.entry(key_clone.clone()).or_insert_with(Vec::new);
317
318 for i in 0..df.row_count() {
319 if let Some(row) = df.get_row(i) {
320 if let Some(owned) = OwnedDataRow::from_data_row(&row) {
321 rows.push(owned);
322 }
323 }
324 }
325 }
326 }
327 });
328
329 let mut subscriptions = self.subscriptions.lock().unwrap();
330 subscriptions.insert(key, handle);
331 Ok(())
332 }
333
334 pub fn unsubscribe_live(&self, symbol: &str, timeframe: &Timeframe) {
343 let key = CacheKey::new(symbol.to_string(), *timeframe);
344
345 let mut subscriptions = self.subscriptions.lock().unwrap();
346 if let Some(handle) = subscriptions.remove(&key) {
347 handle.abort();
348 }
349
350 let _ = self.provider.unsubscribe(symbol, timeframe);
352
353 if let Ok(mut buffer) = self.live_buffer.write() {
355 buffer.remove(&key);
356 }
357 }
358
359 pub fn row_count(&self, id: &str, timeframe: &Timeframe) -> usize {
370 let key = CacheKey::new(id.to_string(), *timeframe);
371
372 let historical = self.historical.read().unwrap();
373 let hist_count = historical.get(&key).map(|c| c.row_count()).unwrap_or(0);
374
375 let live_count = self
376 .live_buffer
377 .read()
378 .ok()
379 .and_then(|b| b.get(&key).map(|v| v.len()))
380 .unwrap_or(0);
381
382 hist_count + live_count
383 }
384
385 pub fn has_cached(&self, symbol: &str, timeframe: &Timeframe) -> bool {
396 let key = CacheKey::new(symbol.to_string(), *timeframe);
397 let historical = self.historical.read().unwrap();
398 historical.contains_key(&key)
399 }
400
401 pub fn cached_keys(&self) -> Vec<(String, Timeframe)> {
407 let historical = self.historical.read().unwrap();
408 historical
409 .keys()
410 .map(|k| (k.id.clone(), k.timeframe))
411 .collect()
412 }
413
414 pub fn clear(&self) {
418 let mut subscriptions = self.subscriptions.lock().unwrap();
420 for (_, handle) in subscriptions.drain() {
421 handle.abort();
422 }
423 drop(subscriptions);
424
425 let mut historical = self.historical.write().unwrap();
427 historical.clear();
428 drop(historical);
429
430 if let Ok(mut buffer) = self.live_buffer.write() {
432 buffer.clear();
433 }
434 }
435
436 pub fn provider(&self) -> SharedAsyncProvider {
440 self.provider.clone()
441 }
442
443 pub fn snapshot(&self, store: &SnapshotStore) -> AnyResult<DataCacheSnapshot> {
453 let _ = (
454 store,
455 &self.historical,
456 &self.live_buffer,
457 DEFAULT_CHUNK_LEN,
458 );
459 let _: Option<CacheKeySnapshot> = None;
460 let _: Option<CachedDataSnapshot> = None;
461 let _: Option<LiveBufferSnapshot> = None;
462 let _ = store_chunked_vec::<u8>;
463 anyhow::bail!(
464 "DataCache::snapshot: W17-snapshot-resume surface — \
465 DataFrame / cached-row (de)serializers were deleted alongside \
466 the kind-threaded `slot_to_serializable` rebuild. The kinded \
467 replacement uses `store_chunked_vec` over the parallel \
468 (bits, NativeKind) per-row track. Tracked as \
469 W17-snapshot-resume per docs/cluster-audits/phase-2d-playbook.md §3. \
470 ADR-006 §2.7.4 (snapshot serialization deferral) + §2.7.5.1 \
471 (post-proof wire-format shape for new HeapKinds).",
472 );
473 }
474
475 pub fn restore_from_snapshot(
479 &self,
480 _snapshot: DataCacheSnapshot,
481 _store: &SnapshotStore,
482 ) -> AnyResult<()> {
483 let _ = load_chunked_vec::<OwnedDataRow>;
484 anyhow::bail!(
485 "DataCache::restore_from_snapshot: W17-snapshot-resume \
486 surface — symmetric to `snapshot()`. The kinded \
487 `serializable_to_slot(sv, expected_kind, store)` inverse \
488 reconstructs row-storage parallel kind tracks from the \
489 persisted discriminator. Tracked as W17-snapshot-resume per \
490 docs/cluster-audits/phase-2d-playbook.md §3. ADR-006 \
491 §2.7.4 + §2.7.5.1.",
492 );
493 }
494}
495
496impl Drop for DataCache {
497 fn drop(&mut self) {
498 let mut subscriptions = self.subscriptions.lock().unwrap();
500 for (_, handle) in subscriptions.drain() {
501 handle.abort();
502 }
503 }
504}
505
506#[cfg(test)]
507mod tests {
508 use super::*;
509 use crate::data::{DataQuery, NullAsyncProvider};
510 use crate::snapshot::SnapshotStore;
511 use std::sync::Arc;
512 use std::sync::atomic::{AtomicUsize, Ordering};
513 use std::time::{SystemTime, UNIX_EPOCH};
514
/// Keys compare equal exactly when both the id and the timeframe match.
#[test]
fn test_cache_key() {
    let aapl = CacheKey::new(String::from("AAPL"), Timeframe::d1());
    let aapl_again = CacheKey::new(String::from("AAPL"), Timeframe::d1());
    let msft = CacheKey::new(String::from("MSFT"), Timeframe::d1());

    assert_eq!(aapl, aapl_again);
    assert_ne!(aapl, msft);
}
527
/// A freshly wrapped empty frame reports zero rows and a cursor at 0.
#[test]
fn test_cached_data() {
    let cached = CachedData::new(DataFrame::new("TEST", Timeframe::d1()));

    assert_eq!(cached.row_count(), 0);
    assert_eq!(cached.current_index, 0);
}
536
537 #[derive(Clone)]
538 struct TestAsyncProvider {
539 frames: Arc<HashMap<CacheKey, DataFrame>>,
540 load_calls: Arc<AtomicUsize>,
541 }
542
543 impl crate::data::AsyncDataProvider for TestAsyncProvider {
544 fn load<'a>(
545 &'a self,
546 query: &'a DataQuery,
547 ) -> std::pin::Pin<
548 Box<
549 dyn std::future::Future<Output = Result<DataFrame, crate::data::AsyncDataError>>
550 + Send
551 + 'a,
552 >,
553 > {
554 let key = CacheKey::new(query.id.clone(), query.timeframe);
555 let frames = self.frames.clone();
556 let calls = self.load_calls.clone();
557 Box::pin(async move {
558 calls.fetch_add(1, Ordering::SeqCst);
559 frames
560 .get(&key)
561 .cloned()
562 .ok_or_else(|| crate::data::AsyncDataError::SymbolNotFound(query.id.clone()))
563 })
564 }
565
566 fn has_data(&self, symbol: &str, timeframe: &Timeframe) -> bool {
567 let key = CacheKey::new(symbol.to_string(), *timeframe);
568 self.frames.contains_key(&key)
569 }
570
571 fn symbols(&self) -> Vec<String> {
572 self.frames.keys().map(|k| k.id.clone()).collect()
573 }
574 }
575
576 #[allow(dead_code)]
580 fn _unused_test_imports(
581 _provider: TestAsyncProvider,
582 _df: DataFrame,
583 _query: DataQuery,
584 _kind: NullAsyncProvider,
585 _store: SnapshotStore,
586 _arc: Arc<()>,
587 _atomic: AtomicUsize,
588 _ordering: Ordering,
589 ) {
590 let _ = (SystemTime::UNIX_EPOCH, UNIX_EPOCH);
591 }
592
/// `snapshot` must fail with a message that names the tracking marker and
/// cites the ADR section.
#[test]
fn test_w17_data_cache_snapshot_returns_structured_error() {
    let scratch = tempfile::tempdir().expect("tempdir");
    let store = SnapshotStore::new(scratch.path()).expect("snapshot store");
    let cache = DataCache::from_test_data(HashMap::new());

    let msg = cache
        .snapshot(&store)
        .expect_err("expected Err, got Ok")
        .to_string();

    let has_marker = msg.contains("W17-snapshot-resume surface");
    assert!(has_marker, "missing W17 marker; got: {msg}");

    let has_cite = msg.contains("§2.7.4");
    assert!(has_cite, "missing ADR-006 §2.7.4 cite; got: {msg}");
}
615}