Skip to main content

somatize_runtime/
graph_session.rs

1//! Graph session — the primary orchestrator for Graph → Compile → Execute.
2//!
3//! [`GraphSession`] binds a [`Graph`] with its [`FilterLibrary`], cache,
4//! event bus, and optional distributed components into a single object
5//! that can compile, fit, and execute.
6
7use crate::cache::MemoryCache;
8use crate::event_bus::EventBus;
9use crate::executor::{self, Context, GraphInfo};
10use crate::filter_library::FilterLibrary;
11use crate::runner::Runner;
12use crate::runner::Transport;
13use somatize_compiler::{CompileMode, CompileResult, compile};
14use somatize_core::cache::{CacheKey, CacheStore};
15use somatize_core::error::{Result, SomaError};
16use somatize_core::event::Event;
17use somatize_core::filter::FilterKind;
18use somatize_core::graph::Graph;
19use somatize_core::store::{DataRef, DataStore};
20use somatize_core::util::timestamp_id;
21use somatize_core::value::Value;
22use std::collections::HashMap;
23use std::sync::Arc;
24
25/// The primary orchestrator: Graph + filters + cache + events.
26///
27/// ```ignore
28/// let mut lib = FilterLibrary::new();
29/// lib.register("scaler", Box::new(MyScaler::new()));
30/// lib.register("model", Box::new(MyModel::new()));
31///
32/// let mut session = GraphSession::new(graph, lib);
33/// session.fit(&train_x, Some(&train_y))?;
34/// let output = session.forward(&test_x)?;
35/// ```
36pub struct GraphSession {
37    graph: Graph,
38    library: FilterLibrary,
39    cache: Arc<dyn CacheStore>,
40    event_bus: Arc<EventBus>,
41    data_store: Option<Arc<dyn DataStore>>,
42    transport: Option<Arc<dyn Transport>>,
43    fitted: bool,
44}
45
46impl GraphSession {
47    pub fn new(graph: Graph, library: FilterLibrary) -> Self {
48        Self {
49            graph,
50            library,
51            cache: Arc::new(MemoryCache::default()),
52            event_bus: Arc::new(EventBus::new(256)),
53            data_store: None,
54            transport: None,
55            fitted: false,
56        }
57    }
58
59    pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
60        self.cache = cache;
61        self
62    }
63
64    pub fn with_event_bus(mut self, bus: Arc<EventBus>) -> Self {
65        self.event_bus = bus;
66        self
67    }
68
69    pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
70        self.data_store = Some(store);
71        self
72    }
73
74    pub fn with_transport(mut self, transport: Arc<dyn Transport>) -> Self {
75        self.transport = Some(transport);
76        self
77    }
78
79    // ── Core operations ──
80
81    /// Compile the graph and return diagnostics without executing.
82    pub fn compile(&self, mode: CompileMode) -> Result<CompileResult> {
83        compile(&self.graph, &self.library, mode, Some(self.cache.as_ref()))
84    }
85
86    /// Compile and execute the graph, returning all node outputs.
87    pub fn run(&mut self, mode: CompileMode) -> Result<HashMap<String, Value>> {
88        let CompileResult { plan, diagnostics } =
89            compile(&self.graph, &self.library, mode, Some(self.cache.as_ref()))?;
90
91        for diag in &diagnostics {
92            tracing::warn!("compile diagnostic: {:?}", diag);
93        }
94
95        let graph_info = GraphInfo::from_graph(&self.graph);
96        let mut ctx = Context::new(self.event_bus.clone(), timestamp_id("graph_run"))
97            .with_graph_info(graph_info);
98
99        if let Some(store) = &self.data_store {
100            ctx = ctx.with_data_store(store.clone());
101        }
102        if let Some(transport) = &self.transport {
103            ctx = ctx.with_transport(transport.clone());
104        }
105
106        executor::execute(&plan, &mut ctx, &self.library, self.cache.as_ref())?;
107
108        Ok(ctx
109            .store
110            .into_iter()
111            .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
112            .collect())
113    }
114
115    /// Fit all trainable filters in topological order.
116    /// Delegates to LocalRunner — same execution path as remote workers.
117    pub fn fit(&mut self, x: &Value, y: Option<&Value>) -> Result<HashMap<String, Value>> {
118        self.graph.validate()?;
119
120        let CompileResult { plan, .. } = compile(
121            &self.graph,
122            &self.library,
123            CompileMode::NoCache,
124            Some(self.cache.as_ref()),
125        )?;
126
127        let runner = crate::runner::LocalRunner;
128        let (_last_output, mut all_outputs) = runner.fit(
129            &plan,
130            &self.library,
131            self.cache.as_ref(),
132            &self.event_bus,
133            x,
134            y,
135        )?;
136
137        // Store trained states from __state_ keys into FilterLibrary
138        for (key, value) in &all_outputs {
139            if let Some(node_id) = key.strip_prefix("__state_") {
140                self.library.set_state(node_id, value.clone());
141            }
142        }
143
144        // Remove __state_ keys from returned outputs (callers expect node IDs only)
145        all_outputs.retain(|k, _| !k.starts_with("__state_"));
146
147        self.fitted = true;
148        Ok(all_outputs)
149    }
150
151    /// Forward pass through the fitted graph (inference).
152    ///
153    /// Compiles in Inference mode and executes. Trainable filters
154    /// use states cached during `fit()`.
155    /// Forward pass through the fitted graph (inference).
156    /// Delegates to LocalRunner — same execution path as remote workers.
157    pub fn forward(&self, x: &Value) -> Result<Value> {
158        let CompileResult { plan, .. } = compile(
159            &self.graph,
160            &self.library,
161            CompileMode::Inference,
162            Some(self.cache.as_ref()),
163        )?;
164
165        let runner = crate::runner::LocalRunner;
166        runner.forward(
167            &plan,
168            &self.library,
169            self.cache.as_ref(),
170            &self.event_bus,
171            x,
172        )
173    }
174
175    /// Forward pass in batches from a DataStore reference.
176    pub fn forward_batched(&self, data_ref: &DataRef, batch_size: usize) -> Result<Value> {
177        let store = self
178            .data_store
179            .as_ref()
180            .ok_or_else(|| SomaError::Execution {
181                node_id: "session".into(),
182                message: "forward_batched requires a data store (use with_data_store)".into(),
183            })?;
184
185        let meta = store.meta(data_ref)?;
186        let total_rows = meta.total_rows;
187        if total_rows == 0 {
188            return Ok(Value::Empty);
189        }
190
191        let mut all_values: Vec<f64> = Vec::new();
192        let mut result_shape: Option<Vec<usize>> = None;
193        let mut rows_processed = 0;
194
195        while rows_processed < total_rows {
196            let batch_len = batch_size.min(total_rows - rows_processed);
197            let batch = store.get_rows(data_ref, rows_processed, batch_len)?;
198            let output = self.forward(&batch)?;
199
200            if let Value::Tensor { values, shape } = &output {
201                if result_shape.is_none() {
202                    result_shape = Some(shape.clone());
203                }
204                all_values.extend_from_slice(values);
205            } else {
206                return Ok(output);
207            }
208
209            rows_processed += batch_len;
210        }
211
212        match result_shape {
213            Some(mut shape) => {
214                shape[0] = total_rows;
215                Ok(Value::tensor(all_values, shape))
216            }
217            None => Ok(Value::Empty),
218        }
219    }
220
221    // ── State persistence ──
222
223    /// Persist all trained states to the data store.
224    pub fn persist_states(&self) -> Result<DataRef> {
225        let store = self
226            .data_store
227            .as_ref()
228            .ok_or_else(|| SomaError::Execution {
229                node_id: "session".into(),
230                message: "persist_states requires a data store".into(),
231            })?;
232
233        let sorted = self.graph.topological_sort()?;
234        let mut states_map = serde_json::Map::new();
235        for node_id in &sorted {
236            if let Some(state) = self.library.get_state(node_id) {
237                let json = serde_json::to_value(state)
238                    .map_err(|e| SomaError::Other(format!("state serialize: {e}")))?;
239                states_map.insert(node_id.to_string(), json);
240            }
241        }
242
243        let states_value = Value::Json(serde_json::Value::Object(states_map));
244        let key = CacheKey::from_parts(&[b"graph_states", self.graph_config_hash().as_bytes()]);
245        store.put(&key, &states_value)
246    }
247
248    /// Load previously persisted states from a data store reference.
249    pub fn load_states(&mut self, data_ref: &DataRef) -> Result<()> {
250        let store = self
251            .data_store
252            .as_ref()
253            .ok_or_else(|| SomaError::Execution {
254                node_id: "session".into(),
255                message: "load_states requires a data store".into(),
256            })?;
257
258        let states_value = store.get(data_ref)?;
259        let states_json = states_value
260            .as_json()
261            .ok_or_else(|| SomaError::Other("persisted states must be JSON".into()))?;
262        let obj = states_json
263            .as_object()
264            .ok_or_else(|| SomaError::Other("persisted states must be a JSON object".into()))?;
265
266        for (node_id, json_val) in obj {
267            let value: Value = serde_json::from_value(json_val.clone())
268                .map_err(|e| SomaError::Other(format!("state deserialize: {e}")))?;
269            self.library.set_state(node_id.clone(), value);
270        }
271
272        self.fitted = true;
273        Ok(())
274    }
275
276    // ── Observability ──
277
278    /// Subscribe to execution events.
279    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Event> {
280        self.event_bus.subscribe()
281    }
282
283    /// Access the event bus directly.
284    pub fn event_bus(&self) -> &Arc<EventBus> {
285        &self.event_bus
286    }
287
288    /// Whether the session has been fitted.
289    pub fn is_fitted(&self) -> bool {
290        self.fitted
291    }
292
293    /// Access the graph.
294    pub fn graph(&self) -> &Graph {
295        &self.graph
296    }
297
298    /// Access the filter library.
299    pub fn library(&self) -> &FilterLibrary {
300        &self.library
301    }
302
303    /// Mutable access to the filter library (for registering filters after creation).
304    pub fn library_mut(&mut self) -> &mut FilterLibrary {
305        &mut self.library
306    }
307
308    // ── Private helpers ──
309
310    fn graph_config_hash(&self) -> String {
311        let node_ids: Vec<&str> = self.graph.nodes.iter().map(|n| n.id.as_str()).collect();
312        node_ids.join(",")
313    }
314}
315
316// ── Convenience free functions (backward compat) ──
317
318/// Compile and execute a graph, returning all node outputs.
319pub fn graph_run(
320    graph: &Graph,
321    library: &FilterLibrary,
322    mode: CompileMode,
323    cache: &dyn CacheStore,
324) -> Result<HashMap<String, Value>> {
325    let CompileResult { plan, diagnostics } = compile(graph, library, mode, Some(cache))?;
326
327    for diag in &diagnostics {
328        tracing::warn!("compile diagnostic: {:?}", diag);
329    }
330
331    let bus = Arc::new(EventBus::new(256));
332    let graph_info = GraphInfo::from_graph(graph);
333
334    let mut ctx = Context::new(bus, timestamp_id("graph_run")).with_graph_info(graph_info);
335
336    executor::execute(&plan, &mut ctx, library, cache)?;
337
338    Ok(ctx
339        .store
340        .into_iter()
341        .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
342        .collect())
343}
344
345/// Fit all trainable filters in topological order.
346pub fn graph_fit(
347    graph: &Graph,
348    library: &FilterLibrary,
349    x: &Value,
350    y: Option<&Value>,
351    cache: &dyn CacheStore,
352) -> Result<HashMap<String, Value>> {
353    graph.validate()?;
354    let sorted = graph.topological_sort()?;
355    let graph_info = GraphInfo::from_graph(graph);
356
357    let bus = Arc::new(EventBus::new(256));
358    let run_id = timestamp_id("graph_fit");
359
360    let mut outputs: HashMap<String, Value> = HashMap::new();
361
362    let roots = graph.roots();
363    for root_id in &roots {
364        outputs.insert(format!("__input_{root_id}"), x.clone());
365    }
366
367    for node_id in &sorted {
368        let filter = library
369            .get(node_id)
370            .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
371
372        bus.emit(Event::NodeStarted {
373            run_id: run_id.clone(),
374            node_id: node_id.to_string(),
375            kind: filter.meta().kind,
376        });
377
378        let preds = graph_info.predecessors(node_id);
379        let input = match preds.len() {
380            0 => x.clone(),
381            1 => outputs.get(&preds[0]).cloned().unwrap_or_else(|| x.clone()),
382            _ => {
383                let mut merged = serde_json::Map::new();
384                for pred_id in preds {
385                    if let Some(val) = outputs.get(pred_id.as_str()) {
386                        let json_val = serde_json::to_value(val).unwrap_or(serde_json::Value::Null);
387                        merged.insert(pred_id.clone(), json_val);
388                    }
389                }
390                Value::Json(serde_json::Value::Object(merged))
391            }
392        };
393
394        let meta = filter.meta();
395        let start = std::time::Instant::now();
396
397        let (state, output) = if meta.kind == FilterKind::Trainable {
398            let data_hash = CacheKey::hash_data(&serde_json::to_vec(&input).unwrap_or_default());
399            let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
400
401            let state = if let Some(cached) = cache.get(&state_key)? {
402                cached
403            } else {
404                let s = filter.fit(&input, y)?;
405                cache.put(&state_key, &s)?;
406                s
407            };
408
409            let output = filter.forward(&input, &state)?;
410            (state, output)
411        } else {
412            let output = filter.forward(&input, &Value::Empty)?;
413            (Value::Empty, output)
414        };
415
416        let _ = state;
417
418        bus.emit(Event::NodeCompleted {
419            run_id: run_id.clone(),
420            node_id: node_id.to_string(),
421            duration: start.elapsed(),
422            output_summary: format!("{output}"),
423        });
424
425        outputs.insert(node_id.to_string(), output);
426    }
427
428    Ok(outputs)
429}
430
431/// Compile in Inference mode and execute, returning the last leaf's output.
432pub fn graph_predict(
433    graph: &Graph,
434    library: &FilterLibrary,
435    x: &Value,
436    cache: &dyn CacheStore,
437) -> Result<Value> {
438    let CompileResult { plan, .. } = compile(graph, library, CompileMode::Inference, Some(cache))?;
439
440    let bus = Arc::new(EventBus::new(256));
441    let graph_info = GraphInfo::from_graph(graph);
442    let mut ctx = Context::new(bus, timestamp_id("graph_predict")).with_graph_info(graph_info);
443
444    let roots = graph.roots();
445    if roots.len() == 1 {
446        ctx.set(format!("__input_{}", roots[0]), x.clone());
447    }
448    ctx.set("__input__", x.clone());
449
450    executor::execute(&plan, &mut ctx, library, cache)?;
451
452    let leaves = graph.leaves();
453    let mut extract =
454        |id: &str| -> Option<Value> { ctx.store.remove(id).and_then(|vv| vv.as_value().cloned()) };
455
456    if let Some(leaf_id) = leaves.first() {
457        extract(leaf_id)
458            .ok_or_else(|| SomaError::Other(format!("leaf node '{leaf_id}' produced no output")))
459    } else {
460        ctx.execution_order
461            .last()
462            .and_then(|id| extract(id))
463            .ok_or_else(|| SomaError::Other("no output produced".into()))
464    }
465}
466
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::MemoryCache;
    use somatize_compiler::FilterRegistry;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result;
    use somatize_core::filter::{FilterKind, FilterMeta, StreamMode};
    use somatize_core::graph::{Edge, Node};

    // ── Shared helpers ──

    /// Metadata shared by all test filters; only name, kind and
    /// differentiability vary between them.
    fn test_meta(name: &'static str, kind: FilterKind, differentiable: bool) -> FilterMeta {
        FilterMeta {
            name: name.into(),
            kind,
            cacheable: true,
            differentiable,
            stream_mode: StreamMode::FixedState,
            distribution: somatize_core::filter::Distribution::Local,
            input_schema: None,
            output_schema: None,
        }
    }

    /// Apply `f` to every element of a tensor value, keeping its shape.
    fn map_tensor(x: &Value, f: impl Fn(f64) -> f64) -> Result<Value> {
        let (data, shape) = x
            .as_tensor()
            .ok_or_else(|| SomaError::Other("need tensor".into()))?;
        Ok(Value::tensor(
            data.iter().map(|&v| f(v)).collect(),
            shape.to_vec(),
        ))
    }

    /// Execute a compiled plan with `input` bound to `__input__` and collect
    /// every store entry that exposes a plain value.
    fn execute_and_collect(
        compiled: &CompileResult,
        graph: &Graph,
        lib: &FilterLibrary,
        cache: &MemoryCache,
        input: Value,
    ) -> HashMap<String, Value> {
        let bus = Arc::new(EventBus::new(64));
        let mut ctx = Context::new(bus, "test").with_graph_info(GraphInfo::from_graph(graph));
        ctx.set("__input__", input);
        executor::execute(&compiled.plan, &mut ctx, lib, cache).unwrap();

        ctx.store
            .into_iter()
            .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
            .collect()
    }

    /// Build a chain graph `ids[0] → ids[1] → …` with one data edge per pair.
    fn linear_graph(ids: &[&str]) -> Graph {
        let mut g = Graph::new();
        for &id in ids {
            g.nodes.push(Node::new(id, id, id));
        }
        for (i, pair) in ids.windows(2).enumerate() {
            g.edges.push(Edge::data(format!("e{i}"), pair[0], pair[1]));
        }
        g
    }

    // ── Test filters ──

    /// Stateless filter that multiplies every tensor element by two.
    struct DoublerFilter;
    impl somatize_core::filter::Filter for DoublerFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Doubler"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            map_tensor(x, |v| v * 2.0)
        }
        fn meta(&self) -> FilterMeta {
            test_meta("Doubler", FilterKind::Stateless, true)
        }
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Stateless filter that adds a fixed offset to every tensor element.
    struct AdderFilter(f64);
    impl somatize_core::filter::Filter for AdderFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Adder", &self.0.to_le_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            let offset = self.0;
            map_tensor(x, |v| v + offset)
        }
        fn meta(&self) -> FilterMeta {
            test_meta("Adder", FilterKind::Stateless, true)
        }
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Trainable filter: `fit` learns the mean, `forward` subtracts it
    /// (or 0.0 when the state carries no mean).
    struct MeanFilter;
    impl somatize_core::filter::Filter for MeanFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Mean"])
        }
        fn fit(&self, x: &Value, _y: Option<&Value>) -> Result<Value> {
            let (data, _) = x
                .as_tensor()
                .ok_or_else(|| SomaError::Other("need tensor".into()))?;
            let mean = data.iter().sum::<f64>() / data.len() as f64;
            Ok(Value::json(serde_json::json!({ "mean": mean })))
        }
        fn forward(&self, x: &Value, state: &Value) -> Result<Value> {
            let mean = state
                .as_json()
                .and_then(|j| j["mean"].as_f64())
                .unwrap_or(0.0);
            map_tensor(x, |v| v - mean)
        }
        fn meta(&self) -> FilterMeta {
            test_meta("Mean", FilterKind::Trainable, true)
        }
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Library with a doubler registered as "double" and an adder as "add".
    fn double_add_library(offset: f64) -> FilterLibrary {
        let mut lib = FilterLibrary::new();
        lib.register("double", Box::new(DoublerFilter));
        lib.register("add", Box::new(AdderFilter(offset)));
        lib
    }

    /// Library with the mean filter as "mean" and a doubler as "double".
    fn mean_double_library() -> FilterLibrary {
        let mut lib = FilterLibrary::new();
        lib.register("mean", Box::new(MeanFilter));
        lib.register("double", Box::new(DoublerFilter));
        lib
    }

    // ── GraphSession tests ──

    #[test]
    fn session_run_linear() {
        let session = GraphSession::new(
            linear_graph(&["double", "add"]),
            double_add_library(10.0),
        )
        .with_cache(Arc::new(MemoryCache::default()));

        // Compile through the session, then execute the plan by hand.
        let compiled = session.compile(CompileMode::NoCache).unwrap();
        let outputs = execute_and_collect(
            &compiled,
            session.graph(),
            session.library(),
            &MemoryCache::default(),
            Value::tensor(vec![1.0, 2.0, 3.0], vec![3]),
        );

        let (data, _) = outputs.get("add").unwrap().as_tensor().unwrap();
        assert_eq!(data, &[12.0, 14.0, 16.0]);
    }

    #[test]
    fn session_fit_and_forward() {
        let mut session =
            GraphSession::new(linear_graph(&["mean", "double"]), mean_double_library());

        let x = Value::tensor(vec![10.0, 20.0, 30.0], vec![3]);
        let outputs = session.fit(&x, None).unwrap();

        // mean learns 20 and centres the data to [-10, 0, 10];
        // double then yields [-20, 0, 20].
        let (data, _) = outputs.get("double").unwrap().as_tensor().unwrap();
        assert_eq!(data, &[-20.0, 0.0, 20.0]);

        assert!(session.is_fitted());
    }

    #[test]
    fn session_compile_diagnostics() {
        let mut lib = FilterLibrary::new();
        lib.register("double", Box::new(DoublerFilter));

        let session = GraphSession::new(linear_graph(&["double"]), lib);
        let compiled = session.compile(CompileMode::NoCache).unwrap();
        assert!(compiled.plan.node_count() > 0);
    }

    // ── Free function tests (backward compat) ──

    #[test]
    fn graph_run_linear() {
        let graph = linear_graph(&["double", "add"]);
        let lib = double_add_library(10.0);
        let cache = MemoryCache::default();

        let compiled = compile(&graph, &lib, CompileMode::NoCache, None).unwrap();
        let outputs = execute_and_collect(
            &compiled,
            &graph,
            &lib,
            &cache,
            Value::tensor(vec![1.0, 2.0, 3.0], vec![3]),
        );

        let (data, _) = outputs.get("add").unwrap().as_tensor().unwrap();
        assert_eq!(data, &[12.0, 14.0, 16.0]);
    }

    #[test]
    fn graph_run_diamond() {
        /// Stateless filter that returns its input unchanged.
        struct MergeFilter;
        impl somatize_core::filter::Filter for MergeFilter {
            fn config_hash(&self) -> CacheKey {
                CacheKey::from_parts(&[b"Merge"])
            }
            fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
                Ok(Value::Empty)
            }
            fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
                Ok(x.clone())
            }
            fn meta(&self) -> FilterMeta {
                test_meta("Merge", FilterKind::Stateless, false)
            }
            fn as_any(&self) -> &dyn std::any::Any {
                self
            }
        }

        // Two independent roots feeding one merge node.
        let mut graph = Graph::new();
        graph.nodes.push(Node::new("double", "Double", "double"));
        graph.nodes.push(Node::new("add", "Add", "add"));
        graph.nodes.push(Node::new("merge", "Merge", "merge"));
        graph.edges.push(Edge::data("e1", "double", "merge"));
        graph.edges.push(Edge::data("e2", "add", "merge"));

        let mut lib = double_add_library(100.0);
        lib.register("merge", Box::new(MergeFilter));

        let cache = MemoryCache::default();
        let CompileResult { plan, .. } = compile(&graph, &lib, CompileMode::NoCache, None).unwrap();

        let bus = Arc::new(EventBus::new(64));
        let mut ctx = Context::new(bus, "test").with_graph_info(GraphInfo::from_graph(&graph));
        ctx.set("__input__", Value::tensor(vec![5.0], vec![1]));
        executor::execute(&plan, &mut ctx, &lib, &cache).unwrap();

        let merge_output = ctx.get("merge").unwrap();
        assert!(
            merge_output.as_json().is_some(),
            "merge should receive JSON from multiple predecessors"
        );
    }

    #[test]
    fn graph_fit_trainable() {
        let graph = linear_graph(&["mean", "double"]);
        let lib = mean_double_library();
        let cache = MemoryCache::default();
        let x = Value::tensor(vec![10.0, 20.0, 30.0], vec![3]);

        let outputs = graph_fit(&graph, &lib, &x, None, &cache).unwrap();

        let (data, _) = outputs.get("double").unwrap().as_tensor().unwrap();
        assert_eq!(data, &[-20.0, 0.0, 20.0]);

        // Fitting the trainable filter must have stored its state in the cache.
        assert!(!cache.is_empty());
    }

    #[test]
    fn filter_library_registry_compat() {
        let mut lib = FilterLibrary::new();
        lib.register("a", Box::new(DoublerFilter));

        let registry: &dyn FilterRegistry = &lib;
        assert!(registry.meta("a").is_some());
        assert_eq!(registry.meta("a").unwrap().name, "Doubler");
        assert!(registry.config_hash("a").is_some());
        assert!(registry.meta("b").is_none());
    }
}