1use crate::cache::MemoryCache;
8use crate::event_bus::EventBus;
9use crate::executor::{self, Context, GraphInfo};
10use crate::filter_library::FilterLibrary;
11use crate::runner::Runner;
12use crate::runner::Transport;
13use somatize_compiler::{CompileMode, CompileResult, compile};
14use somatize_core::cache::{CacheKey, CacheStore};
15use somatize_core::error::{Result, SomaError};
16use somatize_core::event::Event;
17use somatize_core::filter::FilterKind;
18use somatize_core::graph::Graph;
19use somatize_core::store::{DataRef, DataStore};
20use somatize_core::util::timestamp_id;
21use somatize_core::value::Value;
22use std::collections::HashMap;
23use std::sync::Arc;
24
/// A reusable handle that owns a [`Graph`] together with the [`FilterLibrary`]
/// implementing its nodes, plus the shared infrastructure (cache, event bus,
/// optional data store and transport) used when compiling, running and
/// fitting that graph.
pub struct GraphSession {
    /// The graph compiled, executed and fitted by this session.
    graph: Graph,
    /// Filters for the graph's nodes; fitted states are written here by
    /// `fit` and `load_states` via `set_state`.
    library: FilterLibrary,
    /// Cache handed to the compiler and executor; `new` installs a `MemoryCache`.
    cache: Arc<dyn CacheStore>,
    /// Bus that execution events go to; `new` creates one with capacity 256.
    event_bus: Arc<EventBus>,
    /// Optional data store; required by `persist_states` / `load_states` and
    /// forwarded to execution contexts when present.
    data_store: Option<Arc<dyn DataStore>>,
    /// Optional transport forwarded to the execution context by `run`.
    transport: Option<Arc<dyn Transport>>,
    /// `true` once `fit` or `load_states` has completed successfully.
    fitted: bool,
}
45
46impl GraphSession {
47 pub fn new(graph: Graph, library: FilterLibrary) -> Self {
48 Self {
49 graph,
50 library,
51 cache: Arc::new(MemoryCache::default()),
52 event_bus: Arc::new(EventBus::new(256)),
53 data_store: None,
54 transport: None,
55 fitted: false,
56 }
57 }
58
59 pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
60 self.cache = cache;
61 self
62 }
63
64 pub fn with_event_bus(mut self, bus: Arc<EventBus>) -> Self {
65 self.event_bus = bus;
66 self
67 }
68
69 pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
70 self.data_store = Some(store);
71 self
72 }
73
74 pub fn with_transport(mut self, transport: Arc<dyn Transport>) -> Self {
75 self.transport = Some(transport);
76 self
77 }
78
79 pub fn compile(&self, mode: CompileMode) -> Result<CompileResult> {
83 compile(&self.graph, &self.library, mode, Some(self.cache.as_ref()))
84 }
85
86 pub fn run(&mut self, mode: CompileMode) -> Result<HashMap<String, Value>> {
88 let CompileResult { plan, diagnostics } =
89 compile(&self.graph, &self.library, mode, Some(self.cache.as_ref()))?;
90
91 for diag in &diagnostics {
92 tracing::warn!("compile diagnostic: {:?}", diag);
93 }
94
95 let graph_info = GraphInfo::from_graph(&self.graph);
96 let mut ctx = Context::new(self.event_bus.clone(), timestamp_id("graph_run"))
97 .with_graph_info(graph_info);
98
99 if let Some(store) = &self.data_store {
100 ctx = ctx.with_data_store(store.clone());
101 }
102 if let Some(transport) = &self.transport {
103 ctx = ctx.with_transport(transport.clone());
104 }
105
106 executor::execute(&plan, &mut ctx, &self.library, self.cache.as_ref())?;
107
108 Ok(ctx
109 .store
110 .into_iter()
111 .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
112 .collect())
113 }
114
115 pub fn fit(&mut self, x: &Value, y: Option<&Value>) -> Result<HashMap<String, Value>> {
118 self.graph.validate()?;
119
120 let CompileResult { plan, .. } = compile(
121 &self.graph,
122 &self.library,
123 CompileMode::NoCache,
124 Some(self.cache.as_ref()),
125 )?;
126
127 let runner = crate::runner::LocalRunner;
128 let (_last_output, mut all_outputs) = runner.fit(
129 &plan,
130 &self.library,
131 self.cache.as_ref(),
132 &self.event_bus,
133 x,
134 y,
135 )?;
136
137 for (key, value) in &all_outputs {
139 if let Some(node_id) = key.strip_prefix("__state_") {
140 self.library.set_state(node_id, value.clone());
141 }
142 }
143
144 all_outputs.retain(|k, _| !k.starts_with("__state_"));
146
147 self.fitted = true;
148 Ok(all_outputs)
149 }
150
151 pub fn forward_with(
158 &self,
159 x: &Value,
160 strategy: &dyn crate::forward::ForwardStrategy,
161 ) -> Result<Value> {
162 strategy.forward(
163 &self.graph,
164 &self.library,
165 self.cache.as_ref(),
166 &self.event_bus,
167 self.data_store.as_ref(),
168 x,
169 )
170 }
171
172 pub fn forward(&self, x: &Value) -> Result<Value> {
174 self.forward_with(x, &crate::forward::Standard)
175 }
176
177 pub fn persist_states(&self) -> Result<DataRef> {
181 let store = self
182 .data_store
183 .as_ref()
184 .ok_or_else(|| SomaError::Execution {
185 node_id: "session".into(),
186 message: "persist_states requires a data store".into(),
187 })?;
188
189 let sorted = self.graph.topological_sort()?;
190 let mut states_map = serde_json::Map::new();
191 for node_id in &sorted {
192 if let Some(state) = self.library.get_state(node_id) {
193 let json = serde_json::to_value(&*state)
194 .map_err(|e| SomaError::Other(format!("state serialize: {e}")))?;
195 states_map.insert(node_id.to_string(), json);
196 }
197 }
198
199 let states_value = Value::json(serde_json::Value::Object(states_map));
200 let key = CacheKey::from_parts(&[b"graph_states", self.graph_config_hash().as_bytes()]);
201 store.put(&key, &states_value)
202 }
203
204 pub fn load_states(&mut self, data_ref: &DataRef) -> Result<()> {
206 let store = self
207 .data_store
208 .as_ref()
209 .ok_or_else(|| SomaError::Execution {
210 node_id: "session".into(),
211 message: "load_states requires a data store".into(),
212 })?;
213
214 let states_value = store.get(data_ref)?;
215 let states_json = states_value
216 .as_json()
217 .ok_or_else(|| SomaError::Other("persisted states must be JSON".into()))?;
218 let obj = states_json
219 .as_object()
220 .ok_or_else(|| SomaError::Other("persisted states must be a JSON object".into()))?;
221
222 for (node_id, json_val) in obj {
223 let value: Value = serde_json::from_value(json_val.clone())
224 .map_err(|e| SomaError::Other(format!("state deserialize: {e}")))?;
225 self.library.set_state(node_id.clone(), value);
226 }
227
228 self.fitted = true;
229 Ok(())
230 }
231
232 pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Event> {
236 self.event_bus.subscribe()
237 }
238
239 pub fn event_bus(&self) -> &Arc<EventBus> {
241 &self.event_bus
242 }
243
244 pub fn is_fitted(&self) -> bool {
246 self.fitted
247 }
248
249 pub fn graph(&self) -> &Graph {
251 &self.graph
252 }
253
254 pub fn library(&self) -> &FilterLibrary {
256 &self.library
257 }
258
259 pub fn library_mut(&mut self) -> &mut FilterLibrary {
261 &mut self.library
262 }
263
264 fn graph_config_hash(&self) -> String {
267 let node_ids: Vec<&str> = self.graph.nodes.iter().map(|n| n.id.as_str()).collect();
268 node_ids.join(",")
269 }
270}
271
272pub fn graph_run(
276 graph: &Graph,
277 library: &FilterLibrary,
278 mode: CompileMode,
279 cache: &dyn CacheStore,
280) -> Result<HashMap<String, Value>> {
281 let CompileResult { plan, diagnostics } = compile(graph, library, mode, Some(cache))?;
282
283 for diag in &diagnostics {
284 tracing::warn!("compile diagnostic: {:?}", diag);
285 }
286
287 let bus = Arc::new(EventBus::new(256));
288 let graph_info = GraphInfo::from_graph(graph);
289
290 let mut ctx = Context::new(bus, timestamp_id("graph_run")).with_graph_info(graph_info);
291
292 executor::execute(&plan, &mut ctx, library, cache)?;
293
294 Ok(ctx
295 .store
296 .into_iter()
297 .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
298 .collect())
299}
300
301pub fn graph_fit(
303 graph: &Graph,
304 library: &FilterLibrary,
305 x: &Value,
306 y: Option<&Value>,
307 cache: &dyn CacheStore,
308) -> Result<HashMap<String, Value>> {
309 graph.validate()?;
310 let sorted = graph.topological_sort()?;
311 let graph_info = GraphInfo::from_graph(graph);
312
313 let bus = Arc::new(EventBus::new(256));
314 let run_id = timestamp_id("graph_fit");
315
316 let mut outputs: HashMap<String, Value> = HashMap::new();
317
318 let roots = graph.roots();
319 for root_id in &roots {
320 outputs.insert(format!("__input_{root_id}"), x.clone());
321 }
322
323 for node_id in &sorted {
324 let filter = library
325 .get(node_id)
326 .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
327
328 bus.emit(Event::NodeStarted {
329 run_id: run_id.clone(),
330 node_id: node_id.to_string(),
331 kind: filter.meta().kind,
332 });
333
334 let preds = graph_info.predecessors(node_id);
335 let input = match preds.len() {
336 0 => x.clone(),
337 1 => outputs.get(&preds[0]).cloned().unwrap_or_else(|| x.clone()),
338 _ => {
339 let mut merged = serde_json::Map::new();
340 for pred_id in preds {
341 if let Some(val) = outputs.get(pred_id.as_str()) {
342 let json_val = serde_json::to_value(val).unwrap_or(serde_json::Value::Null);
343 merged.insert(pred_id.clone(), json_val);
344 }
345 }
346 Value::json(serde_json::Value::Object(merged))
347 }
348 };
349
350 let meta = filter.meta();
351 let start = std::time::Instant::now();
352
353 let (state, output) = if meta.kind == FilterKind::Trainable {
354 let data_hash = CacheKey::hash_data(&serde_json::to_vec(&input).unwrap_or_default());
355 let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
356
357 let state = if let Some(cached) = cache.get(&state_key)? {
358 cached
359 } else {
360 let s = filter.fit(&input, y)?;
361 cache.put(&state_key, &s)?;
362 s
363 };
364
365 let output = filter.forward(&input, &state)?;
366 (state, output)
367 } else {
368 let output = filter.forward(&input, &Value::Empty)?;
369 (Value::Empty, output)
370 };
371
372 let _ = state;
373
374 bus.emit(Event::NodeCompleted {
375 run_id: run_id.clone(),
376 node_id: node_id.to_string(),
377 duration: start.elapsed(),
378 output_summary: format!("{output}"),
379 });
380
381 outputs.insert(node_id.to_string(), output);
382 }
383
384 Ok(outputs)
385}
386
387pub fn graph_predict(
389 graph: &Graph,
390 library: &FilterLibrary,
391 x: &Value,
392 cache: &dyn CacheStore,
393) -> Result<Value> {
394 let CompileResult { plan, .. } = compile(graph, library, CompileMode::Inference, Some(cache))?;
395
396 let bus = Arc::new(EventBus::new(256));
397 let graph_info = GraphInfo::from_graph(graph);
398 let mut ctx = Context::new(bus, timestamp_id("graph_predict")).with_graph_info(graph_info);
399
400 let roots = graph.roots();
401 if roots.len() == 1 {
402 ctx.set(format!("__input_{}", roots[0]), x.clone());
403 }
404 ctx.set("__input__", x.clone());
405
406 executor::execute(&plan, &mut ctx, library, cache)?;
407
408 let leaves = graph.leaves();
409 let mut extract =
410 |id: &str| -> Option<Value> { ctx.store.remove(id).and_then(|vv| vv.as_value().cloned()) };
411
412 if let Some(leaf_id) = leaves.first() {
413 extract(leaf_id)
414 .ok_or_else(|| SomaError::Other(format!("leaf node '{leaf_id}' produced no output")))
415 } else {
416 ctx.execution_order
417 .last()
418 .and_then(|id| extract(id))
419 .ok_or_else(|| SomaError::Other("no output produced".into()))
420 }
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::MemoryCache;
    use somatize_compiler::FilterRegistry;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result;
    use somatize_core::filter::{FilterKind, FilterMeta, StreamMode};
    use somatize_core::graph::{Edge, Node};

    /// Metadata shared by every test filter: cacheable, fixed-state, local,
    /// no schemas. Only the name, kind and differentiability vary.
    fn test_meta(name: &'static str, kind: FilterKind, differentiable: bool) -> FilterMeta {
        FilterMeta {
            name: name.into(),
            kind,
            cacheable: true,
            differentiable,
            stream_mode: StreamMode::FixedState,
            distribution: somatize_core::filter::Distribution::Local,
            input_schema: None,
            output_schema: None,
        }
    }

    /// Applies `op` to every element of a tensor value, keeping its shape.
    /// Fails with "need tensor" when `x` is not a tensor.
    fn map_tensor(x: &Value, op: impl Fn(f64) -> f64) -> Result<Value> {
        let (data, shape) = x
            .as_tensor()
            .ok_or(SomaError::Other("need tensor".into()))?;
        Ok(Value::tensor(
            data.iter().map(|&elem| op(elem)).collect(),
            shape.to_vec(),
        ))
    }

    /// Stateless filter that multiplies every tensor element by two.
    struct DoublerFilter;

    impl somatize_core::filter::Filter for DoublerFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Doubler"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            map_tensor(x, |elem| elem * 2.0)
        }
        fn meta(&self) -> FilterMeta {
            test_meta("Doubler", FilterKind::Stateless, true)
        }
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Stateless filter that adds a constant to every tensor element.
    struct AdderFilter(f64);

    impl somatize_core::filter::Filter for AdderFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Adder", &self.0.to_le_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            map_tensor(x, |elem| elem + self.0)
        }
        fn meta(&self) -> FilterMeta {
            test_meta("Adder", FilterKind::Stateless, true)
        }
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Trainable filter: `fit` learns the mean, and `forward` subtracts it
    /// (0.0 if the state carries no mean).
    struct MeanFilter;

    impl somatize_core::filter::Filter for MeanFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Mean"])
        }
        fn fit(&self, x: &Value, _y: Option<&Value>) -> Result<Value> {
            let (data, _) = x
                .as_tensor()
                .ok_or(SomaError::Other("need tensor".into()))?;
            let mean = data.iter().sum::<f64>() / data.len() as f64;
            Ok(Value::json(serde_json::json!({ "mean": mean })))
        }
        fn forward(&self, x: &Value, state: &Value) -> Result<Value> {
            let learned_mean = state
                .as_json()
                .and_then(|j| j["mean"].as_f64())
                .unwrap_or(0.0);
            map_tensor(x, |elem| elem - learned_mean)
        }
        fn meta(&self) -> FilterMeta {
            test_meta("Mean", FilterKind::Trainable, true)
        }
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Builds a chain graph. Each id is its own node id, name and filter key,
    /// and consecutive ids are joined by data edges `e0`, `e1`, ...
    fn linear_graph(ids: &[&str]) -> Graph {
        let mut chain = Graph::new();
        for &id in ids {
            chain.nodes.push(Node::new(id, id, id));
        }
        for (idx, link) in ids.windows(2).enumerate() {
            chain.edges.push(Edge::data(format!("e{idx}"), link[0], link[1]));
        }
        chain
    }

    #[test]
    fn session_run_linear() {
        let mut library = FilterLibrary::new();
        library.register("double", Box::new(DoublerFilter));
        library.register("add", Box::new(AdderFilter(10.0)));

        let session = GraphSession::new(linear_graph(&["double", "add"]), library)
            .with_cache(Arc::new(MemoryCache::default()));

        let CompileResult { plan, .. } = session.compile(CompileMode::NoCache).unwrap();
        let mut ctx = Context::new(Arc::new(EventBus::new(64)), "test")
            .with_graph_info(GraphInfo::from_graph(session.graph()));
        ctx.set("__input__", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
        executor::execute(&plan, &mut ctx, session.library(), &MemoryCache::default()).unwrap();

        let outputs: HashMap<String, Value> = ctx
            .store
            .into_iter()
            .filter_map(|(key, entry)| entry.as_value().cloned().map(|v| (key, v)))
            .collect();

        let (data, _) = outputs["add"].as_tensor().unwrap();
        assert_eq!(data, &[12.0, 14.0, 16.0]);
    }

    #[test]
    fn session_fit_and_forward() {
        let mut library = FilterLibrary::new();
        library.register("mean", Box::new(MeanFilter));
        library.register("double", Box::new(DoublerFilter));

        let mut session = GraphSession::new(linear_graph(&["mean", "double"]), library);

        let samples = Value::tensor(vec![10.0, 20.0, 30.0], vec![3]);
        let outputs = session.fit(&samples, None).unwrap();

        let (data, _) = outputs["double"].as_tensor().unwrap();
        assert_eq!(data, &[-20.0, 0.0, 20.0]);
        assert!(session.is_fitted());
    }

    #[test]
    fn session_compile_diagnostics() {
        let mut library = FilterLibrary::new();
        library.register("double", Box::new(DoublerFilter));

        let session = GraphSession::new(linear_graph(&["double"]), library);
        let compiled = session.compile(CompileMode::NoCache).unwrap();
        assert!(compiled.plan.node_count() > 0);
    }

    #[test]
    fn graph_run_linear() {
        let graph = linear_graph(&["double", "add"]);
        let mut library = FilterLibrary::new();
        library.register("double", Box::new(DoublerFilter));
        library.register("add", Box::new(AdderFilter(10.0)));

        let cache = MemoryCache::default();

        let CompileResult { plan, .. } =
            compile(&graph, &library, CompileMode::NoCache, None).unwrap();
        let mut ctx = Context::new(Arc::new(EventBus::new(64)), "test")
            .with_graph_info(GraphInfo::from_graph(&graph));
        ctx.set("__input__", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
        executor::execute(&plan, &mut ctx, &library, &cache).unwrap();

        let outputs: HashMap<String, Value> = ctx
            .store
            .into_iter()
            .filter_map(|(key, entry)| entry.as_value().cloned().map(|v| (key, v)))
            .collect();

        let (data, _) = outputs["add"].as_tensor().unwrap();
        assert_eq!(data, &[12.0, 14.0, 16.0]);
    }

    #[test]
    fn graph_run_diamond() {
        /// Pass-through filter used to observe what a multi-input node receives.
        struct MergeFilter;

        impl somatize_core::filter::Filter for MergeFilter {
            fn config_hash(&self) -> CacheKey {
                CacheKey::from_parts(&[b"Merge"])
            }
            fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
                Ok(Value::Empty)
            }
            fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
                Ok(x.clone())
            }
            fn meta(&self) -> FilterMeta {
                test_meta("Merge", FilterKind::Stateless, false)
            }
            fn as_any(&self) -> &dyn std::any::Any {
                self
            }
        }

        let mut diamond = Graph::new();
        diamond.nodes.push(Node::new("double", "Double", "double"));
        diamond.nodes.push(Node::new("add", "Add", "add"));
        diamond.nodes.push(Node::new("merge", "Merge", "merge"));
        diamond.edges.push(Edge::data("e1", "double", "merge"));
        diamond.edges.push(Edge::data("e2", "add", "merge"));

        let mut library = FilterLibrary::new();
        library.register("double", Box::new(DoublerFilter));
        library.register("add", Box::new(AdderFilter(100.0)));
        library.register("merge", Box::new(MergeFilter));

        let cache = MemoryCache::default();
        let CompileResult { plan, .. } =
            compile(&diamond, &library, CompileMode::NoCache, None).unwrap();

        let mut ctx = Context::new(Arc::new(EventBus::new(64)), "test")
            .with_graph_info(GraphInfo::from_graph(&diamond));
        ctx.set("__input__", Value::tensor(vec![5.0], vec![1]));
        executor::execute(&plan, &mut ctx, &library, &cache).unwrap();

        let merged = ctx.get("merge").unwrap();
        assert!(
            merged.as_json().is_some(),
            "merge should receive JSON from multiple predecessors"
        );
    }

    #[test]
    fn graph_fit_trainable() {
        let graph = linear_graph(&["mean", "double"]);
        let mut library = FilterLibrary::new();
        library.register("mean", Box::new(MeanFilter));
        library.register("double", Box::new(DoublerFilter));

        let cache = MemoryCache::default();
        let samples = Value::tensor(vec![10.0, 20.0, 30.0], vec![3]);

        let outputs = graph_fit(&graph, &library, &samples, None, &cache).unwrap();

        let (data, _) = outputs["double"].as_tensor().unwrap();
        assert_eq!(data, &[-20.0, 0.0, 20.0]);
        assert!(!cache.is_empty());
    }

    #[test]
    fn filter_library_registry_compat() {
        let mut library = FilterLibrary::new();
        library.register("a", Box::new(DoublerFilter));

        let registry: &dyn FilterRegistry = &library;
        assert!(registry.meta("a").is_some());
        assert_eq!(registry.meta("a").unwrap().name, "Doubler");
        assert!(registry.config_hash("a").is_some());
        assert!(registry.meta("b").is_none());
    }
}