Skip to main content

somatize_runtime/
graph_session.rs

1//! Graph session — the primary orchestrator for Graph → Compile → Execute.
2//!
3//! [`GraphSession`] binds a [`Graph`] with its [`FilterLibrary`], cache,
4//! event bus, and optional distributed components into a single object
5//! that can compile, fit, and execute.
6
7use crate::cache::MemoryCache;
8use crate::event_bus::EventBus;
9use crate::executor::{self, Context, GraphInfo, RemoteExecutor};
10use crate::filter_library::FilterLibrary;
11use somatize_compiler::{CompileMode, CompileResult, compile};
12use somatize_core::cache::{CacheKey, CacheStore};
13use somatize_core::error::{Result, SomaError};
14use somatize_core::event::Event;
15use somatize_core::filter::FilterKind;
16use somatize_core::graph::Graph;
17use somatize_core::store::{DataRef, DataStore};
18use somatize_core::util::timestamp_id;
19use somatize_core::value::Value;
20use std::collections::HashMap;
21use std::sync::Arc;
22
23/// The primary orchestrator: Graph + filters + cache + events.
24///
25/// ```ignore
26/// let mut lib = FilterLibrary::new();
27/// lib.register("scaler", Box::new(MyScaler::new()));
28/// lib.register("model", Box::new(MyModel::new()));
29///
30/// let mut session = GraphSession::new(graph, lib);
31/// session.fit(&train_x, Some(&train_y))?;
32/// let output = session.forward(&test_x)?;
33/// ```
34pub struct GraphSession {
35    graph: Graph,
36    library: FilterLibrary,
37    cache: Arc<dyn CacheStore>,
38    event_bus: Arc<EventBus>,
39    data_store: Option<Arc<dyn DataStore>>,
40    remote_executor: Option<Arc<dyn RemoteExecutor>>,
41    fitted: bool,
42}
43
44impl GraphSession {
45    pub fn new(graph: Graph, library: FilterLibrary) -> Self {
46        Self {
47            graph,
48            library,
49            cache: Arc::new(MemoryCache::default()),
50            event_bus: Arc::new(EventBus::new(256)),
51            data_store: None,
52            remote_executor: None,
53            fitted: false,
54        }
55    }
56
57    pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
58        self.cache = cache;
59        self
60    }
61
62    pub fn with_event_bus(mut self, bus: Arc<EventBus>) -> Self {
63        self.event_bus = bus;
64        self
65    }
66
67    pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
68        self.data_store = Some(store);
69        self
70    }
71
72    pub fn with_remote_executor(mut self, executor: Arc<dyn RemoteExecutor>) -> Self {
73        self.remote_executor = Some(executor);
74        self
75    }
76
77    // ── Core operations ──
78
79    /// Compile the graph and return diagnostics without executing.
80    pub fn compile(&self, mode: CompileMode) -> Result<CompileResult> {
81        compile(&self.graph, &self.library, mode, Some(self.cache.as_ref()))
82    }
83
84    /// Compile and execute the graph, returning all node outputs.
85    pub fn run(&mut self, mode: CompileMode) -> Result<HashMap<String, Value>> {
86        let CompileResult { plan, diagnostics } =
87            compile(&self.graph, &self.library, mode, Some(self.cache.as_ref()))?;
88
89        for diag in &diagnostics {
90            tracing::warn!("compile diagnostic: {:?}", diag);
91        }
92
93        let graph_info = GraphInfo::from_graph(&self.graph);
94        let mut ctx = Context::new(self.event_bus.clone(), timestamp_id("graph_run"))
95            .with_graph_info(graph_info);
96
97        if let Some(store) = &self.data_store {
98            ctx = ctx.with_data_store(store.clone());
99        }
100        if let Some(remote) = &self.remote_executor {
101            ctx = ctx.with_remote_executor(remote.clone());
102        }
103
104        executor::execute(&plan, &mut ctx, &self.library, self.cache.as_ref())?;
105
106        Ok(ctx
107            .store
108            .into_iter()
109            .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
110            .collect())
111    }
112
113    /// Fit all trainable filters in topological order.
114    ///
115    /// For each node in topo order:
116    /// 1. Resolve input from predecessors
117    /// 2. If trainable: call `fit(input, y)`, cache state, then `forward(input, state)`
118    /// 3. If stateless: call `forward(input, Empty)`
119    /// 4. Store output for downstream nodes
120    pub fn fit(&mut self, x: &Value, y: Option<&Value>) -> Result<HashMap<String, Value>> {
121        self.graph.validate()?;
122        let sorted = self.graph.topological_sort()?;
123        let graph_info = GraphInfo::from_graph(&self.graph);
124
125        let run_id = timestamp_id("graph_fit");
126        let mut outputs: HashMap<String, Value> = HashMap::new();
127
128        // Set initial input for root nodes
129        let roots = self.graph.roots();
130        for root_id in &roots {
131            outputs.insert(format!("__input_{root_id}"), x.clone());
132        }
133
134        for node_id in &sorted {
135            let filter = self
136                .library
137                .get(node_id)
138                .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
139
140            self.event_bus.emit(Event::NodeStarted {
141                run_id: run_id.clone(),
142                node_id: node_id.to_string(),
143                kind: filter.meta().kind,
144            });
145
146            // Resolve input from predecessors
147            let preds = graph_info.predecessors(node_id);
148            let input = match preds.len() {
149                0 => x.clone(),
150                1 => outputs.get(&preds[0]).cloned().unwrap_or_else(|| x.clone()),
151                _ => {
152                    let mut merged = serde_json::Map::new();
153                    for pred_id in preds {
154                        if let Some(val) = outputs.get(pred_id.as_str()) {
155                            let json_val =
156                                serde_json::to_value(val).unwrap_or(serde_json::Value::Null);
157                            merged.insert(pred_id.clone(), json_val);
158                        }
159                    }
160                    Value::Json(serde_json::Value::Object(merged))
161                }
162            };
163
164            let meta = filter.meta();
165            let start = std::time::Instant::now();
166
167            // Fit trainable filters, forward all
168            let (state, output) = if meta.kind == FilterKind::Trainable {
169                let data_hash =
170                    CacheKey::hash_data(&serde_json::to_vec(&input).unwrap_or_default());
171                let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
172
173                let state = if let Some(cached) = self.cache.get(&state_key)? {
174                    cached
175                } else {
176                    let s = filter.fit(&input, y)?;
177                    self.cache.put(&state_key, &s)?;
178                    s
179                };
180
181                let output = filter.forward(&input, &state)?;
182                // Store state in library for later use by forward()
183                self.library.set_state(node_id.to_string(), state.clone());
184                (state, output)
185            } else {
186                let output = filter.forward(&input, &Value::Empty)?;
187                (Value::Empty, output)
188            };
189
190            let _ = state; // state already cached above
191
192            self.event_bus.emit(Event::NodeCompleted {
193                run_id: run_id.clone(),
194                node_id: node_id.to_string(),
195                duration: start.elapsed(),
196                output_summary: format!("{output}"),
197            });
198
199            outputs.insert(node_id.to_string(), output);
200        }
201
202        self.fitted = true;
203        Ok(outputs)
204    }
205
206    /// Forward pass through the fitted graph (inference).
207    ///
208    /// Compiles in Inference mode and executes. Trainable filters
209    /// use states cached during `fit()`.
210    pub fn forward(&self, x: &Value) -> Result<Value> {
211        let CompileResult { plan, .. } = compile(
212            &self.graph,
213            &self.library,
214            CompileMode::Inference,
215            Some(self.cache.as_ref()),
216        )?;
217
218        let graph_info = GraphInfo::from_graph(&self.graph);
219        let mut ctx = Context::new(self.event_bus.clone(), timestamp_id("graph_forward"))
220            .with_graph_info(graph_info);
221
222        if let Some(store) = &self.data_store {
223            ctx = ctx.with_data_store(store.clone());
224        }
225        if let Some(remote) = &self.remote_executor {
226            ctx = ctx.with_remote_executor(remote.clone());
227        }
228
229        // Set input for root nodes
230        let roots = self.graph.roots();
231        if roots.len() == 1 {
232            ctx.set(format!("__input_{}", roots[0]), x.clone());
233        }
234        ctx.set("__input__", x.clone());
235
236        executor::execute(&plan, &mut ctx, &self.library, self.cache.as_ref())?;
237
238        // Return the last leaf's output
239        let leaves = self.graph.leaves();
240        let mut extract = |id: &str| -> Option<Value> {
241            ctx.store.remove(id).and_then(|vv| vv.as_value().cloned())
242        };
243
244        if let Some(leaf_id) = leaves.first() {
245            extract(leaf_id).ok_or_else(|| {
246                SomaError::Other(format!("leaf node '{leaf_id}' produced no output"))
247            })
248        } else {
249            ctx.execution_order
250                .last()
251                .and_then(|id| extract(id))
252                .ok_or_else(|| SomaError::Other("no output produced".into()))
253        }
254    }
255
256    /// Forward pass in batches from a DataStore reference.
257    pub fn forward_batched(&self, data_ref: &DataRef, batch_size: usize) -> Result<Value> {
258        let store = self
259            .data_store
260            .as_ref()
261            .ok_or_else(|| SomaError::Execution {
262                node_id: "session".into(),
263                message: "forward_batched requires a data store (use with_data_store)".into(),
264            })?;
265
266        let meta = store.meta(data_ref)?;
267        let total_rows = meta.total_rows;
268        if total_rows == 0 {
269            return Ok(Value::Empty);
270        }
271
272        let mut all_values: Vec<f64> = Vec::new();
273        let mut result_shape: Option<Vec<usize>> = None;
274        let mut rows_processed = 0;
275
276        while rows_processed < total_rows {
277            let batch_len = batch_size.min(total_rows - rows_processed);
278            let batch = store.get_rows(data_ref, rows_processed, batch_len)?;
279            let output = self.forward(&batch)?;
280
281            if let Value::Tensor { values, shape } = &output {
282                if result_shape.is_none() {
283                    result_shape = Some(shape.clone());
284                }
285                all_values.extend_from_slice(values);
286            } else {
287                return Ok(output);
288            }
289
290            rows_processed += batch_len;
291        }
292
293        match result_shape {
294            Some(mut shape) => {
295                shape[0] = total_rows;
296                Ok(Value::tensor(all_values, shape))
297            }
298            None => Ok(Value::Empty),
299        }
300    }
301
302    // ── State persistence ──
303
304    /// Persist all trained states to the data store.
305    pub fn persist_states(&self) -> Result<DataRef> {
306        let store = self
307            .data_store
308            .as_ref()
309            .ok_or_else(|| SomaError::Execution {
310                node_id: "session".into(),
311                message: "persist_states requires a data store".into(),
312            })?;
313
314        let sorted = self.graph.topological_sort()?;
315        let mut states_map = serde_json::Map::new();
316        for node_id in &sorted {
317            if let Some(state) = self.library.get_state(node_id) {
318                let json = serde_json::to_value(state)
319                    .map_err(|e| SomaError::Other(format!("state serialize: {e}")))?;
320                states_map.insert(node_id.to_string(), json);
321            }
322        }
323
324        let states_value = Value::Json(serde_json::Value::Object(states_map));
325        let key = CacheKey::from_parts(&[b"graph_states", self.graph_config_hash().as_bytes()]);
326        store.put(&key, &states_value)
327    }
328
329    /// Load previously persisted states from a data store reference.
330    pub fn load_states(&mut self, data_ref: &DataRef) -> Result<()> {
331        let store = self
332            .data_store
333            .as_ref()
334            .ok_or_else(|| SomaError::Execution {
335                node_id: "session".into(),
336                message: "load_states requires a data store".into(),
337            })?;
338
339        let states_value = store.get(data_ref)?;
340        let states_json = states_value
341            .as_json()
342            .ok_or_else(|| SomaError::Other("persisted states must be JSON".into()))?;
343        let obj = states_json
344            .as_object()
345            .ok_or_else(|| SomaError::Other("persisted states must be a JSON object".into()))?;
346
347        for (node_id, json_val) in obj {
348            let value: Value = serde_json::from_value(json_val.clone())
349                .map_err(|e| SomaError::Other(format!("state deserialize: {e}")))?;
350            self.library.set_state(node_id.clone(), value);
351        }
352
353        self.fitted = true;
354        Ok(())
355    }
356
357    // ── Observability ──
358
359    /// Subscribe to execution events.
360    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Event> {
361        self.event_bus.subscribe()
362    }
363
364    /// Access the event bus directly.
365    pub fn event_bus(&self) -> &Arc<EventBus> {
366        &self.event_bus
367    }
368
369    /// Whether the session has been fitted.
370    pub fn is_fitted(&self) -> bool {
371        self.fitted
372    }
373
374    /// Access the graph.
375    pub fn graph(&self) -> &Graph {
376        &self.graph
377    }
378
379    /// Access the filter library.
380    pub fn library(&self) -> &FilterLibrary {
381        &self.library
382    }
383
384    /// Mutable access to the filter library (for registering filters after creation).
385    pub fn library_mut(&mut self) -> &mut FilterLibrary {
386        &mut self.library
387    }
388
389    // ── Private helpers ──
390
391    fn graph_config_hash(&self) -> String {
392        let node_ids: Vec<&str> = self.graph.nodes.iter().map(|n| n.id.as_str()).collect();
393        node_ids.join(",")
394    }
395}
396
397// ── Convenience free functions (backward compat) ──
398
399/// Compile and execute a graph, returning all node outputs.
400pub fn graph_run(
401    graph: &Graph,
402    library: &FilterLibrary,
403    mode: CompileMode,
404    cache: &dyn CacheStore,
405) -> Result<HashMap<String, Value>> {
406    let CompileResult { plan, diagnostics } = compile(graph, library, mode, Some(cache))?;
407
408    for diag in &diagnostics {
409        tracing::warn!("compile diagnostic: {:?}", diag);
410    }
411
412    let bus = Arc::new(EventBus::new(256));
413    let graph_info = GraphInfo::from_graph(graph);
414
415    let mut ctx = Context::new(bus, timestamp_id("graph_run")).with_graph_info(graph_info);
416
417    executor::execute(&plan, &mut ctx, library, cache)?;
418
419    Ok(ctx
420        .store
421        .into_iter()
422        .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
423        .collect())
424}
425
426/// Fit all trainable filters in topological order.
427pub fn graph_fit(
428    graph: &Graph,
429    library: &FilterLibrary,
430    x: &Value,
431    y: Option<&Value>,
432    cache: &dyn CacheStore,
433) -> Result<HashMap<String, Value>> {
434    graph.validate()?;
435    let sorted = graph.topological_sort()?;
436    let graph_info = GraphInfo::from_graph(graph);
437
438    let bus = Arc::new(EventBus::new(256));
439    let run_id = timestamp_id("graph_fit");
440
441    let mut outputs: HashMap<String, Value> = HashMap::new();
442
443    let roots = graph.roots();
444    for root_id in &roots {
445        outputs.insert(format!("__input_{root_id}"), x.clone());
446    }
447
448    for node_id in &sorted {
449        let filter = library
450            .get(node_id)
451            .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
452
453        bus.emit(Event::NodeStarted {
454            run_id: run_id.clone(),
455            node_id: node_id.to_string(),
456            kind: filter.meta().kind,
457        });
458
459        let preds = graph_info.predecessors(node_id);
460        let input = match preds.len() {
461            0 => x.clone(),
462            1 => outputs.get(&preds[0]).cloned().unwrap_or_else(|| x.clone()),
463            _ => {
464                let mut merged = serde_json::Map::new();
465                for pred_id in preds {
466                    if let Some(val) = outputs.get(pred_id.as_str()) {
467                        let json_val = serde_json::to_value(val).unwrap_or(serde_json::Value::Null);
468                        merged.insert(pred_id.clone(), json_val);
469                    }
470                }
471                Value::Json(serde_json::Value::Object(merged))
472            }
473        };
474
475        let meta = filter.meta();
476        let start = std::time::Instant::now();
477
478        let (state, output) = if meta.kind == FilterKind::Trainable {
479            let data_hash = CacheKey::hash_data(&serde_json::to_vec(&input).unwrap_or_default());
480            let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
481
482            let state = if let Some(cached) = cache.get(&state_key)? {
483                cached
484            } else {
485                let s = filter.fit(&input, y)?;
486                cache.put(&state_key, &s)?;
487                s
488            };
489
490            let output = filter.forward(&input, &state)?;
491            (state, output)
492        } else {
493            let output = filter.forward(&input, &Value::Empty)?;
494            (Value::Empty, output)
495        };
496
497        let _ = state;
498
499        bus.emit(Event::NodeCompleted {
500            run_id: run_id.clone(),
501            node_id: node_id.to_string(),
502            duration: start.elapsed(),
503            output_summary: format!("{output}"),
504        });
505
506        outputs.insert(node_id.to_string(), output);
507    }
508
509    Ok(outputs)
510}
511
512/// Compile in Inference mode and execute, returning the last leaf's output.
513pub fn graph_predict(
514    graph: &Graph,
515    library: &FilterLibrary,
516    x: &Value,
517    cache: &dyn CacheStore,
518) -> Result<Value> {
519    let CompileResult { plan, .. } = compile(graph, library, CompileMode::Inference, Some(cache))?;
520
521    let bus = Arc::new(EventBus::new(256));
522    let graph_info = GraphInfo::from_graph(graph);
523    let mut ctx = Context::new(bus, timestamp_id("graph_predict")).with_graph_info(graph_info);
524
525    let roots = graph.roots();
526    if roots.len() == 1 {
527        ctx.set(format!("__input_{}", roots[0]), x.clone());
528    }
529    ctx.set("__input__", x.clone());
530
531    executor::execute(&plan, &mut ctx, library, cache)?;
532
533    let leaves = graph.leaves();
534    let mut extract =
535        |id: &str| -> Option<Value> { ctx.store.remove(id).and_then(|vv| vv.as_value().cloned()) };
536
537    if let Some(leaf_id) = leaves.first() {
538        extract(leaf_id)
539            .ok_or_else(|| SomaError::Other(format!("leaf node '{leaf_id}' produced no output")))
540    } else {
541        ctx.execution_order
542            .last()
543            .and_then(|id| extract(id))
544            .ok_or_else(|| SomaError::Other("no output produced".into()))
545    }
546}
547
548#[cfg(test)]
549mod tests {
550    use super::*;
551    use crate::cache::MemoryCache;
552    use somatize_compiler::FilterRegistry;
553    use somatize_core::cache::CacheKey;
554    use somatize_core::error::Result;
555    use somatize_core::filter::{FilterKind, FilterMeta, StreamMode};
556    use somatize_core::graph::{Edge, Node};
557
558    // ── Test filters ──
559
560    struct DoublerFilter;
561    impl somatize_core::filter::Filter for DoublerFilter {
562        fn config_hash(&self) -> CacheKey {
563            CacheKey::from_parts(&[b"Doubler"])
564        }
565        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
566            Ok(Value::Empty)
567        }
568        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
569            let (data, shape) = x
570                .as_tensor()
571                .ok_or(SomaError::Other("need tensor".into()))?;
572            Ok(Value::tensor(
573                data.iter().map(|v| v * 2.0).collect(),
574                shape.to_vec(),
575            ))
576        }
577        fn meta(&self) -> FilterMeta {
578            FilterMeta {
579                name: "Doubler".into(),
580                kind: FilterKind::Stateless,
581                cacheable: true,
582                differentiable: true,
583                stream_mode: StreamMode::FixedState,
584                distribution: somatize_core::filter::Distribution::Local,
585                input_schema: None,
586                output_schema: None,
587            }
588        }
589
590        fn as_any(&self) -> &dyn std::any::Any {
591            self
592        }
593    }
594
595    struct AdderFilter(f64);
596    impl somatize_core::filter::Filter for AdderFilter {
597        fn config_hash(&self) -> CacheKey {
598            CacheKey::from_parts(&[b"Adder", &self.0.to_le_bytes()])
599        }
600        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
601            Ok(Value::Empty)
602        }
603        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
604            let (data, shape) = x
605                .as_tensor()
606                .ok_or(SomaError::Other("need tensor".into()))?;
607            Ok(Value::tensor(
608                data.iter().map(|v| v + self.0).collect(),
609                shape.to_vec(),
610            ))
611        }
612        fn meta(&self) -> FilterMeta {
613            FilterMeta {
614                name: "Adder".into(),
615                kind: FilterKind::Stateless,
616                cacheable: true,
617                differentiable: true,
618                stream_mode: StreamMode::FixedState,
619                distribution: somatize_core::filter::Distribution::Local,
620                input_schema: None,
621                output_schema: None,
622            }
623        }
624
625        fn as_any(&self) -> &dyn std::any::Any {
626            self
627        }
628    }
629
630    struct MeanFilter;
631    impl somatize_core::filter::Filter for MeanFilter {
632        fn config_hash(&self) -> CacheKey {
633            CacheKey::from_parts(&[b"Mean"])
634        }
635        fn fit(&self, x: &Value, _y: Option<&Value>) -> Result<Value> {
636            let (data, _) = x
637                .as_tensor()
638                .ok_or(SomaError::Other("need tensor".into()))?;
639            let mean = data.iter().sum::<f64>() / data.len() as f64;
640            Ok(Value::json(serde_json::json!({ "mean": mean })))
641        }
642        fn forward(&self, x: &Value, state: &Value) -> Result<Value> {
643            let (data, shape) = x
644                .as_tensor()
645                .ok_or(SomaError::Other("need tensor".into()))?;
646            let mean = state
647                .as_json()
648                .and_then(|j| j["mean"].as_f64())
649                .unwrap_or(0.0);
650            Ok(Value::tensor(
651                data.iter().map(|v| v - mean).collect(),
652                shape.to_vec(),
653            ))
654        }
655        fn meta(&self) -> FilterMeta {
656            FilterMeta {
657                name: "Mean".into(),
658                kind: FilterKind::Trainable,
659                cacheable: true,
660                differentiable: true,
661                stream_mode: StreamMode::FixedState,
662                distribution: somatize_core::filter::Distribution::Local,
663                input_schema: None,
664                output_schema: None,
665            }
666        }
667
668        fn as_any(&self) -> &dyn std::any::Any {
669            self
670        }
671    }
672
673    fn linear_graph(ids: &[&str]) -> Graph {
674        let mut g = Graph::new();
675        for &id in ids {
676            g.nodes.push(Node::new(id, id, id));
677        }
678        for (i, pair) in ids.windows(2).enumerate() {
679            g.edges.push(Edge::data(format!("e{i}"), pair[0], pair[1]));
680        }
681        g
682    }
683
684    // ── GraphSession tests ──
685
686    #[test]
687    fn session_run_linear() {
688        let graph = linear_graph(&["double", "add"]);
689        let mut lib = FilterLibrary::new();
690        lib.register("double", Box::new(DoublerFilter));
691        lib.register("add", Box::new(AdderFilter(10.0)));
692
693        let mut session = GraphSession::new(graph, lib);
694
695        let cache = MemoryCache::default();
696        session = session.with_cache(Arc::new(cache));
697
698        // Manual compile + execute via run
699        let CompileResult { plan, .. } = session.compile(CompileMode::NoCache).unwrap();
700        let bus = Arc::new(EventBus::new(64));
701        let mut ctx =
702            Context::new(bus, "test").with_graph_info(GraphInfo::from_graph(session.graph()));
703        ctx.set("__input__", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
704        executor::execute(&plan, &mut ctx, session.library(), &MemoryCache::default()).unwrap();
705
706        let outputs: HashMap<String, Value> = ctx
707            .store
708            .into_iter()
709            .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
710            .collect();
711
712        let result = outputs.get("add").unwrap();
713        let (data, _) = result.as_tensor().unwrap();
714        assert_eq!(data, &[12.0, 14.0, 16.0]);
715    }
716
717    #[test]
718    fn session_fit_and_forward() {
719        let graph = linear_graph(&["mean", "double"]);
720        let mut lib = FilterLibrary::new();
721        lib.register("mean", Box::new(MeanFilter));
722        lib.register("double", Box::new(DoublerFilter));
723
724        let mut session = GraphSession::new(graph, lib);
725
726        let x = Value::tensor(vec![10.0, 20.0, 30.0], vec![3]);
727        let outputs = session.fit(&x, None).unwrap();
728
729        // mean: fit learns mean=20, forward: [10-20, 20-20, 30-20] = [-10, 0, 10]
730        // double: [-10, 0, 10] → [-20, 0, 20]
731        let result = outputs.get("double").unwrap();
732        let (data, _) = result.as_tensor().unwrap();
733        assert_eq!(data, &[-20.0, 0.0, 20.0]);
734
735        assert!(session.is_fitted());
736    }
737
738    #[test]
739    fn session_compile_diagnostics() {
740        let graph = linear_graph(&["double"]);
741        let mut lib = FilterLibrary::new();
742        lib.register("double", Box::new(DoublerFilter));
743
744        let session = GraphSession::new(graph, lib);
745        let result = session.compile(CompileMode::NoCache).unwrap();
746        assert!(result.plan.node_count() > 0);
747    }
748
749    // ── Free function tests (backward compat) ──
750
751    #[test]
752    fn graph_run_linear() {
753        let graph = linear_graph(&["double", "add"]);
754        let mut lib = FilterLibrary::new();
755        lib.register("double", Box::new(DoublerFilter));
756        lib.register("add", Box::new(AdderFilter(10.0)));
757
758        let cache = MemoryCache::default();
759
760        let outputs = {
761            let CompileResult { plan, .. } =
762                compile(&graph, &lib, CompileMode::NoCache, None).unwrap();
763            let bus = Arc::new(EventBus::new(64));
764            let mut ctx = Context::new(bus, "test").with_graph_info(GraphInfo::from_graph(&graph));
765            ctx.set("__input__", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
766            executor::execute(&plan, &mut ctx, &lib, &cache).unwrap();
767            ctx.store
768                .into_iter()
769                .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
770                .collect::<HashMap<String, Value>>()
771        };
772
773        let result = outputs.get("add").unwrap();
774        let (data, _) = result.as_tensor().unwrap();
775        assert_eq!(data, &[12.0, 14.0, 16.0]);
776    }
777
778    #[test]
779    fn graph_run_diamond() {
780        let mut graph = Graph::new();
781        graph.nodes.push(Node::new("double", "Double", "double"));
782        graph.nodes.push(Node::new("add", "Add", "add"));
783        graph.nodes.push(Node::new("merge", "Merge", "merge"));
784        graph.edges.push(Edge::data("e1", "double", "merge"));
785        graph.edges.push(Edge::data("e2", "add", "merge"));
786
787        let mut lib = FilterLibrary::new();
788        lib.register("double", Box::new(DoublerFilter));
789        lib.register("add", Box::new(AdderFilter(100.0)));
790
791        struct MergeFilter;
792        impl somatize_core::filter::Filter for MergeFilter {
793            fn config_hash(&self) -> CacheKey {
794                CacheKey::from_parts(&[b"Merge"])
795            }
796            fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
797                Ok(Value::Empty)
798            }
799            fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
800                Ok(x.clone())
801            }
802            fn meta(&self) -> FilterMeta {
803                FilterMeta {
804                    name: "Merge".into(),
805                    kind: FilterKind::Stateless,
806                    cacheable: true,
807                    differentiable: false,
808                    stream_mode: StreamMode::FixedState,
809                    distribution: somatize_core::filter::Distribution::Local,
810                    input_schema: None,
811                    output_schema: None,
812                }
813            }
814
815            fn as_any(&self) -> &dyn std::any::Any {
816                self
817            }
818        }
819        lib.register("merge", Box::new(MergeFilter));
820
821        let cache = MemoryCache::default();
822        let CompileResult { plan, .. } = compile(&graph, &lib, CompileMode::NoCache, None).unwrap();
823
824        let bus = Arc::new(EventBus::new(64));
825        let mut ctx = Context::new(bus, "test").with_graph_info(GraphInfo::from_graph(&graph));
826        ctx.set("__input__", Value::tensor(vec![5.0], vec![1]));
827        executor::execute(&plan, &mut ctx, &lib, &cache).unwrap();
828
829        let merge_output = ctx.get("merge").unwrap();
830        assert!(
831            merge_output.as_json().is_some(),
832            "merge should receive JSON from multiple predecessors"
833        );
834    }
835
836    #[test]
837    fn graph_fit_trainable() {
838        let graph = linear_graph(&["mean", "double"]);
839        let mut lib = FilterLibrary::new();
840        lib.register("mean", Box::new(MeanFilter));
841        lib.register("double", Box::new(DoublerFilter));
842
843        let cache = MemoryCache::default();
844        let x = Value::tensor(vec![10.0, 20.0, 30.0], vec![3]);
845
846        let outputs = graph_fit(&graph, &lib, &x, None, &cache).unwrap();
847
848        let result = outputs.get("double").unwrap();
849        let (data, _) = result.as_tensor().unwrap();
850        assert_eq!(data, &[-20.0, 0.0, 20.0]);
851
852        assert!(!cache.is_empty());
853    }
854
855    #[test]
856    fn filter_library_registry_compat() {
857        let mut lib = FilterLibrary::new();
858        lib.register("a", Box::new(DoublerFilter));
859
860        let registry: &dyn FilterRegistry = &lib;
861        assert!(registry.meta("a").is_some());
862        assert_eq!(registry.meta("a").unwrap().name, "Doubler");
863        assert!(registry.config_hash("a").is_some());
864        assert!(registry.meta("b").is_none());
865    }
866}