Skip to main content

somatize_runtime/
graph_session.rs

1//! Graph session — the primary orchestrator for Graph → Compile → Execute.
2//!
3//! [`GraphSession`] binds a [`Graph`] with its [`FilterLibrary`], cache,
4//! event bus, and optional distributed components into a single object
5//! that can compile, fit, and execute.
6
7use crate::cache::MemoryCache;
8use crate::event_bus::EventBus;
9use crate::executor::{self, Context, GraphInfo};
10use crate::filter_library::FilterLibrary;
11use crate::runner::Runner;
12use crate::runner::Transport;
13use somatize_compiler::{CompileMode, CompileResult, compile};
14use somatize_core::cache::{CacheKey, CacheStore};
15use somatize_core::error::{Result, SomaError};
16use somatize_core::event::Event;
17use somatize_core::filter::FilterKind;
18use somatize_core::graph::Graph;
19use somatize_core::store::{DataRef, DataStore};
20use somatize_core::util::timestamp_id;
21use somatize_core::value::Value;
22use std::collections::HashMap;
23use std::sync::Arc;
24
25/// The primary orchestrator: Graph + filters + cache + events.
26///
27/// ```ignore
28/// let mut lib = FilterLibrary::new();
29/// lib.register("scaler", Box::new(MyScaler::new()));
30/// lib.register("model", Box::new(MyModel::new()));
31///
32/// let mut session = GraphSession::new(graph, lib);
33/// session.fit(&train_x, Some(&train_y))?;
34/// let output = session.forward(&test_x)?;
35/// ```
36pub struct GraphSession {
37    graph: Graph,
38    library: FilterLibrary,
39    cache: Arc<dyn CacheStore>,
40    event_bus: Arc<EventBus>,
41    data_store: Option<Arc<dyn DataStore>>,
42    transport: Option<Arc<dyn Transport>>,
43    fitted: bool,
44}
45
46impl GraphSession {
47    pub fn new(graph: Graph, library: FilterLibrary) -> Self {
48        Self {
49            graph,
50            library,
51            cache: Arc::new(MemoryCache::default()),
52            event_bus: Arc::new(EventBus::new(256)),
53            data_store: None,
54            transport: None,
55            fitted: false,
56        }
57    }
58
59    pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
60        self.cache = cache;
61        self
62    }
63
64    pub fn with_event_bus(mut self, bus: Arc<EventBus>) -> Self {
65        self.event_bus = bus;
66        self
67    }
68
69    pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
70        self.data_store = Some(store);
71        self
72    }
73
74    pub fn with_transport(mut self, transport: Arc<dyn Transport>) -> Self {
75        self.transport = Some(transport);
76        self
77    }
78
79    // ── Core operations ──
80
81    /// Compile the graph and return diagnostics without executing.
82    pub fn compile(&self, mode: CompileMode) -> Result<CompileResult> {
83        compile(&self.graph, &self.library, mode, Some(self.cache.as_ref()))
84    }
85
86    /// Compile and execute the graph, returning all node outputs.
87    pub fn run(&mut self, mode: CompileMode) -> Result<HashMap<String, Value>> {
88        let CompileResult { plan, diagnostics } =
89            compile(&self.graph, &self.library, mode, Some(self.cache.as_ref()))?;
90
91        for diag in &diagnostics {
92            tracing::warn!("compile diagnostic: {:?}", diag);
93        }
94
95        let graph_info = GraphInfo::from_graph(&self.graph);
96        let mut ctx = Context::new(self.event_bus.clone(), timestamp_id("graph_run"))
97            .with_graph_info(graph_info);
98
99        if let Some(store) = &self.data_store {
100            ctx = ctx.with_data_store(store.clone());
101        }
102        if let Some(transport) = &self.transport {
103            ctx = ctx.with_transport(transport.clone());
104        }
105
106        executor::execute(&plan, &mut ctx, &self.library, self.cache.as_ref())?;
107
108        Ok(ctx
109            .store
110            .into_iter()
111            .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
112            .collect())
113    }
114
115    /// Fit all trainable filters in topological order.
116    /// Delegates to LocalRunner — same execution path as remote workers.
117    pub fn fit(&mut self, x: &Value, y: Option<&Value>) -> Result<HashMap<String, Value>> {
118        self.graph.validate()?;
119
120        let CompileResult { plan, .. } = compile(
121            &self.graph,
122            &self.library,
123            CompileMode::NoCache,
124            Some(self.cache.as_ref()),
125        )?;
126
127        let runner = crate::runner::LocalRunner;
128        let (_last_output, mut all_outputs) = runner.fit(
129            &plan,
130            &self.library,
131            self.cache.as_ref(),
132            &self.event_bus,
133            x,
134            y,
135        )?;
136
137        // Store trained states from __state_ keys into FilterLibrary
138        for (key, value) in &all_outputs {
139            if let Some(node_id) = key.strip_prefix("__state_") {
140                self.library.set_state(node_id, value.clone());
141            }
142        }
143
144        // Remove __state_ keys from returned outputs (callers expect node IDs only)
145        all_outputs.retain(|k, _| !k.starts_with("__state_"));
146
147        self.fitted = true;
148        Ok(all_outputs)
149    }
150
151    /// Forward pass using the given strategy.
152    ///
153    /// Strategies define HOW data flows through the compiled graph:
154    /// - [`Standard`] — full input at once with inference caching (default)
155    /// - [`Stream`] — chunked input through StreamExecutor
156    /// - [`Batched`] — rows from DataStore, batch by batch
157    pub fn forward_with(
158        &self,
159        x: &Value,
160        strategy: &dyn crate::forward::ForwardStrategy,
161    ) -> Result<Value> {
162        strategy.forward(
163            &self.graph,
164            &self.library,
165            self.cache.as_ref(),
166            &self.event_bus,
167            self.data_store.as_ref(),
168            x,
169        )
170    }
171
172    /// Standard forward pass (shortcut for `forward_with(x, &Standard)`).
173    pub fn forward(&self, x: &Value) -> Result<Value> {
174        self.forward_with(x, &crate::forward::Standard)
175    }
176
177    // ── State persistence ──
178
179    /// Persist all trained states to the data store.
180    pub fn persist_states(&self) -> Result<DataRef> {
181        let store = self
182            .data_store
183            .as_ref()
184            .ok_or_else(|| SomaError::Execution {
185                node_id: "session".into(),
186                message: "persist_states requires a data store".into(),
187            })?;
188
189        let sorted = self.graph.topological_sort()?;
190        let mut states_map = serde_json::Map::new();
191        for node_id in &sorted {
192            if let Some(state) = self.library.get_state(node_id) {
193                let json = serde_json::to_value(&*state)
194                    .map_err(|e| SomaError::Other(format!("state serialize: {e}")))?;
195                states_map.insert(node_id.to_string(), json);
196            }
197        }
198
199        let states_value = Value::json(serde_json::Value::Object(states_map));
200        let key = CacheKey::from_parts(&[b"graph_states", self.graph_config_hash().as_bytes()]);
201        store.put(&key, &states_value)
202    }
203
204    /// Load previously persisted states from a data store reference.
205    pub fn load_states(&mut self, data_ref: &DataRef) -> Result<()> {
206        let store = self
207            .data_store
208            .as_ref()
209            .ok_or_else(|| SomaError::Execution {
210                node_id: "session".into(),
211                message: "load_states requires a data store".into(),
212            })?;
213
214        let states_value = store.get(data_ref)?;
215        let states_json = states_value
216            .as_json()
217            .ok_or_else(|| SomaError::Other("persisted states must be JSON".into()))?;
218        let obj = states_json
219            .as_object()
220            .ok_or_else(|| SomaError::Other("persisted states must be a JSON object".into()))?;
221
222        for (node_id, json_val) in obj {
223            let value: Value = serde_json::from_value(json_val.clone())
224                .map_err(|e| SomaError::Other(format!("state deserialize: {e}")))?;
225            self.library.set_state(node_id.clone(), value);
226        }
227
228        self.fitted = true;
229        Ok(())
230    }
231
232    // ── Observability ──
233
234    /// Subscribe to execution events.
235    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Event> {
236        self.event_bus.subscribe()
237    }
238
239    /// Access the event bus directly.
240    pub fn event_bus(&self) -> &Arc<EventBus> {
241        &self.event_bus
242    }
243
244    /// Whether the session has been fitted.
245    pub fn is_fitted(&self) -> bool {
246        self.fitted
247    }
248
249    /// Access the graph.
250    pub fn graph(&self) -> &Graph {
251        &self.graph
252    }
253
254    /// Access the filter library.
255    pub fn library(&self) -> &FilterLibrary {
256        &self.library
257    }
258
259    /// Mutable access to the filter library (for registering filters after creation).
260    pub fn library_mut(&mut self) -> &mut FilterLibrary {
261        &mut self.library
262    }
263
264    // ── Private helpers ──
265
266    fn graph_config_hash(&self) -> String {
267        let node_ids: Vec<&str> = self.graph.nodes.iter().map(|n| n.id.as_str()).collect();
268        node_ids.join(",")
269    }
270}
271
272// ── Convenience free functions (backward compat) ──
273
274/// Compile and execute a graph, returning all node outputs.
275pub fn graph_run(
276    graph: &Graph,
277    library: &FilterLibrary,
278    mode: CompileMode,
279    cache: &dyn CacheStore,
280) -> Result<HashMap<String, Value>> {
281    let CompileResult { plan, diagnostics } = compile(graph, library, mode, Some(cache))?;
282
283    for diag in &diagnostics {
284        tracing::warn!("compile diagnostic: {:?}", diag);
285    }
286
287    let bus = Arc::new(EventBus::new(256));
288    let graph_info = GraphInfo::from_graph(graph);
289
290    let mut ctx = Context::new(bus, timestamp_id("graph_run")).with_graph_info(graph_info);
291
292    executor::execute(&plan, &mut ctx, library, cache)?;
293
294    Ok(ctx
295        .store
296        .into_iter()
297        .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
298        .collect())
299}
300
301/// Fit all trainable filters in topological order.
302pub fn graph_fit(
303    graph: &Graph,
304    library: &FilterLibrary,
305    x: &Value,
306    y: Option<&Value>,
307    cache: &dyn CacheStore,
308) -> Result<HashMap<String, Value>> {
309    graph.validate()?;
310    let sorted = graph.topological_sort()?;
311    let graph_info = GraphInfo::from_graph(graph);
312
313    let bus = Arc::new(EventBus::new(256));
314    let run_id = timestamp_id("graph_fit");
315
316    let mut outputs: HashMap<String, Value> = HashMap::new();
317
318    let roots = graph.roots();
319    for root_id in &roots {
320        outputs.insert(format!("__input_{root_id}"), x.clone());
321    }
322
323    for node_id in &sorted {
324        let filter = library
325            .get(node_id)
326            .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
327
328        bus.emit(Event::NodeStarted {
329            run_id: run_id.clone(),
330            node_id: node_id.to_string(),
331            kind: filter.meta().kind,
332        });
333
334        let preds = graph_info.predecessors(node_id);
335        let input = match preds.len() {
336            0 => x.clone(),
337            1 => outputs.get(&preds[0]).cloned().unwrap_or_else(|| x.clone()),
338            _ => {
339                let mut merged = serde_json::Map::new();
340                for pred_id in preds {
341                    if let Some(val) = outputs.get(pred_id.as_str()) {
342                        let json_val = serde_json::to_value(val).unwrap_or(serde_json::Value::Null);
343                        merged.insert(pred_id.clone(), json_val);
344                    }
345                }
346                Value::json(serde_json::Value::Object(merged))
347            }
348        };
349
350        let meta = filter.meta();
351        let start = std::time::Instant::now();
352
353        let (state, output) = if meta.kind == FilterKind::Trainable {
354            let data_hash = CacheKey::hash_data(&serde_json::to_vec(&input).unwrap_or_default());
355            let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
356
357            let state = if let Some(cached) = cache.get(&state_key)? {
358                cached
359            } else {
360                let s = filter.fit(&input, y)?;
361                cache.put(&state_key, &s)?;
362                s
363            };
364
365            let output = filter.forward(&input, &state)?;
366            (state, output)
367        } else {
368            let output = filter.forward(&input, &Value::Empty)?;
369            (Value::Empty, output)
370        };
371
372        let _ = state;
373
374        bus.emit(Event::NodeCompleted {
375            run_id: run_id.clone(),
376            node_id: node_id.to_string(),
377            duration: start.elapsed(),
378            output_summary: format!("{output}"),
379        });
380
381        outputs.insert(node_id.to_string(), output);
382    }
383
384    Ok(outputs)
385}
386
387/// Compile in Inference mode and execute, returning the last leaf's output.
388pub fn graph_predict(
389    graph: &Graph,
390    library: &FilterLibrary,
391    x: &Value,
392    cache: &dyn CacheStore,
393) -> Result<Value> {
394    let CompileResult { plan, .. } = compile(graph, library, CompileMode::Inference, Some(cache))?;
395
396    let bus = Arc::new(EventBus::new(256));
397    let graph_info = GraphInfo::from_graph(graph);
398    let mut ctx = Context::new(bus, timestamp_id("graph_predict")).with_graph_info(graph_info);
399
400    let roots = graph.roots();
401    if roots.len() == 1 {
402        ctx.set(format!("__input_{}", roots[0]), x.clone());
403    }
404    ctx.set("__input__", x.clone());
405
406    executor::execute(&plan, &mut ctx, library, cache)?;
407
408    let leaves = graph.leaves();
409    let mut extract =
410        |id: &str| -> Option<Value> { ctx.store.remove(id).and_then(|vv| vv.as_value().cloned()) };
411
412    if let Some(leaf_id) = leaves.first() {
413        extract(leaf_id)
414            .ok_or_else(|| SomaError::Other(format!("leaf node '{leaf_id}' produced no output")))
415    } else {
416        ctx.execution_order
417            .last()
418            .and_then(|id| extract(id))
419            .ok_or_else(|| SomaError::Other("no output produced".into()))
420    }
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::MemoryCache;
    use somatize_compiler::FilterRegistry;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result;
    use somatize_core::filter::{FilterKind, FilterMeta, StreamMode};
    use somatize_core::graph::{Edge, Node};

    // ── Shared helpers ──

    /// Metadata common to all test filters; only the name, kind, and
    /// differentiability differ between them.
    fn test_meta(name: &'static str, kind: FilterKind, differentiable: bool) -> FilterMeta {
        FilterMeta {
            name: name.into(),
            kind,
            cacheable: true,
            differentiable,
            stream_mode: StreamMode::FixedState,
            distribution: somatize_core::filter::Distribution::Local,
            input_schema: None,
            output_schema: None,
        }
    }

    /// Apply `op` to every element of a tensor value, keeping its shape.
    /// Fails with "need tensor" when `x` is not a tensor.
    fn map_tensor(x: &Value, op: impl Fn(f64) -> f64) -> Result<Value> {
        let (data, shape) = x
            .as_tensor()
            .ok_or_else(|| SomaError::Other("need tensor".into()))?;
        Ok(Value::tensor(
            data.iter().map(|v| op(*v)).collect(),
            shape.to_vec(),
        ))
    }

    /// Execution context for `graph` with `input` stored under `__input__`.
    fn test_context(graph: &Graph, input: Value) -> Context {
        let bus = Arc::new(EventBus::new(64));
        let mut ctx = Context::new(bus, "test").with_graph_info(GraphInfo::from_graph(graph));
        ctx.set("__input__", input);
        ctx
    }

    /// Collect every concrete value left in the context after execution.
    fn collect_values(ctx: Context) -> HashMap<String, Value> {
        let mut values = HashMap::new();
        for (key, slot) in ctx.store {
            if let Some(value) = slot.as_value() {
                values.insert(key, value.clone());
            }
        }
        values
    }

    /// Chain of nodes `ids[0] → ids[1] → …`, each node using its ID for all
    /// three `Node::new` arguments and edges named `e0`, `e1`, ….
    fn linear_graph(ids: &[&str]) -> Graph {
        let mut g = Graph::new();
        g.nodes.extend(ids.iter().map(|&id| Node::new(id, id, id)));
        for (i, (&from, &to)) in ids.iter().zip(ids.iter().skip(1)).enumerate() {
            g.edges.push(Edge::data(format!("e{i}"), from, to));
        }
        g
    }

    // ── Test filters ──

    /// Stateless filter that multiplies every element by two.
    struct DoublerFilter;
    impl somatize_core::filter::Filter for DoublerFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Doubler"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            map_tensor(x, |v| v * 2.0)
        }
        fn meta(&self) -> FilterMeta {
            test_meta("Doubler", FilterKind::Stateless, true)
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Stateless filter that adds a constant to every element.
    struct AdderFilter(f64);
    impl somatize_core::filter::Filter for AdderFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Adder", &self.0.to_le_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            map_tensor(x, |v| v + self.0)
        }
        fn meta(&self) -> FilterMeta {
            test_meta("Adder", FilterKind::Stateless, true)
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Trainable filter: learns the mean during `fit` and subtracts it in
    /// `forward` (a missing or malformed state counts as mean 0).
    struct MeanFilter;
    impl somatize_core::filter::Filter for MeanFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Mean"])
        }
        fn fit(&self, x: &Value, _y: Option<&Value>) -> Result<Value> {
            let (data, _) = x
                .as_tensor()
                .ok_or_else(|| SomaError::Other("need tensor".into()))?;
            let mean = data.iter().sum::<f64>() / data.len() as f64;
            Ok(Value::json(serde_json::json!({ "mean": mean })))
        }
        fn forward(&self, x: &Value, state: &Value) -> Result<Value> {
            // Validate the input before looking at the state, as a tensor
            // error must take precedence.
            x.as_tensor()
                .ok_or_else(|| SomaError::Other("need tensor".into()))?;
            let mean = state
                .as_json()
                .and_then(|j| j["mean"].as_f64())
                .unwrap_or(0.0);
            map_tensor(x, |v| v - mean)
        }
        fn meta(&self) -> FilterMeta {
            test_meta("Mean", FilterKind::Trainable, true)
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Stateless pass-through used as the join node of the diamond graph.
    struct MergeFilter;
    impl somatize_core::filter::Filter for MergeFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Merge"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            Ok(x.clone())
        }
        fn meta(&self) -> FilterMeta {
            test_meta("Merge", FilterKind::Stateless, false)
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    // ── GraphSession tests ──

    #[test]
    fn session_run_linear() {
        let graph = linear_graph(&["double", "add"]);
        let mut lib = FilterLibrary::new();
        lib.register("double", Box::new(DoublerFilter));
        lib.register("add", Box::new(AdderFilter(10.0)));

        let session =
            GraphSession::new(graph, lib).with_cache(Arc::new(MemoryCache::default()));

        // Compile through the session, then drive the executor by hand so the
        // input tensor can be injected.
        let plan = session.compile(CompileMode::NoCache).unwrap().plan;
        let mut ctx = test_context(
            session.graph(),
            Value::tensor(vec![1.0, 2.0, 3.0], vec![3]),
        );
        executor::execute(&plan, &mut ctx, session.library(), &MemoryCache::default()).unwrap();

        let outputs = collect_values(ctx);

        let result = outputs.get("add").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[12.0, 14.0, 16.0]);
    }

    #[test]
    fn session_fit_and_forward() {
        let graph = linear_graph(&["mean", "double"]);
        let mut lib = FilterLibrary::new();
        lib.register("mean", Box::new(MeanFilter));
        lib.register("double", Box::new(DoublerFilter));

        let mut session = GraphSession::new(graph, lib);

        let x = Value::tensor(vec![10.0, 20.0, 30.0], vec![3]);
        let outputs = session.fit(&x, None).unwrap();

        // mean: learns 20, so forward yields [-10, 0, 10];
        // double: turns that into [-20, 0, 20].
        let result = outputs.get("double").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[-20.0, 0.0, 20.0]);

        assert!(session.is_fitted());
    }

    #[test]
    fn session_compile_diagnostics() {
        let graph = linear_graph(&["double"]);
        let mut lib = FilterLibrary::new();
        lib.register("double", Box::new(DoublerFilter));

        let session = GraphSession::new(graph, lib);
        let result = session.compile(CompileMode::NoCache).unwrap();
        assert!(result.plan.node_count() > 0);
    }

    // ── Free function tests (backward compat) ──

    #[test]
    fn graph_run_linear() {
        let graph = linear_graph(&["double", "add"]);
        let mut lib = FilterLibrary::new();
        lib.register("double", Box::new(DoublerFilter));
        lib.register("add", Box::new(AdderFilter(10.0)));

        let cache = MemoryCache::default();

        let plan = compile(&graph, &lib, CompileMode::NoCache, None)
            .unwrap()
            .plan;
        let mut ctx = test_context(&graph, Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
        executor::execute(&plan, &mut ctx, &lib, &cache).unwrap();
        let outputs = collect_values(ctx);

        let result = outputs.get("add").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[12.0, 14.0, 16.0]);
    }

    #[test]
    fn graph_run_diamond() {
        let mut graph = Graph::new();
        graph.nodes.push(Node::new("double", "Double", "double"));
        graph.nodes.push(Node::new("add", "Add", "add"));
        graph.nodes.push(Node::new("merge", "Merge", "merge"));
        graph.edges.push(Edge::data("e1", "double", "merge"));
        graph.edges.push(Edge::data("e2", "add", "merge"));

        let mut lib = FilterLibrary::new();
        lib.register("double", Box::new(DoublerFilter));
        lib.register("add", Box::new(AdderFilter(100.0)));
        lib.register("merge", Box::new(MergeFilter));

        let cache = MemoryCache::default();
        let plan = compile(&graph, &lib, CompileMode::NoCache, None)
            .unwrap()
            .plan;

        let mut ctx = test_context(&graph, Value::tensor(vec![5.0], vec![1]));
        executor::execute(&plan, &mut ctx, &lib, &cache).unwrap();

        let merge_output = ctx.get("merge").unwrap();
        assert!(
            merge_output.as_json().is_some(),
            "merge should receive JSON from multiple predecessors"
        );
    }

    #[test]
    fn graph_fit_trainable() {
        let graph = linear_graph(&["mean", "double"]);
        let mut lib = FilterLibrary::new();
        lib.register("mean", Box::new(MeanFilter));
        lib.register("double", Box::new(DoublerFilter));

        let cache = MemoryCache::default();
        let x = Value::tensor(vec![10.0, 20.0, 30.0], vec![3]);

        let outputs = graph_fit(&graph, &lib, &x, None, &cache).unwrap();

        let result = outputs.get("double").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[-20.0, 0.0, 20.0]);

        // Fitting the trainable node must have stored its state in the cache.
        assert!(!cache.is_empty());
    }

    #[test]
    fn filter_library_registry_compat() {
        let mut lib = FilterLibrary::new();
        lib.register("a", Box::new(DoublerFilter));

        let registry: &dyn FilterRegistry = &lib;
        assert!(registry.meta("a").is_some());
        assert_eq!(registry.meta("a").unwrap().name, "Doubler");
        assert!(registry.config_hash("a").is_some());
        assert!(registry.meta("b").is_none());
    }
}