1use crate::event_bus::EventBus;
9use crate::filter_library::FilterLibrary;
10use crate::runner::Runner;
11use somatize_compiler::{CompileMode, CompileResult, compile, compile_stream};
12use somatize_core::cache::CacheStore;
13use somatize_core::error::{Result, SomaError};
14use somatize_core::graph::Graph;
15use somatize_core::store::{DataRef, DataStore};
16use somatize_core::value::Value;
17use std::sync::Arc;
18
/// A strategy for running a forward (inference) pass over a [`Graph`].
///
/// Implementations decide how the graph is compiled and how input is fed to the
/// local runner (whole input at once, streamed in chunks, or read from a data
/// store in batches).
pub trait ForwardStrategy {
    /// Runs `graph` and returns its output value.
    ///
    /// * `graph` — the graph to execute.
    /// * `library` — registry of filters used to resolve the graph's nodes.
    /// * `cache` — cache store handed to compilation and/or the runner.
    /// * `event_bus` — event bus forwarded to the runner.
    /// * `data_store` — optional external data source; only strategies that read
    ///   their input from a store need it, others ignore it.
    /// * `x` — the input value; strategies that read from `data_store` may ignore it.
    ///
    /// # Errors
    /// Returns an error if compilation or execution fails, or if the strategy's
    /// own requirements (for example a data store) are not met.
    fn forward(
        &self,
        graph: &Graph,
        library: &FilterLibrary,
        cache: &dyn CacheStore,
        event_bus: &Arc<EventBus>,
        data_store: Option<&Arc<dyn DataStore>>,
        x: &Value,
    ) -> Result<Value>;
}
32
33pub struct Standard;
35
36impl ForwardStrategy for Standard {
37 fn forward(
38 &self,
39 graph: &Graph,
40 library: &FilterLibrary,
41 cache: &dyn CacheStore,
42 event_bus: &Arc<EventBus>,
43 _data_store: Option<&Arc<dyn DataStore>>,
44 x: &Value,
45 ) -> Result<Value> {
46 let CompileResult { plan, .. } =
47 compile(graph, library, CompileMode::Inference, Some(cache))?;
48
49 let runner = crate::runner::LocalRunner;
50 runner.forward(&plan, library, cache, event_bus, x)
51 }
52}
53
54pub struct Stream {
57 pub chunk_size: usize,
58}
59
60impl ForwardStrategy for Stream {
61 fn forward(
62 &self,
63 graph: &Graph,
64 library: &FilterLibrary,
65 cache: &dyn CacheStore,
66 event_bus: &Arc<EventBus>,
67 _data_store: Option<&Arc<dyn DataStore>>,
68 x: &Value,
69 ) -> Result<Value> {
70 let CompileResult { plan, .. } = compile_stream(graph, library, self.chunk_size)?;
71
72 let runner = crate::runner::LocalRunner;
73 runner.forward(&plan, library, cache, event_bus, x)
74 }
75}
76
77pub struct Batched<'a> {
80 pub data_ref: &'a DataRef,
81 pub batch_size: usize,
82}
83
84impl ForwardStrategy for Batched<'_> {
85 fn forward(
86 &self,
87 graph: &Graph,
88 library: &FilterLibrary,
89 cache: &dyn CacheStore,
90 event_bus: &Arc<EventBus>,
91 data_store: Option<&Arc<dyn DataStore>>,
92 _x: &Value,
93 ) -> Result<Value> {
94 let store = data_store.ok_or_else(|| SomaError::Execution {
95 node_id: "session".into(),
96 message: "Batched strategy requires a data store (use with_data_store)".into(),
97 })?;
98
99 let meta = store.meta(self.data_ref)?;
100 let total_rows = meta.total_rows;
101 if total_rows == 0 {
102 return Ok(Value::Empty);
103 }
104
105 let CompileResult { plan, .. } =
107 compile(graph, library, CompileMode::Inference, Some(cache))?;
108 let runner = crate::runner::LocalRunner;
109
110 let mut all_values: Vec<f64> = Vec::new();
111 let mut result_shape: Option<Vec<usize>> = None;
112 let mut rows_processed = 0;
113
114 while rows_processed < total_rows {
115 let batch_len = self.batch_size.min(total_rows - rows_processed);
116 let batch = store.get_rows(self.data_ref, rows_processed, batch_len)?;
117 let output = runner.forward(&plan, library, cache, event_bus, &batch)?;
118
119 if let Value::Tensor { values, shape } = &output {
120 if result_shape.is_none() {
121 result_shape = Some(shape.clone());
122 }
123 all_values.extend_from_slice(values);
124 } else {
125 return Ok(output);
126 }
127
128 rows_processed += batch_len;
129 }
130
131 match result_shape {
132 Some(mut shape) => {
133 shape[0] = total_rows;
134 Ok(Value::tensor(all_values, shape))
135 }
136 None => Ok(Value::Empty),
137 }
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::MemoryCache;
    use crate::filter_library::FilterLibrary;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result as SomaResult;
    use somatize_core::filter::{Distribution, Filter, FilterKind, FilterMeta, StreamMode};
    use somatize_core::graph::{Graph, Node};

    /// Everything a strategy needs to run: graph, filter library, cache and event bus.
    type Session = (Graph, FilterLibrary, Arc<dyn CacheStore>, Arc<EventBus>);

    /// Stateless test filter: multiplies every tensor element by two and passes
    /// any non-tensor value through unchanged.
    struct DoublerFilter;

    impl Filter for DoublerFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Doubler"])
        }

        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::Empty)
        }

        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            if let Value::Tensor { values, shape } = x {
                let doubled = values.iter().map(|v| v * 2.0).collect::<Vec<f64>>();
                Ok(Value::tensor(doubled, shape.clone()))
            } else {
                Ok(x.clone())
            }
        }

        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Doubler".into(),
                kind: FilterKind::Stateless,
                cacheable: false,
                differentiable: false,
                stream_mode: StreamMode::FixedState,
                distribution: Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Builds a one-node graph ("double") backed by `DoublerFilter`.
    fn make_session() -> Session {
        let mut library = FilterLibrary::new();
        library.register("double", Box::new(DoublerFilter));

        let mut graph = Graph::new();
        graph.nodes.push(Node::new("double", "Double", "double"));

        let cache: Arc<dyn CacheStore> = Arc::new(MemoryCache::default());
        (graph, library, cache, Arc::new(EventBus::new(64)))
    }

    /// Runs `strategy` on `input` against `session` without a data store.
    fn run(strategy: &dyn ForwardStrategy, session: &Session, input: &Value) -> Value {
        let (graph, library, cache, bus) = session;
        strategy
            .forward(graph, library, cache.as_ref(), bus, None, input)
            .unwrap()
    }

    #[test]
    fn standard_forward() {
        let session = make_session();
        let output = run(
            &Standard,
            &session,
            &Value::tensor(vec![1.0, 2.0, 3.0], vec![3]),
        );

        let (data, _) = output.as_tensor().unwrap();
        assert_eq!(data, &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn stream_forward() {
        let session = make_session();
        let output = run(
            &Stream { chunk_size: 2 },
            &session,
            &Value::tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]),
        );

        let (data, shape) = output.as_tensor().unwrap();
        assert_eq!(data, &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0]);
        assert_eq!(shape, &[6]);
    }

    #[test]
    fn stream_matches_standard() {
        let session = make_session();
        let input = Value::tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]);

        let standard = run(&Standard, &session, &input);
        let streamed = run(&Stream { chunk_size: 2 }, &session, &input);
        assert_eq!(standard, streamed);
    }
}