1use somatize_core::cache::{CacheKey, CacheStore};
8use somatize_core::error::{Result, SomaError};
9use somatize_core::filter::{Filter, StreamMode};
10use somatize_core::value::Value;
11use std::sync::Arc;
12
13pub struct FittedFilter {
15 pub name: String,
16 pub filter: Arc<dyn Filter>,
17 pub state: Value,
18}
19
20struct FilterStreamState {
22 barrier_buffer: Vec<Value>,
24 evolving_state: Option<Value>,
26}
27
28pub struct StreamExecutor {
35 filters: Vec<FittedFilter>,
36 cache: Option<Arc<dyn CacheStore>>,
37 states: Vec<FilterStreamState>,
38 chunk_count: usize,
39}
40
41impl StreamExecutor {
42 pub fn new(filters: Vec<FittedFilter>) -> Self {
43 let n = filters.len();
44 Self {
45 filters,
46 cache: None,
47 states: (0..n)
48 .map(|_| FilterStreamState {
49 barrier_buffer: Vec::new(),
50 evolving_state: None,
51 })
52 .collect(),
53 chunk_count: 0,
54 }
55 }
56
57 pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
58 self.cache = Some(cache);
59 self
60 }
61
62 pub fn process_chunk(&mut self, chunk: Value) -> Result<Option<Value>> {
65 let mut current = chunk;
66 self.chunk_count += 1;
67
68 for i in 0..self.filters.len() {
69 let mode = self.filters[i].filter.meta().stream_mode;
70 match process_by_mode(
71 &mode,
72 &self.filters[i],
73 ¤t,
74 &mut self.states[i],
75 self.cache.as_deref(),
76 self.chunk_count,
77 )? {
78 ChunkResult::Output(val) => current = val,
79 ChunkResult::Buffered => return Ok(None),
80 }
81 }
82
83 Ok(Some(current))
84 }
85
86 pub fn flush(&mut self) -> Result<Option<Value>> {
88 let mut current: Option<Value> = None;
89
90 for i in 0..self.filters.len() {
91 let mode = self.filters[i].filter.meta().stream_mode;
92 if let Some(val) = flush_by_mode(&mode, &self.filters[i], &mut self.states[i])? {
93 current = Some(val);
94 } else if let Some(val) = current.take() {
95 current = Some(
96 self.filters[i]
97 .filter
98 .forward(&val, &self.filters[i].state)?,
99 );
100 }
101 }
102
103 Ok(current)
104 }
105
106 pub fn process_all(&mut self, chunks: Vec<Value>) -> Result<Vec<Value>> {
108 let mut outputs = Vec::new();
109 for chunk in chunks {
110 if let Some(output) = self.process_chunk(chunk)? {
111 outputs.push(output);
112 }
113 }
114 if let Some(flushed) = self.flush()? {
115 outputs.push(flushed);
116 }
117 Ok(outputs)
118 }
119
120 pub fn chunks_processed(&self) -> usize {
122 self.chunk_count
123 }
124}
125
126enum ChunkResult {
128 Output(Value),
130 Buffered,
132}
133
134fn process_by_mode(
138 mode: &StreamMode,
139 fitted: &FittedFilter,
140 input: &Value,
141 state: &mut FilterStreamState,
142 cache: Option<&dyn CacheStore>,
143 chunk_count: usize,
144) -> Result<ChunkResult> {
145 match mode {
146 StreamMode::FixedState => {
147 let result = forward_cached(fitted, input, cache)?;
148 Ok(ChunkResult::Output(result))
149 }
150 StreamMode::Evolving { checkpoint_every } => {
151 let filter_state = state.evolving_state.as_ref().unwrap_or(&fitted.state);
152 let result = fitted.filter.forward(input, filter_state)?;
153 state.evolving_state = Some(result.clone());
154
155 if *checkpoint_every > 0
156 && chunk_count.is_multiple_of(*checkpoint_every)
157 && let Some(c) = cache
158 {
159 let key = CacheKey::from_parts(&[
160 b"checkpoint",
161 fitted.name.as_bytes(),
162 &(chunk_count as u64).to_le_bytes(),
163 ]);
164 let _ = c.put(&key, &result);
165 }
166 Ok(ChunkResult::Output(result))
167 }
168 StreamMode::Barrier => {
169 state.barrier_buffer.push(input.clone());
170 Ok(ChunkResult::Buffered)
171 }
172 _ => {
173 let result = forward_cached(fitted, input, cache)?;
175 Ok(ChunkResult::Output(result))
176 }
177 }
178}
179
180fn flush_by_mode(
182 mode: &StreamMode,
183 fitted: &FittedFilter,
184 state: &mut FilterStreamState,
185) -> Result<Option<Value>> {
186 match mode {
187 StreamMode::Barrier if !state.barrier_buffer.is_empty() => {
188 let materialized = materialize_buffer(&state.barrier_buffer)?;
189 state.barrier_buffer.clear();
190 let result = fitted.filter.forward(&materialized, &fitted.state)?;
191 Ok(Some(result))
192 }
193 _ => Ok(None),
194 }
195}
196
197fn forward_cached(
199 fitted: &FittedFilter,
200 input: &Value,
201 cache: Option<&dyn CacheStore>,
202) -> Result<Value> {
203 if let Some(c) = cache {
204 let chunk_hash = CacheKey::hash_data(&serde_json::to_vec(input).unwrap_or_default());
205 let state_hash =
206 CacheKey::hash_data(&serde_json::to_vec(&fitted.state).unwrap_or_default());
207 let cache_key =
208 CacheKey::for_output(&fitted.filter.config_hash(), &state_hash, &chunk_hash);
209 if let Some(cached) = c.get(&cache_key)? {
210 return Ok(cached);
211 }
212 let result = fitted.filter.forward(input, &fitted.state)?;
213 let _ = c.put(&cache_key, &result);
214 return Ok(result);
215 }
216 fitted.filter.forward(input, &fitted.state)
217}
218
219pub fn materialize_buffer(buffer: &[Value]) -> Result<Value> {
221 if buffer.is_empty() {
222 return Ok(Value::Empty);
223 }
224 let mut all_data = Vec::new();
225 let mut total_rows = 0;
226 let mut cols = 0;
227
228 for chunk in buffer {
229 match chunk {
230 Value::Tensor { values, shape } => {
231 all_data.extend(values);
232 if shape.len() == 1 {
233 total_rows += shape[0];
234 cols = 1;
235 } else if shape.len() >= 2 {
236 total_rows += shape[0];
237 cols = shape[1];
238 }
239 }
240 _ => {
241 return Err(SomaError::Other(
242 "barrier buffer contains non-tensor values".into(),
243 ));
244 }
245 }
246 }
247
248 if cols <= 1 {
249 Ok(Value::tensor(all_data, vec![total_rows]))
250 } else {
251 Ok(Value::tensor(all_data, vec![total_rows, cols]))
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result as SomaResult;
    use somatize_core::filter::{Distribution, FilterKind, FilterMeta};

    /// Builds the metadata shared by the test filters; only the varying fields are parameters.
    fn local_meta(
        name: &str,
        kind: FilterKind,
        cacheable: bool,
        stream_mode: StreamMode,
    ) -> FilterMeta {
        FilterMeta {
            name: name.into(),
            kind,
            cacheable,
            differentiable: false,
            stream_mode,
            distribution: Distribution::Local,
            input_schema: None,
            output_schema: None,
        }
    }

    /// Fixed-state filter that doubles every tensor element and passes other values through.
    struct DoubleChunk;

    impl Filter for DoubleChunk {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"DoubleChunk"])
        }

        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::Empty)
        }

        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            match x {
                Value::Tensor { values, shape } => {
                    let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
                    Ok(Value::tensor(doubled, shape.clone()))
                }
                other => Ok(other.clone()),
            }
        }

        fn meta(&self) -> FilterMeta {
            local_meta(
                "DoubleChunk",
                FilterKind::Stateless,
                true,
                StreamMode::FixedState,
            )
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Barrier filter whose forward pass is the identity.
    struct Accumulator;

    impl Filter for Accumulator {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Accumulator"])
        }

        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::Empty)
        }

        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            Ok(x.clone())
        }

        fn meta(&self) -> FilterMeta {
            local_meta(
                "Accumulator",
                FilterKind::Stateless,
                false,
                StreamMode::Barrier,
            )
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Evolving filter that adds the sum of the chunk to the first element of its state.
    struct RunningSum;

    impl Filter for RunningSum {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"RunningSum"])
        }

        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::tensor(vec![0.0], vec![1]))
        }

        fn forward(&self, x: &Value, state: &Value) -> SomaResult<Value> {
            let x_sum: f64 = if let Value::Tensor { values, .. } = x {
                values.iter().sum()
            } else {
                0.0
            };
            let state_sum: f64 = if let Value::Tensor { values, .. } = state {
                values.first().copied().unwrap_or(0.0)
            } else {
                0.0
            };
            Ok(Value::tensor(vec![x_sum + state_sum], vec![1]))
        }

        fn meta(&self) -> FilterMeta {
            local_meta(
                "RunningSum",
                FilterKind::Trainable,
                false,
                StreamMode::Evolving {
                    checkpoint_every: 2,
                },
            )
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Wraps `filter` and `state` into a `FittedFilter` named after the filter's metadata.
    fn make_fitted(filter: impl Filter + 'static, state: Value) -> FittedFilter {
        let name = filter.meta().name.clone();
        FittedFilter {
            name,
            filter: Arc::new(filter),
            state,
        }
    }

    #[test]
    fn fixed_state_processes_each_chunk() {
        let mut exec = StreamExecutor::new(vec![make_fitted(DoubleChunk, Value::Empty)]);

        let cases = [
            (
                Value::tensor(vec![1.0, 2.0], vec![2]),
                Value::tensor(vec![2.0, 4.0], vec![2]),
            ),
            (
                Value::tensor(vec![3.0], vec![1]),
                Value::tensor(vec![6.0], vec![1]),
            ),
        ];
        for (input, expected) in cases {
            let produced = exec.process_chunk(input).unwrap();
            assert_eq!(produced, Some(expected));
        }
    }

    #[test]
    fn barrier_accumulates_then_flushes() {
        let mut exec = StreamExecutor::new(vec![make_fitted(Accumulator, Value::Empty)]);

        // Every chunk is swallowed by the barrier until the stream is flushed.
        let chunks = [
            Value::tensor(vec![1.0, 2.0], vec![2]),
            Value::tensor(vec![3.0, 4.0], vec![2]),
        ];
        for chunk in chunks {
            let produced = exec.process_chunk(chunk).unwrap();
            assert_eq!(produced, None);
        }

        let flushed = exec.flush().unwrap().unwrap();
        let (data, shape) = flushed.as_tensor().unwrap();
        assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(shape, &[4]);
    }

    #[test]
    fn evolving_state_accumulates() {
        let fitted = make_fitted(RunningSum, Value::tensor(vec![0.0], vec![1]));
        let mut exec = StreamExecutor::new(vec![fitted]);

        let first = exec
            .process_chunk(Value::tensor(vec![10.0], vec![1]))
            .unwrap()
            .unwrap();
        let (after_first, _) = first.as_tensor().unwrap();
        assert_eq!(after_first, &[10.0]);

        // The second chunk is added on top of the state left by the first one.
        let second = exec
            .process_chunk(Value::tensor(vec![5.0], vec![1]))
            .unwrap()
            .unwrap();
        let (after_second, _) = second.as_tensor().unwrap();
        assert_eq!(after_second, &[15.0]);
    }

    #[test]
    fn mixed_pipeline_fixed_then_barrier() {
        let stages = vec![
            make_fitted(DoubleChunk, Value::Empty),
            make_fitted(Accumulator, Value::Empty),
        ];
        let mut exec = StreamExecutor::new(stages);

        // Chunks are doubled first and then held back by the barrier.
        let chunks = [
            Value::tensor(vec![1.0], vec![1]),
            Value::tensor(vec![2.0], vec![1]),
        ];
        for chunk in chunks {
            let produced = exec.process_chunk(chunk).unwrap();
            assert_eq!(produced, None);
        }

        let flushed = exec.flush().unwrap().unwrap();
        let (data, _) = flushed.as_tensor().unwrap();
        assert_eq!(data, &[2.0, 4.0]);
    }

    #[test]
    fn process_all_combines_chunks() {
        let mut exec = StreamExecutor::new(vec![make_fitted(DoubleChunk, Value::Empty)]);

        let chunks = vec![
            Value::tensor(vec![1.0], vec![1]),
            Value::tensor(vec![2.0], vec![1]),
            Value::tensor(vec![3.0], vec![1]),
        ];
        let outputs = exec.process_all(chunks).unwrap();

        assert_eq!(outputs.len(), 3);
        let (first, _) = outputs[0].as_tensor().unwrap();
        assert_eq!(first, &[2.0]);
    }

    #[test]
    fn fixed_state_with_cache() {
        let cache = Arc::new(crate::MemoryCache::default());
        let mut exec =
            StreamExecutor::new(vec![make_fitted(DoubleChunk, Value::Empty)]).with_cache(cache);

        // The same chunk twice must give the same answer with a cache attached.
        let first = exec
            .process_chunk(Value::tensor(vec![5.0], vec![1]))
            .unwrap()
            .unwrap();
        let second = exec
            .process_chunk(Value::tensor(vec![5.0], vec![1]))
            .unwrap()
            .unwrap();
        assert_eq!(first, second);
    }
}