1use somatize_core::cache::{CacheKey, CacheStore};
8use somatize_core::error::{Result, SomaError};
9use somatize_core::filter::{Filter, StreamMode};
10use somatize_core::value::Value;
11use std::sync::Arc;
12
13pub struct FittedFilter {
15 pub name: String,
16 pub filter: Arc<dyn Filter>,
17 pub state: Arc<Value>,
18}
19
20struct FilterStreamState {
22 barrier_buffer: Vec<Value>,
24 evolving_state: Option<Value>,
26}
27
28pub struct StreamExecutor {
35 filters: Vec<FittedFilter>,
36 cache: Option<Arc<dyn CacheStore>>,
37 states: Vec<FilterStreamState>,
38 chunk_count: usize,
39}
40
41impl StreamExecutor {
42 pub fn new(filters: Vec<FittedFilter>) -> Self {
43 let n = filters.len();
44 Self {
45 filters,
46 cache: None,
47 states: (0..n)
48 .map(|_| FilterStreamState {
49 barrier_buffer: Vec::new(),
50 evolving_state: None,
51 })
52 .collect(),
53 chunk_count: 0,
54 }
55 }
56
57 pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
58 self.cache = Some(cache);
59 self
60 }
61
62 pub fn process_chunk(&mut self, chunk: Value) -> Result<Option<Value>> {
65 let mut current = chunk;
66 self.chunk_count += 1;
67
68 for i in 0..self.filters.len() {
69 let mode = self.filters[i].filter.meta().stream_mode;
70 match process_by_mode(
71 &mode,
72 &self.filters[i],
73 ¤t,
74 &mut self.states[i],
75 self.cache.as_deref(),
76 self.chunk_count,
77 )? {
78 ChunkResult::Output(val) => current = val,
79 ChunkResult::Buffered => return Ok(None),
80 }
81 }
82
83 Ok(Some(current))
84 }
85
86 pub fn flush(&mut self) -> Result<Option<Value>> {
88 let mut current: Option<Value> = None;
89
90 for i in 0..self.filters.len() {
91 let mode = self.filters[i].filter.meta().stream_mode;
92 if let Some(val) = flush_by_mode(&mode, &self.filters[i], &mut self.states[i])? {
93 current = Some(val);
94 } else if let Some(val) = current.take() {
95 current = Some(
96 self.filters[i]
97 .filter
98 .forward(&val, &self.filters[i].state)?,
99 );
100 }
101 }
102
103 Ok(current)
104 }
105
106 pub fn process_all(&mut self, chunks: Vec<Value>) -> Result<Vec<Value>> {
108 let mut outputs = Vec::new();
109 for chunk in chunks {
110 if let Some(output) = self.process_chunk(chunk)? {
111 outputs.push(output);
112 }
113 }
114 if let Some(flushed) = self.flush()? {
115 outputs.push(flushed);
116 }
117 Ok(outputs)
118 }
119
120 pub fn chunks_processed(&self) -> usize {
122 self.chunk_count
123 }
124}
125
126enum ChunkResult {
128 Output(Value),
130 Buffered,
132}
133
134fn process_by_mode(
138 mode: &StreamMode,
139 fitted: &FittedFilter,
140 input: &Value,
141 state: &mut FilterStreamState,
142 cache: Option<&dyn CacheStore>,
143 chunk_count: usize,
144) -> Result<ChunkResult> {
145 match mode {
146 StreamMode::FixedState => {
147 let result = forward_cached(fitted, input, cache)?;
148 Ok(ChunkResult::Output(result))
149 }
150 StreamMode::Evolving { checkpoint_every } => {
151 let default_state: &Value = &fitted.state;
152 let filter_state = state.evolving_state.as_ref().unwrap_or(default_state);
153 let result = fitted.filter.forward(input, filter_state)?;
154 state.evolving_state = Some(result.clone());
155
156 if *checkpoint_every > 0
157 && chunk_count.is_multiple_of(*checkpoint_every)
158 && let Some(c) = cache
159 {
160 let key = CacheKey::from_parts(&[
161 b"checkpoint",
162 fitted.name.as_bytes(),
163 &(chunk_count as u64).to_le_bytes(),
164 ]);
165 let _ = c.put(&key, &result);
166 }
167 Ok(ChunkResult::Output(result))
168 }
169 StreamMode::Barrier => {
170 state.barrier_buffer.push(input.clone());
171 Ok(ChunkResult::Buffered)
172 }
173 _ => {
174 let result = forward_cached(fitted, input, cache)?;
176 Ok(ChunkResult::Output(result))
177 }
178 }
179}
180
181fn flush_by_mode(
183 mode: &StreamMode,
184 fitted: &FittedFilter,
185 state: &mut FilterStreamState,
186) -> Result<Option<Value>> {
187 match mode {
188 StreamMode::Barrier if !state.barrier_buffer.is_empty() => {
189 let materialized = materialize_buffer(&state.barrier_buffer)?;
190 state.barrier_buffer.clear();
191 let result = fitted.filter.forward(&materialized, &fitted.state)?;
192 Ok(Some(result))
193 }
194 _ => Ok(None),
195 }
196}
197
198fn forward_cached(
200 fitted: &FittedFilter,
201 input: &Value,
202 cache: Option<&dyn CacheStore>,
203) -> Result<Value> {
204 if let Some(c) = cache {
205 let chunk_hash = CacheKey::hash_data(&serde_json::to_vec(input).unwrap_or_default());
206 let state_hash =
207 CacheKey::hash_data(&serde_json::to_vec(&fitted.state).unwrap_or_default());
208 let cache_key =
209 CacheKey::for_output(&fitted.filter.config_hash(), &state_hash, &chunk_hash);
210 if let Some(cached) = c.get(&cache_key)? {
211 return Ok(cached);
212 }
213 let result = fitted.filter.forward(input, &fitted.state)?;
214 let _ = c.put(&cache_key, &result);
215 return Ok(result);
216 }
217 fitted.filter.forward(input, &fitted.state)
218}
219
220pub fn materialize_buffer(buffer: &[Value]) -> Result<Value> {
222 if buffer.is_empty() {
223 return Ok(Value::Empty);
224 }
225 let mut all_data = Vec::new();
226 let mut total_rows = 0;
227 let mut cols = 0;
228
229 for chunk in buffer {
230 match chunk {
231 Value::Tensor { values, shape } => {
232 all_data.extend(values.iter());
233 if shape.len() == 1 {
234 total_rows += shape[0];
235 cols = 1;
236 } else if shape.len() >= 2 {
237 total_rows += shape[0];
238 cols = shape[1];
239 }
240 }
241 _ => {
242 return Err(SomaError::Other(
243 "barrier buffer contains non-tensor values".into(),
244 ));
245 }
246 }
247 }
248
249 if cols <= 1 {
250 Ok(Value::tensor(all_data, vec![total_rows]))
251 } else {
252 Ok(Value::tensor(all_data, vec![total_rows, cols]))
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result as SomaResult;
    use somatize_core::filter::{Distribution, FilterKind, FilterMeta};

    /// Fixed-state test filter that doubles every element of a tensor and
    /// passes any other value through unchanged.
    struct DoubleChunk;

    impl Filter for DoubleChunk {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"DoubleChunk"])
        }

        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::Empty)
        }

        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            match x {
                Value::Tensor { values, shape } => {
                    let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
                    Ok(Value::tensor(doubled, shape.clone()))
                }
                other => Ok(other.clone()),
            }
        }

        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "DoubleChunk".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: false,
                stream_mode: StreamMode::FixedState,
                distribution: Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Barrier-mode test filter whose `forward` is the identity.
    struct Accumulator;

    impl Filter for Accumulator {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Accumulator"])
        }

        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::Empty)
        }

        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            Ok(x.clone())
        }

        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Accumulator".into(),
                kind: FilterKind::Stateless,
                cacheable: false,
                differentiable: false,
                stream_mode: StreamMode::Barrier,
                distribution: Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Evolving-mode test filter: outputs the sum of the input's elements
    /// plus the first element of the state.
    struct RunningSum;

    impl Filter for RunningSum {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"RunningSum"])
        }

        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::tensor(vec![0.0], vec![1]))
        }

        fn forward(&self, x: &Value, state: &Value) -> SomaResult<Value> {
            let incoming: f64 = if let Value::Tensor { values, .. } = x {
                values.iter().sum()
            } else {
                0.0
            };
            let carried: f64 = if let Value::Tensor { values, .. } = state {
                values.first().copied().unwrap_or(0.0)
            } else {
                0.0
            };
            Ok(Value::tensor(vec![incoming + carried], vec![1]))
        }

        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "RunningSum".into(),
                kind: FilterKind::Trainable,
                cacheable: false,
                differentiable: false,
                stream_mode: StreamMode::Evolving {
                    checkpoint_every: 2,
                },
                distribution: Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Wraps `filter` and `state` into a `FittedFilter` named after the
    /// filter's own metadata.
    fn make_fitted(filter: impl Filter + 'static, state: Value) -> FittedFilter {
        FittedFilter {
            name: filter.meta().name.clone(),
            filter: Arc::new(filter),
            state: Arc::new(state),
        }
    }

    #[test]
    fn fixed_state_processes_each_chunk() {
        let mut exec = StreamExecutor::new(vec![make_fitted(DoubleChunk, Value::Empty)]);

        let first = exec
            .process_chunk(Value::tensor(vec![1.0, 2.0], vec![2]))
            .unwrap();
        assert_eq!(first, Some(Value::tensor(vec![2.0, 4.0], vec![2])));

        let second = exec
            .process_chunk(Value::tensor(vec![3.0], vec![1]))
            .unwrap();
        assert_eq!(second, Some(Value::tensor(vec![6.0], vec![1])));
    }

    #[test]
    fn barrier_accumulates_then_flushes() {
        let mut exec = StreamExecutor::new(vec![make_fitted(Accumulator, Value::Empty)]);

        // Both chunks are held back by the barrier.
        let first = exec
            .process_chunk(Value::tensor(vec![1.0, 2.0], vec![2]))
            .unwrap();
        assert_eq!(first, None);

        let second = exec
            .process_chunk(Value::tensor(vec![3.0, 4.0], vec![2]))
            .unwrap();
        assert_eq!(second, None);

        // Flushing releases them as one concatenated tensor.
        let flushed = exec.flush().unwrap().unwrap();
        let (data, shape) = flushed.as_tensor().unwrap();
        assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(shape, &[4]);
    }

    #[test]
    fn evolving_state_accumulates() {
        let stage = make_fitted(RunningSum, Value::tensor(vec![0.0], vec![1]));
        let mut exec = StreamExecutor::new(vec![stage]);

        let after_first = exec
            .process_chunk(Value::tensor(vec![10.0], vec![1]))
            .unwrap()
            .unwrap();
        let (sum_one, _) = after_first.as_tensor().unwrap();
        assert_eq!(sum_one, &[10.0]);

        // The second chunk is added on top of the first chunk's output.
        let after_second = exec
            .process_chunk(Value::tensor(vec![5.0], vec![1]))
            .unwrap()
            .unwrap();
        let (sum_two, _) = after_second.as_tensor().unwrap();
        assert_eq!(sum_two, &[15.0]);
    }

    #[test]
    fn mixed_pipeline_fixed_then_barrier() {
        let doubler = make_fitted(DoubleChunk, Value::Empty);
        let barrier = make_fitted(Accumulator, Value::Empty);
        let mut exec = StreamExecutor::new(vec![doubler, barrier]);

        let first = exec
            .process_chunk(Value::tensor(vec![1.0], vec![1]))
            .unwrap();
        assert_eq!(first, None);

        let second = exec
            .process_chunk(Value::tensor(vec![2.0], vec![1]))
            .unwrap();
        assert_eq!(second, None);

        // The barrier buffered the already-doubled chunks.
        let flushed = exec.flush().unwrap().unwrap();
        let (data, _) = flushed.as_tensor().unwrap();
        assert_eq!(data, &[2.0, 4.0]);
    }

    #[test]
    fn process_all_combines_chunks() {
        let mut exec = StreamExecutor::new(vec![make_fitted(DoubleChunk, Value::Empty)]);

        let chunks = vec![
            Value::tensor(vec![1.0], vec![1]),
            Value::tensor(vec![2.0], vec![1]),
            Value::tensor(vec![3.0], vec![1]),
        ];
        let outputs = exec.process_all(chunks).unwrap();

        assert_eq!(outputs.len(), 3);
        let (head, _) = outputs[0].as_tensor().unwrap();
        assert_eq!(head, &[2.0]);
    }

    #[test]
    fn fixed_state_with_cache() {
        let cache = Arc::new(crate::MemoryCache::default());
        let mut exec =
            StreamExecutor::new(vec![make_fitted(DoubleChunk, Value::Empty)]).with_cache(cache);

        // The same chunk twice must yield the same result.
        let first = exec
            .process_chunk(Value::tensor(vec![5.0], vec![1]))
            .unwrap()
            .unwrap();
        let second = exec
            .process_chunk(Value::tensor(vec![5.0], vec![1]))
            .unwrap()
            .unwrap();
        assert_eq!(first, second);
    }
}