1use crate::plan::ExecutionPlan;
7use somatize_core::cache::{CacheKey, CacheStore};
8use somatize_core::error::Result;
9use somatize_core::filter::{Filter, FilterMeta};
10use somatize_core::graph::{Graph, NodeId};
11use std::collections::{HashMap, HashSet};
12
/// Strategy the [`Compiler`] follows when turning a graph into a plan.
///
/// Within this module the mode only influences cache resolution.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompileMode {
    /// Cache keys are derived and cacheable nodes whose key already exists
    /// in the store are replaced by cached steps.
    Inference,
    /// Cache keys are still derived, but no node is replaced by a cached step.
    Differentiable,
    /// Cache resolution is skipped entirely.
    NoCache,
}
23
24#[derive(Debug, Clone)]
26pub struct Diagnostic {
27 pub node_id: NodeId,
28 pub level: DiagnosticLevel,
29 pub message: String,
30}
31
/// Severity attached to a [`Diagnostic`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DiagnosticLevel {
    /// Something that likely needs the user's attention.
    Warning,
    /// Purely informational.
    Info,
}
37
38pub struct CompileResult {
40 pub plan: ExecutionPlan,
41 pub diagnostics: Vec<Diagnostic>,
42}
43
44pub trait FilterRegistry: Send + Sync {
48 fn meta(&self, node_id: &str) -> Option<FilterMeta>;
49 fn config_hash(&self, node_id: &str) -> Option<CacheKey>;
50}
51
52pub struct SimpleFilterRegistry {
54 entries: HashMap<String, (FilterMeta, CacheKey)>,
55}
56
57impl SimpleFilterRegistry {
58 pub fn new() -> Self {
59 Self {
60 entries: HashMap::new(),
61 }
62 }
63
64 pub fn register(&mut self, node_id: impl Into<String>, filter: &dyn Filter) {
65 let id = node_id.into();
66 self.entries
67 .insert(id, (filter.meta(), filter.config_hash()));
68 }
69
70 pub fn register_meta(
71 &mut self,
72 node_id: impl Into<String>,
73 meta: FilterMeta,
74 config_hash: CacheKey,
75 ) {
76 self.entries.insert(node_id.into(), (meta, config_hash));
77 }
78}
79
80impl Default for SimpleFilterRegistry {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86impl FilterRegistry for SimpleFilterRegistry {
87 fn meta(&self, node_id: &str) -> Option<FilterMeta> {
88 self.entries.get(node_id).map(|(m, _)| m.clone())
89 }
90
91 fn config_hash(&self, node_id: &str) -> Option<CacheKey> {
92 self.entries.get(node_id).map(|(_, h)| h.clone())
93 }
94}
95
96pub struct Compiler<'a> {
98 graph: &'a Graph,
99 registry: &'a dyn FilterRegistry,
100 mode: CompileMode,
101 diagnostics: Vec<Diagnostic>,
102}
103
104impl<'a> Compiler<'a> {
105 pub fn new(graph: &'a Graph, registry: &'a dyn FilterRegistry, mode: CompileMode) -> Self {
106 Self {
107 graph,
108 registry,
109 mode,
110 diagnostics: Vec::new(),
111 }
112 }
113
114 pub fn compile(mut self, cache: Option<&dyn CacheStore>) -> Result<CompileResult> {
116 self.graph.validate()?;
117
118 let sorted = self.graph.topological_sort()?;
119
120 if sorted.is_empty() {
121 return Ok(CompileResult {
122 plan: ExecutionPlan::Empty,
123 diagnostics: self.diagnostics,
124 });
125 }
126
127 self.check_gradient_flow(&sorted);
129
130 self.validate_schemas(&sorted);
132
133 let plan = self.build_plan(&sorted);
135
136 let plan = if let Some(cache) = cache {
138 self.resolve_cache(plan, cache, &sorted)?
139 } else {
140 plan
141 };
142
143 let plan = self.resolve_distribution(plan);
145
146 let plan = plan.simplify();
147
148 Ok(CompileResult {
149 plan,
150 diagnostics: self.diagnostics,
151 })
152 }
153
154 fn build_plan(&self, sorted: &[&str]) -> ExecutionPlan {
156 let levels = self.compute_levels(sorted);
158
159 let mut plan_steps: Vec<ExecutionPlan> = Vec::new();
160
161 for level in &levels {
162 if level.len() == 1 {
163 plan_steps.push(self.plan_for_node(level[0]));
164 } else {
165 let branches: Vec<ExecutionPlan> =
166 level.iter().map(|id| self.plan_for_node(id)).collect();
167 plan_steps.push(ExecutionPlan::Parallel(branches));
168 }
169 }
170
171 if plan_steps.len() == 1 {
172 plan_steps.into_iter().next().unwrap()
173 } else {
174 ExecutionPlan::Sequence(plan_steps)
175 }
176 }
177
178 fn plan_for_node(&self, node_id: &str) -> ExecutionPlan {
180 use somatize_core::graph::NodeKind;
181
182 let node = match self.graph.node(node_id) {
183 Some(n) => n,
184 None => {
185 return ExecutionPlan::Execute {
186 node_id: node_id.to_string(),
187 };
188 }
189 };
190
191 match &node.kind {
192 NodeKind::Filter { .. } => ExecutionPlan::Execute {
193 node_id: node_id.to_string(),
194 },
195
196 NodeKind::SubGraph { graph } => {
197 let inner_compiler = Compiler::new(graph, self.registry, self.mode);
199 match inner_compiler.compile(None) {
200 Ok(result) => result.plan,
201 Err(_) => ExecutionPlan::Execute {
202 node_id: node_id.to_string(),
203 },
204 }
205 }
206
207 NodeKind::Loop { max_iterations } => {
208 let successors = self.graph.successors(node_id);
211 let body = if successors.len() == 1 {
212 self.plan_for_node(successors[0])
213 } else if successors.len() > 1 {
214 let branches: Vec<ExecutionPlan> =
215 successors.iter().map(|id| self.plan_for_node(id)).collect();
216 ExecutionPlan::Parallel(branches)
217 } else {
218 ExecutionPlan::Empty
219 };
220 ExecutionPlan::Loop {
221 node_id: node_id.to_string(),
222 body: Box::new(body),
223 max_iterations: *max_iterations,
224 }
225 }
226
227 NodeKind::Branch => {
228 let arms: Vec<(String, ExecutionPlan)> = self
230 .graph
231 .edges
232 .iter()
233 .filter(|e| e.source == node_id)
234 .map(|e| {
235 let label = e.label.clone().unwrap_or_else(|| e.target.clone());
236 let plan = self.plan_for_node(&e.target);
237 (label, plan)
238 })
239 .collect();
240 ExecutionPlan::Branch {
241 node_id: node_id.to_string(),
242 arms,
243 }
244 }
245
246 _ => ExecutionPlan::Execute {
247 node_id: node_id.to_string(),
248 },
249 }
250 }
251
252 fn compute_levels<'b>(&self, sorted: &[&'b str]) -> Vec<Vec<&'b str>> {
255 let mut node_level: HashMap<&str, usize> = HashMap::new();
256 let mut max_level: usize = 0;
257
258 for &node in sorted {
259 let preds = self.graph.predecessors(node);
260 let level = if preds.is_empty() {
261 0
262 } else {
263 preds
264 .iter()
265 .map(|p| node_level.get(p).copied().unwrap_or(0) + 1)
266 .max()
267 .unwrap_or(0)
268 };
269 node_level.insert(node, level);
270 if level > max_level {
271 max_level = level;
272 }
273 }
274
275 let mut levels: Vec<Vec<&str>> = vec![Vec::new(); max_level + 1];
276 for &node in sorted {
277 let level = node_level[node];
278 levels[level].push(node);
279 }
280
281 levels.retain(|l| !l.is_empty());
283 levels
284 }
285
286 fn resolve_cache(
289 &self,
290 plan: ExecutionPlan,
291 cache: &dyn CacheStore,
292 sorted: &[&str],
293 ) -> Result<ExecutionPlan> {
294 if self.mode == CompileMode::NoCache {
295 return Ok(plan);
296 }
297
298 let mut node_keys: HashMap<String, CacheKey> = HashMap::new();
301 let mut cached_nodes: HashSet<String> = HashSet::new();
302
303 for &node_id in sorted {
304 let config_hash = match self.registry.config_hash(node_id) {
305 Some(h) => h,
306 None => continue, };
308
309 let meta = self.registry.meta(node_id);
310 let cacheable = meta.as_ref().is_some_and(|m| m.cacheable);
311
312 let can_cache = cacheable && self.mode == CompileMode::Inference;
315
316 let pred_ids = self.graph.predecessors(node_id);
318 let mut key_parts: Vec<Vec<u8>> = vec![config_hash.0.to_vec()];
319 for pred in &pred_ids {
320 if let Some(pred_key) = node_keys.get(*pred) {
321 key_parts.push(pred_key.0.to_vec());
322 } else {
323 debug_assert!(
326 false,
327 "predecessor `{pred}` of `{node_id}` not in node_keys - \
328 topological order may be broken"
329 );
330 }
331 }
332 let parts_refs: Vec<&[u8]> = key_parts.iter().map(|p| p.as_slice()).collect();
333 let key = CacheKey::from_parts(&parts_refs);
334 node_keys.insert(node_id.to_string(), key.clone());
335
336 if can_cache {
338 if cache.exists(&key)? {
342 cached_nodes.insert(node_id.to_string());
343 }
344 }
345 }
346
347 Ok(self.apply_cache_to_plan(plan, &cached_nodes, &node_keys))
349 }
350
351 fn apply_cache_to_plan(
352 &self,
353 plan: ExecutionPlan,
354 cached: &HashSet<String>,
355 keys: &HashMap<String, CacheKey>,
356 ) -> ExecutionPlan {
357 match plan {
358 ExecutionPlan::Execute { ref node_id } => {
359 if cached.contains(node_id)
360 && let Some(key) = keys.get(node_id)
361 {
362 return ExecutionPlan::Cached {
363 node_id: node_id.clone(),
364 key: key.clone(),
365 };
366 }
367 plan
368 }
369 ExecutionPlan::Sequence(steps) => ExecutionPlan::Sequence(
370 steps
371 .into_iter()
372 .map(|s| self.apply_cache_to_plan(s, cached, keys))
373 .collect(),
374 ),
375 ExecutionPlan::Parallel(branches) => ExecutionPlan::Parallel(
376 branches
377 .into_iter()
378 .map(|b| self.apply_cache_to_plan(b, cached, keys))
379 .collect(),
380 ),
381 other => other,
382 }
383 }
384
385 fn resolve_distribution(&self, plan: ExecutionPlan) -> ExecutionPlan {
387 match plan {
388 ExecutionPlan::Execute { ref node_id } => {
389 if let Some(meta) = self.registry.meta(node_id) {
390 match &meta.distribution {
391 somatize_core::filter::Distribution::Remote(target) => {
392 ExecutionPlan::Remote {
393 node_id: node_id.clone(),
394 target: target.clone(),
395 plan: Box::new(plan),
396 }
397 }
398 _ => plan,
399 }
400 } else {
401 plan
402 }
403 }
404 ExecutionPlan::Sequence(steps) => ExecutionPlan::Sequence(
405 steps
406 .into_iter()
407 .map(|s| self.resolve_distribution(s))
408 .collect(),
409 ),
410 ExecutionPlan::Parallel(branches) => ExecutionPlan::Parallel(
411 branches
412 .into_iter()
413 .map(|b| self.resolve_distribution(b))
414 .collect(),
415 ),
416 other => other,
417 }
418 }
419
420 fn validate_schemas(&mut self, sorted: &[&str]) {
426 for &node_id in sorted {
427 let input_schema = self
428 .registry
429 .meta(node_id)
430 .and_then(|m| m.input_schema.clone());
431
432 let Some(expected_input) = input_schema else {
434 continue;
435 };
436
437 for pred_id in self.graph.predecessors(node_id) {
439 let pred_output = self
440 .registry
441 .meta(pred_id)
442 .and_then(|m| m.output_schema.clone());
443
444 let Some(actual_output) = pred_output else {
445 continue; };
447
448 if !actual_output.is_compatible_with(&expected_input) {
449 self.diagnostics.push(Diagnostic {
450 node_id: node_id.to_string(),
451 level: DiagnosticLevel::Warning,
452 message: format!(
453 "schema mismatch: `{pred_id}` outputs {actual_output} \
454 but `{node_id}` expects {expected_input}",
455 ),
456 });
457 }
458 }
459 }
460 }
461
462 fn check_gradient_flow(&mut self, sorted: &[&str]) {
468 let mut gradient_flows = true;
469
470 for &node_id in sorted {
471 if let Some(meta) = self.registry.meta(node_id) {
472 if gradient_flows && !meta.differentiable {
473 self.diagnostics.push(Diagnostic {
474 node_id: node_id.to_string(),
475 level: DiagnosticLevel::Warning,
476 message: format!(
477 "gradient flow interrupted at `{}` ({:?}). \
478 Gradients from upstream will not reach downstream filters \
479 through this node.",
480 node_id, meta.kind,
481 ),
482 });
483 gradient_flows = false;
484 } else if !gradient_flows && meta.differentiable {
485 gradient_flows = true;
488 }
489 }
490 }
491 }
492}
493
494pub fn compile(
496 graph: &Graph,
497 registry: &dyn FilterRegistry,
498 mode: CompileMode,
499 cache: Option<&dyn CacheStore>,
500) -> Result<CompileResult> {
501 Compiler::new(graph, registry, mode).compile(cache)
502}
503
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::cache::EntryMeta;
    use somatize_core::error::SomaError;
    use somatize_core::filter::{FilterKind, StreamMode};
    use somatize_core::graph::{Edge, Graph, Node, linear_pipeline};
    use somatize_core::value::Value;
    use std::sync::Mutex;

    /// Cache store double that only tracks which keys "exist".
    struct MockCacheStore {
        entries: Mutex<HashSet<CacheKey>>,
    }

    impl MockCacheStore {
        fn new() -> Self {
            Self {
                entries: Mutex::new(HashSet::new()),
            }
        }

        /// Marks `key` as present in the store.
        fn insert(&self, key: CacheKey) {
            self.entries.lock().unwrap().insert(key);
        }
    }

    impl CacheStore for MockCacheStore {
        fn get(&self, _key: &CacheKey) -> Result<Option<Value>> {
            Ok(None)
        }
        fn put(&self, _key: &CacheKey, _value: &Value) -> Result<()> {
            Ok(())
        }
        fn exists(&self, key: &CacheKey) -> Result<bool> {
            Ok(self.entries.lock().unwrap().contains(key))
        }
        fn remove(&self, _key: &CacheKey) -> Result<()> {
            Ok(())
        }
        fn metadata(&self, _key: &CacheKey) -> Result<Option<EntryMeta>> {
            Ok(None)
        }
    }

    /// Cacheable, local, schema-less metadata with the given kind and
    /// differentiability.
    fn make_meta(kind: FilterKind, differentiable: bool) -> FilterMeta {
        FilterMeta {
            name: "test".into(),
            kind,
            cacheable: true,
            differentiable,
            stream_mode: StreamMode::FixedState,
            distribution: somatize_core::filter::Distribution::Local,
            input_schema: None,
            output_schema: None,
        }
    }

    /// Registers every id with a copy of `meta` and a hash derived from the
    /// id and its position.
    fn register_nodes(registry: &mut SimpleFilterRegistry, ids: &[&str], meta: FilterMeta) {
        for (i, id) in ids.iter().enumerate() {
            // Truncating cast is fine: tests register only a handful of ids.
            let hash = CacheKey::from_parts(&[id.as_bytes(), &[i as u8]]);
            registry.register_meta(*id, meta.clone(), hash);
        }
    }

    #[test]
    fn compile_empty_graph() {
        let graph = Graph::new();
        let registry = SimpleFilterRegistry::new();
        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
        assert!(matches!(result.plan, ExecutionPlan::Empty));
    }

    #[test]
    fn compile_single_node() {
        let mut graph = Graph::new();
        graph.add_node(Node::new("a", "A", "F"));
        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
        assert!(matches!(result.plan, ExecutionPlan::Execute { .. }));
    }

    #[test]
    fn compile_linear_pipeline_produces_sequence() {
        let graph = linear_pipeline(vec![
            Node::new("a", "Scaler", "F"),
            Node::new("b", "PCA", "F"),
            Node::new("c", "SVM", "F"),
        ]);
        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b", "c"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

        if let ExecutionPlan::Sequence(steps) = &result.plan {
            assert_eq!(steps.len(), 3);
            assert!(
                steps
                    .iter()
                    .all(|s| matches!(s, ExecutionPlan::Execute { .. }))
            );
        } else {
            panic!("expected Sequence, got: {:?}", result.plan);
        }
    }

    #[test]
    fn compile_diamond_detects_parallelism() {
        let mut graph = Graph::new();
        graph.add_node(Node::new("root", "Root", "F"));
        graph.add_node(Node::new("b1", "B1", "F"));
        graph.add_node(Node::new("b2", "B2", "F"));
        graph.add_node(Node::new("merge", "Merge", "F"));
        graph.add_edge(Edge::data("e1", "root", "b1"));
        graph.add_edge(Edge::data("e2", "root", "b2"));
        graph.add_edge(Edge::data("e3", "b1", "merge"));
        graph.add_edge(Edge::data("e4", "b2", "merge"));

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["root", "b1", "b2", "merge"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

        if let ExecutionPlan::Sequence(steps) = &result.plan {
            assert_eq!(steps.len(), 3);
            assert!(matches!(&steps[0], ExecutionPlan::Execute { node_id } if node_id == "root"));
            assert!(matches!(&steps[1], ExecutionPlan::Parallel(branches) if branches.len() == 2));
            assert!(matches!(&steps[2], ExecutionPlan::Execute { node_id } if node_id == "merge"));
        } else {
            panic!("expected Sequence, got: {:?}", result.plan);
        }
    }

    #[test]
    fn compile_independent_roots_parallel() {
        let mut graph = Graph::new();
        graph.add_node(Node::new("a", "A", "F"));
        graph.add_node(Node::new("b", "B", "F"));
        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

        assert!(matches!(result.plan, ExecutionPlan::Parallel(_)));
    }

    #[test]
    fn cache_resolution_replaces_cached_nodes() {
        let graph = linear_pipeline(vec![
            Node::new("a", "Scaler", "F"),
            Node::new("b", "PCA", "F"),
            Node::new("c", "SVM", "F"),
        ]);

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b", "c"],
            make_meta(FilterKind::Trainable, true),
        );

        // A root node's key is derived from its configuration hash alone.
        let a_config = registry.config_hash("a").unwrap();
        let a_cache_key = CacheKey::from_parts(&[&a_config.0]);

        let cache = MockCacheStore::new();
        cache.insert(a_cache_key);

        let result = compile(&graph, &registry, CompileMode::Inference, Some(&cache)).unwrap();

        if let ExecutionPlan::Sequence(steps) = &result.plan {
            assert!(
                matches!(&steps[0], ExecutionPlan::Cached { node_id, .. } if node_id == "a"),
                "first node should be cached, got: {:?}",
                steps[0]
            );
            assert!(matches!(&steps[1], ExecutionPlan::Execute { .. }));
            assert!(matches!(&steps[2], ExecutionPlan::Execute { .. }));
        } else {
            panic!("expected Sequence, got: {:?}", result.plan);
        }
    }

    #[test]
    fn cascade_invalidation_different_config_changes_keys() {
        let graph = linear_pipeline(vec![
            Node::new("a", "Scaler", "F"),
            Node::new("b", "PCA", "F"),
        ]);

        let mut reg1 = SimpleFilterRegistry::new();
        reg1.register_meta(
            "a",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"scaler_v1"),
        );
        reg1.register_meta(
            "b",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"pca_v1"),
        );

        // Same pipeline, but only the upstream node's configuration changed.
        let mut reg2 = SimpleFilterRegistry::new();
        reg2.register_meta(
            "a",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"scaler_v2"),
        );
        reg2.register_meta(
            "b",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"pca_v1"),
        );

        let a_key_v1 = CacheKey::from_parts(&[&CacheKey::hash_data(b"scaler_v1").0]);
        let b_key_v1 = CacheKey::from_parts(&[&CacheKey::hash_data(b"pca_v1").0, &a_key_v1.0]);

        let a_key_v2 = CacheKey::from_parts(&[&CacheKey::hash_data(b"scaler_v2").0]);
        let b_key_v2 = CacheKey::from_parts(&[&CacheKey::hash_data(b"pca_v1").0, &a_key_v2.0]);

        assert_ne!(a_key_v1, a_key_v2);
        assert_ne!(b_key_v1, b_key_v2);

        // Exercise the compiler itself: an entry stored under b's v1 key is
        // a hit with the v1 registry and a miss once the upstream changed.
        let cache = MockCacheStore::new();
        cache.insert(b_key_v1);

        let with_v1 = compile(&graph, &reg1, CompileMode::Inference, Some(&cache)).unwrap();
        assert_eq!(with_v1.plan.cached_count(), 1);

        let with_v2 = compile(&graph, &reg2, CompileMode::Inference, Some(&cache)).unwrap();
        assert_eq!(with_v2.plan.cached_count(), 0);
    }

    #[test]
    fn no_cache_mode_skips_all_caching() {
        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b"],
            make_meta(FilterKind::Trainable, true),
        );

        let a_config = registry.config_hash("a").unwrap();
        let a_key = CacheKey::from_parts(&[&a_config.0]);
        let cache = MockCacheStore::new();
        cache.insert(a_key);

        let result = compile(&graph, &registry, CompileMode::NoCache, Some(&cache)).unwrap();

        assert_eq!(result.plan.cached_count(), 0);
    }

    #[test]
    fn differentiable_mode_skips_output_caching() {
        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b"],
            make_meta(FilterKind::Trainable, true),
        );

        let a_config = registry.config_hash("a").unwrap();
        let a_key = CacheKey::from_parts(&[&a_config.0]);
        let cache = MockCacheStore::new();
        cache.insert(a_key);

        let result =
            compile(&graph, &registry, CompileMode::Differentiable, Some(&cache)).unwrap();

        assert_eq!(result.plan.cached_count(), 0);
    }

    #[test]
    fn gradient_flow_diagnostic_on_opaque() {
        let graph = linear_pipeline(vec![
            Node::new("scaler", "Scaler", "F"),
            Node::new("tree", "DecisionTree", "F"),
            Node::new("linear", "Linear", "F"),
        ]);

        let mut registry = SimpleFilterRegistry::new();
        registry.register_meta(
            "scaler",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"s"),
        );
        registry.register_meta(
            "tree",
            make_meta(FilterKind::Opaque, false),
            CacheKey::hash_data(b"t"),
        );
        registry.register_meta(
            "linear",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"l"),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

        assert_eq!(result.diagnostics.len(), 1);
        assert_eq!(result.diagnostics[0].node_id, "tree");
        assert_eq!(result.diagnostics[0].level, DiagnosticLevel::Warning);
        assert!(
            result.diagnostics[0]
                .message
                .contains("gradient flow interrupted")
        );
    }

    #[test]
    fn no_diagnostic_when_all_differentiable() {
        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
        assert!(result.diagnostics.is_empty());
    }

    #[test]
    fn compile_rejects_cycle() {
        let mut graph = Graph::new();
        graph.add_node(Node::new("a", "A", "F"));
        graph.add_node(Node::new("b", "B", "F"));
        graph.add_edge(Edge::data("e1", "a", "b"));
        graph.add_edge(Edge::data("e2", "b", "a"));

        let registry = SimpleFilterRegistry::new();
        let result = compile(&graph, &registry, CompileMode::Inference, None);
        assert!(matches!(result, Err(SomaError::CycleDetected)));
    }

    #[test]
    fn plan_summary_is_accurate() {
        let mut graph = Graph::new();
        graph.add_node(Node::new("root", "Root", "F"));
        graph.add_node(Node::new("b1", "B1", "F"));
        graph.add_node(Node::new("b2", "B2", "F"));
        graph.add_node(Node::new("end", "End", "F"));
        graph.add_edge(Edge::data("e1", "root", "b1"));
        graph.add_edge(Edge::data("e2", "root", "b2"));
        graph.add_edge(Edge::data("e3", "b1", "end"));
        graph.add_edge(Edge::data("e4", "b2", "end"));

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["root", "b1", "b2", "end"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
        let summary = result.plan.summary();
        assert_eq!(summary.total_nodes, 4);
        assert_eq!(summary.parallel_branches, 2);
    }

    #[test]
    fn distribution_wraps_remote_nodes() {
        let graph = linear_pipeline(vec![
            Node::new("preprocess", "Preprocess", "F"),
            Node::new("gpu_train", "GpuTrain", "F"),
            Node::new("evaluate", "Evaluate", "F"),
        ]);

        let mut registry = SimpleFilterRegistry::new();
        registry.register_meta(
            "preprocess",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"pre"),
        );
        // Only the middle node is declared remote.
        let mut gpu_meta = make_meta(FilterKind::Trainable, true);
        gpu_meta.distribution = somatize_core::filter::Distribution::Remote(
            somatize_core::filter::RemoteTarget::Tag("gpu".into()),
        );
        registry.register_meta("gpu_train", gpu_meta, CacheKey::hash_data(b"gpu"));
        registry.register_meta(
            "evaluate",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"eval"),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

        if let ExecutionPlan::Sequence(steps) = &result.plan {
            assert_eq!(steps.len(), 3);
            assert!(
                matches!(&steps[0], ExecutionPlan::Execute { node_id } if node_id == "preprocess")
            );
            assert!(
                matches!(&steps[1], ExecutionPlan::Remote { node_id, target, .. }
                    if node_id == "gpu_train"
                    && *target == somatize_core::filter::RemoteTarget::Tag("gpu".into())
                ),
                "expected Remote, got: {:?}",
                steps[1]
            );
            assert!(
                matches!(&steps[2], ExecutionPlan::Execute { node_id } if node_id == "evaluate")
            );
        } else {
            panic!("expected Sequence, got: {:?}", result.plan);
        }
    }

    #[test]
    fn local_distribution_not_wrapped() {
        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

        let ids = result.plan.node_ids();
        assert_eq!(ids.len(), 2);
        if let ExecutionPlan::Sequence(steps) = &result.plan {
            assert!(
                steps
                    .iter()
                    .all(|s| matches!(s, ExecutionPlan::Execute { .. }))
            );
        } else {
            // Fail loudly instead of silently skipping the assertion above.
            panic!("expected Sequence, got: {:?}", result.plan);
        }
    }
}