Skip to main content

somatize_runtime/
graph_session.rs

1//! Graph session — the primary orchestrator for Graph → Compile → Execute.
2//!
3//! [`GraphSession`] binds a [`Graph`] with its [`FilterLibrary`], cache,
4//! event bus, and optional distributed components into a single object
5//! that can compile, fit, and execute.
6
7use crate::cache::MemoryCache;
8use crate::event_bus::EventBus;
9use crate::executor::{self, Context, GraphInfo, RemoteExecutor};
10use crate::filter_library::FilterLibrary;
11use somatize_compiler::{CompileMode, CompileResult, compile};
12use somatize_core::cache::{CacheKey, CacheStore};
13use somatize_core::error::{Result, SomaError};
14use somatize_core::event::Event;
15use somatize_core::filter::FilterKind;
16use somatize_core::graph::Graph;
17use somatize_core::store::{DataRef, DataStore};
18use somatize_core::util::timestamp_id;
19use somatize_core::value::Value;
20use std::collections::HashMap;
21use std::sync::Arc;
22
23/// The primary orchestrator: Graph + filters + cache + events.
24///
25/// ```ignore
26/// let mut lib = FilterLibrary::new();
27/// lib.register("scaler", Box::new(MyScaler::new()));
28/// lib.register("model", Box::new(MyModel::new()));
29///
30/// let mut session = GraphSession::new(graph, lib);
31/// session.fit(&train_x, Some(&train_y))?;
32/// let output = session.forward(&test_x)?;
33/// ```
34pub struct GraphSession {
35    graph: Graph,
36    library: FilterLibrary,
37    cache: Arc<dyn CacheStore>,
38    event_bus: Arc<EventBus>,
39    data_store: Option<Arc<dyn DataStore>>,
40    remote_executor: Option<Arc<dyn RemoteExecutor>>,
41    fitted: bool,
42}
43
44impl GraphSession {
45    pub fn new(graph: Graph, library: FilterLibrary) -> Self {
46        Self {
47            graph,
48            library,
49            cache: Arc::new(MemoryCache::default()),
50            event_bus: Arc::new(EventBus::new(256)),
51            data_store: None,
52            remote_executor: None,
53            fitted: false,
54        }
55    }
56
57    pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
58        self.cache = cache;
59        self
60    }
61
62    pub fn with_event_bus(mut self, bus: Arc<EventBus>) -> Self {
63        self.event_bus = bus;
64        self
65    }
66
67    pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
68        self.data_store = Some(store);
69        self
70    }
71
72    pub fn with_remote_executor(mut self, executor: Arc<dyn RemoteExecutor>) -> Self {
73        self.remote_executor = Some(executor);
74        self
75    }
76
77    // ── Core operations ──
78
79    /// Compile the graph and return diagnostics without executing.
80    pub fn compile(&self, mode: CompileMode) -> Result<CompileResult> {
81        compile(&self.graph, &self.library, mode, Some(self.cache.as_ref()))
82    }
83
84    /// Compile and execute the graph, returning all node outputs.
85    pub fn run(&mut self, mode: CompileMode) -> Result<HashMap<String, Value>> {
86        let CompileResult { plan, diagnostics } =
87            compile(&self.graph, &self.library, mode, Some(self.cache.as_ref()))?;
88
89        for diag in &diagnostics {
90            tracing::warn!("compile diagnostic: {:?}", diag);
91        }
92
93        let graph_info = GraphInfo::from_graph(&self.graph);
94        let mut ctx = Context::new(self.event_bus.clone(), timestamp_id("graph_run"))
95            .with_graph_info(graph_info);
96
97        if let Some(store) = &self.data_store {
98            ctx = ctx.with_data_store(store.clone());
99        }
100        if let Some(remote) = &self.remote_executor {
101            ctx = ctx.with_remote_executor(remote.clone());
102        }
103
104        executor::execute(&plan, &mut ctx, &self.library, self.cache.as_ref())?;
105
106        Ok(ctx
107            .store
108            .into_iter()
109            .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
110            .collect())
111    }
112
113    /// Fit all trainable filters in topological order.
114    ///
115    /// For each node in topo order:
116    /// 1. Resolve input from predecessors
117    /// 2. If trainable: call `fit(input, y)`, cache state, then `forward(input, state)`
118    /// 3. If stateless: call `forward(input, Empty)`
119    /// 4. Store output for downstream nodes
120    pub fn fit(&mut self, x: &Value, y: Option<&Value>) -> Result<HashMap<String, Value>> {
121        self.graph.validate()?;
122        let sorted = self.graph.topological_sort()?;
123        let graph_info = GraphInfo::from_graph(&self.graph);
124
125        let run_id = timestamp_id("graph_fit");
126        let mut outputs: HashMap<String, Value> = HashMap::new();
127
128        // Set initial input for root nodes
129        let roots = self.graph.roots();
130        for root_id in &roots {
131            outputs.insert(format!("__input_{root_id}"), x.clone());
132        }
133
134        for node_id in &sorted {
135            let filter = self
136                .library
137                .get(node_id)
138                .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
139
140            self.event_bus.emit(Event::NodeStarted {
141                run_id: run_id.clone(),
142                node_id: node_id.to_string(),
143                kind: filter.meta().kind,
144            });
145
146            // Resolve input from predecessors
147            let preds = graph_info.predecessors(node_id);
148            let input = match preds.len() {
149                0 => x.clone(),
150                1 => outputs.get(&preds[0]).cloned().unwrap_or_else(|| x.clone()),
151                _ => {
152                    let mut merged = serde_json::Map::new();
153                    for pred_id in preds {
154                        if let Some(val) = outputs.get(pred_id.as_str()) {
155                            let json_val =
156                                serde_json::to_value(val).unwrap_or(serde_json::Value::Null);
157                            merged.insert(pred_id.clone(), json_val);
158                        }
159                    }
160                    Value::Json(serde_json::Value::Object(merged))
161                }
162            };
163
164            let meta = filter.meta();
165            let start = std::time::Instant::now();
166
167            // Fit trainable filters, forward all
168            let (state, output) = if meta.kind == FilterKind::Trainable {
169                let data_hash =
170                    CacheKey::hash_data(&serde_json::to_vec(&input).unwrap_or_default());
171                let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
172
173                let state = if let Some(cached) = self.cache.get(&state_key)? {
174                    cached
175                } else {
176                    let s = filter.fit(&input, y)?;
177                    self.cache.put(&state_key, &s)?;
178                    s
179                };
180
181                let output = filter.forward(&input, &state)?;
182                // Store state in library for later use by forward()
183                self.library.set_state(node_id.to_string(), state.clone());
184                (state, output)
185            } else {
186                let output = filter.forward(&input, &Value::Empty)?;
187                (Value::Empty, output)
188            };
189
190            let _ = state; // state already cached above
191
192            self.event_bus.emit(Event::NodeCompleted {
193                run_id: run_id.clone(),
194                node_id: node_id.to_string(),
195                duration: start.elapsed(),
196                output_summary: format!("{output}"),
197            });
198
199            outputs.insert(node_id.to_string(), output);
200        }
201
202        self.fitted = true;
203        Ok(outputs)
204    }
205
206    /// Forward pass through the fitted graph (inference).
207    ///
208    /// Compiles in Inference mode and executes. Trainable filters
209    /// use states cached during `fit()`.
210    pub fn forward(&self, x: &Value) -> Result<Value> {
211        let CompileResult { plan, .. } = compile(
212            &self.graph,
213            &self.library,
214            CompileMode::Inference,
215            Some(self.cache.as_ref()),
216        )?;
217
218        let graph_info = GraphInfo::from_graph(&self.graph);
219        let mut ctx = Context::new(self.event_bus.clone(), timestamp_id("graph_forward"))
220            .with_graph_info(graph_info);
221
222        if let Some(store) = &self.data_store {
223            ctx = ctx.with_data_store(store.clone());
224        }
225        if let Some(remote) = &self.remote_executor {
226            ctx = ctx.with_remote_executor(remote.clone());
227        }
228
229        // Set input for root nodes
230        let roots = self.graph.roots();
231        if roots.len() == 1 {
232            ctx.set(format!("__input_{}", roots[0]), x.clone());
233        }
234        ctx.set("__input__", x.clone());
235
236        executor::execute(&plan, &mut ctx, &self.library, self.cache.as_ref())?;
237
238        // Return the last leaf's output
239        let leaves = self.graph.leaves();
240        let mut extract = |id: &str| -> Option<Value> {
241            ctx.store.remove(id).and_then(|vv| vv.as_value().cloned())
242        };
243
244        if let Some(leaf_id) = leaves.first() {
245            extract(leaf_id).ok_or_else(|| {
246                SomaError::Other(format!("leaf node '{leaf_id}' produced no output"))
247            })
248        } else {
249            ctx.execution_order
250                .last()
251                .and_then(|id| extract(id))
252                .ok_or_else(|| SomaError::Other("no output produced".into()))
253        }
254    }
255
256    /// Forward pass in batches from a DataStore reference.
257    pub fn forward_batched(&self, data_ref: &DataRef, batch_size: usize) -> Result<Value> {
258        let store = self
259            .data_store
260            .as_ref()
261            .ok_or_else(|| SomaError::Execution {
262                node_id: "session".into(),
263                message: "forward_batched requires a data store (use with_data_store)".into(),
264            })?;
265
266        let meta = store.meta(data_ref)?;
267        let total_rows = meta.total_rows;
268        if total_rows == 0 {
269            return Ok(Value::Empty);
270        }
271
272        let mut all_values: Vec<f64> = Vec::new();
273        let mut result_shape: Option<Vec<usize>> = None;
274        let mut rows_processed = 0;
275
276        while rows_processed < total_rows {
277            let batch_len = batch_size.min(total_rows - rows_processed);
278            let batch = store.get_rows(data_ref, rows_processed, batch_len)?;
279            let output = self.forward(&batch)?;
280
281            if let Value::Tensor { values, shape } = &output {
282                if result_shape.is_none() {
283                    result_shape = Some(shape.clone());
284                }
285                all_values.extend_from_slice(values);
286            } else {
287                return Ok(output);
288            }
289
290            rows_processed += batch_len;
291        }
292
293        match result_shape {
294            Some(mut shape) => {
295                shape[0] = total_rows;
296                Ok(Value::tensor(all_values, shape))
297            }
298            None => Ok(Value::Empty),
299        }
300    }
301
302    // ── State persistence ──
303
304    /// Persist all trained states to the data store.
305    pub fn persist_states(&self) -> Result<DataRef> {
306        let store = self
307            .data_store
308            .as_ref()
309            .ok_or_else(|| SomaError::Execution {
310                node_id: "session".into(),
311                message: "persist_states requires a data store".into(),
312            })?;
313
314        let sorted = self.graph.topological_sort()?;
315        let mut states_map = serde_json::Map::new();
316        for node_id in &sorted {
317            if let Some(state) = self.library.get_state(node_id) {
318                let json = serde_json::to_value(state)
319                    .map_err(|e| SomaError::Other(format!("state serialize: {e}")))?;
320                states_map.insert(node_id.to_string(), json);
321            }
322        }
323
324        let states_value = Value::Json(serde_json::Value::Object(states_map));
325        let key = CacheKey::from_parts(&[b"graph_states", self.graph_config_hash().as_bytes()]);
326        store.put(&key, &states_value)
327    }
328
329    /// Load previously persisted states from a data store reference.
330    pub fn load_states(&mut self, data_ref: &DataRef) -> Result<()> {
331        let store = self
332            .data_store
333            .as_ref()
334            .ok_or_else(|| SomaError::Execution {
335                node_id: "session".into(),
336                message: "load_states requires a data store".into(),
337            })?;
338
339        let states_value = store.get(data_ref)?;
340        let states_json = states_value
341            .as_json()
342            .ok_or_else(|| SomaError::Other("persisted states must be JSON".into()))?;
343        let obj = states_json
344            .as_object()
345            .ok_or_else(|| SomaError::Other("persisted states must be a JSON object".into()))?;
346
347        for (node_id, json_val) in obj {
348            let value: Value = serde_json::from_value(json_val.clone())
349                .map_err(|e| SomaError::Other(format!("state deserialize: {e}")))?;
350            self.library.set_state(node_id.clone(), value);
351        }
352
353        self.fitted = true;
354        Ok(())
355    }
356
357    // ── Observability ──
358
359    /// Subscribe to execution events.
360    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Event> {
361        self.event_bus.subscribe()
362    }
363
364    /// Access the event bus directly.
365    pub fn event_bus(&self) -> &Arc<EventBus> {
366        &self.event_bus
367    }
368
369    /// Whether the session has been fitted.
370    pub fn is_fitted(&self) -> bool {
371        self.fitted
372    }
373
374    /// Access the graph.
375    pub fn graph(&self) -> &Graph {
376        &self.graph
377    }
378
379    /// Access the filter library.
380    pub fn library(&self) -> &FilterLibrary {
381        &self.library
382    }
383
384    /// Mutable access to the filter library (for registering filters after creation).
385    pub fn library_mut(&mut self) -> &mut FilterLibrary {
386        &mut self.library
387    }
388
389    // ── Private helpers ──
390
391    fn graph_config_hash(&self) -> String {
392        let node_ids: Vec<&str> = self.graph.nodes.iter().map(|n| n.id.as_str()).collect();
393        node_ids.join(",")
394    }
395}
396
397// ── Convenience free functions (backward compat) ──
398
399/// Compile and execute a graph, returning all node outputs.
400pub fn graph_run(
401    graph: &Graph,
402    library: &FilterLibrary,
403    mode: CompileMode,
404    cache: &dyn CacheStore,
405) -> Result<HashMap<String, Value>> {
406    let CompileResult { plan, diagnostics } = compile(graph, library, mode, Some(cache))?;
407
408    for diag in &diagnostics {
409        tracing::warn!("compile diagnostic: {:?}", diag);
410    }
411
412    let bus = Arc::new(EventBus::new(256));
413    let graph_info = GraphInfo::from_graph(graph);
414
415    let mut ctx = Context::new(bus, timestamp_id("graph_run")).with_graph_info(graph_info);
416
417    executor::execute(&plan, &mut ctx, library, cache)?;
418
419    Ok(ctx
420        .store
421        .into_iter()
422        .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
423        .collect())
424}
425
426/// Fit all trainable filters in topological order.
427pub fn graph_fit(
428    graph: &Graph,
429    library: &FilterLibrary,
430    x: &Value,
431    y: Option<&Value>,
432    cache: &dyn CacheStore,
433) -> Result<HashMap<String, Value>> {
434    graph.validate()?;
435    let sorted = graph.topological_sort()?;
436    let graph_info = GraphInfo::from_graph(graph);
437
438    let bus = Arc::new(EventBus::new(256));
439    let run_id = timestamp_id("graph_fit");
440
441    let mut outputs: HashMap<String, Value> = HashMap::new();
442
443    let roots = graph.roots();
444    for root_id in &roots {
445        outputs.insert(format!("__input_{root_id}"), x.clone());
446    }
447
448    for node_id in &sorted {
449        let filter = library
450            .get(node_id)
451            .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
452
453        bus.emit(Event::NodeStarted {
454            run_id: run_id.clone(),
455            node_id: node_id.to_string(),
456            kind: filter.meta().kind,
457        });
458
459        let preds = graph_info.predecessors(node_id);
460        let input = match preds.len() {
461            0 => x.clone(),
462            1 => outputs.get(&preds[0]).cloned().unwrap_or_else(|| x.clone()),
463            _ => {
464                let mut merged = serde_json::Map::new();
465                for pred_id in preds {
466                    if let Some(val) = outputs.get(pred_id.as_str()) {
467                        let json_val = serde_json::to_value(val).unwrap_or(serde_json::Value::Null);
468                        merged.insert(pred_id.clone(), json_val);
469                    }
470                }
471                Value::Json(serde_json::Value::Object(merged))
472            }
473        };
474
475        let meta = filter.meta();
476        let start = std::time::Instant::now();
477
478        let (state, output) = if meta.kind == FilterKind::Trainable {
479            let data_hash = CacheKey::hash_data(&serde_json::to_vec(&input).unwrap_or_default());
480            let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
481
482            let state = if let Some(cached) = cache.get(&state_key)? {
483                cached
484            } else {
485                let s = filter.fit(&input, y)?;
486                cache.put(&state_key, &s)?;
487                s
488            };
489
490            let output = filter.forward(&input, &state)?;
491            (state, output)
492        } else {
493            let output = filter.forward(&input, &Value::Empty)?;
494            (Value::Empty, output)
495        };
496
497        let _ = state;
498
499        bus.emit(Event::NodeCompleted {
500            run_id: run_id.clone(),
501            node_id: node_id.to_string(),
502            duration: start.elapsed(),
503            output_summary: format!("{output}"),
504        });
505
506        outputs.insert(node_id.to_string(), output);
507    }
508
509    Ok(outputs)
510}
511
512/// Compile in Inference mode and execute, returning the last leaf's output.
513pub fn graph_predict(
514    graph: &Graph,
515    library: &FilterLibrary,
516    x: &Value,
517    cache: &dyn CacheStore,
518) -> Result<Value> {
519    let CompileResult { plan, .. } = compile(graph, library, CompileMode::Inference, Some(cache))?;
520
521    let bus = Arc::new(EventBus::new(256));
522    let graph_info = GraphInfo::from_graph(graph);
523    let mut ctx = Context::new(bus, timestamp_id("graph_predict")).with_graph_info(graph_info);
524
525    let roots = graph.roots();
526    if roots.len() == 1 {
527        ctx.set(format!("__input_{}", roots[0]), x.clone());
528    }
529    ctx.set("__input__", x.clone());
530
531    executor::execute(&plan, &mut ctx, library, cache)?;
532
533    let leaves = graph.leaves();
534    let mut extract =
535        |id: &str| -> Option<Value> { ctx.store.remove(id).and_then(|vv| vv.as_value().cloned()) };
536
537    if let Some(leaf_id) = leaves.first() {
538        extract(leaf_id)
539            .ok_or_else(|| SomaError::Other(format!("leaf node '{leaf_id}' produced no output")))
540    } else {
541        ctx.execution_order
542            .last()
543            .and_then(|id| extract(id))
544            .ok_or_else(|| SomaError::Other("no output produced".into()))
545    }
546}
547
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::MemoryCache;
    use somatize_compiler::FilterRegistry;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result;
    use somatize_core::filter::{Distribution, Filter, FilterKind, FilterMeta, StreamMode};
    use somatize_core::graph::{Edge, Node};

    // ── Shared helpers ──

    /// Metadata common to all test filters: cacheable, fixed-state, local,
    /// no schemas. Only the name, kind, and differentiability vary.
    fn test_meta(name: &'static str, kind: FilterKind, differentiable: bool) -> FilterMeta {
        FilterMeta {
            name: name.into(),
            kind,
            cacheable: true,
            differentiable,
            stream_mode: StreamMode::FixedState,
            distribution: Distribution::Local,
            input_schema: None,
            output_schema: None,
        }
    }

    /// Apply `f` to every element of a tensor value, keeping its shape.
    fn map_tensor(x: &Value, f: impl Fn(f64) -> f64) -> Result<Value> {
        let (data, shape) = x
            .as_tensor()
            .ok_or_else(|| SomaError::Other("need tensor".into()))?;
        Ok(Value::tensor(
            data.iter().map(|v| f(*v)).collect(),
            shape.to_vec(),
        ))
    }

    /// Collect every concrete value left in an execution context.
    fn value_outputs(ctx: Context) -> HashMap<String, Value> {
        ctx.store
            .into_iter()
            .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
            .collect()
    }

    /// Chain `ids` into a straight-line graph: one node per id, one data edge
    /// between each consecutive pair.
    fn linear_graph(ids: &[&str]) -> Graph {
        let mut g = Graph::new();
        for &id in ids {
            g.nodes.push(Node::new(id, id, id));
        }
        for (i, pair) in ids.windows(2).enumerate() {
            g.edges.push(Edge::data(format!("e{i}"), pair[0], pair[1]));
        }
        g
    }

    // ── Test filters ──

    /// Stateless filter that multiplies every element by two.
    struct DoublerFilter;
    impl Filter for DoublerFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Doubler"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            map_tensor(x, |v| v * 2.0)
        }
        fn meta(&self) -> FilterMeta {
            test_meta("Doubler", FilterKind::Stateless, true)
        }
    }

    /// Stateless filter that adds a fixed constant to every element.
    struct AdderFilter(f64);
    impl Filter for AdderFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Adder", &self.0.to_le_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            map_tensor(x, |v| v + self.0)
        }
        fn meta(&self) -> FilterMeta {
            test_meta("Adder", FilterKind::Stateless, true)
        }
    }

    /// Trainable filter: learns the mean in `fit`, subtracts it in `forward`.
    struct MeanFilter;
    impl Filter for MeanFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Mean"])
        }
        fn fit(&self, x: &Value, _y: Option<&Value>) -> Result<Value> {
            let (data, _) = x
                .as_tensor()
                .ok_or_else(|| SomaError::Other("need tensor".into()))?;
            let total: f64 = data.iter().sum();
            let mean = total / data.len() as f64;
            Ok(Value::json(serde_json::json!({ "mean": mean })))
        }
        fn forward(&self, x: &Value, state: &Value) -> Result<Value> {
            // A missing or malformed state is treated as a mean of zero.
            let mean = state
                .as_json()
                .and_then(|j| j["mean"].as_f64())
                .unwrap_or(0.0);
            map_tensor(x, |v| v - mean)
        }
        fn meta(&self) -> FilterMeta {
            test_meta("Mean", FilterKind::Trainable, true)
        }
    }

    // ── GraphSession tests ──

    #[test]
    fn session_run_linear() {
        let mut lib = FilterLibrary::new();
        lib.register("double", Box::new(DoublerFilter));
        lib.register("add", Box::new(AdderFilter(10.0)));

        let session = GraphSession::new(linear_graph(&["double", "add"]), lib)
            .with_cache(Arc::new(MemoryCache::default()));

        // Compile through the session, then drive the executor by hand so the
        // input can be seeded directly into the context.
        let plan = session.compile(CompileMode::NoCache).unwrap().plan;
        let mut ctx = Context::new(Arc::new(EventBus::new(64)), "test")
            .with_graph_info(GraphInfo::from_graph(session.graph()));
        ctx.set("__input__", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
        executor::execute(&plan, &mut ctx, session.library(), &MemoryCache::default()).unwrap();

        let outputs = value_outputs(ctx);
        let (data, _) = outputs.get("add").unwrap().as_tensor().unwrap();
        assert_eq!(data, &[12.0, 14.0, 16.0]);
    }

    #[test]
    fn session_fit_and_forward() {
        let mut lib = FilterLibrary::new();
        lib.register("mean", Box::new(MeanFilter));
        lib.register("double", Box::new(DoublerFilter));

        let mut session = GraphSession::new(linear_graph(&["mean", "double"]), lib);

        let x = Value::tensor(vec![10.0, 20.0, 30.0], vec![3]);
        let outputs = session.fit(&x, None).unwrap();

        // mean learns 20 and centres the data to [-10, 0, 10];
        // double then yields [-20, 0, 20].
        let (data, _) = outputs.get("double").unwrap().as_tensor().unwrap();
        assert_eq!(data, &[-20.0, 0.0, 20.0]);

        assert!(session.is_fitted());
    }

    #[test]
    fn session_compile_diagnostics() {
        let mut lib = FilterLibrary::new();
        lib.register("double", Box::new(DoublerFilter));

        let session = GraphSession::new(linear_graph(&["double"]), lib);
        let result = session.compile(CompileMode::NoCache).unwrap();
        assert!(result.plan.node_count() > 0);
    }

    // ── Free function tests (backward compat) ──

    #[test]
    fn graph_run_linear() {
        let graph = linear_graph(&["double", "add"]);
        let mut lib = FilterLibrary::new();
        lib.register("double", Box::new(DoublerFilter));
        lib.register("add", Box::new(AdderFilter(10.0)));

        let cache = MemoryCache::default();

        let plan = compile(&graph, &lib, CompileMode::NoCache, None)
            .unwrap()
            .plan;
        let mut ctx = Context::new(Arc::new(EventBus::new(64)), "test")
            .with_graph_info(GraphInfo::from_graph(&graph));
        ctx.set("__input__", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
        executor::execute(&plan, &mut ctx, &lib, &cache).unwrap();

        let outputs = value_outputs(ctx);
        let (data, _) = outputs.get("add").unwrap().as_tensor().unwrap();
        assert_eq!(data, &[12.0, 14.0, 16.0]);
    }

    #[test]
    fn graph_run_diamond() {
        /// Stateless pass-through used as the join node.
        struct MergeFilter;
        impl Filter for MergeFilter {
            fn config_hash(&self) -> CacheKey {
                CacheKey::from_parts(&[b"Merge"])
            }
            fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
                Ok(Value::Empty)
            }
            fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
                Ok(x.clone())
            }
            fn meta(&self) -> FilterMeta {
                test_meta("Merge", FilterKind::Stateless, false)
            }
        }

        // Two independent roots feeding a single join node.
        let mut graph = Graph::new();
        graph.nodes.push(Node::new("double", "Double", "double"));
        graph.nodes.push(Node::new("add", "Add", "add"));
        graph.nodes.push(Node::new("merge", "Merge", "merge"));
        graph.edges.push(Edge::data("e1", "double", "merge"));
        graph.edges.push(Edge::data("e2", "add", "merge"));

        let mut lib = FilterLibrary::new();
        lib.register("double", Box::new(DoublerFilter));
        lib.register("add", Box::new(AdderFilter(100.0)));
        lib.register("merge", Box::new(MergeFilter));

        let cache = MemoryCache::default();
        let plan = compile(&graph, &lib, CompileMode::NoCache, None)
            .unwrap()
            .plan;

        let mut ctx = Context::new(Arc::new(EventBus::new(64)), "test")
            .with_graph_info(GraphInfo::from_graph(&graph));
        ctx.set("__input__", Value::tensor(vec![5.0], vec![1]));
        executor::execute(&plan, &mut ctx, &lib, &cache).unwrap();

        let merge_output = ctx.get("merge").unwrap();
        assert!(
            merge_output.as_json().is_some(),
            "merge should receive JSON from multiple predecessors"
        );
    }

    #[test]
    fn graph_fit_trainable() {
        let graph = linear_graph(&["mean", "double"]);
        let mut lib = FilterLibrary::new();
        lib.register("mean", Box::new(MeanFilter));
        lib.register("double", Box::new(DoublerFilter));

        let cache = MemoryCache::default();
        let x = Value::tensor(vec![10.0, 20.0, 30.0], vec![3]);

        let outputs = graph_fit(&graph, &lib, &x, None, &cache).unwrap();

        let (data, _) = outputs.get("double").unwrap().as_tensor().unwrap();
        assert_eq!(data, &[-20.0, 0.0, 20.0]);

        // Fitting the trainable node must have stored its state.
        assert!(!cache.is_empty());
    }

    #[test]
    fn filter_library_registry_compat() {
        let mut lib = FilterLibrary::new();
        lib.register("a", Box::new(DoublerFilter));

        let registry: &dyn FilterRegistry = &lib;
        assert!(registry.meta("a").is_some());
        assert_eq!(registry.meta("a").unwrap().name, "Doubler");
        assert!(registry.config_hash("a").is_some());
        assert!(registry.meta("b").is_none());
    }
}