Skip to main content

somatize_core/store/
mod.rs

1//! Data Store: abstraction for moving data between workers.
2//!
3//! Separates WHERE data lives from HOW it's processed.
4//! Workers use DataRef to reference data without materializing it.
5
6#[cfg(feature = "s3")]
7pub mod s3;
8
9#[cfg(feature = "s3")]
10pub use s3::S3DataStore;
11
12#[cfg(feature = "zarr")]
13pub mod zarr;
14
15#[cfg(feature = "zarr")]
16pub use zarr::ZarrStore;
17
18use crate::cache::CacheKey;
19use crate::error::{Result, SomaError};
20use crate::value::Value;
21use serde::{Deserialize, Serialize};
22
/// Metadata about a stored value, queryable without loading data.
///
/// In-memory values produce it via [`StoreMeta::from_value`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoreMeta {
    /// Number of rows: `shape[0]` for tensors (0 when the shape is empty),
    /// 1 for JSON, the byte length for bytes, and 0 for `Empty`.
    pub total_rows: usize,
    /// Remaining shape dimensions after the row axis (`shape[1..]` for
    /// tensors; empty for every other kind).
    pub shape_tail: Vec<usize>,
    /// Type tag: "tensor", "json", "bytes", or "empty".
    pub dtype: String,
}
33
34impl StoreMeta {
35    /// Build metadata from an in-memory Value.
36    pub fn from_value(value: &Value) -> Self {
37        match value {
38            Value::Tensor { shape, .. } => Self {
39                total_rows: shape.first().copied().unwrap_or(0),
40                shape_tail: shape.get(1..).unwrap_or_default().to_vec(),
41                dtype: "tensor".into(),
42            },
43            Value::Json(_) => Self {
44                total_rows: 1,
45                shape_tail: vec![],
46                dtype: "json".into(),
47            },
48            Value::Bytes(b) => Self {
49                total_rows: b.len(),
50                shape_tail: vec![],
51                dtype: "bytes".into(),
52            },
53            Value::Empty => Self {
54                total_rows: 0,
55                shape_tail: vec![],
56                dtype: "empty".into(),
57            },
58        }
59    }
60}
61
62/// Slice rows `[start..start+len)` from a tensor value.
63pub fn slice_tensor_rows(value: &Value, start: usize, len: usize) -> Result<Value> {
64    match value {
65        Value::Tensor { values, shape } => {
66            if shape.is_empty() {
67                return Err(SomaError::DataStore("cannot slice scalar tensor".into()));
68            }
69            let cols: usize = shape[1..].iter().product::<usize>().max(1);
70            let row_start = start * cols;
71            let row_end = (start + len) * cols;
72            if row_end > values.len() {
73                return Err(SomaError::DataStore(format!(
74                    "row range {start}..{} out of bounds (total rows: {})",
75                    start + len,
76                    shape[0]
77                )));
78            }
79            let mut new_shape = shape.clone();
80            new_shape[0] = len;
81            Ok(Value::tensor(
82                values[row_start..row_end].to_vec(),
83                new_shape,
84            ))
85        }
86        _ => Err(SomaError::DataStore(
87            "get_rows only works on Tensor values".into(),
88        )),
89    }
90}
91
/// A reference to data that may live in different places.
/// Workers exchange DataRefs instead of raw data.
///
/// Serialized as an internally tagged object: a `"type"` field holds the
/// variant name unchanged (`"Local"`, `"S3"`, `"Cached"`, `"Stream"`,
/// `"Inline"`, `"Zarr"`) alongside the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum DataRef {
    /// Data in local filesystem, identified by a file path.
    Local { path: String },
    /// Data in S3-compatible object storage
    S3 {
        bucket: String,
        /// Object key within `bucket`.
        key: String,
        /// Optional region. NOTE(review): `None` presumably falls back to a
        /// backend default; confirm in the S3 store.
        region: Option<String>,
    },
    /// Data in Soma cache (content-addressable)
    Cached { cache_key: CacheKey },
    /// Data available as a stream endpoint
    Stream {
        endpoint: String,
        /// Wire format of the stream.
        format: StreamFormat,
    },
    /// Data materialized inline (small values only).
    ///
    /// The size limit is a convention; nothing in this module checks it.
    Inline { value: Value },
    /// Data stored as a Zarr v3 array in object storage (chunked tensors).
    Zarr {
        bucket: String,
        /// Root path of the Zarr array (contains zarr.json + chunk objects).
        array_path: String,
        region: Option<String>,
    },
}
123
/// Stream data format.
///
/// Serialized in snake_case: `"json_lines"`, `"csv"`, `"arrow"`, `"protobuf"`.
/// `JsonLines` is the `Default`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum StreamFormat {
    /// JSON Lines stream (the default).
    #[default]
    JsonLines,
    /// CSV stream.
    Csv,
    /// Arrow stream. NOTE(review): IPC stream vs. file flavor is not specified here.
    Arrow,
    /// Protobuf stream. NOTE(review): message framing/schema is not specified here.
    Protobuf,
}
135
/// Storage configuration for an investigation/pipeline.
///
/// Serialized as an internally tagged object whose `"type"` field is
/// `"local"`, `"s3"`, or `"zarr"`. The `Default` is a local store.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum StorageConfig {
    /// Local filesystem (NFS, mounted volume)
    #[serde(rename = "local")]
    Local { base_path: String },
    /// S3-compatible object storage
    #[serde(rename = "s3")]
    S3 {
        bucket: String,
        /// Key prefix for stored objects. Presumably prepended to object keys;
        /// confirm in the S3 store.
        prefix: String,
        region: Option<String>,
        /// Custom endpoint for non-AWS S3-compatible services. NOTE(review):
        /// `None` presumably means the provider default.
        endpoint: Option<String>,
    },
    /// Zarr v3 chunked storage on S3-compatible backend.
    #[serde(rename = "zarr")]
    Zarr {
        bucket: String,
        prefix: String,
        region: Option<String>,
        endpoint: Option<String>,
        /// Rows per chunk (first dimension). Not validated here; presumably
        /// must be > 0. TODO confirm in the Zarr store.
        chunk_rows: usize,
    },
}
163
164impl Default for StorageConfig {
165    fn default() -> Self {
166        Self::Local {
167            base_path: "/tmp/soma-data".to_string(),
168        }
169    }
170}
171
172/// The DataStore trait: put/get/stream data across workers.
173///
174/// Unlike CacheStore (which stores Values by CacheKey),
175/// DataStore moves data between locations and supports streaming.
176pub trait DataStore: Send + Sync {
177    /// Store data and return a reference to it.
178    fn put(&self, key: &CacheKey, data: &Value) -> Result<DataRef>;
179
180    /// Retrieve data from a reference.
181    fn get(&self, data_ref: &DataRef) -> Result<Value>;
182
183    /// Check if data exists at a reference.
184    fn exists(&self, data_ref: &DataRef) -> Result<bool>;
185
186    /// Delete data at a reference.
187    fn remove(&self, data_ref: &DataRef) -> Result<()>;
188
189    /// Get the storage config.
190    fn config(&self) -> &StorageConfig;
191
192    /// Read a range of rows `[start..start+len)` from a tensor.
193    /// Returns a `Value::Tensor` with `shape[0] == len`.
194    /// Default impl downloads the full value and slices in memory.
195    fn get_rows(&self, data_ref: &DataRef, start: usize, len: usize) -> Result<Value> {
196        let value = self.get(data_ref)?;
197        slice_tensor_rows(&value, start, len)
198    }
199
200    /// Get metadata about a stored value without reading the data.
201    /// Default impl downloads the full value to extract metadata.
202    fn meta(&self, data_ref: &DataRef) -> Result<StoreMeta> {
203        let value = self.get(data_ref)?;
204        Ok(StoreMeta::from_value(&value))
205    }
206}
207
/// Local filesystem data store.
///
/// `put` writes each value as JSON to a file named by the hex form of its
/// `CacheKey`, directly under `base_path`.
pub struct LocalDataStore {
    /// `StorageConfig::Local` mirroring `base_path`; returned by `config()`.
    config: StorageConfig,
    /// Directory holding one file per stored key.
    base_path: std::path::PathBuf,
}
213
214impl LocalDataStore {
215    pub fn new(base_path: impl Into<std::path::PathBuf>) -> Self {
216        let base = base_path.into();
217        std::fs::create_dir_all(&base).ok();
218        Self {
219            config: StorageConfig::Local {
220                base_path: base.to_string_lossy().to_string(),
221            },
222            base_path: base,
223        }
224    }
225}
226
227impl DataStore for LocalDataStore {
228    fn put(&self, key: &CacheKey, data: &Value) -> Result<DataRef> {
229        let path = self.base_path.join(key.to_hex());
230        let bytes = serde_json::to_vec(data)
231            .map_err(|e| crate::error::SomaError::DataStore(e.to_string()))?;
232        std::fs::write(&path, &bytes)
233            .map_err(|e| crate::error::SomaError::DataStore(e.to_string()))?;
234        Ok(DataRef::Local {
235            path: path.to_string_lossy().to_string(),
236        })
237    }
238
239    fn get(&self, data_ref: &DataRef) -> Result<Value> {
240        match data_ref {
241            DataRef::Local { path } => {
242                let bytes = std::fs::read(path)
243                    .map_err(|e| crate::error::SomaError::DataStore(e.to_string()))?;
244                serde_json::from_slice(&bytes)
245                    .map_err(|e| crate::error::SomaError::DataStore(e.to_string()))
246            }
247            DataRef::Cached { cache_key } => {
248                let path = self.base_path.join(cache_key.to_hex());
249                let bytes = std::fs::read(&path)
250                    .map_err(|e| crate::error::SomaError::DataStore(e.to_string()))?;
251                serde_json::from_slice(&bytes)
252                    .map_err(|e| crate::error::SomaError::DataStore(e.to_string()))
253            }
254            DataRef::Inline { value } => Ok(value.clone()),
255            _ => Err(crate::error::SomaError::DataStore(
256                "Cannot get non-local DataRef from LocalDataStore".into(),
257            )),
258        }
259    }
260
261    fn exists(&self, data_ref: &DataRef) -> Result<bool> {
262        match data_ref {
263            DataRef::Local { path } => Ok(std::path::Path::new(path).exists()),
264            DataRef::Cached { cache_key } => Ok(self.base_path.join(cache_key.to_hex()).exists()),
265            DataRef::Inline { .. } => Ok(true),
266            _ => Ok(false),
267        }
268    }
269
270    fn remove(&self, data_ref: &DataRef) -> Result<()> {
271        if let DataRef::Local { path } = data_ref {
272            std::fs::remove_file(path).ok();
273        }
274        Ok(())
275    }
276
277    fn config(&self) -> &StorageConfig {
278        &self.config
279    }
280}
281
/// Stream-aware cache for inference pipelines.
///
/// Key insight: during inference, the filter STATE is fixed (from training).
/// Only the DATA changes. So we cache:
/// 1. Filter states (from training) — keyed by config_hash + training_data_hash
/// 2. Chunk results — keyed by config_hash + state_hash + chunk_hash
///
/// This means: if the same chunk passes through the same filter with the
/// same trained state, the result is returned from cache instantly.
pub struct StreamCache {
    /// State cache: filter_id → (state_key, cached state)
    states: std::collections::HashMap<String, (CacheKey, Value)>,
    /// Chunk result cache keyed by chunk key. When it reaches `max_chunks`,
    /// `put_chunk` evicts an arbitrary entry (first in `HashMap` iteration
    /// order), not the least-recently-used one.
    chunk_cache: std::collections::HashMap<CacheKey, Value>,
    /// Max cached chunks before `put_chunk` starts evicting.
    max_chunks: usize,
    /// Number of `get_chunk` lookups that found a cached result.
    pub hits: u64,
    /// Number of `get_chunk` lookups that found nothing.
    pub misses: u64,
}
302
303impl StreamCache {
304    pub fn new(max_chunks: usize) -> Self {
305        Self {
306            states: std::collections::HashMap::new(),
307            chunk_cache: std::collections::HashMap::new(),
308            max_chunks,
309            hits: 0,
310            misses: 0,
311        }
312    }
313
314    /// Load a filter's trained state into the stream cache.
315    pub fn load_state(&mut self, filter_id: &str, state_key: CacheKey, state: Value) {
316        self.states
317            .insert(filter_id.to_string(), (state_key, state));
318    }
319
320    /// Get a filter's cached state (for forward() during inference).
321    pub fn get_state(&self, filter_id: &str) -> Option<&Value> {
322        self.states.get(filter_id).map(|(_, v)| v)
323    }
324
325    /// Try to get a cached chunk result.
326    /// chunk_key = hash(config_hash + state_hash + chunk_data_hash)
327    pub fn get_chunk(&mut self, chunk_key: &CacheKey) -> Option<&Value> {
328        if let Some(v) = self.chunk_cache.get(chunk_key) {
329            self.hits += 1;
330            Some(v)
331        } else {
332            self.misses += 1;
333            None
334        }
335    }
336
337    /// Cache a chunk result.
338    pub fn put_chunk(&mut self, chunk_key: CacheKey, value: Value) {
339        if self.chunk_cache.len() >= self.max_chunks {
340            // Simple eviction: remove first entry (not true LRU, but fast)
341            if let Some(k) = self.chunk_cache.keys().next().cloned() {
342                self.chunk_cache.remove(&k);
343            }
344        }
345        self.chunk_cache.insert(chunk_key, value);
346    }
347
348    /// Cache hit rate.
349    pub fn hit_rate(&self) -> f64 {
350        let total = self.hits + self.misses;
351        if total == 0 {
352            0.0
353        } else {
354            self.hits as f64 / total as f64
355        }
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};

    /// Open a `LocalDataStore` rooted at `<temp_dir>/<name>`, returning the root as well.
    fn scratch_store(name: &str) -> (PathBuf, LocalDataStore) {
        let root = std::env::temp_dir().join(name);
        let store = LocalDataStore::new(&root);
        (root, store)
    }

    /// Best-effort removal of a scratch directory created by `scratch_store`.
    fn discard(root: &Path) {
        let _ = std::fs::remove_dir_all(root);
    }

    #[test]
    fn local_data_store_roundtrip() {
        let (root, store) = scratch_store("soma-ds-test");
        let key = CacheKey::hash_data(b"test_data");

        let stored = store
            .put(&key, &Value::tensor(vec![1.0, 2.0, 3.0], vec![3]))
            .unwrap();
        assert!(store.exists(&stored).unwrap());

        let fetched = store.get(&stored).unwrap();
        let (values, _) = fetched.as_tensor().unwrap();
        assert_eq!(values, &[1.0, 2.0, 3.0]);

        store.remove(&stored).unwrap();
        assert!(!store.exists(&stored).unwrap());

        discard(&root);
    }

    #[test]
    fn inline_data_ref() {
        let (root, store) = scratch_store("soma-ds-test-inline");
        let inline_ref = DataRef::Inline {
            value: Value::tensor(vec![42.0], vec![1]),
        };

        assert!(store.exists(&inline_ref).unwrap());
        let fetched = store.get(&inline_ref).unwrap();
        let (values, _) = fetched.as_tensor().unwrap();
        assert_eq!(values, &[42.0]);

        discard(&root);
    }

    #[test]
    fn stream_cache_basics() {
        let mut cache = StreamCache::new(100);
        cache.load_state(
            "normalize",
            CacheKey::hash_data(b"state_001"),
            Value::tensor(vec![0.0, 1.0], vec![2]),
        );

        assert!(cache.get_state("normalize").is_some());
        assert!(cache.get_state("unknown").is_none());
    }

    #[test]
    fn stream_cache_chunks() {
        let mut cache = StreamCache::new(3);
        let k1 = CacheKey::hash_data(b"chunk_1");
        let k2 = CacheKey::hash_data(b"chunk_2");
        let k3 = CacheKey::hash_data(b"chunk_3");
        let k4 = CacheKey::hash_data(b"chunk_4");

        // Fill the cache exactly to its capacity of 3.
        for (key, x) in [(&k1, 1.0), (&k2, 2.0), (&k3, 3.0)] {
            cache.put_chunk(key.clone(), Value::tensor(vec![x], vec![1]));
        }

        // Nothing was evicted yet, so every lookup is a hit.
        for key in [&k1, &k2, &k3] {
            assert!(cache.get_chunk(key).is_some());
        }
        assert_eq!(cache.hits, 3);

        // A fourth chunk forces an eviction but is itself retrievable.
        cache.put_chunk(k4.clone(), Value::tensor(vec![4.0], vec![1]));
        assert!(cache.get_chunk(&k4).is_some());

        assert!(cache.hit_rate() > 0.0);
    }

    #[test]
    fn storage_config_serde() {
        let cases = [
            (
                StorageConfig::S3 {
                    bucket: "my-lab".into(),
                    prefix: "experiments/".into(),
                    region: Some("eu-west-1".into()),
                    endpoint: None,
                },
                "my-lab",
            ),
            (
                StorageConfig::Local {
                    base_path: "/data".into(),
                },
                "/data",
            ),
        ];

        for (config, expected) in cases {
            let json = serde_json::to_string(&config).unwrap();
            assert!(json.contains(expected));
        }
    }

    #[test]
    fn data_ref_serde() {
        let refs = [
            DataRef::Local {
                path: "/tmp/x".into(),
            },
            DataRef::S3 {
                bucket: "b".into(),
                key: "k".into(),
                region: None,
            },
            DataRef::Cached {
                cache_key: CacheKey::hash_data(b"x"),
            },
            DataRef::Inline {
                value: Value::Empty,
            },
            DataRef::Zarr {
                bucket: "b".into(),
                array_path: "data/abc".into(),
                region: None,
            },
        ];

        // Every variant must survive a JSON round trip.
        for original in &refs {
            let json = serde_json::to_string(original).unwrap();
            serde_json::from_str::<DataRef>(&json).unwrap();
        }
    }

    #[test]
    fn slice_tensor_rows_basic() {
        // 4 rows × 3 cols holding 1..=12 in row-major order.
        let grid = Value::tensor(
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            vec![4, 3],
        );

        // Rows 1..3 → [[4,5,6], [7,8,9]]
        let middle = slice_tensor_rows(&grid, 1, 2).unwrap();
        let (values, shape) = middle.as_tensor().unwrap();
        assert_eq!(shape, &[2, 3]);
        assert_eq!(values, &[4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
    }

    #[test]
    fn slice_tensor_rows_single() {
        let column = Value::tensor(vec![10.0, 20.0, 30.0], vec![3]);
        let row = slice_tensor_rows(&column, 1, 1).unwrap();
        let (values, shape) = row.as_tensor().unwrap();
        assert_eq!(shape, &[1]);
        assert_eq!(values, &[20.0]);
    }

    #[test]
    fn slice_tensor_rows_out_of_bounds() {
        let column = Value::tensor(vec![1.0, 2.0, 3.0], vec![3]);
        // Rows 2..7 run past the 3 available rows.
        assert!(slice_tensor_rows(&column, 2, 5).is_err());
    }

    #[test]
    fn store_meta_from_tensor() {
        let meta = StoreMeta::from_value(&Value::tensor(vec![0.0; 12], vec![4, 3]));
        assert_eq!(meta.total_rows, 4);
        assert_eq!(meta.shape_tail, vec![3]);
        assert_eq!(meta.dtype, "tensor");
    }

    #[test]
    fn store_meta_from_json() {
        let meta = StoreMeta::from_value(&Value::json(serde_json::json!({"a": 1})));
        assert_eq!(meta.dtype, "json");
        assert_eq!(meta.total_rows, 1);
    }

    #[test]
    fn default_get_rows_on_local_store() {
        let (root, store) = scratch_store("soma-ds-test-getrows");
        let stored = store
            .put(
                &CacheKey::hash_data(b"rows_test"),
                &Value::tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]),
            )
            .unwrap();

        // Rows 1..3 through the trait's default impl (full get, then in-memory slice).
        let tail = store.get_rows(&stored, 1, 2).unwrap();
        let (values, shape) = tail.as_tensor().unwrap();
        assert_eq!(shape, &[2, 2]);
        assert_eq!(values, &[3.0, 4.0, 5.0, 6.0]);

        discard(&root);
    }

    #[test]
    fn default_meta_on_local_store() {
        let (root, store) = scratch_store("soma-ds-test-meta");
        let stored = store
            .put(
                &CacheKey::hash_data(b"meta_test"),
                &Value::tensor(vec![0.0; 20], vec![5, 4]),
            )
            .unwrap();

        let meta = store.meta(&stored).unwrap();
        assert_eq!(meta.total_rows, 5);
        assert_eq!(meta.shape_tail, vec![4]);
        assert_eq!(meta.dtype, "tensor");

        discard(&root);
    }

    #[test]
    fn zarr_storage_config_serde() {
        let config = StorageConfig::Zarr {
            bucket: "soma-research".into(),
            prefix: "data/".into(),
            region: None,
            endpoint: Some("s3.eu-central-003.backblazeb2.com".into()),
            chunk_rows: 1024,
        };

        let json = serde_json::to_string(&config).unwrap();
        for needle in ["soma-research", "1024"] {
            assert!(json.contains(needle));
        }
        serde_json::from_str::<StorageConfig>(&json).unwrap();
    }
}