Skip to main content

somatize_runtime/executors/
stream.rs

1//! Streaming executor — processes data in chunks through fitted filters.
2//!
3//! Respects each filter's [`StreamMode`]: FixedState processes chunks
4//! independently, Evolving updates state per chunk with checkpoints,
5//! Barrier accumulates all chunks before processing.
6
7use somatize_core::cache::{CacheKey, CacheStore};
8use somatize_core::error::{Result, SomaError};
9use somatize_core::filter::{Filter, StreamMode};
10use somatize_core::value::Value;
11use std::sync::Arc;
12
13/// A fitted filter with its learned state, ready for streaming.
14pub struct FittedFilter {
15    pub name: String,
16    pub filter: Arc<dyn Filter>,
17    pub state: Arc<Value>,
18}
19
20/// Per-filter streaming state — one per filter in the pipeline.
21struct FilterStreamState {
22    /// Accumulated chunks for Barrier mode.
23    barrier_buffer: Vec<Value>,
24    /// Evolving state (mutated per chunk).
25    evolving_state: Option<Value>,
26}
27
28/// Processes a stream of chunks through a sequence of fitted filters.
29///
30/// Each filter's StreamMode defines its contract:
31/// - FixedState: each chunk processed independently, cacheable per chunk
32/// - Evolving: state mutates with each chunk, periodic checkpoints
33/// - Barrier: accumulates all chunks, processes as batch on flush
34pub struct StreamExecutor {
35    filters: Vec<FittedFilter>,
36    cache: Option<Arc<dyn CacheStore>>,
37    states: Vec<FilterStreamState>,
38    chunk_count: usize,
39}
40
41impl StreamExecutor {
42    pub fn new(filters: Vec<FittedFilter>) -> Self {
43        let n = filters.len();
44        Self {
45            filters,
46            cache: None,
47            states: (0..n)
48                .map(|_| FilterStreamState {
49                    barrier_buffer: Vec::new(),
50                    evolving_state: None,
51                })
52                .collect(),
53            chunk_count: 0,
54        }
55    }
56
57    pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
58        self.cache = Some(cache);
59        self
60    }
61
62    /// Process a single chunk through the pipeline.
63    /// Returns the output chunk, or None if a Barrier filter is still accumulating.
64    pub fn process_chunk(&mut self, chunk: Value) -> Result<Option<Value>> {
65        let mut current = chunk;
66        self.chunk_count += 1;
67
68        for i in 0..self.filters.len() {
69            let mode = self.filters[i].filter.meta().stream_mode;
70            match process_by_mode(
71                &mode,
72                &self.filters[i],
73                &current,
74                &mut self.states[i],
75                self.cache.as_deref(),
76                self.chunk_count,
77            )? {
78                ChunkResult::Output(val) => current = val,
79                ChunkResult::Buffered => return Ok(None),
80            }
81        }
82
83        Ok(Some(current))
84    }
85
86    /// Flush barrier filters and process remaining data as batch.
87    pub fn flush(&mut self) -> Result<Option<Value>> {
88        let mut current: Option<Value> = None;
89
90        for i in 0..self.filters.len() {
91            let mode = self.filters[i].filter.meta().stream_mode;
92            if let Some(val) = flush_by_mode(&mode, &self.filters[i], &mut self.states[i])? {
93                current = Some(val);
94            } else if let Some(val) = current.take() {
95                current = Some(
96                    self.filters[i]
97                        .filter
98                        .forward(&val, &self.filters[i].state)?,
99                );
100            }
101        }
102
103        Ok(current)
104    }
105
106    /// Process multiple chunks and collect outputs.
107    pub fn process_all(&mut self, chunks: Vec<Value>) -> Result<Vec<Value>> {
108        let mut outputs = Vec::new();
109        for chunk in chunks {
110            if let Some(output) = self.process_chunk(chunk)? {
111                outputs.push(output);
112            }
113        }
114        if let Some(flushed) = self.flush()? {
115            outputs.push(flushed);
116        }
117        Ok(outputs)
118    }
119
120    /// Number of chunks processed so far.
121    pub fn chunks_processed(&self) -> usize {
122        self.chunk_count
123    }
124}
125
126/// Result of processing a chunk through one filter.
127enum ChunkResult {
128    /// Filter produced output — pass to next filter.
129    Output(Value),
130    /// Filter is buffering (Barrier) — no output yet.
131    Buffered,
132}
133
134// ── StreamMode dispatch ──
135
136/// Process one chunk according to the stream mode.
137fn process_by_mode(
138    mode: &StreamMode,
139    fitted: &FittedFilter,
140    input: &Value,
141    state: &mut FilterStreamState,
142    cache: Option<&dyn CacheStore>,
143    chunk_count: usize,
144) -> Result<ChunkResult> {
145    match mode {
146        StreamMode::FixedState => {
147            let result = forward_cached(fitted, input, cache)?;
148            Ok(ChunkResult::Output(result))
149        }
150        StreamMode::Evolving { checkpoint_every } => {
151            let default_state: &Value = &fitted.state;
152            let filter_state = state.evolving_state.as_ref().unwrap_or(default_state);
153            let result = fitted.filter.forward(input, filter_state)?;
154            state.evolving_state = Some(result.clone());
155
156            if *checkpoint_every > 0
157                && chunk_count.is_multiple_of(*checkpoint_every)
158                && let Some(c) = cache
159            {
160                let key = CacheKey::from_parts(&[
161                    b"checkpoint",
162                    fitted.name.as_bytes(),
163                    &(chunk_count as u64).to_le_bytes(),
164                ]);
165                let _ = c.put(&key, &result);
166            }
167            Ok(ChunkResult::Output(result))
168        }
169        StreamMode::Barrier => {
170            state.barrier_buffer.push(input.clone());
171            Ok(ChunkResult::Buffered)
172        }
173        _ => {
174            // Default: treat as FixedState
175            let result = forward_cached(fitted, input, cache)?;
176            Ok(ChunkResult::Output(result))
177        }
178    }
179}
180
181/// Flush a filter by mode. Only Barrier has work to do.
182fn flush_by_mode(
183    mode: &StreamMode,
184    fitted: &FittedFilter,
185    state: &mut FilterStreamState,
186) -> Result<Option<Value>> {
187    match mode {
188        StreamMode::Barrier if !state.barrier_buffer.is_empty() => {
189            let materialized = materialize_buffer(&state.barrier_buffer)?;
190            state.barrier_buffer.clear();
191            let result = fitted.filter.forward(&materialized, &fitted.state)?;
192            Ok(Some(result))
193        }
194        _ => Ok(None),
195    }
196}
197
198/// Forward with optional cache lookup.
199fn forward_cached(
200    fitted: &FittedFilter,
201    input: &Value,
202    cache: Option<&dyn CacheStore>,
203) -> Result<Value> {
204    if let Some(c) = cache {
205        let chunk_hash = CacheKey::hash_data(&serde_json::to_vec(input).unwrap_or_default());
206        let state_hash =
207            CacheKey::hash_data(&serde_json::to_vec(&fitted.state).unwrap_or_default());
208        let cache_key =
209            CacheKey::for_output(&fitted.filter.config_hash(), &state_hash, &chunk_hash);
210        if let Some(cached) = c.get(&cache_key)? {
211            return Ok(cached);
212        }
213        let result = fitted.filter.forward(input, &fitted.state)?;
214        let _ = c.put(&cache_key, &result);
215        return Ok(result);
216    }
217    fitted.filter.forward(input, &fitted.state)
218}
219
220/// Concatenate tensor chunks along first dimension.
221pub fn materialize_buffer(buffer: &[Value]) -> Result<Value> {
222    if buffer.is_empty() {
223        return Ok(Value::Empty);
224    }
225    let mut all_data = Vec::new();
226    let mut total_rows = 0;
227    let mut cols = 0;
228
229    for chunk in buffer {
230        match chunk {
231            Value::Tensor { values, shape } => {
232                all_data.extend(values.iter());
233                if shape.len() == 1 {
234                    total_rows += shape[0];
235                    cols = 1;
236                } else if shape.len() >= 2 {
237                    total_rows += shape[0];
238                    cols = shape[1];
239                }
240            }
241            _ => {
242                return Err(SomaError::Other(
243                    "barrier buffer contains non-tensor values".into(),
244                ));
245            }
246        }
247    }
248
249    if cols <= 1 {
250        Ok(Value::tensor(all_data, vec![total_rows]))
251    } else {
252        Ok(Value::tensor(all_data, vec![total_rows, cols]))
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result as SomaResult;
    use somatize_core::filter::{Distribution, FilterKind, FilterMeta};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds mock metadata; only the fields that differ between the mocks
    /// are parameters.
    fn mock_meta(
        name: &str,
        kind: FilterKind,
        cacheable: bool,
        stream_mode: StreamMode,
    ) -> FilterMeta {
        FilterMeta {
            name: name.into(),
            kind,
            cacheable,
            differentiable: false,
            stream_mode,
            distribution: Distribution::Local,
            input_schema: None,
            output_schema: None,
        }
    }

    /// Doubles every element of a tensor; any other value is returned as-is.
    fn double_values(x: &Value) -> Value {
        if let Value::Tensor { values, shape } = x {
            let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
            Value::tensor(doubled, shape.clone())
        } else {
            x.clone()
        }
    }

    /// Shorthand for a 1-D tensor holding `values`.
    fn tensor1(values: &[f64]) -> Value {
        Value::tensor(values.to_vec(), vec![values.len()])
    }

    /// FixedState mock: doubles each chunk independently.
    struct DoubleChunk;

    impl Filter for DoubleChunk {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"DoubleChunk"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            Ok(double_values(x))
        }
        fn meta(&self) -> FilterMeta {
            mock_meta(
                "DoubleChunk",
                FilterKind::Stateless,
                true,
                StreamMode::FixedState,
            )
        }
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// FixedState mock that doubles like `DoubleChunk` and counts how many
    /// times `forward` actually ran, so cache hits can be observed.
    struct CountingDouble {
        calls: Arc<AtomicUsize>,
    }

    impl Filter for CountingDouble {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"CountingDouble"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(double_values(x))
        }
        fn meta(&self) -> FilterMeta {
            mock_meta(
                "CountingDouble",
                FilterKind::Stateless,
                true,
                StreamMode::FixedState,
            )
        }
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Barrier mock: identity on the materialized batch.
    struct Accumulator;

    impl Filter for Accumulator {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Accumulator"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            Ok(x.clone())
        }
        fn meta(&self) -> FilterMeta {
            mock_meta(
                "Accumulator",
                FilterKind::Stateless,
                false,
                StreamMode::Barrier,
            )
        }
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Evolving mock: output (and therefore next state) is the chunk's sum
    /// plus the first element of the incoming state.
    struct RunningSum;

    impl Filter for RunningSum {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"RunningSum"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::tensor(vec![0.0], vec![1]))
        }
        fn forward(&self, x: &Value, state: &Value) -> SomaResult<Value> {
            let x_sum: f64 = match x {
                Value::Tensor { values, .. } => values.iter().sum(),
                _ => 0.0,
            };
            let state_sum: f64 = match state {
                Value::Tensor { values, .. } => values.first().copied().unwrap_or(0.0),
                _ => 0.0,
            };
            Ok(Value::tensor(vec![x_sum + state_sum], vec![1]))
        }
        fn meta(&self) -> FilterMeta {
            mock_meta(
                "RunningSum",
                FilterKind::Trainable,
                false,
                StreamMode::Evolving {
                    checkpoint_every: 2,
                },
            )
        }
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Wraps a mock filter and a state into a `FittedFilter` named after the
    /// filter's metadata.
    fn make_fitted(filter: impl Filter + 'static, state: Value) -> FittedFilter {
        let name = filter.meta().name.clone();
        FittedFilter {
            name,
            filter: Arc::new(filter),
            state: Arc::new(state),
        }
    }

    #[test]
    fn fixed_state_processes_each_chunk() {
        let f = make_fitted(DoubleChunk, Value::Empty);
        let mut exec = StreamExecutor::new(vec![f]);

        let out1 = exec.process_chunk(tensor1(&[1.0, 2.0])).unwrap();
        assert_eq!(out1, Some(tensor1(&[2.0, 4.0])));

        let out2 = exec.process_chunk(tensor1(&[3.0])).unwrap();
        assert_eq!(out2, Some(tensor1(&[6.0])));
    }

    #[test]
    fn barrier_accumulates_then_flushes() {
        let f = make_fitted(Accumulator, Value::Empty);
        let mut exec = StreamExecutor::new(vec![f]);

        let r1 = exec.process_chunk(tensor1(&[1.0, 2.0])).unwrap();
        assert_eq!(r1, None);

        let r2 = exec.process_chunk(tensor1(&[3.0, 4.0])).unwrap();
        assert_eq!(r2, None);

        let flushed = exec.flush().unwrap().unwrap();
        let (data, shape) = flushed.as_tensor().unwrap();
        assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(shape, &[4]);
    }

    #[test]
    fn evolving_state_accumulates() {
        let f = make_fitted(RunningSum, Value::tensor(vec![0.0], vec![1]));
        let mut exec = StreamExecutor::new(vec![f]);

        let r1 = exec.process_chunk(tensor1(&[10.0])).unwrap().unwrap();
        let (d1, _) = r1.as_tensor().unwrap();
        assert_eq!(d1, &[10.0]);

        let r2 = exec.process_chunk(tensor1(&[5.0])).unwrap().unwrap();
        let (d2, _) = r2.as_tensor().unwrap();
        assert_eq!(d2, &[15.0]); // 10 + 5
    }

    #[test]
    fn mixed_pipeline_fixed_then_barrier() {
        let f1 = make_fitted(DoubleChunk, Value::Empty);
        let f2 = make_fitted(Accumulator, Value::Empty);
        let mut exec = StreamExecutor::new(vec![f1, f2]);

        let r1 = exec.process_chunk(tensor1(&[1.0])).unwrap();
        assert_eq!(r1, None); // held by the barrier

        let r2 = exec.process_chunk(tensor1(&[2.0])).unwrap();
        assert_eq!(r2, None);

        let flushed = exec.flush().unwrap().unwrap();
        let (data, _) = flushed.as_tensor().unwrap();
        assert_eq!(data, &[2.0, 4.0]); // doubled, then accumulated
    }

    #[test]
    fn process_all_combines_chunks() {
        let f = make_fitted(DoubleChunk, Value::Empty);
        let mut exec = StreamExecutor::new(vec![f]);

        let outputs = exec
            .process_all(vec![tensor1(&[1.0]), tensor1(&[2.0]), tensor1(&[3.0])])
            .unwrap();

        // No barrier in the pipeline, so flush adds nothing: one output per chunk.
        assert_eq!(
            outputs,
            vec![tensor1(&[2.0]), tensor1(&[4.0]), tensor1(&[6.0])]
        );
        assert_eq!(exec.chunks_processed(), 3);
    }

    #[test]
    fn fixed_state_with_cache() {
        let calls = Arc::new(AtomicUsize::new(0));
        let f = make_fitted(
            CountingDouble {
                calls: Arc::clone(&calls),
            },
            Value::Empty,
        );
        let cache = Arc::new(crate::MemoryCache::default());
        let mut exec = StreamExecutor::new(vec![f]).with_cache(cache);

        let r1 = exec.process_chunk(tensor1(&[5.0])).unwrap().unwrap();
        // Same input again: must be served from the cache.
        let r2 = exec.process_chunk(tensor1(&[5.0])).unwrap().unwrap();

        assert_eq!(r1, tensor1(&[10.0]));
        assert_eq!(r1, r2);
        // Equal outputs alone would also pass without caching; the call count
        // is what proves the second chunk never reached `forward`.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}