1use crate::plan::ExecutionPlan;
7use somatize_core::cache::{CacheKey, CacheStore};
8use somatize_core::error::Result;
9use somatize_core::filter::{Filter, FilterMeta};
10use somatize_core::graph::{Graph, NodeId};
11use std::collections::{HashMap, HashSet};
12
/// Compilation strategy chosen by the caller.
///
/// In this file the mode is consulted during cache resolution and is handed
/// unchanged to the compilers created for nested sub-graphs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompileMode {
    /// Cacheable nodes whose key already exists in the cache store are
    /// replaced by cached plan entries.
    Inference,
    /// Cache keys are still derived, but the cache store is never queried, so
    /// no node is replaced by a cached entry.
    Differentiable,
    /// Cache resolution is skipped entirely and the plan is left untouched.
    NoCache,
}
23
24#[derive(Debug, Clone)]
26pub struct Diagnostic {
27 pub node_id: NodeId,
28 pub level: DiagnosticLevel,
29 pub message: String,
30}
31
/// Severity attached to a [`Diagnostic`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticLevel {
    /// Used by the compiler for gradient-flow interruptions and schema
    /// mismatches.
    Warning,
    /// Informational severity; not emitted by the passes in this file.
    Info,
}
37
38pub struct CompileResult {
40 pub plan: ExecutionPlan,
41 pub diagnostics: Vec<Diagnostic>,
42}
43
44pub trait FilterRegistry: Send + Sync {
48 fn meta(&self, node_id: &str) -> Option<FilterMeta>;
49 fn config_hash(&self, node_id: &str) -> Option<CacheKey>;
50}
51
52pub struct SimpleFilterRegistry {
54 entries: HashMap<String, (FilterMeta, CacheKey)>,
55}
56
57impl SimpleFilterRegistry {
58 pub fn new() -> Self {
59 Self {
60 entries: HashMap::new(),
61 }
62 }
63
64 pub fn register(&mut self, node_id: impl Into<String>, filter: &dyn Filter) {
65 let id = node_id.into();
66 self.entries
67 .insert(id, (filter.meta(), filter.config_hash()));
68 }
69
70 pub fn register_meta(
71 &mut self,
72 node_id: impl Into<String>,
73 meta: FilterMeta,
74 config_hash: CacheKey,
75 ) {
76 self.entries.insert(node_id.into(), (meta, config_hash));
77 }
78}
79
80impl Default for SimpleFilterRegistry {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86impl FilterRegistry for SimpleFilterRegistry {
87 fn meta(&self, node_id: &str) -> Option<FilterMeta> {
88 self.entries.get(node_id).map(|(m, _)| m.clone())
89 }
90
91 fn config_hash(&self, node_id: &str) -> Option<CacheKey> {
92 self.entries.get(node_id).map(|(_, h)| h.clone())
93 }
94}
95
96pub struct Compiler<'a> {
98 graph: &'a Graph,
99 registry: &'a dyn FilterRegistry,
100 mode: CompileMode,
101 diagnostics: Vec<Diagnostic>,
102}
103
104impl<'a> Compiler<'a> {
105 pub fn new(graph: &'a Graph, registry: &'a dyn FilterRegistry, mode: CompileMode) -> Self {
106 Self {
107 graph,
108 registry,
109 mode,
110 diagnostics: Vec::new(),
111 }
112 }
113
114 pub fn compile(mut self, cache: Option<&dyn CacheStore>) -> Result<CompileResult> {
116 self.graph.validate()?;
117
118 let sorted = self.graph.topological_sort()?;
119
120 if sorted.is_empty() {
121 return Ok(CompileResult {
122 plan: ExecutionPlan::Empty,
123 diagnostics: self.diagnostics,
124 });
125 }
126
127 self.check_gradient_flow(&sorted);
129
130 self.validate_schemas(&sorted);
132
133 let plan = self.build_plan(&sorted);
135
136 let plan = if let Some(cache) = cache {
138 self.resolve_cache(plan, cache, &sorted)?
139 } else {
140 plan
141 };
142
143 let plan = self.resolve_distribution(plan);
145
146 let plan = self.collapse_differentiable(plan);
148
149 let plan = plan.simplify();
150
151 Ok(CompileResult {
152 plan,
153 diagnostics: self.diagnostics,
154 })
155 }
156
157 fn build_plan(&self, sorted: &[&str]) -> ExecutionPlan {
159 let levels = self.compute_levels(sorted);
161
162 let mut plan_steps: Vec<ExecutionPlan> = Vec::new();
163
164 for level in &levels {
165 if level.len() == 1 {
166 plan_steps.push(self.plan_for_node(level[0]));
167 } else {
168 let branches: Vec<ExecutionPlan> =
169 level.iter().map(|id| self.plan_for_node(id)).collect();
170 plan_steps.push(ExecutionPlan::Parallel(branches));
171 }
172 }
173
174 if plan_steps.len() == 1 {
175 plan_steps.into_iter().next().unwrap()
176 } else {
177 ExecutionPlan::Sequence(plan_steps)
178 }
179 }
180
181 fn plan_for_node(&self, node_id: &str) -> ExecutionPlan {
183 use somatize_core::graph::NodeKind;
184
185 let node = match self.graph.node(node_id) {
186 Some(n) => n,
187 None => {
188 return ExecutionPlan::Execute {
189 node_id: node_id.to_string(),
190 };
191 }
192 };
193
194 match &node.kind {
195 NodeKind::Filter { .. } => ExecutionPlan::Execute {
196 node_id: node_id.to_string(),
197 },
198
199 NodeKind::SubGraph { graph } => {
200 let inner_compiler = Compiler::new(graph, self.registry, self.mode);
202 match inner_compiler.compile(None) {
203 Ok(result) => result.plan,
204 Err(_) => ExecutionPlan::Execute {
205 node_id: node_id.to_string(),
206 },
207 }
208 }
209
210 NodeKind::Loop { max_iterations } => {
211 let successors = self.graph.successors(node_id);
214 let body = if successors.len() == 1 {
215 self.plan_for_node(successors[0])
216 } else if successors.len() > 1 {
217 let branches: Vec<ExecutionPlan> =
218 successors.iter().map(|id| self.plan_for_node(id)).collect();
219 ExecutionPlan::Parallel(branches)
220 } else {
221 ExecutionPlan::Empty
222 };
223 ExecutionPlan::Loop {
224 node_id: node_id.to_string(),
225 body: Box::new(body),
226 max_iterations: *max_iterations,
227 }
228 }
229
230 NodeKind::Branch => {
231 let arms: Vec<(String, ExecutionPlan)> = self
233 .graph
234 .edges
235 .iter()
236 .filter(|e| e.source == node_id)
237 .map(|e| {
238 let label = e.label.clone().unwrap_or_else(|| e.target.clone());
239 let plan = self.plan_for_node(&e.target);
240 (label, plan)
241 })
242 .collect();
243 ExecutionPlan::Branch {
244 node_id: node_id.to_string(),
245 arms,
246 }
247 }
248
249 _ => ExecutionPlan::Execute {
250 node_id: node_id.to_string(),
251 },
252 }
253 }
254
255 fn compute_levels<'b>(&self, sorted: &[&'b str]) -> Vec<Vec<&'b str>> {
258 let mut node_level: HashMap<&str, usize> = HashMap::new();
259 let mut max_level: usize = 0;
260
261 for &node in sorted {
262 let preds = self.graph.predecessors(node);
263 let level = if preds.is_empty() {
264 0
265 } else {
266 preds
267 .iter()
268 .map(|p| node_level.get(p).copied().unwrap_or(0) + 1)
269 .max()
270 .unwrap_or(0)
271 };
272 node_level.insert(node, level);
273 if level > max_level {
274 max_level = level;
275 }
276 }
277
278 let mut levels: Vec<Vec<&str>> = vec![Vec::new(); max_level + 1];
279 for &node in sorted {
280 let level = node_level[node];
281 levels[level].push(node);
282 }
283
284 levels.retain(|l| !l.is_empty());
286 levels
287 }
288
289 fn resolve_cache(
292 &self,
293 plan: ExecutionPlan,
294 cache: &dyn CacheStore,
295 sorted: &[&str],
296 ) -> Result<ExecutionPlan> {
297 if self.mode == CompileMode::NoCache {
298 return Ok(plan);
299 }
300
301 let mut node_keys: HashMap<String, CacheKey> = HashMap::new();
304 let mut cached_nodes: HashSet<String> = HashSet::new();
305
306 for &node_id in sorted {
307 let config_hash = match self.registry.config_hash(node_id) {
308 Some(h) => h,
309 None => continue, };
311
312 let meta = self.registry.meta(node_id);
313 let cacheable = meta.as_ref().is_some_and(|m| m.cacheable);
314
315 let can_cache = cacheable && self.mode == CompileMode::Inference;
318
319 let pred_ids = self.graph.predecessors(node_id);
321 let mut key_parts: Vec<Vec<u8>> = vec![config_hash.0.to_vec()];
322 for pred in &pred_ids {
323 if let Some(pred_key) = node_keys.get(*pred) {
324 key_parts.push(pred_key.0.to_vec());
325 } else {
326 debug_assert!(
329 false,
330 "predecessor `{pred}` of `{node_id}` not in node_keys - \
331 topological order may be broken"
332 );
333 }
334 }
335 let parts_refs: Vec<&[u8]> = key_parts.iter().map(|p| p.as_slice()).collect();
336 let key = CacheKey::from_parts(&parts_refs);
337 node_keys.insert(node_id.to_string(), key.clone());
338
339 if can_cache {
341 if cache.exists(&key)? {
345 cached_nodes.insert(node_id.to_string());
346 }
347 }
348 }
349
350 Ok(self.apply_cache_to_plan(plan, &cached_nodes, &node_keys))
352 }
353
354 fn apply_cache_to_plan(
355 &self,
356 plan: ExecutionPlan,
357 cached: &HashSet<String>,
358 keys: &HashMap<String, CacheKey>,
359 ) -> ExecutionPlan {
360 match plan {
361 ExecutionPlan::Execute { ref node_id } => {
362 if cached.contains(node_id)
363 && let Some(key) = keys.get(node_id)
364 {
365 return ExecutionPlan::Cached {
366 node_id: node_id.clone(),
367 key: key.clone(),
368 };
369 }
370 plan
371 }
372 ExecutionPlan::Sequence(steps) => ExecutionPlan::Sequence(
373 steps
374 .into_iter()
375 .map(|s| self.apply_cache_to_plan(s, cached, keys))
376 .collect(),
377 ),
378 ExecutionPlan::Parallel(branches) => ExecutionPlan::Parallel(
379 branches
380 .into_iter()
381 .map(|b| self.apply_cache_to_plan(b, cached, keys))
382 .collect(),
383 ),
384 other => other,
385 }
386 }
387
388 fn resolve_distribution(&self, plan: ExecutionPlan) -> ExecutionPlan {
390 match plan {
391 ExecutionPlan::Execute { ref node_id } => {
392 if let Some(meta) = self.registry.meta(node_id) {
393 match &meta.distribution {
394 somatize_core::filter::Distribution::Remote(target) => {
395 ExecutionPlan::Remote {
396 node_id: node_id.clone(),
397 target: target.clone(),
398 plan: Box::new(plan),
399 }
400 }
401 _ => plan,
402 }
403 } else {
404 plan
405 }
406 }
407 ExecutionPlan::Sequence(steps) => ExecutionPlan::Sequence(
408 steps
409 .into_iter()
410 .map(|s| self.resolve_distribution(s))
411 .collect(),
412 ),
413 ExecutionPlan::Parallel(branches) => ExecutionPlan::Parallel(
414 branches
415 .into_iter()
416 .map(|b| self.resolve_distribution(b))
417 .collect(),
418 ),
419 ExecutionPlan::Composite { ref node_ids } => {
420 let targets: Vec<_> = node_ids
424 .iter()
425 .filter_map(|nid| {
426 self.registry.meta(nid).and_then(|m| match &m.distribution {
427 somatize_core::filter::Distribution::Remote(t) => Some(t.clone()),
428 _ => None,
429 })
430 })
431 .collect();
432
433 if targets.len() == node_ids.len() && !targets.is_empty() {
434 let first_id = node_ids[0].clone();
435 ExecutionPlan::Remote {
436 node_id: first_id,
437 target: targets.into_iter().next().unwrap(),
438 plan: Box::new(plan),
439 }
440 } else {
441 plan
442 }
443 }
444 other => other,
445 }
446 }
447
448 fn collapse_differentiable(&self, plan: ExecutionPlan) -> ExecutionPlan {
453 match plan {
454 ExecutionPlan::Sequence(steps) => {
455 let mut result: Vec<ExecutionPlan> = Vec::new();
456 let mut diff_group: Vec<String> = Vec::new();
457
458 for step in steps {
459 if let ExecutionPlan::Execute { ref node_id } = step
460 && self
461 .registry
462 .meta(node_id)
463 .map(|m| m.differentiable)
464 .unwrap_or(false)
465 {
466 diff_group.push(node_id.clone());
467 continue;
468 }
469 Self::flush_diff_group(&mut diff_group, &mut result);
471 result.push(self.collapse_differentiable(step));
472 }
473 Self::flush_diff_group(&mut diff_group, &mut result);
474
475 if result.len() == 1 {
476 result.pop().unwrap()
477 } else {
478 ExecutionPlan::Sequence(result)
479 }
480 }
481 ExecutionPlan::Parallel(branches) => ExecutionPlan::Parallel(
482 branches
483 .into_iter()
484 .map(|b| self.collapse_differentiable(b))
485 .collect(),
486 ),
487 ExecutionPlan::Remote {
488 node_id,
489 target,
490 plan,
491 } => ExecutionPlan::Remote {
492 node_id,
493 target,
494 plan: Box::new(self.collapse_differentiable(*plan)),
495 },
496 other => other,
497 }
498 }
499
500 fn flush_diff_group(group: &mut Vec<String>, result: &mut Vec<ExecutionPlan>) {
501 if group.len() > 1 {
502 result.push(ExecutionPlan::Composite {
503 node_ids: std::mem::take(group),
504 });
505 } else if let Some(id) = group.pop() {
506 result.push(ExecutionPlan::Execute { node_id: id });
507 }
508 }
509
510 fn validate_schemas(&mut self, sorted: &[&str]) {
516 for &node_id in sorted {
517 let input_schema = self
518 .registry
519 .meta(node_id)
520 .and_then(|m| m.input_schema.clone());
521
522 let Some(expected_input) = input_schema else {
524 continue;
525 };
526
527 for pred_id in self.graph.predecessors(node_id) {
529 let pred_output = self
530 .registry
531 .meta(pred_id)
532 .and_then(|m| m.output_schema.clone());
533
534 let Some(actual_output) = pred_output else {
535 continue; };
537
538 if !actual_output.is_compatible_with(&expected_input) {
539 self.diagnostics.push(Diagnostic {
540 node_id: node_id.to_string(),
541 level: DiagnosticLevel::Warning,
542 message: format!(
543 "schema mismatch: `{pred_id}` outputs {actual_output} \
544 but `{node_id}` expects {expected_input}",
545 ),
546 });
547 }
548 }
549 }
550 }
551
552 fn check_gradient_flow(&mut self, sorted: &[&str]) {
558 let mut gradient_flows = true;
559
560 for &node_id in sorted {
561 if let Some(meta) = self.registry.meta(node_id) {
562 if gradient_flows && !meta.differentiable {
563 self.diagnostics.push(Diagnostic {
564 node_id: node_id.to_string(),
565 level: DiagnosticLevel::Warning,
566 message: format!(
567 "gradient flow interrupted at `{}` ({:?}). \
568 Gradients from upstream will not reach downstream filters \
569 through this node.",
570 node_id, meta.kind,
571 ),
572 });
573 gradient_flows = false;
574 } else if !gradient_flows && meta.differentiable {
575 gradient_flows = true;
578 }
579 }
580 }
581 }
582}
583
584pub fn compile(
586 graph: &Graph,
587 registry: &dyn FilterRegistry,
588 mode: CompileMode,
589 cache: Option<&dyn CacheStore>,
590) -> Result<CompileResult> {
591 Compiler::new(graph, registry, mode).compile(cache)
592}
593
594pub fn compile_stream(
600 graph: &Graph,
601 _registry: &dyn FilterRegistry,
602 chunk_size: usize,
603) -> Result<CompileResult> {
604 graph.validate()?;
605 let sorted = graph.topological_sort()?;
606
607 if sorted.is_empty() {
608 return Ok(CompileResult {
609 plan: ExecutionPlan::Empty,
610 diagnostics: Vec::new(),
611 });
612 }
613
614 let node_ids: Vec<NodeId> = sorted.into_iter().map(|s| s.to_string()).collect();
615 let plan = ExecutionPlan::Stream {
616 node_ids,
617 chunk_size,
618 };
619
620 Ok(CompileResult {
621 plan,
622 diagnostics: Vec::new(),
623 })
624}
625
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::cache::EntryMeta;
    use somatize_core::error::SomaError;
    use somatize_core::filter::{FilterKind, StreamMode};
    use somatize_core::graph::{Edge, Graph, Node, linear_pipeline};
    use somatize_core::value::Value;
    use std::sync::Mutex;

    /// Cache store that only tracks which keys exist; it stores no values.
    struct MockCacheStore {
        entries: Mutex<HashSet<CacheKey>>,
    }

    impl MockCacheStore {
        fn new() -> Self {
            Self {
                entries: Mutex::new(HashSet::new()),
            }
        }

        /// Marks `key` as present so that `exists` reports it.
        fn insert(&self, key: CacheKey) {
            self.entries.lock().unwrap().insert(key);
        }
    }

    impl CacheStore for MockCacheStore {
        fn get(&self, _key: &CacheKey) -> Result<Option<Value>> {
            Ok(None)
        }
        fn put(&self, _key: &CacheKey, _value: &Value) -> Result<()> {
            Ok(())
        }
        fn exists(&self, key: &CacheKey) -> Result<bool> {
            Ok(self.entries.lock().unwrap().contains(key))
        }
        fn remove(&self, _key: &CacheKey) -> Result<()> {
            Ok(())
        }
        fn metadata(&self, _key: &CacheKey) -> Result<Option<EntryMeta>> {
            Ok(None)
        }
    }

    /// Builds cacheable, locally distributed metadata without schemas.
    fn make_meta(kind: FilterKind, differentiable: bool) -> FilterMeta {
        FilterMeta {
            name: "test".into(),
            kind,
            cacheable: true,
            differentiable,
            stream_mode: StreamMode::FixedState,
            distribution: somatize_core::filter::Distribution::Local,
            input_schema: None,
            output_schema: None,
        }
    }

    /// Registers every id with the same metadata and a hash derived from the
    /// id and its position in `ids`.
    fn register_nodes(registry: &mut SimpleFilterRegistry, ids: &[&str], meta: FilterMeta) {
        for (i, id) in ids.iter().enumerate() {
            let hash = CacheKey::from_parts(&[id.as_bytes(), &[i as u8]]);
            registry.register_meta(*id, meta.clone(), hash);
        }
    }

    #[test]
    fn compile_empty_graph() {
        let graph = Graph::new();
        let registry = SimpleFilterRegistry::new();
        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
        assert!(matches!(result.plan, ExecutionPlan::Empty));
    }

    #[test]
    fn compile_single_node() {
        let mut graph = Graph::new();
        graph.add_node(Node::new("a", "A", "F"));
        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
        assert!(matches!(result.plan, ExecutionPlan::Execute { .. }));
    }

    #[test]
    fn compile_linear_pipeline_produces_sequence() {
        let graph = linear_pipeline(vec![
            Node::new("a", "Scaler", "F"),
            Node::new("b", "PCA", "F"),
            Node::new("c", "SVM", "F"),
        ]);
        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b", "c"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

        // All three nodes are differentiable, so the sequence is fused.
        if let ExecutionPlan::Composite { node_ids } = &result.plan {
            assert_eq!(node_ids, &["a", "b", "c"]);
        } else {
            panic!("expected Composite, got: {:?}", result.plan);
        }
    }

    #[test]
    fn compile_diamond_detects_parallelism() {
        let mut graph = Graph::new();
        graph.add_node(Node::new("root", "Root", "F"));
        graph.add_node(Node::new("b1", "B1", "F"));
        graph.add_node(Node::new("b2", "B2", "F"));
        graph.add_node(Node::new("merge", "Merge", "F"));
        graph.add_edge(Edge::data("e1", "root", "b1"));
        graph.add_edge(Edge::data("e2", "root", "b2"));
        graph.add_edge(Edge::data("e3", "b1", "merge"));
        graph.add_edge(Edge::data("e4", "b2", "merge"));

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["root", "b1", "b2", "merge"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

        if let ExecutionPlan::Sequence(steps) = &result.plan {
            assert_eq!(steps.len(), 3);
            assert!(matches!(&steps[0], ExecutionPlan::Execute { node_id } if node_id == "root"));
            assert!(matches!(&steps[1], ExecutionPlan::Parallel(branches) if branches.len() == 2));
            assert!(matches!(&steps[2], ExecutionPlan::Execute { node_id } if node_id == "merge"));
        } else {
            panic!("expected Sequence, got: {:?}", result.plan);
        }
    }

    #[test]
    fn compile_independent_roots_parallel() {
        let mut graph = Graph::new();
        graph.add_node(Node::new("a", "A", "F"));
        graph.add_node(Node::new("b", "B", "F"));
        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

        assert!(matches!(result.plan, ExecutionPlan::Parallel(_)));
    }

    #[test]
    fn cache_resolution_replaces_cached_nodes() {
        let graph = linear_pipeline(vec![
            Node::new("a", "Scaler", "F"),
            Node::new("b", "PCA", "F"),
            Node::new("c", "SVM", "F"),
        ]);

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b", "c"],
            make_meta(FilterKind::Trainable, true),
        );

        // `a` has no predecessors, so its cache key is built from its config
        // hash alone.
        let a_config = registry.config_hash("a").unwrap();
        let a_cache_key = CacheKey::from_parts(&[&a_config.0]);

        let cache = MockCacheStore::new();
        cache.insert(a_cache_key);

        let result = compile(&graph, &registry, CompileMode::Inference, Some(&cache)).unwrap();

        if let ExecutionPlan::Sequence(steps) = &result.plan {
            assert!(
                matches!(&steps[0], ExecutionPlan::Cached { node_id, .. } if node_id == "a"),
                "first node should be cached, got: {:?}",
                steps[0]
            );
            assert!(
                matches!(&steps[1], ExecutionPlan::Composite { node_ids } if node_ids == &["b", "c"]),
                "b+c should be Composite, got: {:?}",
                steps[1]
            );
        } else {
            panic!("expected Sequence, got: {:?}", result.plan);
        }
    }

    #[test]
    fn cascade_invalidation_different_config_changes_keys() {
        let mut reg1 = SimpleFilterRegistry::new();
        reg1.register_meta(
            "a",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"scaler_v1"),
        );
        reg1.register_meta(
            "b",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"pca_v1"),
        );

        let mut reg2 = SimpleFilterRegistry::new();
        reg2.register_meta(
            "a",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"scaler_v2"),
        );
        reg2.register_meta(
            "b",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"pca_v1"),
        );

        // Keys are rebuilt by hand, mirroring how the compiler chains a
        // node's config hash with the keys of its predecessors.
        let a_key_v1 = CacheKey::from_parts(&[&CacheKey::hash_data(b"scaler_v1").0]);
        let b_key_v1 = CacheKey::from_parts(&[&CacheKey::hash_data(b"pca_v1").0, &a_key_v1.0]);

        let a_key_v2 = CacheKey::from_parts(&[&CacheKey::hash_data(b"scaler_v2").0]);
        let b_key_v2 = CacheKey::from_parts(&[&CacheKey::hash_data(b"pca_v1").0, &a_key_v2.0]);

        assert_ne!(a_key_v1, a_key_v2);
        assert_ne!(b_key_v1, b_key_v2);
    }

    #[test]
    fn no_cache_mode_skips_all_caching() {
        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b"],
            make_meta(FilterKind::Trainable, true),
        );

        let a_config = registry.config_hash("a").unwrap();
        let a_key = CacheKey::from_parts(&[&a_config.0]);
        let cache = MockCacheStore::new();
        cache.insert(a_key);

        let result = compile(&graph, &registry, CompileMode::NoCache, Some(&cache)).unwrap();

        assert_eq!(result.plan.cached_count(), 0);
    }

    #[test]
    fn differentiable_mode_skips_output_caching() {
        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b"],
            make_meta(FilterKind::Trainable, true),
        );

        let a_config = registry.config_hash("a").unwrap();
        let a_key = CacheKey::from_parts(&[&a_config.0]);
        let cache = MockCacheStore::new();
        cache.insert(a_key);

        let result =
            compile(&graph, &registry, CompileMode::Differentiable, Some(&cache)).unwrap();

        assert_eq!(result.plan.cached_count(), 0);
    }

    #[test]
    fn gradient_flow_diagnostic_on_opaque() {
        let graph = linear_pipeline(vec![
            Node::new("scaler", "Scaler", "F"),
            Node::new("tree", "DecisionTree", "F"),
            Node::new("linear", "Linear", "F"),
        ]);

        let mut registry = SimpleFilterRegistry::new();
        registry.register_meta(
            "scaler",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"s"),
        );
        registry.register_meta(
            "tree",
            make_meta(FilterKind::Opaque, false),
            CacheKey::hash_data(b"t"),
        );
        registry.register_meta(
            "linear",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"l"),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

        assert_eq!(result.diagnostics.len(), 1);
        assert_eq!(result.diagnostics[0].node_id, "tree");
        assert_eq!(result.diagnostics[0].level, DiagnosticLevel::Warning);
        assert!(
            result.diagnostics[0]
                .message
                .contains("gradient flow interrupted")
        );
    }

    #[test]
    fn no_diagnostic_when_all_differentiable() {
        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
        assert!(result.diagnostics.is_empty());
    }

    #[test]
    fn compile_rejects_cycle() {
        let mut graph = Graph::new();
        graph.add_node(Node::new("a", "A", "F"));
        graph.add_node(Node::new("b", "B", "F"));
        graph.add_edge(Edge::data("e1", "a", "b"));
        graph.add_edge(Edge::data("e2", "b", "a"));

        let registry = SimpleFilterRegistry::new();
        let result = compile(&graph, &registry, CompileMode::Inference, None);
        assert!(matches!(result, Err(SomaError::CycleDetected)));
    }

    #[test]
    fn plan_summary_is_accurate() {
        let mut graph = Graph::new();
        graph.add_node(Node::new("root", "Root", "F"));
        graph.add_node(Node::new("b1", "B1", "F"));
        graph.add_node(Node::new("b2", "B2", "F"));
        graph.add_node(Node::new("end", "End", "F"));
        graph.add_edge(Edge::data("e1", "root", "b1"));
        graph.add_edge(Edge::data("e2", "root", "b2"));
        graph.add_edge(Edge::data("e3", "b1", "end"));
        graph.add_edge(Edge::data("e4", "b2", "end"));

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["root", "b1", "b2", "end"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
        let summary = result.plan.summary();
        assert_eq!(summary.total_nodes, 4);
        assert_eq!(summary.parallel_branches, 2);
    }

    #[test]
    fn distribution_wraps_remote_nodes() {
        let graph = linear_pipeline(vec![
            Node::new("preprocess", "Preprocess", "F"),
            Node::new("gpu_train", "GpuTrain", "F"),
            Node::new("evaluate", "Evaluate", "F"),
        ]);

        let mut registry = SimpleFilterRegistry::new();
        registry.register_meta(
            "preprocess",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"pre"),
        );
        let mut gpu_meta = make_meta(FilterKind::Trainable, true);
        gpu_meta.distribution = somatize_core::filter::Distribution::Remote(
            somatize_core::filter::RemoteTarget::Tag("gpu".into()),
        );
        registry.register_meta("gpu_train", gpu_meta, CacheKey::hash_data(b"gpu"));
        registry.register_meta(
            "evaluate",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"eval"),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

        if let ExecutionPlan::Sequence(steps) = &result.plan {
            assert_eq!(steps.len(), 3);
            assert!(
                matches!(&steps[0], ExecutionPlan::Execute { node_id } if node_id == "preprocess")
            );
            assert!(
                matches!(&steps[1], ExecutionPlan::Remote { node_id, target, .. }
                    if node_id == "gpu_train"
                        && *target == somatize_core::filter::RemoteTarget::Tag("gpu".into())
                ),
                "expected Remote, got: {:?}",
                steps[1]
            );
            assert!(
                matches!(&steps[2], ExecutionPlan::Execute { node_id } if node_id == "evaluate")
            );
        } else {
            panic!("expected Sequence, got: {:?}", result.plan);
        }
    }

    #[test]
    fn local_distribution_not_wrapped() {
        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

        let ids = result.plan.node_ids();
        assert_eq!(ids.len(), 2);
        if let ExecutionPlan::Sequence(steps) = &result.plan {
            assert!(
                steps
                    .iter()
                    .all(|s| matches!(s, ExecutionPlan::Execute { .. }))
            );
        }
    }
}