Skip to main content

somatize_core/store/
mod.rs

1//! Data Store: abstraction for moving data between workers.
2//!
3//! Separates WHERE data lives from HOW it's processed.
4//! Workers use DataRef to reference data without materializing it.
5
/// Largest payload, in bytes, that is sent inline over the WebSocket transport.
///
/// Anything larger is uploaded through HTTP bulk transfer or a [`DataStore`].
pub const INLINE_THRESHOLD_BYTES: usize = 10 << 20; // 10 MiB
9
10#[cfg(feature = "s3")]
11pub mod s3;
12
13#[cfg(feature = "s3")]
14pub use s3::S3DataStore;
15
16#[cfg(feature = "zarr")]
17pub mod zarr;
18
19#[cfg(feature = "zarr")]
20pub use zarr::ZarrStore;
21
22use crate::cache::CacheKey;
23use crate::error::{Result, SomaError};
24use crate::value::Value;
25use serde::{Deserialize, Serialize};
26
27/// Metadata about a stored value, queryable without loading data.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct StoreMeta {
30    /// Total number of rows (shape[0] for tensors, 1 for scalar types).
31    pub total_rows: usize,
32    /// Remaining shape dimensions after the row axis (shape[1..] for tensors).
33    pub shape_tail: Vec<usize>,
34    /// Type tag: "tensor", "json", "bytes", or "empty".
35    pub dtype: String,
36}
37
38impl StoreMeta {
39    /// Build metadata from an in-memory Value.
40    pub fn from_value(value: &Value) -> Self {
41        match value {
42            Value::Tensor { shape, .. } => Self {
43                total_rows: shape.first().copied().unwrap_or(0),
44                shape_tail: shape.get(1..).unwrap_or_default().to_vec(),
45                dtype: "tensor".into(),
46            },
47            Value::Json(_) => Self {
48                total_rows: 1,
49                shape_tail: vec![],
50                dtype: "json".into(),
51            },
52            Value::Bytes(b) | Value::Object(b) => Self {
53                total_rows: b.len(),
54                shape_tail: vec![],
55                dtype: "bytes".into(),
56            },
57            Value::Empty => Self {
58                total_rows: 0,
59                shape_tail: vec![],
60                dtype: "empty".into(),
61            },
62        }
63    }
64}
65
66/// Slice rows `[start..start+len)` from a tensor value.
67pub fn slice_tensor_rows(value: &Value, start: usize, len: usize) -> Result<Value> {
68    match value {
69        Value::Tensor { values, shape } => {
70            if shape.is_empty() {
71                return Err(SomaError::DataStore("cannot slice scalar tensor".into()));
72            }
73            let cols: usize = shape[1..].iter().product::<usize>().max(1);
74            let row_start = start * cols;
75            let row_end = (start + len) * cols;
76            if row_end > values.len() {
77                return Err(SomaError::DataStore(format!(
78                    "row range {start}..{} out of bounds (total rows: {})",
79                    start + len,
80                    shape[0]
81                )));
82            }
83            let mut new_shape = shape.clone();
84            new_shape[0] = len;
85            Ok(Value::tensor(
86                values[row_start..row_end].to_vec(),
87                new_shape,
88            ))
89        }
90        _ => Err(SomaError::DataStore(
91            "get_rows only works on Tensor values".into(),
92        )),
93    }
94}
95
96/// A reference to data that may live in different places.
97/// Workers exchange DataRefs instead of raw data.
98#[derive(Debug, Clone, Serialize, Deserialize)]
99#[serde(tag = "type")]
100#[non_exhaustive]
101pub enum DataRef {
102    /// Data in local filesystem
103    Local { path: String },
104    /// Data in S3-compatible object storage
105    S3 {
106        bucket: String,
107        key: String,
108        region: Option<String>,
109    },
110    /// Data in Soma cache (content-addressable)
111    Cached { cache_key: CacheKey },
112    /// Data available as a stream endpoint
113    Stream {
114        endpoint: String,
115        format: StreamFormat,
116    },
117    /// Data materialized inline (small values only)
118    Inline { value: Value },
119    /// Data stored as a Zarr v3 array in object storage (chunked tensors).
120    Zarr {
121        bucket: String,
122        /// Root path of the Zarr array (contains zarr.json + chunk objects).
123        array_path: String,
124        region: Option<String>,
125    },
126}
127
128/// Stream data format.
129#[derive(Debug, Clone, Serialize, Deserialize, Default)]
130#[serde(rename_all = "snake_case")]
131#[non_exhaustive]
132pub enum StreamFormat {
133    #[default]
134    JsonLines,
135    Csv,
136    Arrow,
137    Protobuf,
138}
139
140/// Storage configuration for an investigation/pipeline.
141#[derive(Debug, Clone, Serialize, Deserialize)]
142#[serde(tag = "type")]
143#[non_exhaustive]
144pub enum StorageConfig {
145    /// Local filesystem (NFS, mounted volume)
146    #[serde(rename = "local")]
147    Local { base_path: String },
148    /// S3-compatible object storage
149    #[serde(rename = "s3")]
150    S3 {
151        bucket: String,
152        prefix: String,
153        region: Option<String>,
154        endpoint: Option<String>,
155    },
156    /// Zarr v3 chunked storage on S3-compatible backend.
157    #[serde(rename = "zarr")]
158    Zarr {
159        bucket: String,
160        prefix: String,
161        region: Option<String>,
162        endpoint: Option<String>,
163        /// Rows per chunk (first dimension).
164        chunk_rows: usize,
165    },
166}
167
168impl Default for StorageConfig {
169    fn default() -> Self {
170        Self::Local {
171            base_path: "/tmp/soma-data".to_string(),
172        }
173    }
174}
175
176/// The DataStore trait: put/get/stream data across workers.
177///
178/// Unlike CacheStore (which stores Values by CacheKey),
179/// DataStore moves data between locations and supports streaming.
180pub trait DataStore: Send + Sync {
181    /// Store data and return a reference to it.
182    fn put(&self, key: &CacheKey, data: &Value) -> Result<DataRef>;
183
184    /// Retrieve data from a reference.
185    fn get(&self, data_ref: &DataRef) -> Result<Value>;
186
187    /// Check if data exists at a reference.
188    fn exists(&self, data_ref: &DataRef) -> Result<bool>;
189
190    /// Delete data at a reference.
191    fn remove(&self, data_ref: &DataRef) -> Result<()>;
192
193    /// Get the storage config.
194    fn config(&self) -> &StorageConfig;
195
196    /// Read a range of rows `[start..start+len)` from a tensor.
197    /// Returns a `Value::Tensor` with `shape[0] == len`.
198    /// Default impl downloads the full value and slices in memory.
199    fn get_rows(&self, data_ref: &DataRef, start: usize, len: usize) -> Result<Value> {
200        let value = self.get(data_ref)?;
201        slice_tensor_rows(&value, start, len)
202    }
203
204    /// Get metadata about a stored value without reading the data.
205    /// Default impl downloads the full value to extract metadata.
206    fn meta(&self, data_ref: &DataRef) -> Result<StoreMeta> {
207        let value = self.get(data_ref)?;
208        Ok(StoreMeta::from_value(&value))
209    }
210}
211
212/// Local filesystem data store.
213pub struct LocalDataStore {
214    config: StorageConfig,
215    base_path: std::path::PathBuf,
216}
217
218impl LocalDataStore {
219    pub fn new(base_path: impl Into<std::path::PathBuf>) -> Self {
220        let base = base_path.into();
221        std::fs::create_dir_all(&base).ok();
222        Self {
223            config: StorageConfig::Local {
224                base_path: base.to_string_lossy().to_string(),
225            },
226            base_path: base,
227        }
228    }
229}
230
231impl DataStore for LocalDataStore {
232    fn put(&self, key: &CacheKey, data: &Value) -> Result<DataRef> {
233        let path = self.base_path.join(key.to_hex());
234        let bytes = serde_json::to_vec(data)
235            .map_err(|e| crate::error::SomaError::DataStore(e.to_string()))?;
236        std::fs::write(&path, &bytes)
237            .map_err(|e| crate::error::SomaError::DataStore(e.to_string()))?;
238        Ok(DataRef::Local {
239            path: path.to_string_lossy().to_string(),
240        })
241    }
242
243    fn get(&self, data_ref: &DataRef) -> Result<Value> {
244        match data_ref {
245            DataRef::Local { path } => {
246                let bytes = std::fs::read(path)
247                    .map_err(|e| crate::error::SomaError::DataStore(e.to_string()))?;
248                serde_json::from_slice(&bytes)
249                    .map_err(|e| crate::error::SomaError::DataStore(e.to_string()))
250            }
251            DataRef::Cached { cache_key } => {
252                let path = self.base_path.join(cache_key.to_hex());
253                let bytes = std::fs::read(&path)
254                    .map_err(|e| crate::error::SomaError::DataStore(e.to_string()))?;
255                serde_json::from_slice(&bytes)
256                    .map_err(|e| crate::error::SomaError::DataStore(e.to_string()))
257            }
258            DataRef::Inline { value } => Ok(value.clone()),
259            _ => Err(crate::error::SomaError::DataStore(
260                "Cannot get non-local DataRef from LocalDataStore".into(),
261            )),
262        }
263    }
264
265    fn exists(&self, data_ref: &DataRef) -> Result<bool> {
266        match data_ref {
267            DataRef::Local { path } => Ok(std::path::Path::new(path).exists()),
268            DataRef::Cached { cache_key } => Ok(self.base_path.join(cache_key.to_hex()).exists()),
269            DataRef::Inline { .. } => Ok(true),
270            _ => Ok(false),
271        }
272    }
273
274    fn remove(&self, data_ref: &DataRef) -> Result<()> {
275        if let DataRef::Local { path } = data_ref {
276            std::fs::remove_file(path).ok();
277        }
278        Ok(())
279    }
280
281    fn config(&self) -> &StorageConfig {
282        &self.config
283    }
284}
285
286/// Stream-aware cache for inference pipelines.
287///
288/// Key insight: during inference, the filter STATE is fixed (from training).
289/// Only the DATA changes. So we cache:
290/// 1. Filter states (from training) — keyed by config_hash + training_data_hash
291/// 2. Chunk results — keyed by config_hash + state_hash + chunk_hash
292///
293/// This means: if the same chunk passes through the same filter with the
294/// same trained state, the result is returned from cache instantly.
295pub struct StreamCache {
296    /// State cache: filter_id → (state_key, cached state)
297    states: std::collections::HashMap<String, (CacheKey, Value)>,
298    /// Chunk result cache: LRU of chunk results
299    chunk_cache: std::collections::HashMap<CacheKey, Value>,
300    /// Max cached chunks (LRU eviction)
301    max_chunks: usize,
302    /// Stats
303    pub hits: u64,
304    pub misses: u64,
305}
306
307impl StreamCache {
308    pub fn new(max_chunks: usize) -> Self {
309        Self {
310            states: std::collections::HashMap::new(),
311            chunk_cache: std::collections::HashMap::new(),
312            max_chunks,
313            hits: 0,
314            misses: 0,
315        }
316    }
317
318    /// Load a filter's trained state into the stream cache.
319    pub fn load_state(&mut self, filter_id: &str, state_key: CacheKey, state: Value) {
320        self.states
321            .insert(filter_id.to_string(), (state_key, state));
322    }
323
324    /// Get a filter's cached state (for forward() during inference).
325    pub fn get_state(&self, filter_id: &str) -> Option<&Value> {
326        self.states.get(filter_id).map(|(_, v)| v)
327    }
328
329    /// Try to get a cached chunk result.
330    /// chunk_key = hash(config_hash + state_hash + chunk_data_hash)
331    pub fn get_chunk(&mut self, chunk_key: &CacheKey) -> Option<&Value> {
332        if let Some(v) = self.chunk_cache.get(chunk_key) {
333            self.hits += 1;
334            Some(v)
335        } else {
336            self.misses += 1;
337            None
338        }
339    }
340
341    /// Cache a chunk result.
342    pub fn put_chunk(&mut self, chunk_key: CacheKey, value: Value) {
343        if self.chunk_cache.len() >= self.max_chunks {
344            // Simple eviction: remove first entry (not true LRU, but fast)
345            if let Some(k) = self.chunk_cache.keys().next().cloned() {
346                self.chunk_cache.remove(&k);
347            }
348        }
349        self.chunk_cache.insert(chunk_key, value);
350    }
351
352    /// Cache hit rate.
353    pub fn hit_rate(&self) -> f64 {
354        let total = self.hits + self.misses;
355        if total == 0 {
356            0.0
357        } else {
358            self.hits as f64 / total as f64
359        }
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Scratch directory unique to this test and this process, so parallel
    /// tests and concurrent `cargo test` runs never share files.
    fn scratch_dir(tag: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("soma-ds-test-{tag}-{}", std::process::id()))
    }

    #[test]
    fn local_data_store_roundtrip() {
        let dir = scratch_dir("roundtrip");
        let store = LocalDataStore::new(&dir);

        let key = CacheKey::hash_data(b"test_data");
        let value = Value::tensor(vec![1.0, 2.0, 3.0], vec![3]);

        let data_ref = store.put(&key, &value).unwrap();
        assert!(matches!(data_ref, DataRef::Local { .. }));
        assert!(store.exists(&data_ref).unwrap());

        let retrieved = store.get(&data_ref).unwrap();
        let (data, shape) = retrieved.as_tensor().unwrap();
        assert_eq!(data, &[1.0, 2.0, 3.0]);
        assert_eq!(shape, &[3]);

        store.remove(&data_ref).unwrap();
        assert!(!store.exists(&data_ref).unwrap());
        // Removing data that is already gone is not an error.
        store.remove(&data_ref).unwrap();

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn inline_data_ref() {
        let dir = scratch_dir("inline");
        let store = LocalDataStore::new(&dir);

        let data_ref = DataRef::Inline {
            value: Value::tensor(vec![42.0], vec![1]),
        };

        assert!(store.exists(&data_ref).unwrap());
        let v = store.get(&data_ref).unwrap();
        let (data, _) = v.as_tensor().unwrap();
        assert_eq!(data, &[42.0]);

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn local_store_rejects_remote_refs() {
        let dir = scratch_dir("remote");
        let store = LocalDataStore::new(&dir);

        let remote = DataRef::S3 {
            bucket: "b".into(),
            key: "k".into(),
            region: None,
        };
        assert!(store.get(&remote).is_err());
        assert!(!store.exists(&remote).unwrap());
        store.remove(&remote).unwrap();

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn stream_cache_basics() {
        let mut cache = StreamCache::new(100);

        let state = Value::tensor(vec![0.0, 1.0], vec![2]);
        let state_key = CacheKey::hash_data(b"state_001");
        cache.load_state("normalize", state_key, state);

        let (data, _) = cache.get_state("normalize").unwrap().as_tensor().unwrap();
        assert_eq!(data, &[0.0, 1.0]);
        assert!(cache.get_state("unknown").is_none());
    }

    #[test]
    fn stream_cache_hit_rate_starts_at_zero() {
        let cache = StreamCache::new(4);
        assert_eq!(cache.hit_rate(), 0.0);
    }

    #[test]
    fn stream_cache_chunks() {
        let mut cache = StreamCache::new(3);

        let k1 = CacheKey::hash_data(b"chunk_1");
        let k2 = CacheKey::hash_data(b"chunk_2");
        let k3 = CacheKey::hash_data(b"chunk_3");
        let k4 = CacheKey::hash_data(b"chunk_4");

        cache.put_chunk(k1.clone(), Value::tensor(vec![1.0], vec![1]));
        cache.put_chunk(k2.clone(), Value::tensor(vec![2.0], vec![1]));
        cache.put_chunk(k3.clone(), Value::tensor(vec![3.0], vec![1]));

        // All 3 should be cached.
        assert!(cache.get_chunk(&k1).is_some());
        assert!(cache.get_chunk(&k2).is_some());
        assert!(cache.get_chunk(&k3).is_some());
        assert_eq!(cache.hits, 3);
        assert_eq!(cache.misses, 0);

        // Adding k4 must evict exactly one earlier chunk (max_chunks = 3).
        cache.put_chunk(k4.clone(), Value::tensor(vec![4.0], vec![1]));
        assert!(cache.get_chunk(&k4).is_some());

        let mut survivors = 0;
        for k in [&k1, &k2, &k3] {
            if cache.get_chunk(k).is_some() {
                survivors += 1;
            }
        }
        assert_eq!(survivors, 2);

        // 3 + 1 (k4) + 2 (survivors) hits; 1 miss (the evicted chunk).
        assert_eq!(cache.hits, 6);
        assert_eq!(cache.misses, 1);
        assert!((cache.hit_rate() - 6.0 / 7.0).abs() < 1e-12);
    }

    #[test]
    fn storage_config_serde() {
        let s3 = StorageConfig::S3 {
            bucket: "my-lab".into(),
            prefix: "experiments/".into(),
            region: Some("eu-west-1".into()),
            endpoint: None,
        };
        let json = serde_json::to_string(&s3).unwrap();
        assert!(json.contains("my-lab"));
        assert!(json.contains(r#""type":"s3""#));
        let _: StorageConfig = serde_json::from_str(&json).unwrap();

        let local = StorageConfig::Local {
            base_path: "/data".into(),
        };
        let json = serde_json::to_string(&local).unwrap();
        assert!(json.contains("/data"));
        assert!(json.contains(r#""type":"local""#));
        let _: StorageConfig = serde_json::from_str(&json).unwrap();
    }

    #[test]
    fn storage_config_default_is_local_tmp() {
        match StorageConfig::default() {
            StorageConfig::Local { base_path } => assert_eq!(base_path, "/tmp/soma-data"),
            other => panic!("unexpected default config: {other:?}"),
        }
    }

    #[test]
    fn data_ref_serde() {
        let refs = vec![
            DataRef::Local {
                path: "/tmp/x".into(),
            },
            DataRef::S3 {
                bucket: "b".into(),
                key: "k".into(),
                region: None,
            },
            DataRef::Cached {
                cache_key: CacheKey::hash_data(b"x"),
            },
            DataRef::Inline {
                value: Value::Empty,
            },
            DataRef::Zarr {
                bucket: "b".into(),
                array_path: "data/abc".into(),
                region: None,
            },
        ];
        for r in &refs {
            let json = serde_json::to_string(r).unwrap();
            let back: DataRef = serde_json::from_str(&json).unwrap();
            // The roundtrip must be lossless: re-serializing gives the same JSON.
            assert_eq!(serde_json::to_string(&back).unwrap(), json);
        }

        // DataRef has no rename rule, so the tag is the variant name verbatim.
        let local_json = serde_json::to_string(&refs[0]).unwrap();
        assert!(local_json.contains(r#""type":"Local""#));
    }

    #[test]
    fn slice_tensor_rows_basic() {
        // 4 rows × 3 cols
        let v = Value::tensor(
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            vec![4, 3],
        );
        // Rows 1..3 → [[4,5,6], [7,8,9]]
        let sliced = slice_tensor_rows(&v, 1, 2).unwrap();
        let (data, shape) = sliced.as_tensor().unwrap();
        assert_eq!(shape, &[2, 3]);
        assert_eq!(data, &[4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
    }

    #[test]
    fn slice_tensor_rows_single() {
        let v = Value::tensor(vec![10.0, 20.0, 30.0], vec![3]);
        let sliced = slice_tensor_rows(&v, 1, 1).unwrap();
        let (data, shape) = sliced.as_tensor().unwrap();
        assert_eq!(shape, &[1]);
        assert_eq!(data, &[20.0]);
    }

    #[test]
    fn slice_tensor_rows_out_of_bounds() {
        let v = Value::tensor(vec![1.0, 2.0, 3.0], vec![3]);
        assert!(slice_tensor_rows(&v, 2, 5).is_err());
        // A start at the row count with a non-zero length is also out of range.
        assert!(slice_tensor_rows(&v, 3, 1).is_err());
    }

    #[test]
    fn slice_tensor_rows_rejects_non_tensor() {
        assert!(slice_tensor_rows(&Value::Empty, 0, 1).is_err());
    }

    #[test]
    fn store_meta_from_tensor() {
        let v = Value::tensor(vec![0.0; 12], vec![4, 3]);
        let meta = StoreMeta::from_value(&v);
        assert_eq!(meta.total_rows, 4);
        assert_eq!(meta.shape_tail, vec![3]);
        assert_eq!(meta.dtype, "tensor");
    }

    #[test]
    fn store_meta_from_json() {
        let v = Value::json(serde_json::json!({"a": 1}));
        let meta = StoreMeta::from_value(&v);
        assert_eq!(meta.dtype, "json");
        assert_eq!(meta.total_rows, 1);
        assert!(meta.shape_tail.is_empty());
    }

    #[test]
    fn store_meta_from_empty() {
        let meta = StoreMeta::from_value(&Value::Empty);
        assert_eq!(meta.dtype, "empty");
        assert_eq!(meta.total_rows, 0);
        assert!(meta.shape_tail.is_empty());
    }

    #[test]
    fn default_get_rows_on_local_store() {
        let dir = scratch_dir("getrows");
        let store = LocalDataStore::new(&dir);

        let key = CacheKey::hash_data(b"rows_test");
        let value = Value::tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);
        let data_ref = store.put(&key, &value).unwrap();

        // Read rows 1..3 (start = 1, len = 2) via the default impl (full get + slice).
        let sliced = store.get_rows(&data_ref, 1, 2).unwrap();
        let (data, shape) = sliced.as_tensor().unwrap();
        assert_eq!(shape, &[2, 2]);
        assert_eq!(data, &[3.0, 4.0, 5.0, 6.0]);

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn default_meta_on_local_store() {
        let dir = scratch_dir("meta");
        let store = LocalDataStore::new(&dir);

        let key = CacheKey::hash_data(b"meta_test");
        let value = Value::tensor(vec![0.0; 20], vec![5, 4]);
        let data_ref = store.put(&key, &value).unwrap();

        let meta = store.meta(&data_ref).unwrap();
        assert_eq!(meta.total_rows, 5);
        assert_eq!(meta.shape_tail, vec![4]);
        assert_eq!(meta.dtype, "tensor");

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn zarr_storage_config_serde() {
        let zarr = StorageConfig::Zarr {
            bucket: "soma-research".into(),
            prefix: "data/".into(),
            region: None,
            endpoint: Some("s3.eu-central-003.backblazeb2.com".into()),
            chunk_rows: 1024,
        };
        let json = serde_json::to_string(&zarr).unwrap();
        assert!(json.contains("soma-research"));
        assert!(json.contains("1024"));
        assert!(json.contains(r#""type":"zarr""#));
        let _: StorageConfig = serde_json::from_str(&json).unwrap();
    }
}