1#[cfg(feature = "s3")]
7pub mod s3;
8
9#[cfg(feature = "s3")]
10pub use s3::S3DataStore;
11
12#[cfg(feature = "zarr")]
13pub mod zarr;
14
15#[cfg(feature = "zarr")]
16pub use zarr::ZarrStore;
17
18use crate::cache::CacheKey;
19use crate::error::{Result, SomaError};
20use crate::value::Value;
21use serde::{Deserialize, Serialize};
22
/// Lightweight description of a stored [`Value`]: how many rows it has, the
/// remaining dimensions, and a coarse type tag.
///
/// Built by [`StoreMeta::from_value`]; serializable with serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoreMeta {
    /// Size of the first dimension for tensors. `from_value` uses `1` for JSON
    /// values, the byte length for byte blobs, and `0` for empty values (and
    /// for tensors with an empty shape).
    pub total_rows: usize,
    /// All tensor dimensions after the first; empty for non-tensor values.
    pub shape_tail: Vec<usize>,
    /// Type tag. `from_value` emits `"tensor"`, `"json"`, `"bytes"` or
    /// `"empty"`.
    pub dtype: String,
}
33
34impl StoreMeta {
35 pub fn from_value(value: &Value) -> Self {
37 match value {
38 Value::Tensor { shape, .. } => Self {
39 total_rows: shape.first().copied().unwrap_or(0),
40 shape_tail: shape.get(1..).unwrap_or_default().to_vec(),
41 dtype: "tensor".into(),
42 },
43 Value::Json(_) => Self {
44 total_rows: 1,
45 shape_tail: vec![],
46 dtype: "json".into(),
47 },
48 Value::Bytes(b) => Self {
49 total_rows: b.len(),
50 shape_tail: vec![],
51 dtype: "bytes".into(),
52 },
53 Value::Empty => Self {
54 total_rows: 0,
55 shape_tail: vec![],
56 dtype: "empty".into(),
57 },
58 }
59 }
60}
61
62pub fn slice_tensor_rows(value: &Value, start: usize, len: usize) -> Result<Value> {
64 match value {
65 Value::Tensor { values, shape } => {
66 if shape.is_empty() {
67 return Err(SomaError::DataStore("cannot slice scalar tensor".into()));
68 }
69 let cols: usize = shape[1..].iter().product::<usize>().max(1);
70 let row_start = start * cols;
71 let row_end = (start + len) * cols;
72 if row_end > values.len() {
73 return Err(SomaError::DataStore(format!(
74 "row range {start}..{} out of bounds (total rows: {})",
75 start + len,
76 shape[0]
77 )));
78 }
79 let mut new_shape = shape.clone();
80 new_shape[0] = len;
81 Ok(Value::tensor(
82 values[row_start..row_end].to_vec(),
83 new_shape,
84 ))
85 }
86 _ => Err(SomaError::DataStore(
87 "get_rows only works on Tensor values".into(),
88 )),
89 }
90}
91
/// A serializable handle describing where a [`Value`] lives.
///
/// Serialized by serde as an internally tagged enum: the variant name is
/// stored in a `"type"` field alongside the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum DataRef {
    /// A file on the local filesystem, addressed by its path.
    Local { path: String },
    /// An object addressed by bucket and key, with an optional region.
    S3 {
        bucket: String,
        key: String,
        region: Option<String>,
    },
    /// An entry addressed by cache key. `LocalDataStore` resolves it to
    /// `<base_path>/<cache_key as hex>`.
    Cached { cache_key: CacheKey },
    /// A stream endpoint together with the format of the data it carries.
    Stream {
        endpoint: String,
        format: StreamFormat,
    },
    /// The value itself, embedded directly in the reference.
    Inline { value: Value },
    /// An array addressed by bucket and array path, with an optional region.
    Zarr {
        bucket: String,
        array_path: String,
        region: Option<String>,
    },
}
123
/// Wire format of a [`DataRef::Stream`].
///
/// Variants serialize under their snake_case names (`json_lines`, `csv`,
/// `arrow`, `protobuf`). The default is [`StreamFormat::JsonLines`].
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum StreamFormat {
    #[default]
    JsonLines,
    Csv,
    Arrow,
    Protobuf,
}
135
/// Configuration selecting and parameterizing a storage backend.
///
/// Serialized by serde as an internally tagged enum whose `"type"` field is
/// `"local"`, `"s3"` or `"zarr"`. The [`Default`] is local storage.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum StorageConfig {
    /// Local filesystem storage rooted at `base_path`.
    #[serde(rename = "local")]
    Local { base_path: String },
    /// Bucket-based storage with a key prefix and optional region/endpoint.
    #[serde(rename = "s3")]
    S3 {
        bucket: String,
        prefix: String,
        region: Option<String>,
        endpoint: Option<String>,
    },
    /// Zarr storage: the same fields as `S3` plus a required `chunk_rows`.
    #[serde(rename = "zarr")]
    Zarr {
        bucket: String,
        prefix: String,
        region: Option<String>,
        endpoint: Option<String>,
        // NOTE(review): presumably the number of rows per chunk; confirm
        // against the `zarr` module, which is not visible here.
        chunk_rows: usize,
    },
}
163
164impl Default for StorageConfig {
165 fn default() -> Self {
166 Self::Local {
167 base_path: "/tmp/soma-data".to_string(),
168 }
169 }
170}
171
172pub trait DataStore: Send + Sync {
177 fn put(&self, key: &CacheKey, data: &Value) -> Result<DataRef>;
179
180 fn get(&self, data_ref: &DataRef) -> Result<Value>;
182
183 fn exists(&self, data_ref: &DataRef) -> Result<bool>;
185
186 fn remove(&self, data_ref: &DataRef) -> Result<()>;
188
189 fn config(&self) -> &StorageConfig;
191
192 fn get_rows(&self, data_ref: &DataRef, start: usize, len: usize) -> Result<Value> {
196 let value = self.get(data_ref)?;
197 slice_tensor_rows(&value, start, len)
198 }
199
200 fn meta(&self, data_ref: &DataRef) -> Result<StoreMeta> {
203 let value = self.get(data_ref)?;
204 Ok(StoreMeta::from_value(&value))
205 }
206}
207
/// A [`DataStore`] that keeps each value as a JSON file in a local directory.
pub struct LocalDataStore {
    /// The `StorageConfig::Local` describing this store, returned by `config()`.
    config: StorageConfig,
    /// Directory under which files named after the hex cache key are written.
    base_path: std::path::PathBuf,
}
213
214impl LocalDataStore {
215 pub fn new(base_path: impl Into<std::path::PathBuf>) -> Self {
216 let base = base_path.into();
217 std::fs::create_dir_all(&base).ok();
218 Self {
219 config: StorageConfig::Local {
220 base_path: base.to_string_lossy().to_string(),
221 },
222 base_path: base,
223 }
224 }
225}
226
227impl DataStore for LocalDataStore {
228 fn put(&self, key: &CacheKey, data: &Value) -> Result<DataRef> {
229 let path = self.base_path.join(key.to_hex());
230 let bytes = serde_json::to_vec(data)
231 .map_err(|e| crate::error::SomaError::DataStore(e.to_string()))?;
232 std::fs::write(&path, &bytes)
233 .map_err(|e| crate::error::SomaError::DataStore(e.to_string()))?;
234 Ok(DataRef::Local {
235 path: path.to_string_lossy().to_string(),
236 })
237 }
238
239 fn get(&self, data_ref: &DataRef) -> Result<Value> {
240 match data_ref {
241 DataRef::Local { path } => {
242 let bytes = std::fs::read(path)
243 .map_err(|e| crate::error::SomaError::DataStore(e.to_string()))?;
244 serde_json::from_slice(&bytes)
245 .map_err(|e| crate::error::SomaError::DataStore(e.to_string()))
246 }
247 DataRef::Cached { cache_key } => {
248 let path = self.base_path.join(cache_key.to_hex());
249 let bytes = std::fs::read(&path)
250 .map_err(|e| crate::error::SomaError::DataStore(e.to_string()))?;
251 serde_json::from_slice(&bytes)
252 .map_err(|e| crate::error::SomaError::DataStore(e.to_string()))
253 }
254 DataRef::Inline { value } => Ok(value.clone()),
255 _ => Err(crate::error::SomaError::DataStore(
256 "Cannot get non-local DataRef from LocalDataStore".into(),
257 )),
258 }
259 }
260
261 fn exists(&self, data_ref: &DataRef) -> Result<bool> {
262 match data_ref {
263 DataRef::Local { path } => Ok(std::path::Path::new(path).exists()),
264 DataRef::Cached { cache_key } => Ok(self.base_path.join(cache_key.to_hex()).exists()),
265 DataRef::Inline { .. } => Ok(true),
266 _ => Ok(false),
267 }
268 }
269
270 fn remove(&self, data_ref: &DataRef) -> Result<()> {
271 if let DataRef::Local { path } = data_ref {
272 std::fs::remove_file(path).ok();
273 }
274 Ok(())
275 }
276
277 fn config(&self) -> &StorageConfig {
278 &self.config
279 }
280}
281
/// In-memory cache holding per-filter state and a bounded set of data chunks,
/// with hit/miss counters for chunk lookups.
pub struct StreamCache {
    /// Per-filter state, keyed by filter id; each entry stores the state's
    /// cache key together with the state value.
    states: std::collections::HashMap<String, (CacheKey, Value)>,
    /// Cached chunks, keyed by chunk cache key.
    chunk_cache: std::collections::HashMap<CacheKey, Value>,
    /// Chunk-count threshold at which `put_chunk` evicts an entry before
    /// inserting.
    max_chunks: usize,
    /// Number of `get_chunk` calls that found the requested chunk.
    pub hits: u64,
    /// Number of `get_chunk` calls that did not find the requested chunk.
    pub misses: u64,
}
302
303impl StreamCache {
304 pub fn new(max_chunks: usize) -> Self {
305 Self {
306 states: std::collections::HashMap::new(),
307 chunk_cache: std::collections::HashMap::new(),
308 max_chunks,
309 hits: 0,
310 misses: 0,
311 }
312 }
313
314 pub fn load_state(&mut self, filter_id: &str, state_key: CacheKey, state: Value) {
316 self.states
317 .insert(filter_id.to_string(), (state_key, state));
318 }
319
320 pub fn get_state(&self, filter_id: &str) -> Option<&Value> {
322 self.states.get(filter_id).map(|(_, v)| v)
323 }
324
325 pub fn get_chunk(&mut self, chunk_key: &CacheKey) -> Option<&Value> {
328 if let Some(v) = self.chunk_cache.get(chunk_key) {
329 self.hits += 1;
330 Some(v)
331 } else {
332 self.misses += 1;
333 None
334 }
335 }
336
337 pub fn put_chunk(&mut self, chunk_key: CacheKey, value: Value) {
339 if self.chunk_cache.len() >= self.max_chunks {
340 if let Some(k) = self.chunk_cache.keys().next().cloned() {
342 self.chunk_cache.remove(&k);
343 }
344 }
345 self.chunk_cache.insert(chunk_key, value);
346 }
347
348 pub fn hit_rate(&self) -> f64 {
350 let total = self.hits + self.misses;
351 if total == 0 {
352 0.0
353 } else {
354 self.hits as f64 / total as f64
355 }
356 }
357}
358
// Unit tests for the local store, the stream cache, serde round-trips and the
// tensor slicing / metadata helpers. Filesystem tests each use their own
// directory under the system temp dir and remove it at the end.
#[cfg(test)]
mod tests {
    use super::*;

    // put -> exists -> get -> remove -> !exists on a one-dimensional tensor.
    #[test]
    fn local_data_store_roundtrip() {
        let dir = std::env::temp_dir().join("soma-ds-test");
        let store = LocalDataStore::new(&dir);

        let key = CacheKey::hash_data(b"test_data");
        let value = Value::tensor(vec![1.0, 2.0, 3.0], vec![3]);

        let data_ref = store.put(&key, &value).unwrap();
        assert!(store.exists(&data_ref).unwrap());

        let retrieved = store.get(&data_ref).unwrap();
        let (data, _) = retrieved.as_tensor().unwrap();
        assert_eq!(data, &[1.0, 2.0, 3.0]);

        store.remove(&data_ref).unwrap();
        assert!(!store.exists(&data_ref).unwrap());

        let _ = std::fs::remove_dir_all(&dir);
    }

    // Inline references always exist and resolve to the embedded value
    // without touching the filesystem contents.
    #[test]
    fn inline_data_ref() {
        let dir = std::env::temp_dir().join("soma-ds-test-inline");
        let store = LocalDataStore::new(&dir);

        let data_ref = DataRef::Inline {
            value: Value::tensor(vec![42.0], vec![1]),
        };

        assert!(store.exists(&data_ref).unwrap());
        let v = store.get(&data_ref).unwrap();
        let (data, _) = v.as_tensor().unwrap();
        assert_eq!(data, &[42.0]);

        let _ = std::fs::remove_dir_all(&dir);
    }

    // Loaded state is retrievable by filter id; unknown ids yield None.
    #[test]
    fn stream_cache_basics() {
        let mut cache = StreamCache::new(100);

        let state = Value::tensor(vec![0.0, 1.0], vec![2]);
        let state_key = CacheKey::hash_data(b"state_001");
        cache.load_state("normalize", state_key, state.clone());

        assert!(cache.get_state("normalize").is_some());
        assert!(cache.get_state("unknown").is_none());
    }

    // Fills a 3-slot cache, checks all three hit, then inserts a fourth chunk.
    // Which older chunk gets evicted is not asserted, only that the new one
    // is present and the hit rate is positive.
    #[test]
    fn stream_cache_chunks() {
        let mut cache = StreamCache::new(3);

        let k1 = CacheKey::hash_data(b"chunk_1");
        let k2 = CacheKey::hash_data(b"chunk_2");
        let k3 = CacheKey::hash_data(b"chunk_3");
        let k4 = CacheKey::hash_data(b"chunk_4");

        cache.put_chunk(k1.clone(), Value::tensor(vec![1.0], vec![1]));
        cache.put_chunk(k2.clone(), Value::tensor(vec![2.0], vec![1]));
        cache.put_chunk(k3.clone(), Value::tensor(vec![3.0], vec![1]));

        assert!(cache.get_chunk(&k1).is_some());
        assert!(cache.get_chunk(&k2).is_some());
        assert!(cache.get_chunk(&k3).is_some());
        assert_eq!(cache.hits, 3);

        cache.put_chunk(k4.clone(), Value::tensor(vec![4.0], vec![1]));
        assert!(cache.get_chunk(&k4).is_some());

        assert!(cache.hit_rate() > 0.0);
    }

    // S3 and Local configs serialize to JSON containing their field values.
    #[test]
    fn storage_config_serde() {
        let s3 = StorageConfig::S3 {
            bucket: "my-lab".into(),
            prefix: "experiments/".into(),
            region: Some("eu-west-1".into()),
            endpoint: None,
        };
        let json = serde_json::to_string(&s3).unwrap();
        assert!(json.contains("my-lab"));

        let local = StorageConfig::Local {
            base_path: "/data".into(),
        };
        let json = serde_json::to_string(&local).unwrap();
        assert!(json.contains("/data"));
    }

    // Several DataRef variants survive a JSON serialize/deserialize round-trip.
    #[test]
    fn data_ref_serde() {
        let refs = vec![
            DataRef::Local {
                path: "/tmp/x".into(),
            },
            DataRef::S3 {
                bucket: "b".into(),
                key: "k".into(),
                region: None,
            },
            DataRef::Cached {
                cache_key: CacheKey::hash_data(b"x"),
            },
            DataRef::Inline {
                value: Value::Empty,
            },
            DataRef::Zarr {
                bucket: "b".into(),
                array_path: "data/abc".into(),
                region: None,
            },
        ];
        for r in &refs {
            let json = serde_json::to_string(r).unwrap();
            let _: DataRef = serde_json::from_str(&json).unwrap();
        }
    }

    // Rows 1..3 of a 4x3 tensor: shape becomes [2, 3] and data is rows 1 and 2.
    #[test]
    fn slice_tensor_rows_basic() {
        let v = Value::tensor(
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            vec![4, 3],
        );
        let sliced = slice_tensor_rows(&v, 1, 2).unwrap();
        let (data, shape) = sliced.as_tensor().unwrap();
        assert_eq!(shape, &[2, 3]);
        assert_eq!(data, &[4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
    }

    // One-dimensional tensor: each row is a single element.
    #[test]
    fn slice_tensor_rows_single() {
        let v = Value::tensor(vec![10.0, 20.0, 30.0], vec![3]);
        let sliced = slice_tensor_rows(&v, 1, 1).unwrap();
        let (data, shape) = sliced.as_tensor().unwrap();
        assert_eq!(shape, &[1]);
        assert_eq!(data, &[20.0]);
    }

    // A window running past the last row is rejected.
    #[test]
    fn slice_tensor_rows_out_of_bounds() {
        let v = Value::tensor(vec![1.0, 2.0, 3.0], vec![3]);
        assert!(slice_tensor_rows(&v, 2, 5).is_err());
    }

    // Tensor metadata: first dimension -> total_rows, the rest -> shape_tail.
    #[test]
    fn store_meta_from_tensor() {
        let v = Value::tensor(vec![0.0; 12], vec![4, 3]);
        let meta = StoreMeta::from_value(&v);
        assert_eq!(meta.total_rows, 4);
        assert_eq!(meta.shape_tail, vec![3]);
        assert_eq!(meta.dtype, "tensor");
    }

    // JSON metadata: a single row tagged "json".
    #[test]
    fn store_meta_from_json() {
        let v = Value::json(serde_json::json!({"a": 1}));
        let meta = StoreMeta::from_value(&v);
        assert_eq!(meta.dtype, "json");
        assert_eq!(meta.total_rows, 1);
    }

    // LocalDataStore does not override get_rows, so this exercises the
    // trait's default implementation (full get, then in-memory slice).
    #[test]
    fn default_get_rows_on_local_store() {
        let dir = std::env::temp_dir().join("soma-ds-test-getrows");
        let store = LocalDataStore::new(&dir);

        let key = CacheKey::hash_data(b"rows_test");
        let value = Value::tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);
        let data_ref = store.put(&key, &value).unwrap();

        let sliced = store.get_rows(&data_ref, 1, 2).unwrap();
        let (data, shape) = sliced.as_tensor().unwrap();
        assert_eq!(shape, &[2, 2]);
        assert_eq!(data, &[3.0, 4.0, 5.0, 6.0]);

        let _ = std::fs::remove_dir_all(&dir);
    }

    // LocalDataStore does not override meta, so this exercises the trait's
    // default implementation.
    #[test]
    fn default_meta_on_local_store() {
        let dir = std::env::temp_dir().join("soma-ds-test-meta");
        let store = LocalDataStore::new(&dir);

        let key = CacheKey::hash_data(b"meta_test");
        let value = Value::tensor(vec![0.0; 20], vec![5, 4]);
        let data_ref = store.put(&key, &value).unwrap();

        let meta = store.meta(&data_ref).unwrap();
        assert_eq!(meta.total_rows, 5);
        assert_eq!(meta.shape_tail, vec![4]);
        assert_eq!(meta.dtype, "tensor");

        let _ = std::fs::remove_dir_all(&dir);
    }

    // Zarr config serializes its bucket and chunk_rows and deserializes back.
    #[test]
    fn zarr_storage_config_serde() {
        let zarr = StorageConfig::Zarr {
            bucket: "soma-research".into(),
            prefix: "data/".into(),
            region: None,
            endpoint: Some("s3.eu-central-003.backblazeb2.com".into()),
            chunk_rows: 1024,
        };
        let json = serde_json::to_string(&zarr).unwrap();
        assert!(json.contains("soma-research"));
        assert!(json.contains("1024"));
        let _: StorageConfig = serde_json::from_str(&json).unwrap();
    }
}