Skip to main content

somatize_runtime/executors/
stream.rs

1//! Streaming executor — processes data in chunks through fitted filters.
2//!
3//! Respects each filter's [`StreamMode`]: FixedState processes chunks
4//! independently, Evolving updates state per chunk with checkpoints,
5//! Barrier accumulates all chunks before processing.
6
7use somatize_core::cache::{CacheKey, CacheStore};
8use somatize_core::error::{Result, SomaError};
9use somatize_core::filter::{Filter, StreamMode};
10use somatize_core::value::Value;
11use std::sync::Arc;
12
/// A fitted filter with its learned state, ready for streaming.
///
/// The executor never mutates a `FittedFilter`; per-stream mutable state is
/// kept separately in `FilterStreamState`.
pub struct FittedFilter {
    /// Filter name. Used as part of the checkpoint cache key for Evolving filters.
    pub name: String,
    /// The filter implementation whose `forward` is invoked on each chunk.
    pub filter: Arc<dyn Filter>,
    /// State handed to `forward`. For Evolving filters this is only the
    /// starting state; for cached FixedState filters it is hashed into the
    /// cache key.
    pub state: Value,
}
19
/// Per-filter streaming state — one per filter in the pipeline.
///
/// Stored in a vector parallel to the executor's filter list, so entry `i`
/// belongs to filter `i`.
struct FilterStreamState {
    /// Accumulated chunks for Barrier mode. Filled by `process_by_mode` and
    /// drained by `flush_by_mode`.
    barrier_buffer: Vec<Value>,
    /// Evolving state (mutated per chunk). `None` until the filter has
    /// processed its first chunk, in which case the fitted state is used.
    evolving_state: Option<Value>,
}
27
/// Processes a stream of chunks through a sequence of fitted filters.
///
/// Each filter's StreamMode defines its contract:
/// - FixedState: each chunk processed independently, cacheable per chunk
/// - Evolving: state mutates with each chunk, periodic checkpoints
/// - Barrier: accumulates all chunks, processes as batch on flush
pub struct StreamExecutor {
    /// Filters applied in order to every chunk.
    filters: Vec<FittedFilter>,
    /// Optional cache used for FixedState outputs and Evolving checkpoints.
    cache: Option<Arc<dyn CacheStore>>,
    /// Mutable streaming state, parallel to `filters` (same length, same order).
    states: Vec<FilterStreamState>,
    /// Number of chunks handed to `process_chunk` so far.
    chunk_count: usize,
}
40
41impl StreamExecutor {
42    pub fn new(filters: Vec<FittedFilter>) -> Self {
43        let n = filters.len();
44        Self {
45            filters,
46            cache: None,
47            states: (0..n)
48                .map(|_| FilterStreamState {
49                    barrier_buffer: Vec::new(),
50                    evolving_state: None,
51                })
52                .collect(),
53            chunk_count: 0,
54        }
55    }
56
57    pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
58        self.cache = Some(cache);
59        self
60    }
61
62    /// Process a single chunk through the pipeline.
63    /// Returns the output chunk, or None if a Barrier filter is still accumulating.
64    pub fn process_chunk(&mut self, chunk: Value) -> Result<Option<Value>> {
65        let mut current = chunk;
66        self.chunk_count += 1;
67
68        for i in 0..self.filters.len() {
69            let mode = self.filters[i].filter.meta().stream_mode;
70            match process_by_mode(
71                &mode,
72                &self.filters[i],
73                &current,
74                &mut self.states[i],
75                self.cache.as_deref(),
76                self.chunk_count,
77            )? {
78                ChunkResult::Output(val) => current = val,
79                ChunkResult::Buffered => return Ok(None),
80            }
81        }
82
83        Ok(Some(current))
84    }
85
86    /// Flush barrier filters and process remaining data as batch.
87    pub fn flush(&mut self) -> Result<Option<Value>> {
88        let mut current: Option<Value> = None;
89
90        for i in 0..self.filters.len() {
91            let mode = self.filters[i].filter.meta().stream_mode;
92            if let Some(val) = flush_by_mode(&mode, &self.filters[i], &mut self.states[i])? {
93                current = Some(val);
94            } else if let Some(val) = current.take() {
95                current = Some(
96                    self.filters[i]
97                        .filter
98                        .forward(&val, &self.filters[i].state)?,
99                );
100            }
101        }
102
103        Ok(current)
104    }
105
106    /// Process multiple chunks and collect outputs.
107    pub fn process_all(&mut self, chunks: Vec<Value>) -> Result<Vec<Value>> {
108        let mut outputs = Vec::new();
109        for chunk in chunks {
110            if let Some(output) = self.process_chunk(chunk)? {
111                outputs.push(output);
112            }
113        }
114        if let Some(flushed) = self.flush()? {
115            outputs.push(flushed);
116        }
117        Ok(outputs)
118    }
119
120    /// Number of chunks processed so far.
121    pub fn chunks_processed(&self) -> usize {
122        self.chunk_count
123    }
124}
125
/// Result of processing a chunk through one filter.
///
/// Returned by `process_by_mode`; the executor either keeps walking the
/// pipeline with the output or stops at the buffering filter.
enum ChunkResult {
    /// Filter produced output — pass to next filter.
    Output(Value),
    /// Filter is buffering (Barrier) — no output yet.
    Buffered,
}
133
134// ── StreamMode dispatch ──
135
136/// Process one chunk according to the stream mode.
137fn process_by_mode(
138    mode: &StreamMode,
139    fitted: &FittedFilter,
140    input: &Value,
141    state: &mut FilterStreamState,
142    cache: Option<&dyn CacheStore>,
143    chunk_count: usize,
144) -> Result<ChunkResult> {
145    match mode {
146        StreamMode::FixedState => {
147            let result = forward_cached(fitted, input, cache)?;
148            Ok(ChunkResult::Output(result))
149        }
150        StreamMode::Evolving { checkpoint_every } => {
151            let filter_state = state.evolving_state.as_ref().unwrap_or(&fitted.state);
152            let result = fitted.filter.forward(input, filter_state)?;
153            state.evolving_state = Some(result.clone());
154
155            if *checkpoint_every > 0
156                && chunk_count.is_multiple_of(*checkpoint_every)
157                && let Some(c) = cache
158            {
159                let key = CacheKey::from_parts(&[
160                    b"checkpoint",
161                    fitted.name.as_bytes(),
162                    &(chunk_count as u64).to_le_bytes(),
163                ]);
164                let _ = c.put(&key, &result);
165            }
166            Ok(ChunkResult::Output(result))
167        }
168        StreamMode::Barrier => {
169            state.barrier_buffer.push(input.clone());
170            Ok(ChunkResult::Buffered)
171        }
172        _ => {
173            // Default: treat as FixedState
174            let result = forward_cached(fitted, input, cache)?;
175            Ok(ChunkResult::Output(result))
176        }
177    }
178}
179
180/// Flush a filter by mode. Only Barrier has work to do.
181fn flush_by_mode(
182    mode: &StreamMode,
183    fitted: &FittedFilter,
184    state: &mut FilterStreamState,
185) -> Result<Option<Value>> {
186    match mode {
187        StreamMode::Barrier if !state.barrier_buffer.is_empty() => {
188            let materialized = materialize_buffer(&state.barrier_buffer)?;
189            state.barrier_buffer.clear();
190            let result = fitted.filter.forward(&materialized, &fitted.state)?;
191            Ok(Some(result))
192        }
193        _ => Ok(None),
194    }
195}
196
197/// Forward with optional cache lookup.
198fn forward_cached(
199    fitted: &FittedFilter,
200    input: &Value,
201    cache: Option<&dyn CacheStore>,
202) -> Result<Value> {
203    if let Some(c) = cache {
204        let chunk_hash = CacheKey::hash_data(&serde_json::to_vec(input).unwrap_or_default());
205        let state_hash =
206            CacheKey::hash_data(&serde_json::to_vec(&fitted.state).unwrap_or_default());
207        let cache_key =
208            CacheKey::for_output(&fitted.filter.config_hash(), &state_hash, &chunk_hash);
209        if let Some(cached) = c.get(&cache_key)? {
210            return Ok(cached);
211        }
212        let result = fitted.filter.forward(input, &fitted.state)?;
213        let _ = c.put(&cache_key, &result);
214        return Ok(result);
215    }
216    fitted.filter.forward(input, &fitted.state)
217}
218
219/// Concatenate tensor chunks along first dimension.
220pub fn materialize_buffer(buffer: &[Value]) -> Result<Value> {
221    if buffer.is_empty() {
222        return Ok(Value::Empty);
223    }
224    let mut all_data = Vec::new();
225    let mut total_rows = 0;
226    let mut cols = 0;
227
228    for chunk in buffer {
229        match chunk {
230            Value::Tensor { values, shape } => {
231                all_data.extend(values);
232                if shape.len() == 1 {
233                    total_rows += shape[0];
234                    cols = 1;
235                } else if shape.len() >= 2 {
236                    total_rows += shape[0];
237                    cols = shape[1];
238                }
239            }
240            _ => {
241                return Err(SomaError::Other(
242                    "barrier buffer contains non-tensor values".into(),
243                ));
244            }
245        }
246    }
247
248    if cols <= 1 {
249        Ok(Value::tensor(all_data, vec![total_rows]))
250    } else {
251        Ok(Value::tensor(all_data, vec![total_rows, cols]))
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result as SomaResult;
    use somatize_core::filter::{Distribution, FilterKind, FilterMeta};

    /// FixedState fixture: doubles every tensor element and passes any other
    /// value through unchanged.
    struct DoubleChunk;

    impl Filter for DoubleChunk {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"DoubleChunk"])
        }

        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::Empty)
        }

        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            match x {
                Value::Tensor { values, shape } => {
                    let scaled = values.iter().map(|elem| elem * 2.0).collect::<Vec<f64>>();
                    Ok(Value::tensor(scaled, shape.clone()))
                }
                other => Ok(other.clone()),
            }
        }

        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "DoubleChunk".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: false,
                stream_mode: StreamMode::FixedState,
                distribution: Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Barrier fixture: returns its input unchanged, so a flush exposes the
    /// materialized buffer directly.
    struct Accumulator;

    impl Filter for Accumulator {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Accumulator"])
        }

        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::Empty)
        }

        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            Ok(Value::clone(x))
        }

        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Accumulator".into(),
                kind: FilterKind::Stateless,
                cacheable: false,
                differentiable: false,
                stream_mode: StreamMode::Barrier,
                distribution: Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Evolving fixture: outputs the sum of the input elements plus the first
    /// element of the incoming state.
    struct RunningSum;

    impl Filter for RunningSum {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"RunningSum"])
        }

        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::tensor(vec![0.0], vec![1]))
        }

        fn forward(&self, x: &Value, state: &Value) -> SomaResult<Value> {
            let incoming: f64 = if let Value::Tensor { values, .. } = x {
                values.iter().sum()
            } else {
                0.0
            };
            let carried: f64 = if let Value::Tensor { values, .. } = state {
                values.first().copied().unwrap_or(0.0)
            } else {
                0.0
            };
            Ok(Value::tensor(vec![incoming + carried], vec![1]))
        }

        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "RunningSum".into(),
                kind: FilterKind::Trainable,
                cacheable: false,
                differentiable: false,
                stream_mode: StreamMode::Evolving {
                    checkpoint_every: 2,
                },
                distribution: Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Wraps a filter and its state into a `FittedFilter` named after the
    /// filter's own metadata.
    fn make_fitted(filter: impl Filter + 'static, state: Value) -> FittedFilter {
        let label = filter.meta().name.clone();
        let shared: Arc<dyn Filter> = Arc::new(filter);
        FittedFilter {
            name: label,
            filter: shared,
            state,
        }
    }

    #[test]
    fn fixed_state_processes_each_chunk() {
        let mut executor = StreamExecutor::new(vec![make_fitted(DoubleChunk, Value::Empty)]);

        let cases = [
            (
                Value::tensor(vec![1.0, 2.0], vec![2]),
                Value::tensor(vec![2.0, 4.0], vec![2]),
            ),
            (
                Value::tensor(vec![3.0], vec![1]),
                Value::tensor(vec![6.0], vec![1]),
            ),
        ];
        for (input, expected) in cases {
            let produced = executor.process_chunk(input).unwrap();
            assert_eq!(produced, Some(expected));
        }
    }

    #[test]
    fn barrier_accumulates_then_flushes() {
        let mut executor = StreamExecutor::new(vec![make_fitted(Accumulator, Value::Empty)]);

        // Every chunk is swallowed by the barrier until flush.
        let chunks = [
            Value::tensor(vec![1.0, 2.0], vec![2]),
            Value::tensor(vec![3.0, 4.0], vec![2]),
        ];
        for chunk in chunks {
            assert_eq!(executor.process_chunk(chunk).unwrap(), None);
        }

        let flushed = executor.flush().unwrap().unwrap();
        let (data, shape) = flushed.as_tensor().unwrap();
        assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(shape, &[4]);
    }

    #[test]
    fn evolving_state_accumulates() {
        let fitted = make_fitted(RunningSum, Value::tensor(vec![0.0], vec![1]));
        let mut executor = StreamExecutor::new(vec![fitted]);

        // (chunk value, expected running total): 10, then 10 + 5.
        for (increment, expected_total) in [(10.0, 10.0), (5.0, 15.0)] {
            let output = executor
                .process_chunk(Value::tensor(vec![increment], vec![1]))
                .unwrap()
                .unwrap();
            let (data, _) = output.as_tensor().unwrap();
            assert_eq!(data, &[expected_total]);
        }
    }

    #[test]
    fn mixed_pipeline_fixed_then_barrier() {
        let pipeline = vec![
            make_fitted(DoubleChunk, Value::Empty),
            make_fitted(Accumulator, Value::Empty),
        ];
        let mut executor = StreamExecutor::new(pipeline);

        // The trailing barrier holds back every chunk.
        for raw in [1.0, 2.0] {
            let result = executor
                .process_chunk(Value::tensor(vec![raw], vec![1]))
                .unwrap();
            assert_eq!(result, None);
        }

        let flushed = executor.flush().unwrap().unwrap();
        let (data, _) = flushed.as_tensor().unwrap();
        assert_eq!(data, &[2.0, 4.0]); // doubled then accumulated
    }

    #[test]
    fn process_all_combines_chunks() {
        let mut executor = StreamExecutor::new(vec![make_fitted(DoubleChunk, Value::Empty)]);

        let inputs: Vec<Value> = [1.0, 2.0, 3.0]
            .into_iter()
            .map(|raw| Value::tensor(vec![raw], vec![1]))
            .collect();
        let outputs = executor.process_all(inputs).unwrap();

        assert_eq!(outputs.len(), 3);
        let (head, _) = outputs[0].as_tensor().unwrap();
        assert_eq!(head, &[2.0]);
    }

    #[test]
    fn fixed_state_with_cache() {
        let cache = Arc::new(crate::MemoryCache::default());
        let mut executor =
            StreamExecutor::new(vec![make_fitted(DoubleChunk, Value::Empty)]).with_cache(cache);

        let mut run_once = || {
            executor
                .process_chunk(Value::tensor(vec![5.0], vec![1]))
                .unwrap()
                .unwrap()
        };

        let first = run_once();
        // Second call with same input should hit cache
        let second = run_once();
        assert_eq!(first, second);
    }
}