Skip to main content

somatize_runtime/
graph_session.rs

//! Graph session — the primary orchestrator for Graph → Compile → Execute.
//!
//! [`GraphSession`] binds a [`Graph`] with its [`FilterLibrary`], cache,
//! event bus, and optional distributed components (a data store and a remote
//! executor) into a single object that can compile, fit, and execute.
6
7use crate::cache::MemoryCache;
8use crate::event_bus::EventBus;
9use crate::executor::{self, Context, GraphInfo, RemoteExecutor};
10use crate::filter_library::FilterLibrary;
11use crate::runner::Runner;
12use somatize_compiler::{CompileMode, CompileResult, compile};
13use somatize_core::cache::{CacheKey, CacheStore};
14use somatize_core::error::{Result, SomaError};
15use somatize_core::event::Event;
16use somatize_core::filter::FilterKind;
17use somatize_core::graph::Graph;
18use somatize_core::store::{DataRef, DataStore};
19use somatize_core::util::timestamp_id;
20use somatize_core::value::Value;
21use std::collections::HashMap;
22use std::sync::Arc;
23
24/// The primary orchestrator: Graph + filters + cache + events.
25///
26/// ```ignore
27/// let mut lib = FilterLibrary::new();
28/// lib.register("scaler", Box::new(MyScaler::new()));
29/// lib.register("model", Box::new(MyModel::new()));
30///
31/// let mut session = GraphSession::new(graph, lib);
32/// session.fit(&train_x, Some(&train_y))?;
33/// let output = session.forward(&test_x)?;
34/// ```
35pub struct GraphSession {
36    graph: Graph,
37    library: FilterLibrary,
38    cache: Arc<dyn CacheStore>,
39    event_bus: Arc<EventBus>,
40    data_store: Option<Arc<dyn DataStore>>,
41    remote_executor: Option<Arc<dyn RemoteExecutor>>,
42    fitted: bool,
43}
44
45impl GraphSession {
46    pub fn new(graph: Graph, library: FilterLibrary) -> Self {
47        Self {
48            graph,
49            library,
50            cache: Arc::new(MemoryCache::default()),
51            event_bus: Arc::new(EventBus::new(256)),
52            data_store: None,
53            remote_executor: None,
54            fitted: false,
55        }
56    }
57
58    pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
59        self.cache = cache;
60        self
61    }
62
63    pub fn with_event_bus(mut self, bus: Arc<EventBus>) -> Self {
64        self.event_bus = bus;
65        self
66    }
67
68    pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
69        self.data_store = Some(store);
70        self
71    }
72
73    pub fn with_remote_executor(mut self, executor: Arc<dyn RemoteExecutor>) -> Self {
74        self.remote_executor = Some(executor);
75        self
76    }
77
78    // ── Core operations ──
79
80    /// Compile the graph and return diagnostics without executing.
81    pub fn compile(&self, mode: CompileMode) -> Result<CompileResult> {
82        compile(&self.graph, &self.library, mode, Some(self.cache.as_ref()))
83    }
84
85    /// Compile and execute the graph, returning all node outputs.
86    pub fn run(&mut self, mode: CompileMode) -> Result<HashMap<String, Value>> {
87        let CompileResult { plan, diagnostics } =
88            compile(&self.graph, &self.library, mode, Some(self.cache.as_ref()))?;
89
90        for diag in &diagnostics {
91            tracing::warn!("compile diagnostic: {:?}", diag);
92        }
93
94        let graph_info = GraphInfo::from_graph(&self.graph);
95        let mut ctx = Context::new(self.event_bus.clone(), timestamp_id("graph_run"))
96            .with_graph_info(graph_info);
97
98        if let Some(store) = &self.data_store {
99            ctx = ctx.with_data_store(store.clone());
100        }
101        if let Some(remote) = &self.remote_executor {
102            ctx = ctx.with_remote_executor(remote.clone());
103        }
104
105        executor::execute(&plan, &mut ctx, &self.library, self.cache.as_ref())?;
106
107        Ok(ctx
108            .store
109            .into_iter()
110            .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
111            .collect())
112    }
113
114    /// Fit all trainable filters in topological order.
115    /// Delegates to LocalRunner — same execution path as remote workers.
116    pub fn fit(&mut self, x: &Value, y: Option<&Value>) -> Result<HashMap<String, Value>> {
117        self.graph.validate()?;
118
119        let CompileResult { plan, .. } = compile(
120            &self.graph,
121            &self.library,
122            CompileMode::NoCache,
123            Some(self.cache.as_ref()),
124        )?;
125
126        let runner = crate::runner::LocalRunner;
127        let (_last_output, all_outputs) = runner.fit(
128            &plan,
129            &self.library,
130            self.cache.as_ref(),
131            &self.event_bus,
132            x,
133            y,
134        )?;
135
136        // Store trained states in library for subsequent forward() calls
137        // LocalRunner caches in CacheStore but can't mutate FilterLibrary
138        for node_id in plan.node_ids() {
139            if let Some(filter) = self.library.get(node_id)
140                && filter.meta().kind == FilterKind::Trainable
141                && all_outputs.contains_key(node_id)
142            {
143                let data_hash = CacheKey::hash_data(&serde_json::to_vec(x).unwrap_or_default());
144                let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
145                if let Ok(Some(state)) = self.cache.get(&state_key) {
146                    self.library.set_state(node_id, state);
147                }
148            }
149        }
150
151        self.fitted = true;
152        Ok(all_outputs)
153    }
154
155    /// Forward pass through the fitted graph (inference).
156    ///
157    /// Compiles in Inference mode and executes. Trainable filters
158    /// use states cached during `fit()`.
159    /// Forward pass through the fitted graph (inference).
160    /// Delegates to LocalRunner — same execution path as remote workers.
161    pub fn forward(&self, x: &Value) -> Result<Value> {
162        let CompileResult { plan, .. } = compile(
163            &self.graph,
164            &self.library,
165            CompileMode::Inference,
166            Some(self.cache.as_ref()),
167        )?;
168
169        let runner = crate::runner::LocalRunner;
170        runner.forward(
171            &plan,
172            &self.library,
173            self.cache.as_ref(),
174            &self.event_bus,
175            x,
176        )
177    }
178
179    /// Forward pass in batches from a DataStore reference.
180    pub fn forward_batched(&self, data_ref: &DataRef, batch_size: usize) -> Result<Value> {
181        let store = self
182            .data_store
183            .as_ref()
184            .ok_or_else(|| SomaError::Execution {
185                node_id: "session".into(),
186                message: "forward_batched requires a data store (use with_data_store)".into(),
187            })?;
188
189        let meta = store.meta(data_ref)?;
190        let total_rows = meta.total_rows;
191        if total_rows == 0 {
192            return Ok(Value::Empty);
193        }
194
195        let mut all_values: Vec<f64> = Vec::new();
196        let mut result_shape: Option<Vec<usize>> = None;
197        let mut rows_processed = 0;
198
199        while rows_processed < total_rows {
200            let batch_len = batch_size.min(total_rows - rows_processed);
201            let batch = store.get_rows(data_ref, rows_processed, batch_len)?;
202            let output = self.forward(&batch)?;
203
204            if let Value::Tensor { values, shape } = &output {
205                if result_shape.is_none() {
206                    result_shape = Some(shape.clone());
207                }
208                all_values.extend_from_slice(values);
209            } else {
210                return Ok(output);
211            }
212
213            rows_processed += batch_len;
214        }
215
216        match result_shape {
217            Some(mut shape) => {
218                shape[0] = total_rows;
219                Ok(Value::tensor(all_values, shape))
220            }
221            None => Ok(Value::Empty),
222        }
223    }
224
225    // ── State persistence ──
226
227    /// Persist all trained states to the data store.
228    pub fn persist_states(&self) -> Result<DataRef> {
229        let store = self
230            .data_store
231            .as_ref()
232            .ok_or_else(|| SomaError::Execution {
233                node_id: "session".into(),
234                message: "persist_states requires a data store".into(),
235            })?;
236
237        let sorted = self.graph.topological_sort()?;
238        let mut states_map = serde_json::Map::new();
239        for node_id in &sorted {
240            if let Some(state) = self.library.get_state(node_id) {
241                let json = serde_json::to_value(state)
242                    .map_err(|e| SomaError::Other(format!("state serialize: {e}")))?;
243                states_map.insert(node_id.to_string(), json);
244            }
245        }
246
247        let states_value = Value::Json(serde_json::Value::Object(states_map));
248        let key = CacheKey::from_parts(&[b"graph_states", self.graph_config_hash().as_bytes()]);
249        store.put(&key, &states_value)
250    }
251
252    /// Load previously persisted states from a data store reference.
253    pub fn load_states(&mut self, data_ref: &DataRef) -> Result<()> {
254        let store = self
255            .data_store
256            .as_ref()
257            .ok_or_else(|| SomaError::Execution {
258                node_id: "session".into(),
259                message: "load_states requires a data store".into(),
260            })?;
261
262        let states_value = store.get(data_ref)?;
263        let states_json = states_value
264            .as_json()
265            .ok_or_else(|| SomaError::Other("persisted states must be JSON".into()))?;
266        let obj = states_json
267            .as_object()
268            .ok_or_else(|| SomaError::Other("persisted states must be a JSON object".into()))?;
269
270        for (node_id, json_val) in obj {
271            let value: Value = serde_json::from_value(json_val.clone())
272                .map_err(|e| SomaError::Other(format!("state deserialize: {e}")))?;
273            self.library.set_state(node_id.clone(), value);
274        }
275
276        self.fitted = true;
277        Ok(())
278    }
279
280    // ── Observability ──
281
282    /// Subscribe to execution events.
283    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Event> {
284        self.event_bus.subscribe()
285    }
286
287    /// Access the event bus directly.
288    pub fn event_bus(&self) -> &Arc<EventBus> {
289        &self.event_bus
290    }
291
292    /// Whether the session has been fitted.
293    pub fn is_fitted(&self) -> bool {
294        self.fitted
295    }
296
297    /// Access the graph.
298    pub fn graph(&self) -> &Graph {
299        &self.graph
300    }
301
302    /// Access the filter library.
303    pub fn library(&self) -> &FilterLibrary {
304        &self.library
305    }
306
307    /// Mutable access to the filter library (for registering filters after creation).
308    pub fn library_mut(&mut self) -> &mut FilterLibrary {
309        &mut self.library
310    }
311
312    // ── Private helpers ──
313
314    fn graph_config_hash(&self) -> String {
315        let node_ids: Vec<&str> = self.graph.nodes.iter().map(|n| n.id.as_str()).collect();
316        node_ids.join(",")
317    }
318}
319
320// ── Convenience free functions (backward compat) ──
321
322/// Compile and execute a graph, returning all node outputs.
323pub fn graph_run(
324    graph: &Graph,
325    library: &FilterLibrary,
326    mode: CompileMode,
327    cache: &dyn CacheStore,
328) -> Result<HashMap<String, Value>> {
329    let CompileResult { plan, diagnostics } = compile(graph, library, mode, Some(cache))?;
330
331    for diag in &diagnostics {
332        tracing::warn!("compile diagnostic: {:?}", diag);
333    }
334
335    let bus = Arc::new(EventBus::new(256));
336    let graph_info = GraphInfo::from_graph(graph);
337
338    let mut ctx = Context::new(bus, timestamp_id("graph_run")).with_graph_info(graph_info);
339
340    executor::execute(&plan, &mut ctx, library, cache)?;
341
342    Ok(ctx
343        .store
344        .into_iter()
345        .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
346        .collect())
347}
348
349/// Fit all trainable filters in topological order.
350pub fn graph_fit(
351    graph: &Graph,
352    library: &FilterLibrary,
353    x: &Value,
354    y: Option<&Value>,
355    cache: &dyn CacheStore,
356) -> Result<HashMap<String, Value>> {
357    graph.validate()?;
358    let sorted = graph.topological_sort()?;
359    let graph_info = GraphInfo::from_graph(graph);
360
361    let bus = Arc::new(EventBus::new(256));
362    let run_id = timestamp_id("graph_fit");
363
364    let mut outputs: HashMap<String, Value> = HashMap::new();
365
366    let roots = graph.roots();
367    for root_id in &roots {
368        outputs.insert(format!("__input_{root_id}"), x.clone());
369    }
370
371    for node_id in &sorted {
372        let filter = library
373            .get(node_id)
374            .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
375
376        bus.emit(Event::NodeStarted {
377            run_id: run_id.clone(),
378            node_id: node_id.to_string(),
379            kind: filter.meta().kind,
380        });
381
382        let preds = graph_info.predecessors(node_id);
383        let input = match preds.len() {
384            0 => x.clone(),
385            1 => outputs.get(&preds[0]).cloned().unwrap_or_else(|| x.clone()),
386            _ => {
387                let mut merged = serde_json::Map::new();
388                for pred_id in preds {
389                    if let Some(val) = outputs.get(pred_id.as_str()) {
390                        let json_val = serde_json::to_value(val).unwrap_or(serde_json::Value::Null);
391                        merged.insert(pred_id.clone(), json_val);
392                    }
393                }
394                Value::Json(serde_json::Value::Object(merged))
395            }
396        };
397
398        let meta = filter.meta();
399        let start = std::time::Instant::now();
400
401        let (state, output) = if meta.kind == FilterKind::Trainable {
402            let data_hash = CacheKey::hash_data(&serde_json::to_vec(&input).unwrap_or_default());
403            let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
404
405            let state = if let Some(cached) = cache.get(&state_key)? {
406                cached
407            } else {
408                let s = filter.fit(&input, y)?;
409                cache.put(&state_key, &s)?;
410                s
411            };
412
413            let output = filter.forward(&input, &state)?;
414            (state, output)
415        } else {
416            let output = filter.forward(&input, &Value::Empty)?;
417            (Value::Empty, output)
418        };
419
420        let _ = state;
421
422        bus.emit(Event::NodeCompleted {
423            run_id: run_id.clone(),
424            node_id: node_id.to_string(),
425            duration: start.elapsed(),
426            output_summary: format!("{output}"),
427        });
428
429        outputs.insert(node_id.to_string(), output);
430    }
431
432    Ok(outputs)
433}
434
435/// Compile in Inference mode and execute, returning the last leaf's output.
436pub fn graph_predict(
437    graph: &Graph,
438    library: &FilterLibrary,
439    x: &Value,
440    cache: &dyn CacheStore,
441) -> Result<Value> {
442    let CompileResult { plan, .. } = compile(graph, library, CompileMode::Inference, Some(cache))?;
443
444    let bus = Arc::new(EventBus::new(256));
445    let graph_info = GraphInfo::from_graph(graph);
446    let mut ctx = Context::new(bus, timestamp_id("graph_predict")).with_graph_info(graph_info);
447
448    let roots = graph.roots();
449    if roots.len() == 1 {
450        ctx.set(format!("__input_{}", roots[0]), x.clone());
451    }
452    ctx.set("__input__", x.clone());
453
454    executor::execute(&plan, &mut ctx, library, cache)?;
455
456    let leaves = graph.leaves();
457    let mut extract =
458        |id: &str| -> Option<Value> { ctx.store.remove(id).and_then(|vv| vv.as_value().cloned()) };
459
460    if let Some(leaf_id) = leaves.first() {
461        extract(leaf_id)
462            .ok_or_else(|| SomaError::Other(format!("leaf node '{leaf_id}' produced no output")))
463    } else {
464        ctx.execution_order
465            .last()
466            .and_then(|id| extract(id))
467            .ok_or_else(|| SomaError::Other("no output produced".into()))
468    }
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::MemoryCache;
    use somatize_compiler::FilterRegistry;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result;
    use somatize_core::filter::{FilterKind, FilterMeta, StreamMode};
    use somatize_core::graph::{Edge, Node};

    // ── Test filters ──

    /// Stateless filter that multiplies every tensor element by 2.
    struct DoublerFilter;
    impl somatize_core::filter::Filter for DoublerFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Doubler"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            let (data, shape) = x
                .as_tensor()
                .ok_or(SomaError::Other("need tensor".into()))?;
            Ok(Value::tensor(
                data.iter().map(|v| v * 2.0).collect(),
                shape.to_vec(),
            ))
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Doubler".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Stateless filter that adds a constant to every tensor element.
    struct AdderFilter(f64);
    impl somatize_core::filter::Filter for AdderFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Adder", &self.0.to_le_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            let (data, shape) = x
                .as_tensor()
                .ok_or(SomaError::Other("need tensor".into()))?;
            Ok(Value::tensor(
                data.iter().map(|v| v + self.0).collect(),
                shape.to_vec(),
            ))
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Adder".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Trainable filter: `fit` learns the input mean as `{"mean": m}`, and
    /// `forward` subtracts it (using 0.0 if the state has no mean).
    struct MeanFilter;
    impl somatize_core::filter::Filter for MeanFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Mean"])
        }
        fn fit(&self, x: &Value, _y: Option<&Value>) -> Result<Value> {
            let (data, _) = x
                .as_tensor()
                .ok_or(SomaError::Other("need tensor".into()))?;
            let mean = data.iter().sum::<f64>() / data.len() as f64;
            Ok(Value::json(serde_json::json!({ "mean": mean })))
        }
        fn forward(&self, x: &Value, state: &Value) -> Result<Value> {
            let (data, shape) = x
                .as_tensor()
                .ok_or(SomaError::Other("need tensor".into()))?;
            let mean = state
                .as_json()
                .and_then(|j| j["mean"].as_f64())
                .unwrap_or(0.0);
            Ok(Value::tensor(
                data.iter().map(|v| v - mean).collect(),
                shape.to_vec(),
            ))
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Mean".into(),
                kind: FilterKind::Trainable,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Build a chain graph `ids[0] → ids[1] → …` joined by data edges `e0`, `e1`, ….
    /// Each node passes its id for all three `Node::new` arguments.
    fn linear_graph(ids: &[&str]) -> Graph {
        let mut g = Graph::new();
        for &id in ids {
            g.nodes.push(Node::new(id, id, id));
        }
        for (i, pair) in ids.windows(2).enumerate() {
            g.edges.push(Edge::data(format!("e{i}"), pair[0], pair[1]));
        }
        g
    }

    // ── GraphSession tests ──

    #[test]
    fn session_run_linear() {
        let graph = linear_graph(&["double", "add"]);
        let mut lib = FilterLibrary::new();
        lib.register("double", Box::new(DoublerFilter));
        lib.register("add", Box::new(AdderFilter(10.0)));

        let mut session = GraphSession::new(graph, lib);

        let cache = MemoryCache::default();
        session = session.with_cache(Arc::new(cache));

        // Compile through the session, then execute manually. `run()` seeds no
        // input, so the input has to be placed under `__input__` by hand here.
        let CompileResult { plan, .. } = session.compile(CompileMode::NoCache).unwrap();
        let bus = Arc::new(EventBus::new(64));
        let mut ctx =
            Context::new(bus, "test").with_graph_info(GraphInfo::from_graph(session.graph()));
        ctx.set("__input__", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
        executor::execute(&plan, &mut ctx, session.library(), &MemoryCache::default()).unwrap();

        let outputs: HashMap<String, Value> = ctx
            .store
            .into_iter()
            .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
            .collect();

        let result = outputs.get("add").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[12.0, 14.0, 16.0]);
    }

    #[test]
    fn session_fit_and_forward() {
        let graph = linear_graph(&["mean", "double"]);
        let mut lib = FilterLibrary::new();
        lib.register("mean", Box::new(MeanFilter));
        lib.register("double", Box::new(DoublerFilter));

        let mut session = GraphSession::new(graph, lib);

        let x = Value::tensor(vec![10.0, 20.0, 30.0], vec![3]);
        let outputs = session.fit(&x, None).unwrap();

        // mean: fit learns mean=20, forward: [10-20, 20-20, 30-20] = [-10, 0, 10]
        // double: [-10, 0, 10] → [-20, 0, 20]
        let result = outputs.get("double").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[-20.0, 0.0, 20.0]);

        assert!(session.is_fitted());
    }

    #[test]
    fn session_compile_diagnostics() {
        let graph = linear_graph(&["double"]);
        let mut lib = FilterLibrary::new();
        lib.register("double", Box::new(DoublerFilter));

        let session = GraphSession::new(graph, lib);
        let result = session.compile(CompileMode::NoCache).unwrap();
        assert!(result.plan.node_count() > 0);
    }

    // ── Free function tests (backward compat) ──

    /// Mirrors `graph_run`, but drives compile + execute directly so that the
    /// input can be seeded under `__input__`.
    #[test]
    fn graph_run_linear() {
        let graph = linear_graph(&["double", "add"]);
        let mut lib = FilterLibrary::new();
        lib.register("double", Box::new(DoublerFilter));
        lib.register("add", Box::new(AdderFilter(10.0)));

        let cache = MemoryCache::default();

        let outputs = {
            let CompileResult { plan, .. } =
                compile(&graph, &lib, CompileMode::NoCache, None).unwrap();
            let bus = Arc::new(EventBus::new(64));
            let mut ctx = Context::new(bus, "test").with_graph_info(GraphInfo::from_graph(&graph));
            ctx.set("__input__", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
            executor::execute(&plan, &mut ctx, &lib, &cache).unwrap();
            ctx.store
                .into_iter()
                .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
                .collect::<HashMap<String, Value>>()
        };

        let result = outputs.get("add").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[12.0, 14.0, 16.0]);
    }

    /// Two roots feeding one node: the executor should hand `merge` a JSON
    /// object built from both predecessor outputs.
    #[test]
    fn graph_run_diamond() {
        let mut graph = Graph::new();
        graph.nodes.push(Node::new("double", "Double", "double"));
        graph.nodes.push(Node::new("add", "Add", "add"));
        graph.nodes.push(Node::new("merge", "Merge", "merge"));
        graph.edges.push(Edge::data("e1", "double", "merge"));
        graph.edges.push(Edge::data("e2", "add", "merge"));

        let mut lib = FilterLibrary::new();
        lib.register("double", Box::new(DoublerFilter));
        lib.register("add", Box::new(AdderFilter(100.0)));

        /// Stateless pass-through filter.
        struct MergeFilter;
        impl somatize_core::filter::Filter for MergeFilter {
            fn config_hash(&self) -> CacheKey {
                CacheKey::from_parts(&[b"Merge"])
            }
            fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
                Ok(Value::Empty)
            }
            fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
                Ok(x.clone())
            }
            fn meta(&self) -> FilterMeta {
                FilterMeta {
                    name: "Merge".into(),
                    kind: FilterKind::Stateless,
                    cacheable: true,
                    differentiable: false,
                    stream_mode: StreamMode::FixedState,
                    distribution: somatize_core::filter::Distribution::Local,
                    input_schema: None,
                    output_schema: None,
                }
            }

            fn as_any(&self) -> &dyn std::any::Any {
                self
            }
        }
        lib.register("merge", Box::new(MergeFilter));

        let cache = MemoryCache::default();
        let CompileResult { plan, .. } = compile(&graph, &lib, CompileMode::NoCache, None).unwrap();

        let bus = Arc::new(EventBus::new(64));
        let mut ctx = Context::new(bus, "test").with_graph_info(GraphInfo::from_graph(&graph));
        ctx.set("__input__", Value::tensor(vec![5.0], vec![1]));
        executor::execute(&plan, &mut ctx, &lib, &cache).unwrap();

        let merge_output = ctx.get("merge").unwrap();
        assert!(
            merge_output.as_json().is_some(),
            "merge should receive JSON from multiple predecessors"
        );
    }

    #[test]
    fn graph_fit_trainable() {
        let graph = linear_graph(&["mean", "double"]);
        let mut lib = FilterLibrary::new();
        lib.register("mean", Box::new(MeanFilter));
        lib.register("double", Box::new(DoublerFilter));

        let cache = MemoryCache::default();
        let x = Value::tensor(vec![10.0, 20.0, 30.0], vec![3]);

        let outputs = graph_fit(&graph, &lib, &x, None, &cache).unwrap();

        let result = outputs.get("double").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[-20.0, 0.0, 20.0]);

        assert!(!cache.is_empty());
    }

    /// Calls `graph_predict` itself: the single leaf ("add") yields (x * 2) + 10.
    #[test]
    fn graph_predict_linear() {
        let graph = linear_graph(&["double", "add"]);
        let mut lib = FilterLibrary::new();
        lib.register("double", Box::new(DoublerFilter));
        lib.register("add", Box::new(AdderFilter(10.0)));

        let cache = MemoryCache::default();
        let x = Value::tensor(vec![1.0, 2.0, 3.0], vec![3]);

        let output = graph_predict(&graph, &lib, &x, &cache).unwrap();
        let (data, _) = output.as_tensor().unwrap();
        assert_eq!(data, &[12.0, 14.0, 16.0]);
    }

    #[test]
    fn filter_library_registry_compat() {
        let mut lib = FilterLibrary::new();
        lib.register("a", Box::new(DoublerFilter));

        let registry: &dyn FilterRegistry = &lib;
        assert!(registry.meta("a").is_some());
        assert_eq!(registry.meta("a").unwrap().name, "Doubler");
        assert!(registry.config_hash("a").is_some());
        assert!(registry.meta("b").is_none());
    }
}