Skip to main content

somatize_runtime/
forward.rs

1//! Forward execution strategies for [`GraphSession`].
2//!
3//! Each strategy defines HOW input data flows through a compiled graph:
4//! - [`Standard`] — full input at once, with inference caching
5//! - [`Stream`] — chunked input through [`StreamExecutor`], respecting StreamMode
6//! - [`Batched`] — rows from a [`DataStore`], batch by batch (memory-bounded)
7
8use crate::event_bus::EventBus;
9use crate::filter_library::FilterLibrary;
10use crate::runner::Runner;
11use somatize_compiler::{CompileMode, CompileResult, compile, compile_stream};
12use somatize_core::cache::CacheStore;
13use somatize_core::error::{Result, SomaError};
14use somatize_core::graph::Graph;
15use somatize_core::store::{DataRef, DataStore};
16use somatize_core::value::Value;
17use std::sync::Arc;
18
19/// How a forward pass feeds data through the compiled graph.
20pub trait ForwardStrategy {
21    /// Execute a forward pass, returning the final output.
22    fn forward(
23        &self,
24        graph: &Graph,
25        library: &FilterLibrary,
26        cache: &dyn CacheStore,
27        event_bus: &Arc<EventBus>,
28        data_store: Option<&Arc<dyn DataStore>>,
29        x: &Value,
30    ) -> Result<Value>;
31}
32
33/// Standard forward: full input at once with inference caching.
34pub struct Standard;
35
36impl ForwardStrategy for Standard {
37    fn forward(
38        &self,
39        graph: &Graph,
40        library: &FilterLibrary,
41        cache: &dyn CacheStore,
42        event_bus: &Arc<EventBus>,
43        _data_store: Option<&Arc<dyn DataStore>>,
44        x: &Value,
45    ) -> Result<Value> {
46        let CompileResult { plan, .. } =
47            compile(graph, library, CompileMode::Inference, Some(cache))?;
48
49        let runner = crate::runner::LocalRunner;
50        runner.forward(&plan, library, cache, event_bus, x)
51    }
52}
53
54/// Streaming forward: chunk input and process through StreamExecutor.
55/// Each filter's StreamMode (FixedState/Evolving/Barrier) defines its per-chunk contract.
56pub struct Stream {
57    pub chunk_size: usize,
58}
59
60impl ForwardStrategy for Stream {
61    fn forward(
62        &self,
63        graph: &Graph,
64        library: &FilterLibrary,
65        cache: &dyn CacheStore,
66        event_bus: &Arc<EventBus>,
67        _data_store: Option<&Arc<dyn DataStore>>,
68        x: &Value,
69    ) -> Result<Value> {
70        let CompileResult { plan, .. } = compile_stream(graph, library, self.chunk_size)?;
71
72        let runner = crate::runner::LocalRunner;
73        runner.forward(&plan, library, cache, event_bus, x)
74    }
75}
76
77/// Batched forward: read rows from a DataStore in fixed-size batches.
78/// Keeps memory bounded — only one batch is materialized at a time.
79pub struct Batched<'a> {
80    pub data_ref: &'a DataRef,
81    pub batch_size: usize,
82}
83
84impl ForwardStrategy for Batched<'_> {
85    fn forward(
86        &self,
87        graph: &Graph,
88        library: &FilterLibrary,
89        cache: &dyn CacheStore,
90        event_bus: &Arc<EventBus>,
91        data_store: Option<&Arc<dyn DataStore>>,
92        _x: &Value,
93    ) -> Result<Value> {
94        let store = data_store.ok_or_else(|| SomaError::Execution {
95            node_id: "session".into(),
96            message: "Batched strategy requires a data store (use with_data_store)".into(),
97        })?;
98
99        let meta = store.meta(self.data_ref)?;
100        let total_rows = meta.total_rows;
101        if total_rows == 0 {
102            return Ok(Value::Empty);
103        }
104
105        // Compile once, reuse for each batch.
106        let CompileResult { plan, .. } =
107            compile(graph, library, CompileMode::Inference, Some(cache))?;
108        let runner = crate::runner::LocalRunner;
109
110        let mut all_values: Vec<f64> = Vec::new();
111        let mut result_shape: Option<Vec<usize>> = None;
112        let mut rows_processed = 0;
113
114        while rows_processed < total_rows {
115            let batch_len = self.batch_size.min(total_rows - rows_processed);
116            let batch = store.get_rows(self.data_ref, rows_processed, batch_len)?;
117            let output = runner.forward(&plan, library, cache, event_bus, &batch)?;
118
119            if let Value::Tensor { values, shape } = &output {
120                if result_shape.is_none() {
121                    result_shape = Some(shape.clone());
122                }
123                all_values.extend_from_slice(values.as_slice());
124            } else {
125                return Ok(output);
126            }
127
128            rows_processed += batch_len;
129        }
130
131        match result_shape {
132            Some(mut shape) => {
133                shape[0] = total_rows;
134                Ok(Value::tensor(all_values, shape))
135            }
136            None => Ok(Value::Empty),
137        }
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::MemoryCache;
    use crate::filter_library::FilterLibrary;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result as SomaResult;
    use somatize_core::filter::{Distribution, Filter, FilterKind, FilterMeta, StreamMode};
    use somatize_core::graph::{Graph, Node};

    /// Everything a strategy needs besides the input: graph, library, cache, bus.
    type Session = (Graph, FilterLibrary, Arc<dyn CacheStore>, Arc<EventBus>);

    /// Test filter that multiplies every tensor element by two and passes
    /// any non-tensor value through untouched.
    struct DoublerFilter;

    impl Filter for DoublerFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Doubler"])
        }

        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::Empty)
        }

        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            if let Value::Tensor { values, shape } = x {
                let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
                return Ok(Value::tensor(doubled, shape.clone()));
            }
            Ok(x.clone())
        }

        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Doubler".into(),
                kind: FilterKind::Stateless,
                cacheable: false,
                differentiable: false,
                stream_mode: StreamMode::FixedState,
                distribution: Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Builds a single-node graph wired to `DoublerFilter`, plus a fresh
    /// in-memory cache and event bus.
    fn make_session() -> Session {
        let mut library = FilterLibrary::new();
        library.register("double", Box::new(DoublerFilter));

        let mut graph = Graph::new();
        graph.nodes.push(Node::new("double", "Double", "double"));

        let cache: Arc<dyn CacheStore> = Arc::new(MemoryCache::default());
        let bus = Arc::new(EventBus::new(64));
        (graph, library, cache, bus)
    }

    /// Runs `strategy` against `session` without a data store and unwraps the output.
    fn run<S: ForwardStrategy>(strategy: &S, session: &Session, input: &Value) -> Value {
        let (graph, library, cache, bus) = session;
        strategy
            .forward(graph, library, &**cache, bus, None, input)
            .unwrap()
    }

    #[test]
    fn standard_forward() {
        let session = make_session();
        let input = Value::tensor(vec![1.0, 2.0, 3.0], vec![3]);

        let result = run(&Standard, &session, &input);

        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn stream_forward() {
        let session = make_session();
        let input = Value::tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]);

        let result = run(&Stream { chunk_size: 2 }, &session, &input);

        let (data, shape) = result.as_tensor().unwrap();
        assert_eq!(data, &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0]);
        assert_eq!(shape, &[6]);
    }

    #[test]
    fn stream_matches_standard() {
        // Both strategies share one session (and therefore one cache).
        let session = make_session();
        let input = Value::tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]);

        let standard = run(&Standard, &session, &input);
        let streamed = run(&Stream { chunk_size: 2 }, &session, &input);

        assert_eq!(standard, streamed);
    }
}