1use somatize_core::cache::{CacheKey, CacheStore};
8use somatize_core::error::{Result, SomaError};
9use somatize_core::filter::{Filter, StreamMode};
10use somatize_core::value::Value;
11use std::sync::Arc;
12
/// One pipeline stage: a filter together with the state handed to its `forward` call.
pub struct FittedFilter {
    /// Stage name; also used as part of the checkpoint cache key for evolving filters.
    pub name: String,
    /// The filter implementation, shared behind an `Arc`.
    pub filter: Arc<dyn Filter>,
    /// State passed to `Filter::forward`. For evolving filters this is only the
    /// initial state; later chunks use the state tracked by the executor.
    pub state: Value,
}
19
/// Runs a sequence of fitted filters over a stream of chunks, one chunk at a time.
pub struct StreamExecutor {
    /// Pipeline stages, applied in order.
    filters: Vec<FittedFilter>,
    /// Optional store used for fixed-state output caching and evolving-state checkpoints.
    cache: Option<Arc<dyn CacheStore>>,
    /// Per-filter buffers holding the chunks absorbed by `Barrier` filters until `flush`.
    barrier_buffers: Vec<Vec<Value>>,
    /// Per-filter latest output of an `Evolving` filter, fed back as its state for the
    /// next chunk; `None` until that filter has processed its first chunk.
    evolving_states: Vec<Option<Value>>,
    /// Number of chunks passed to `process_chunk` so far.
    chunk_count: usize,
}
36
37impl StreamExecutor {
38 pub fn new(filters: Vec<FittedFilter>) -> Self {
39 let n = filters.len();
40 Self {
41 filters,
42 cache: None,
43 barrier_buffers: vec![Vec::new(); n],
44 evolving_states: vec![None; n],
45 chunk_count: 0,
46 }
47 }
48
49 pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
50 self.cache = Some(cache);
51 self
52 }
53
54 pub fn process_chunk(&mut self, chunk: Value) -> Result<Option<Value>> {
58 let mut current = chunk;
59 self.chunk_count += 1;
60
61 let n = self.filters.len();
62 for i in 0..n {
63 let mode = self.filters[i].filter.meta().stream_mode;
64
65 match mode {
66 StreamMode::FixedState => {
67 current = self.process_fixed_state(i, ¤t)?;
68 }
69 StreamMode::Evolving { checkpoint_every } => {
70 current = self.process_evolving(i, ¤t, checkpoint_every)?;
71 }
72 StreamMode::Barrier => {
73 self.barrier_buffers[i].push(current);
74 return Ok(None);
75 }
76 _ => {
77 current = self.process_fixed_state(i, ¤t)?;
78 }
79 }
80 }
81
82 Ok(Some(current))
83 }
84
85 pub fn flush(&mut self) -> Result<Option<Value>> {
89 let mut current: Option<Value> = None;
90 let n = self.filters.len();
91
92 for i in 0..n {
93 let mode = self.filters[i].filter.meta().stream_mode;
94
95 if mode == StreamMode::Barrier && !self.barrier_buffers[i].is_empty() {
96 let materialized = self.materialize_buffer(i)?;
97 let result = self.filters[i]
98 .filter
99 .forward(&materialized, &self.filters[i].state)?;
100 self.barrier_buffers[i].clear();
101 current = Some(result);
102 } else if let Some(val) = current.take() {
103 let result = self.filters[i]
104 .filter
105 .forward(&val, &self.filters[i].state)?;
106 current = Some(result);
107 }
108 }
109
110 Ok(current)
111 }
112
113 pub fn process_all(&mut self, chunks: Vec<Value>) -> Result<Vec<Value>> {
115 let mut outputs = Vec::new();
116
117 for chunk in chunks {
118 if let Some(output) = self.process_chunk(chunk)? {
119 outputs.push(output);
120 }
121 }
122
123 if let Some(flushed) = self.flush()? {
125 outputs.push(flushed);
126 }
127
128 Ok(outputs)
129 }
130
131 pub fn chunks_processed(&self) -> usize {
133 self.chunk_count
134 }
135
136 fn process_fixed_state(&self, filter_idx: usize, input: &Value) -> Result<Value> {
137 let fitted = &self.filters[filter_idx];
138
139 if let Some(cache) = &self.cache {
141 let chunk_hash = CacheKey::hash_data(&serde_json::to_vec(input).unwrap_or_default());
142 let cache_key = CacheKey::for_output(
143 &fitted.filter.config_hash(),
144 &CacheKey::hash_data(&serde_json::to_vec(&fitted.state).unwrap_or_default()),
145 &chunk_hash,
146 );
147 if let Some(cached) = cache.get(&cache_key)? {
148 return Ok(cached);
149 }
150 let result = fitted.filter.forward(input, &fitted.state)?;
151 let _ = cache.put(&cache_key, &result);
152 return Ok(result);
153 }
154
155 fitted.filter.forward(input, &fitted.state)
156 }
157
158 fn process_evolving(
159 &mut self,
160 filter_idx: usize,
161 input: &Value,
162 checkpoint_every: usize,
163 ) -> Result<Value> {
164 let fitted = &self.filters[filter_idx];
165
166 let state = self.evolving_states[filter_idx]
168 .as_ref()
169 .unwrap_or(&fitted.state);
170
171 let result = fitted.filter.forward(input, state)?;
172
173 self.evolving_states[filter_idx] = Some(result.clone());
176
177 if checkpoint_every > 0
179 && self.chunk_count.is_multiple_of(checkpoint_every)
180 && let Some(cache) = &self.cache
181 {
182 let checkpoint_key = CacheKey::from_parts(&[
183 b"checkpoint",
184 fitted.name.as_bytes(),
185 &(self.chunk_count as u64).to_le_bytes(),
186 ]);
187 let _ = cache.put(&checkpoint_key, &result);
188 }
189
190 Ok(result)
191 }
192
193 fn materialize_buffer(&self, filter_idx: usize) -> Result<Value> {
194 let buffer = &self.barrier_buffers[filter_idx];
195 if buffer.is_empty() {
196 return Ok(Value::Empty);
197 }
198
199 let mut all_data = Vec::new();
201 let mut total_rows = 0;
202 let mut cols = 0;
203
204 for chunk in buffer {
205 match chunk {
206 Value::Tensor { values, shape } => {
207 all_data.extend(values);
208 if shape.len() == 1 {
209 total_rows += shape[0];
210 cols = 1;
211 } else if shape.len() >= 2 {
212 total_rows += shape[0];
213 cols = shape[1];
214 }
215 }
216 _ => {
217 return Err(SomaError::Other(
218 "barrier buffer contains non-tensor values".into(),
219 ));
220 }
221 }
222 }
223
224 if cols <= 1 {
225 Ok(Value::tensor(all_data, vec![total_rows]))
226 } else {
227 Ok(Value::tensor(all_data, vec![total_rows, cols]))
228 }
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::cache::CacheKey;
    use somatize_core::filter::{FilterKind, FilterMeta};

    /// Builds the metadata shared by the test filters: local distribution, no schemas.
    fn local_meta(
        name: &'static str,
        kind: FilterKind,
        cacheable: bool,
        differentiable: bool,
        stream_mode: StreamMode,
    ) -> FilterMeta {
        FilterMeta {
            name: name.into(),
            kind,
            cacheable,
            differentiable,
            stream_mode,
            distribution: somatize_core::filter::Distribution::Local,
            input_schema: None,
            output_schema: None,
        }
    }

    /// Builds an executor with a single stage.
    fn single_stage(name: &str, filter: Arc<dyn Filter>, state: Value) -> StreamExecutor {
        StreamExecutor::new(vec![FittedFilter {
            name: name.into(),
            filter,
            state,
        }])
    }

    /// Returns the flat data of a tensor value, panicking on any other variant.
    fn tensor_data(value: &Value) -> &[f64] {
        value.as_tensor().unwrap().0
    }

    /// Fixed-state filter that doubles every tensor element.
    struct DoubleChunk;

    impl Filter for DoubleChunk {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"DoubleChunk"])
        }

        fn fit(&self, _: &Value, _: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }

        fn forward(&self, x: &Value, _: &Value) -> Result<Value> {
            let Value::Tensor { values, shape } = x else {
                return Ok(x.clone());
            };
            let doubled = values.iter().map(|v| v * 2.0).collect();
            Ok(Value::tensor(doubled, shape.clone()))
        }

        fn meta(&self) -> FilterMeta {
            local_meta(
                "DoubleChunk",
                FilterKind::Stateless,
                true,
                true,
                StreamMode::FixedState,
            )
        }
    }

    /// Barrier filter that reduces its whole input to the mean of its values.
    struct Accumulator;

    impl Filter for Accumulator {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Accumulator"])
        }

        fn fit(&self, _: &Value, _: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }

        fn forward(&self, x: &Value, _: &Value) -> Result<Value> {
            if let Value::Tensor { values, .. } = x {
                let mean = values.iter().sum::<f64>() / values.len() as f64;
                return Ok(Value::tensor(vec![mean], vec![1]));
            }
            Ok(x.clone())
        }

        fn meta(&self) -> FilterMeta {
            local_meta(
                "Accumulator",
                FilterKind::Trainable,
                false,
                false,
                StreamMode::Barrier,
            )
        }
    }

    /// Evolving filter that adds the first input element to the first state element.
    struct RunningSum;

    impl Filter for RunningSum {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"RunningSum"])
        }

        fn fit(&self, _: &Value, _: Option<&Value>) -> Result<Value> {
            Ok(Value::tensor(vec![0.0], vec![1]))
        }

        fn forward(&self, x: &Value, state: &Value) -> Result<Value> {
            // First tensor element, or 0.0 when the value is not a tensor.
            fn head(value: &Value) -> f64 {
                value.as_tensor().map(|(d, _)| d[0]).unwrap_or(0.0)
            }

            Ok(Value::tensor(vec![head(x) + head(state)], vec![1]))
        }

        fn meta(&self) -> FilterMeta {
            local_meta(
                "RunningSum",
                FilterKind::Trainable,
                false,
                false,
                StreamMode::Evolving {
                    checkpoint_every: 3,
                },
            )
        }
    }

    #[test]
    fn fixed_state_processes_each_chunk() {
        let mut executor = single_stage("double", Arc::new(DoubleChunk), Value::Empty);

        let outputs = executor
            .process_all(vec![
                Value::tensor(vec![1.0, 2.0], vec![2]),
                Value::tensor(vec![3.0, 4.0], vec![2]),
                Value::tensor(vec![5.0], vec![1]),
            ])
            .unwrap();
        assert_eq!(outputs.len(), 3);

        let expected: [&[f64]; 3] = [&[2.0, 4.0], &[6.0, 8.0], &[10.0]];
        for (output, want) in outputs.iter().zip(expected) {
            assert_eq!(tensor_data(output), want);
        }
    }

    #[test]
    fn barrier_accumulates_then_flushes() {
        let mut executor = single_stage("acc", Arc::new(Accumulator), Value::Empty);

        // Every chunk is absorbed by the barrier, so nothing is emitted yet.
        for values in [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] {
            let emitted = executor
                .process_chunk(Value::tensor(values.to_vec(), vec![2]))
                .unwrap();
            assert!(emitted.is_none());
        }

        // Mean of 1..=6.
        let result = executor.flush().unwrap().unwrap();
        assert!((tensor_data(&result)[0] - 3.5).abs() < 0.01);
    }

    #[test]
    fn evolving_state_accumulates() {
        let mut executor = single_stage(
            "sum",
            Arc::new(RunningSum),
            Value::tensor(vec![0.0], vec![1]),
        );

        // (input, running total after that input)
        for (input, total) in [(5.0, 5.0), (3.0, 8.0), (2.0, 10.0)] {
            let output = executor
                .process_chunk(Value::tensor(vec![input], vec![1]))
                .unwrap()
                .unwrap();
            assert_eq!(tensor_data(&output), &[total]);
        }
    }

    #[test]
    fn mixed_pipeline_fixed_then_barrier() {
        let stages = vec![
            FittedFilter {
                name: "double".into(),
                filter: Arc::new(DoubleChunk),
                state: Value::Empty,
            },
            FittedFilter {
                name: "acc".into(),
                filter: Arc::new(Accumulator),
                state: Value::Empty,
            },
        ];
        let mut executor = StreamExecutor::new(stages);

        let outputs = executor
            .process_all(vec![
                Value::tensor(vec![1.0], vec![1]),
                Value::tensor(vec![2.0], vec![1]),
                Value::tensor(vec![3.0], vec![1]),
            ])
            .unwrap();

        // Doubled inputs are 2, 4, 6; the barrier emits their mean once, on flush.
        assert_eq!(outputs.len(), 1);
        assert!((tensor_data(&outputs[0])[0] - 4.0).abs() < 0.01);
    }

    #[test]
    fn fixed_state_with_cache() {
        let cache = Arc::new(crate::MemoryCache::default());
        let mut executor =
            single_stage("double", Arc::new(DoubleChunk), Value::Empty).with_cache(cache.clone());

        let chunk = Value::tensor(vec![7.0], vec![1]);

        // First pass computes the result and populates the cache.
        let first = executor.process_chunk(chunk.clone()).unwrap().unwrap();
        assert_eq!(tensor_data(&first), &[14.0]);
        assert!(!cache.is_empty());

        // Second pass over the same chunk yields the same output.
        let second = executor.process_chunk(chunk).unwrap().unwrap();
        assert_eq!(tensor_data(&second), &[14.0]);
    }

    #[test]
    fn chunks_processed_counter() {
        let mut executor = single_stage("double", Arc::new(DoubleChunk), Value::Empty);

        assert_eq!(executor.chunks_processed(), 0);
        for (seen, x) in [1.0, 2.0].into_iter().enumerate() {
            executor
                .process_chunk(Value::tensor(vec![x], vec![1]))
                .unwrap();
            assert_eq!(executor.chunks_processed(), seen + 1);
        }
    }

    #[test]
    fn empty_stream() {
        let mut executor = single_stage("double", Arc::new(DoubleChunk), Value::Empty);

        let outputs = executor.process_all(Vec::new()).unwrap();
        assert!(outputs.is_empty());
    }
}
501}