1use crate::error::{Result, SomaError};
7use crate::strategy::TrainingStrategy;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10
/// Identifier of a [`Node`]; `Graph::validate` requires these to be unique within one graph.
pub type NodeId = String;

/// Identifier of an [`Edge`].
pub type EdgeId = String;
19
/// What a [`Node`] does.
///
/// Serialized with serde's internally tagged representation: the variant name
/// is stored in a `"type"` field next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum NodeKind {
    /// A filter step identified by `filter_name`.
    Filter { filter_name: String },
    /// A nested graph, stored behind a `Box` (so this variant stays
    /// pointer-sized). `Graph::validate` validates it recursively.
    SubGraph { graph: Box<Graph> },
    /// A loop node with an optional iteration cap.
    /// NOTE(review): `None` presumably means "no cap" — confirm with the executor.
    Loop { max_iterations: Option<usize> },
    /// A branch node; it carries no data of its own.
    Branch,
}
34
/// A vertex of a [`Graph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node {
    /// Identifier that edges refer to; `Graph::validate` rejects duplicates.
    pub id: NodeId,
    /// Human-readable name used by the Mermaid / Graphviz exporters.
    pub label: String,
    /// What the node does.
    pub kind: NodeKind,
    /// Optional execution target; `Some("local")` makes `is_local` return true.
    /// Deserializes to `None` when absent and is omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub target: Option<String>,
}
46
47impl Node {
48 pub fn new(
50 id: impl Into<String>,
51 label: impl Into<String>,
52 filter_name: impl Into<String>,
53 ) -> Self {
54 Self {
55 id: id.into(),
56 label: label.into(),
57 kind: NodeKind::Filter {
58 filter_name: filter_name.into(),
59 },
60 target: None,
61 }
62 }
63
64 pub fn filter_with_id(id: impl Into<String>, filter_name: impl Into<String>) -> Self {
66 let id = id.into();
67 Self {
68 label: id.clone(),
69 id,
70 kind: NodeKind::Filter {
71 filter_name: filter_name.into(),
72 },
73 target: None,
74 }
75 }
76
77 pub fn filter(filter_name: impl Into<String>) -> Self {
79 let name = filter_name.into();
80 Self {
81 id: name.clone(),
82 label: name.clone(),
83 kind: NodeKind::Filter { filter_name: name },
84 target: None,
85 }
86 }
87
88 pub fn subgraph(id: impl Into<String>, graph: Graph) -> Self {
90 let id = id.into();
91 Self {
92 id: id.clone(),
93 label: id,
94 kind: NodeKind::SubGraph {
95 graph: Box::new(graph),
96 },
97 target: None,
98 }
99 }
100
101 pub fn loop_node(id: impl Into<String>, max_iterations: Option<usize>) -> Self {
103 let id = id.into();
104 Self {
105 id: id.clone(),
106 label: id,
107 kind: NodeKind::Loop { max_iterations },
108 target: None,
109 }
110 }
111
112 pub fn branch(id: impl Into<String>) -> Self {
114 let id = id.into();
115 Self {
116 id: id.clone(),
117 label: id,
118 kind: NodeKind::Branch,
119 target: None,
120 }
121 }
122
123 pub fn with_target(mut self, target: impl Into<String>) -> Self {
125 self.target = Some(target.into());
126 self
127 }
128
129 pub fn is_local(&self) -> bool {
131 self.target.as_deref() == Some("local")
132 }
133
134 pub fn filter_name(&self) -> Option<&str> {
136 match &self.kind {
137 NodeKind::Filter { filter_name } => Some(filter_name),
138 _ => None,
139 }
140 }
141}
142
/// The kind of relationship an [`Edge`] expresses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum EdgeKind {
    /// A data edge; drawn as a solid arrow by the Mermaid / Graphviz exporters.
    Data,
    /// A control edge; drawn as a dashed arrow by the Mermaid / Graphviz exporters.
    Control,
}
151
/// A directed connection from node `source` to node `target`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edge {
    /// Edge identifier; `Graph::validate` does not check it for uniqueness.
    pub id: EdgeId,
    /// Id of the node the edge leaves.
    pub source: NodeId,
    /// Id of the node the edge enters.
    pub target: NodeId,
    /// Data or control edge.
    pub kind: EdgeKind,
    /// Optional label shown by the Mermaid / Graphviz exporters.
    pub label: Option<String>,
}
161
162impl Edge {
163 pub fn data(
164 id: impl Into<String>,
165 source: impl Into<String>,
166 target: impl Into<String>,
167 ) -> Self {
168 Self {
169 id: id.into(),
170 source: source.into(),
171 target: target.into(),
172 kind: EdgeKind::Data,
173 label: None,
174 }
175 }
176
177 pub fn control(
178 id: impl Into<String>,
179 source: impl Into<String>,
180 target: impl Into<String>,
181 ) -> Self {
182 Self {
183 id: id.into(),
184 source: source.into(),
185 target: target.into(),
186 kind: EdgeKind::Control,
187 label: None,
188 }
189 }
190}
191
/// A directed graph of [`Node`]s connected by [`Edge`]s, with an optional training strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Graph {
    /// Nodes in insertion order.
    pub nodes: Vec<Node>,
    /// Edges in insertion order.
    pub edges: Vec<Edge>,
    /// Explicitly configured strategy. When `None`, `effective_strategy` falls back to
    /// `TrainingStrategy::Local`. Deserializes to `None` when absent and is omitted
    /// from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub training_strategy: Option<TrainingStrategy>,
}
202
203impl Graph {
204 pub fn new() -> Self {
205 Self {
206 nodes: Vec::new(),
207 edges: Vec::new(),
208 training_strategy: None,
209 }
210 }
211
212 pub fn with_strategy(mut self, strategy: TrainingStrategy) -> Self {
214 self.training_strategy = Some(strategy);
215 self
216 }
217
218 pub fn set_strategy(&mut self, strategy: TrainingStrategy) {
220 self.training_strategy = Some(strategy);
221 }
222
223 pub fn effective_strategy(&self) -> &TrainingStrategy {
225 static LOCAL: TrainingStrategy = TrainingStrategy::Local;
226 self.training_strategy.as_ref().unwrap_or(&LOCAL)
227 }
228
229 pub fn add_node(&mut self, node: Node) {
230 self.nodes.push(node);
231 }
232
233 pub fn add_filter(&mut self, filter_name: impl Into<String>) -> &str {
236 let name = filter_name.into();
237 let id = if self.nodes.iter().any(|n| n.id == name) {
238 let mut i = 2;
239 loop {
240 let candidate = format!("{name}_{i}");
241 if !self.nodes.iter().any(|n| n.id == candidate) {
242 break candidate;
243 }
244 i += 1;
245 }
246 } else {
247 name.clone()
248 };
249 self.nodes.push(Node::filter_with_id(&id, &name));
250 &self.nodes.last().unwrap().id
251 }
252
253 pub fn add_edge(&mut self, edge: Edge) {
254 self.edges.push(edge);
255 }
256
257 pub fn connect(&mut self, source: impl Into<String>, target: impl Into<String>) {
259 let id = format!("e_{}", self.edges.len());
260 self.edges.push(Edge::data(id, source, target));
261 }
262
263 pub fn node(&self, id: &str) -> Option<&Node> {
265 self.nodes.iter().find(|n| n.id == id)
266 }
267
268 pub fn node_ids(&self) -> Vec<&str> {
270 self.nodes.iter().map(|n| n.id.as_str()).collect()
271 }
272
273 pub fn predecessors(&self, node_id: &str) -> Vec<&str> {
275 self.edges
276 .iter()
277 .filter(|e| e.target == node_id)
278 .map(|e| e.source.as_str())
279 .collect()
280 }
281
282 pub fn successors(&self, node_id: &str) -> Vec<&str> {
284 self.edges
285 .iter()
286 .filter(|e| e.source == node_id)
287 .map(|e| e.target.as_str())
288 .collect()
289 }
290
291 pub fn roots(&self) -> Vec<&str> {
293 let has_incoming: HashSet<&str> = self.edges.iter().map(|e| e.target.as_str()).collect();
294 self.nodes
295 .iter()
296 .filter(|n| !has_incoming.contains(n.id.as_str()))
297 .map(|n| n.id.as_str())
298 .collect()
299 }
300
301 pub fn leaves(&self) -> Vec<&str> {
303 let has_outgoing: HashSet<&str> = self.edges.iter().map(|e| e.source.as_str()).collect();
304 self.nodes
305 .iter()
306 .filter(|n| !has_outgoing.contains(n.id.as_str()))
307 .map(|n| n.id.as_str())
308 .collect()
309 }
310
311 fn in_degrees(&self) -> HashMap<&str, usize> {
313 let mut degrees: HashMap<&str, usize> =
314 self.nodes.iter().map(|n| (n.id.as_str(), 0)).collect();
315 for edge in &self.edges {
316 *degrees.entry(edge.target.as_str()).or_insert(0) += 1;
317 }
318 degrees
319 }
320
321 pub fn topological_sort(&self) -> Result<Vec<&str>> {
324 let mut in_deg = self.in_degrees();
325 let mut queue: Vec<&str> = in_deg
326 .iter()
327 .filter(|(_, deg)| **deg == 0)
328 .map(|(&id, _)| id)
329 .collect();
330 queue.sort(); let mut sorted = Vec::with_capacity(self.nodes.len());
333
334 while let Some(node) = queue.pop() {
335 sorted.push(node);
336 let mut next = Vec::new();
337 for succ in self.successors(node) {
338 if let Some(deg) = in_deg.get_mut(succ) {
339 *deg -= 1;
340 if *deg == 0 {
341 next.push(succ);
342 }
343 }
344 }
345 next.sort();
346 for n in next.into_iter().rev() {
348 queue.push(n);
349 }
350 }
351
352 if sorted.len() != self.nodes.len() {
353 return Err(SomaError::CycleDetected);
354 }
355
356 Ok(sorted)
357 }
358
359 pub fn validate(&self) -> Result<()> {
361 let mut seen = HashSet::new();
363 for node in &self.nodes {
364 if !seen.insert(&node.id) {
365 return Err(SomaError::Compilation(format!(
366 "duplicate node id: `{}`",
367 node.id
368 )));
369 }
370 }
371
372 let node_ids: HashSet<&str> = self.nodes.iter().map(|n| n.id.as_str()).collect();
374 for edge in &self.edges {
375 if !node_ids.contains(edge.source.as_str()) {
376 return Err(SomaError::NodeNotFound(edge.source.clone()));
377 }
378 if !node_ids.contains(edge.target.as_str()) {
379 return Err(SomaError::NodeNotFound(edge.target.clone()));
380 }
381 }
382
383 self.topological_sort()?;
385
386 for node in &self.nodes {
388 if let NodeKind::SubGraph { graph } = &node.kind {
389 graph.validate()?;
390 }
391 }
392
393 Ok(())
394 }
395}
396
397impl Graph {
400 pub fn to_mermaid(&self) -> String {
409 use std::fmt::Write;
410 let mut out = String::from("graph LR\n");
411 for node in &self.nodes {
412 let shape = match &node.kind {
413 NodeKind::Filter { .. } => format!(" {}[{}]", node.id, node.label),
414 NodeKind::SubGraph { .. } => format!(" {}[[{}]]", node.id, node.label),
415 NodeKind::Loop { max_iterations } => {
416 let label = match max_iterations {
417 Some(n) => format!("{} (max {})", node.label, n),
418 None => node.label.clone(),
419 };
420 format!(" {}(({}))", node.id, label)
421 }
422 NodeKind::Branch => format!(" {}{{{{{}}}}}", node.id, node.label),
423 };
424 let _ = writeln!(out, "{shape}");
425 }
426 for edge in &self.edges {
427 let arrow = match edge.kind {
428 EdgeKind::Data => "-->",
429 EdgeKind::Control => "-.->",
430 };
431 if let Some(label) = &edge.label {
432 let _ = writeln!(
433 out,
434 " {} {}|{}| {}",
435 edge.source, arrow, label, edge.target
436 );
437 } else {
438 let _ = writeln!(out, " {} {} {}", edge.source, arrow, edge.target);
439 }
440 }
441 out
442 }
443
444 pub fn to_graphviz(&self) -> String {
446 use std::fmt::Write;
447 let mut out = String::from("digraph G {\n rankdir=LR;\n");
448 for node in &self.nodes {
449 let shape = match &node.kind {
450 NodeKind::Filter { .. } => "box",
451 NodeKind::SubGraph { .. } => "doubleoctagon",
452 NodeKind::Loop { .. } => "ellipse",
453 NodeKind::Branch => "diamond",
454 };
455 let _ = writeln!(
456 out,
457 " \"{}\" [label=\"{}\" shape={}];",
458 node.id, node.label, shape
459 );
460 }
461 for edge in &self.edges {
462 let style = match edge.kind {
463 EdgeKind::Data => "",
464 EdgeKind::Control => " [style=dashed]",
465 };
466 let label = edge
467 .label
468 .as_ref()
469 .map(|l| format!(" [label=\"{l}\"]"))
470 .unwrap_or_default();
471 let attrs = if style.is_empty() && label.is_empty() {
472 String::new()
473 } else if label.is_empty() {
474 style.to_string()
475 } else {
476 label
477 };
478 let _ = writeln!(
479 out,
480 " \"{}\" -> \"{}\"{};",
481 edge.source, edge.target, attrs
482 );
483 }
484 out.push_str("}\n");
485 out
486 }
487
488 pub fn to_text(&self) -> String {
490 use std::fmt::Write;
491 let mut out = String::new();
492 let sorted = self.topological_sort().unwrap_or_default();
493 let total_nodes = self.nodes.len();
494 let total_edges = self.edges.len();
495 let _ = writeln!(out, "Graph ({total_nodes} nodes, {total_edges} edges)");
496
497 for (i, node_id) in sorted.iter().enumerate() {
498 let node = match self.node(node_id) {
499 Some(n) => n,
500 None => continue,
501 };
502 let is_last = i == sorted.len() - 1;
503 let prefix = if is_last { "└── " } else { "├── " };
504 let kind_tag = match &node.kind {
505 NodeKind::Filter { filter_name } => {
506 if filter_name == &node.id {
507 String::new()
508 } else {
509 format!(" ({})", filter_name)
510 }
511 }
512 NodeKind::SubGraph { graph } => {
513 format!(" [subgraph: {} nodes]", graph.nodes.len())
514 }
515 NodeKind::Loop { max_iterations } => match max_iterations {
516 Some(n) => format!(" [loop max={n}]"),
517 None => " [loop]".into(),
518 },
519 NodeKind::Branch => " [branch]".into(),
520 };
521 let preds = self.predecessors(node_id);
522 let pred_info = if preds.is_empty() {
523 String::new()
524 } else {
525 format!(" ← {}", preds.join(", "))
526 };
527 let _ = writeln!(out, "{prefix}{}{kind_tag}{pred_info}", node.id);
528 }
529 out
530 }
531}
532
533impl std::fmt::Display for Graph {
534 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
535 write!(f, "{}", self.to_text())
536 }
537}
538
539impl Default for Graph {
540 fn default() -> Self {
541 Self::new()
542 }
543}
544
545pub fn linear_pipeline(nodes: Vec<Node>) -> Graph {
547 let mut graph = Graph::new();
548 for (i, node) in nodes.iter().enumerate() {
549 graph.add_node(node.clone());
550 if i > 0 {
551 graph.add_edge(Edge::data(format!("e_{}", i), &nodes[i - 1].id, &node.id));
552 }
553 }
554 graph
555}
556
#[cfg(test)]
mod tests {
    use super::*;

    /// Three filter nodes chained `a -> b -> c` by `linear_pipeline`.
    fn sample_linear_graph() -> Graph {
        linear_pipeline(vec![
            Node::new("a", "Scaler", "StandardScaler"),
            Node::new("b", "PCA", "PCA"),
            Node::new("c", "SVM", "SVM"),
        ])
    }

    #[test]
    fn linear_pipeline_structure() {
        let g = sample_linear_graph();
        assert_eq!(g.nodes.len(), 3);
        assert_eq!(g.edges.len(), 2);
        // linear_pipeline numbers its edges from e_1 and links consecutive nodes.
        let links: Vec<(&str, &str, &str)> = g
            .edges
            .iter()
            .map(|e| (e.id.as_str(), e.source.as_str(), e.target.as_str()))
            .collect();
        assert_eq!(links, vec![("e_1", "a", "b"), ("e_2", "b", "c")]);
        assert!(g.edges.iter().all(|e| e.kind == EdgeKind::Data));
    }

    #[test]
    fn linear_pipeline_of_nothing_is_empty() {
        let g = linear_pipeline(Vec::new());
        assert!(g.nodes.is_empty());
        assert!(g.edges.is_empty());
    }

    #[test]
    fn roots_and_leaves() {
        let g = sample_linear_graph();
        assert_eq!(g.roots(), vec!["a"]);
        assert_eq!(g.leaves(), vec!["c"]);
    }

    #[test]
    fn predecessors_and_successors() {
        let g = sample_linear_graph();
        assert!(g.predecessors("a").is_empty());
        assert_eq!(g.predecessors("b"), vec!["a"]);
        assert_eq!(g.successors("a"), vec!["b"]);
        assert_eq!(g.successors("b"), vec!["c"]);
        assert!(g.successors("c").is_empty());
    }

    #[test]
    fn topological_sort_linear() {
        let g = sample_linear_graph();
        let sorted = g.topological_sort().unwrap();
        assert_eq!(sorted, vec!["a", "b", "c"]);
    }

    #[test]
    fn topological_sort_parallel() {
        let mut g = Graph::new();
        g.add_node(Node::new("root", "Root", "Input"));
        g.add_node(Node::new("b1", "Branch1", "F1"));
        g.add_node(Node::new("b2", "Branch2", "F2"));
        g.add_node(Node::new("merge", "Merge", "Merge"));
        g.add_edge(Edge::data("e1", "root", "b1"));
        g.add_edge(Edge::data("e2", "root", "b2"));
        g.add_edge(Edge::data("e3", "b1", "merge"));
        g.add_edge(Edge::data("e4", "b2", "merge"));

        let sorted = g.topological_sort().unwrap();
        assert_eq!(sorted.len(), 4);
        assert_eq!(sorted[0], "root");
        assert_eq!(sorted[3], "merge");
        // The two parallel branches may come in either order between root and merge.
        let middle: HashSet<&str> = sorted[1..3].iter().copied().collect();
        assert!(middle.contains("b1"));
        assert!(middle.contains("b2"));
    }

    #[test]
    fn topological_sort_detects_cycle() {
        let mut g = Graph::new();
        g.add_node(Node::new("a", "A", "F"));
        g.add_node(Node::new("b", "B", "F"));
        g.add_edge(Edge::data("e1", "a", "b"));
        g.add_edge(Edge::data("e2", "b", "a"));

        let result = g.topological_sort();
        assert!(matches!(result, Err(SomaError::CycleDetected)));
        assert!(matches!(g.validate(), Err(SomaError::CycleDetected)));
    }

    #[test]
    fn validate_accepts_valid_graph() {
        let g = sample_linear_graph();
        assert!(g.validate().is_ok());
    }

    #[test]
    fn validate_rejects_duplicate_ids() {
        let mut g = Graph::new();
        g.add_node(Node::new("a", "A", "F"));
        g.add_node(Node::new("a", "A2", "F"));
        assert!(matches!(g.validate(), Err(SomaError::Compilation(_))));
    }

    #[test]
    fn validate_rejects_missing_edge_target() {
        let mut g = Graph::new();
        g.add_node(Node::new("a", "A", "F"));
        g.add_edge(Edge::data("e1", "a", "nonexistent"));
        assert!(matches!(g.validate(), Err(SomaError::NodeNotFound(_))));
    }

    #[test]
    fn graph_serde_roundtrip() {
        let g = sample_linear_graph();
        let json = serde_json::to_string(&g).unwrap();
        // `training_strategy` is `None` here, so it is skipped when serializing.
        assert!(!json.contains("training_strategy"));
        let deserialized: Graph = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.nodes.len(), 3);
        assert_eq!(deserialized.edges.len(), 2);
        assert_eq!(deserialized.node_ids(), vec!["a", "b", "c"]);
        assert!(deserialized.validate().is_ok());
    }

    #[test]
    fn empty_graph_is_valid() {
        let g = Graph::new();
        assert!(g.validate().is_ok());
        assert!(g.topological_sort().unwrap().is_empty());
        assert!(Graph::default().nodes.is_empty());
    }

    #[test]
    fn single_node_graph() {
        let mut g = Graph::new();
        g.add_node(Node::new("solo", "Solo", "F"));
        assert_eq!(g.roots(), vec!["solo"]);
        assert_eq!(g.leaves(), vec!["solo"]);
        assert_eq!(g.topological_sort().unwrap(), vec!["solo"]);
    }

    #[test]
    fn node_filter_shorthand() {
        let n = Node::filter("StandardScaler");
        assert_eq!(n.id, "StandardScaler");
        assert_eq!(n.label, "StandardScaler");
        assert_eq!(n.filter_name(), Some("StandardScaler"));
    }

    #[test]
    fn node_filter_with_id() {
        let n = Node::filter_with_id("my_scaler", "StandardScaler");
        assert_eq!(n.id, "my_scaler");
        assert_eq!(n.label, "my_scaler");
        assert_eq!(n.filter_name(), Some("StandardScaler"));
    }

    #[test]
    fn node_target_and_is_local() {
        assert!(!Node::filter("x").is_local());
        assert!(Node::filter("x").with_target("local").is_local());
        assert!(!Node::filter("x").with_target("gpu").is_local());
    }

    #[test]
    fn graph_add_filter_auto_names() {
        let mut g = Graph::new();
        g.add_filter("Scaler");
        g.add_filter("PCA");
        g.connect("Scaler", "PCA");

        assert!(g.validate().is_ok());
        assert_eq!(g.nodes.len(), 2);
        assert_eq!(g.nodes[0].id, "Scaler");
        assert_eq!(g.nodes[1].id, "PCA");
        assert_eq!(g.successors("Scaler"), vec!["PCA"]);
    }

    #[test]
    fn graph_add_filter_deduplicates() {
        let mut g = Graph::new();
        assert_eq!(g.add_filter("Scaler"), "Scaler");
        assert_eq!(g.add_filter("Scaler"), "Scaler_2");
        assert_eq!(g.nodes.len(), 2);
        assert_eq!(g.nodes[0].id, "Scaler");
        assert_eq!(g.nodes[1].id, "Scaler_2");
        // The filter name stays the same; only the id is made unique.
        assert_eq!(g.nodes[1].filter_name(), Some("Scaler"));
    }

    #[test]
    fn subgraph_node() {
        let inner = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);

        let mut outer = Graph::new();
        outer.add_node(Node::new("input", "Input", "Input"));
        outer.add_node(Node::subgraph("pipeline", inner));
        outer.add_node(Node::new("output", "Output", "Output"));
        outer.add_edge(Edge::data("e1", "input", "pipeline"));
        outer.add_edge(Edge::data("e2", "pipeline", "output"));

        assert!(outer.validate().is_ok());
        assert_eq!(outer.nodes.len(), 3);

        // Sub-graph nodes are not filters.
        assert!(outer.node("pipeline").unwrap().filter_name().is_none());
    }

    #[test]
    fn loop_and_branch_nodes() {
        let mut g = Graph::new();
        g.add_node(Node::loop_node("train_loop", Some(100)));
        g.add_node(Node::branch("check_convergence"));
        g.add_edge(Edge::data("e1", "train_loop", "check_convergence"));

        assert!(g.validate().is_ok());
        assert!(matches!(
            g.node("train_loop").unwrap().kind,
            NodeKind::Loop {
                max_iterations: Some(100)
            }
        ));
        assert!(matches!(
            g.node("check_convergence").unwrap().kind,
            NodeKind::Branch
        ));
    }

    #[test]
    fn to_mermaid_linear() {
        let g = sample_linear_graph();
        let m = g.to_mermaid();
        assert!(m.starts_with("graph LR"));
        assert!(m.contains("a[Scaler]"));
        assert!(m.contains("b[PCA]"));
        assert!(m.contains("c[SVM]"));
        assert!(m.contains("a --> b"));
        assert!(m.contains("b --> c"));
    }

    #[test]
    fn to_mermaid_branch_and_loop() {
        let mut g = Graph::new();
        g.add_node(Node::loop_node("train", Some(100)));
        g.add_node(Node::branch("check"));
        g.add_edge(Edge::data("e1", "train", "check"));

        let m = g.to_mermaid();
        assert!(m.contains("train((train (max 100)))"));
        assert!(m.contains("check{{check}}"));
        assert!(m.contains("train --> check"));
    }

    #[test]
    fn to_graphviz_output() {
        let g = sample_linear_graph();
        let dot = g.to_graphviz();
        assert!(dot.starts_with("digraph G {"));
        assert!(dot.contains("rankdir=LR"));
        assert!(dot.contains("\"a\" [label=\"Scaler\" shape=box]"));
        assert!(dot.contains("\"a\" -> \"b\""));
        assert!(dot.ends_with("}\n"));
    }

    #[test]
    fn control_edges_render_dashed() {
        let mut g = Graph::new();
        g.add_node(Node::filter("a"));
        g.add_node(Node::filter("b"));
        g.add_edge(Edge::control("c1", "a", "b"));

        assert!(g.to_graphviz().contains("\"a\" -> \"b\" [style=dashed];"));
        assert!(g.to_mermaid().contains("a -.-> b"));
    }

    #[test]
    fn to_text_output() {
        let g = sample_linear_graph();
        // Pin the whole listing: a bare `contains("a")` would already match the
        // header "Graph".
        assert_eq!(
            g.to_text(),
            "Graph (3 nodes, 2 edges)\n\
             ├── a (StandardScaler)\n\
             ├── b (PCA) ← a\n\
             └── c (SVM) ← b\n"
        );
    }

    #[test]
    fn display_trait() {
        let g = sample_linear_graph();
        let s = format!("{g}");
        assert!(s.contains("Graph (3 nodes"));
        assert_eq!(s, g.to_text());
    }

    #[test]
    fn node_kind_serde_roundtrip() {
        let inner = linear_pipeline(vec![Node::new("x", "X", "F")]);
        let nodes = vec![
            Node::filter("Scaler"),
            Node::subgraph("sub", inner),
            Node::loop_node("loop", Some(50)),
            Node::branch("cond"),
        ];

        for node in &nodes {
            let json = serde_json::to_string(node).unwrap();
            let parsed: Node = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed.id, node.id);
            assert_eq!(parsed.filter_name(), node.filter_name());
        }

        // `NodeKind` is internally tagged by a "type" field.
        let json = serde_json::to_string(&Node::filter("Scaler")).unwrap();
        assert!(json.contains("\"type\":\"Filter\""));
    }
}