1use crate::error::{Result, SomaError};
7use crate::strategy::TrainingStrategy;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10
/// Identifier of a [`Node`] within a graph.
///
/// Edges name their endpoints with this value, and lookups such as
/// `Graph::node` compare against it.
pub type NodeId = String;
16
/// Identifier of an [`Edge`] within a graph.
pub type EdgeId = String;
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22#[serde(tag = "type")]
23#[non_exhaustive]
24pub enum NodeKind {
25 Filter { filter_name: String },
27 SubGraph { graph: Box<Graph> },
29 Loop { max_iterations: Option<usize> },
31 Branch,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct Node {
38 pub id: NodeId,
39 pub label: String,
40 pub kind: NodeKind,
41}
42
43impl Node {
44 pub fn new(
46 id: impl Into<String>,
47 label: impl Into<String>,
48 filter_name: impl Into<String>,
49 ) -> Self {
50 Self {
51 id: id.into(),
52 label: label.into(),
53 kind: NodeKind::Filter {
54 filter_name: filter_name.into(),
55 },
56 }
57 }
58
59 pub fn filter_with_id(id: impl Into<String>, filter_name: impl Into<String>) -> Self {
61 let id = id.into();
62 Self {
63 label: id.clone(),
64 id,
65 kind: NodeKind::Filter {
66 filter_name: filter_name.into(),
67 },
68 }
69 }
70
71 pub fn filter(filter_name: impl Into<String>) -> Self {
73 let name = filter_name.into();
74 Self {
75 id: name.clone(),
76 label: name.clone(),
77 kind: NodeKind::Filter { filter_name: name },
78 }
79 }
80
81 pub fn subgraph(id: impl Into<String>, graph: Graph) -> Self {
83 let id = id.into();
84 Self {
85 id: id.clone(),
86 label: id,
87 kind: NodeKind::SubGraph {
88 graph: Box::new(graph),
89 },
90 }
91 }
92
93 pub fn loop_node(id: impl Into<String>, max_iterations: Option<usize>) -> Self {
95 let id = id.into();
96 Self {
97 id: id.clone(),
98 label: id,
99 kind: NodeKind::Loop { max_iterations },
100 }
101 }
102
103 pub fn branch(id: impl Into<String>) -> Self {
105 let id = id.into();
106 Self {
107 id: id.clone(),
108 label: id,
109 kind: NodeKind::Branch,
110 }
111 }
112
113 pub fn filter_name(&self) -> Option<&str> {
115 match &self.kind {
116 NodeKind::Filter { filter_name } => Some(filter_name),
117 _ => None,
118 }
119 }
120}
121
122#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
124pub enum EdgeKind {
125 Data,
127 Control,
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct Edge {
134 pub id: EdgeId,
135 pub source: NodeId,
136 pub target: NodeId,
137 pub kind: EdgeKind,
138 pub label: Option<String>,
139}
140
141impl Edge {
142 pub fn data(
143 id: impl Into<String>,
144 source: impl Into<String>,
145 target: impl Into<String>,
146 ) -> Self {
147 Self {
148 id: id.into(),
149 source: source.into(),
150 target: target.into(),
151 kind: EdgeKind::Data,
152 label: None,
153 }
154 }
155
156 pub fn control(
157 id: impl Into<String>,
158 source: impl Into<String>,
159 target: impl Into<String>,
160 ) -> Self {
161 Self {
162 id: id.into(),
163 source: source.into(),
164 target: target.into(),
165 kind: EdgeKind::Control,
166 label: None,
167 }
168 }
169}
170
171#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct Graph {
174 pub nodes: Vec<Node>,
175 pub edges: Vec<Edge>,
176 #[serde(default, skip_serializing_if = "Option::is_none")]
179 pub training_strategy: Option<TrainingStrategy>,
180}
181
182impl Graph {
183 pub fn new() -> Self {
184 Self {
185 nodes: Vec::new(),
186 edges: Vec::new(),
187 training_strategy: None,
188 }
189 }
190
191 pub fn with_strategy(mut self, strategy: TrainingStrategy) -> Self {
193 self.training_strategy = Some(strategy);
194 self
195 }
196
197 pub fn set_strategy(&mut self, strategy: TrainingStrategy) {
199 self.training_strategy = Some(strategy);
200 }
201
202 pub fn effective_strategy(&self) -> &TrainingStrategy {
204 static LOCAL: TrainingStrategy = TrainingStrategy::Local;
205 self.training_strategy.as_ref().unwrap_or(&LOCAL)
206 }
207
208 pub fn add_node(&mut self, node: Node) {
209 self.nodes.push(node);
210 }
211
212 pub fn add_filter(&mut self, filter_name: impl Into<String>) -> &str {
215 let name = filter_name.into();
216 let id = if self.nodes.iter().any(|n| n.id == name) {
217 let mut i = 2;
218 loop {
219 let candidate = format!("{name}_{i}");
220 if !self.nodes.iter().any(|n| n.id == candidate) {
221 break candidate;
222 }
223 i += 1;
224 }
225 } else {
226 name.clone()
227 };
228 self.nodes.push(Node::filter_with_id(&id, &name));
229 &self.nodes.last().unwrap().id
230 }
231
232 pub fn add_edge(&mut self, edge: Edge) {
233 self.edges.push(edge);
234 }
235
236 pub fn connect(&mut self, source: impl Into<String>, target: impl Into<String>) {
238 let id = format!("e_{}", self.edges.len());
239 self.edges.push(Edge::data(id, source, target));
240 }
241
242 pub fn node(&self, id: &str) -> Option<&Node> {
244 self.nodes.iter().find(|n| n.id == id)
245 }
246
247 pub fn node_ids(&self) -> Vec<&str> {
249 self.nodes.iter().map(|n| n.id.as_str()).collect()
250 }
251
252 pub fn predecessors(&self, node_id: &str) -> Vec<&str> {
254 self.edges
255 .iter()
256 .filter(|e| e.target == node_id)
257 .map(|e| e.source.as_str())
258 .collect()
259 }
260
261 pub fn successors(&self, node_id: &str) -> Vec<&str> {
263 self.edges
264 .iter()
265 .filter(|e| e.source == node_id)
266 .map(|e| e.target.as_str())
267 .collect()
268 }
269
270 pub fn roots(&self) -> Vec<&str> {
272 let has_incoming: HashSet<&str> = self.edges.iter().map(|e| e.target.as_str()).collect();
273 self.nodes
274 .iter()
275 .filter(|n| !has_incoming.contains(n.id.as_str()))
276 .map(|n| n.id.as_str())
277 .collect()
278 }
279
280 pub fn leaves(&self) -> Vec<&str> {
282 let has_outgoing: HashSet<&str> = self.edges.iter().map(|e| e.source.as_str()).collect();
283 self.nodes
284 .iter()
285 .filter(|n| !has_outgoing.contains(n.id.as_str()))
286 .map(|n| n.id.as_str())
287 .collect()
288 }
289
290 fn in_degrees(&self) -> HashMap<&str, usize> {
292 let mut degrees: HashMap<&str, usize> =
293 self.nodes.iter().map(|n| (n.id.as_str(), 0)).collect();
294 for edge in &self.edges {
295 *degrees.entry(edge.target.as_str()).or_insert(0) += 1;
296 }
297 degrees
298 }
299
300 pub fn topological_sort(&self) -> Result<Vec<&str>> {
303 let mut in_deg = self.in_degrees();
304 let mut queue: Vec<&str> = in_deg
305 .iter()
306 .filter(|(_, deg)| **deg == 0)
307 .map(|(&id, _)| id)
308 .collect();
309 queue.sort(); let mut sorted = Vec::with_capacity(self.nodes.len());
312
313 while let Some(node) = queue.pop() {
314 sorted.push(node);
315 let mut next = Vec::new();
316 for succ in self.successors(node) {
317 if let Some(deg) = in_deg.get_mut(succ) {
318 *deg -= 1;
319 if *deg == 0 {
320 next.push(succ);
321 }
322 }
323 }
324 next.sort();
325 for n in next.into_iter().rev() {
327 queue.push(n);
328 }
329 }
330
331 if sorted.len() != self.nodes.len() {
332 return Err(SomaError::CycleDetected);
333 }
334
335 Ok(sorted)
336 }
337
338 pub fn validate(&self) -> Result<()> {
340 let mut seen = HashSet::new();
342 for node in &self.nodes {
343 if !seen.insert(&node.id) {
344 return Err(SomaError::Compilation(format!(
345 "duplicate node id: `{}`",
346 node.id
347 )));
348 }
349 }
350
351 let node_ids: HashSet<&str> = self.nodes.iter().map(|n| n.id.as_str()).collect();
353 for edge in &self.edges {
354 if !node_ids.contains(edge.source.as_str()) {
355 return Err(SomaError::NodeNotFound(edge.source.clone()));
356 }
357 if !node_ids.contains(edge.target.as_str()) {
358 return Err(SomaError::NodeNotFound(edge.target.clone()));
359 }
360 }
361
362 self.topological_sort()?;
364
365 for node in &self.nodes {
367 if let NodeKind::SubGraph { graph } = &node.kind {
368 graph.validate()?;
369 }
370 }
371
372 Ok(())
373 }
374}
375
376impl Graph {
379 pub fn to_mermaid(&self) -> String {
388 use std::fmt::Write;
389 let mut out = String::from("graph LR\n");
390 for node in &self.nodes {
391 let shape = match &node.kind {
392 NodeKind::Filter { .. } => format!(" {}[{}]", node.id, node.label),
393 NodeKind::SubGraph { .. } => format!(" {}[[{}]]", node.id, node.label),
394 NodeKind::Loop { max_iterations } => {
395 let label = match max_iterations {
396 Some(n) => format!("{} (max {})", node.label, n),
397 None => node.label.clone(),
398 };
399 format!(" {}(({}))", node.id, label)
400 }
401 NodeKind::Branch => format!(" {}{{{{{}}}}}", node.id, node.label),
402 };
403 let _ = writeln!(out, "{shape}");
404 }
405 for edge in &self.edges {
406 let arrow = match edge.kind {
407 EdgeKind::Data => "-->",
408 EdgeKind::Control => "-.->",
409 };
410 if let Some(label) = &edge.label {
411 let _ = writeln!(
412 out,
413 " {} {}|{}| {}",
414 edge.source, arrow, label, edge.target
415 );
416 } else {
417 let _ = writeln!(out, " {} {} {}", edge.source, arrow, edge.target);
418 }
419 }
420 out
421 }
422
423 pub fn to_graphviz(&self) -> String {
425 use std::fmt::Write;
426 let mut out = String::from("digraph G {\n rankdir=LR;\n");
427 for node in &self.nodes {
428 let shape = match &node.kind {
429 NodeKind::Filter { .. } => "box",
430 NodeKind::SubGraph { .. } => "doubleoctagon",
431 NodeKind::Loop { .. } => "ellipse",
432 NodeKind::Branch => "diamond",
433 };
434 let _ = writeln!(
435 out,
436 " \"{}\" [label=\"{}\" shape={}];",
437 node.id, node.label, shape
438 );
439 }
440 for edge in &self.edges {
441 let style = match edge.kind {
442 EdgeKind::Data => "",
443 EdgeKind::Control => " [style=dashed]",
444 };
445 let label = edge
446 .label
447 .as_ref()
448 .map(|l| format!(" [label=\"{l}\"]"))
449 .unwrap_or_default();
450 let attrs = if style.is_empty() && label.is_empty() {
451 String::new()
452 } else if label.is_empty() {
453 style.to_string()
454 } else {
455 label
456 };
457 let _ = writeln!(
458 out,
459 " \"{}\" -> \"{}\"{};",
460 edge.source, edge.target, attrs
461 );
462 }
463 out.push_str("}\n");
464 out
465 }
466
467 pub fn to_text(&self) -> String {
469 use std::fmt::Write;
470 let mut out = String::new();
471 let sorted = self.topological_sort().unwrap_or_default();
472 let total_nodes = self.nodes.len();
473 let total_edges = self.edges.len();
474 let _ = writeln!(out, "Graph ({total_nodes} nodes, {total_edges} edges)");
475
476 for (i, node_id) in sorted.iter().enumerate() {
477 let node = match self.node(node_id) {
478 Some(n) => n,
479 None => continue,
480 };
481 let is_last = i == sorted.len() - 1;
482 let prefix = if is_last { "└── " } else { "├── " };
483 let kind_tag = match &node.kind {
484 NodeKind::Filter { filter_name } => {
485 if filter_name == &node.id {
486 String::new()
487 } else {
488 format!(" ({})", filter_name)
489 }
490 }
491 NodeKind::SubGraph { graph } => {
492 format!(" [subgraph: {} nodes]", graph.nodes.len())
493 }
494 NodeKind::Loop { max_iterations } => match max_iterations {
495 Some(n) => format!(" [loop max={n}]"),
496 None => " [loop]".into(),
497 },
498 NodeKind::Branch => " [branch]".into(),
499 };
500 let preds = self.predecessors(node_id);
501 let pred_info = if preds.is_empty() {
502 String::new()
503 } else {
504 format!(" ← {}", preds.join(", "))
505 };
506 let _ = writeln!(out, "{prefix}{}{kind_tag}{pred_info}", node.id);
507 }
508 out
509 }
510}
511
512impl std::fmt::Display for Graph {
513 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
514 write!(f, "{}", self.to_text())
515 }
516}
517
518impl Default for Graph {
519 fn default() -> Self {
520 Self::new()
521 }
522}
523
524pub fn linear_pipeline(nodes: Vec<Node>) -> Graph {
526 let mut graph = Graph::new();
527 for (i, node) in nodes.iter().enumerate() {
528 graph.add_node(node.clone());
529 if i > 0 {
530 graph.add_edge(Edge::data(format!("e_{}", i), &nodes[i - 1].id, &node.id));
531 }
532 }
533 graph
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// Three filter nodes chained `a -> b -> c`.
    fn sample_linear_graph() -> Graph {
        let stages = vec![
            Node::new("a", "Scaler", "StandardScaler"),
            Node::new("b", "PCA", "PCA"),
            Node::new("c", "SVM", "SVM"),
        ];
        linear_pipeline(stages)
    }

    /// Adds one data edge per `(id, source, target)` triple.
    fn wire(graph: &mut Graph, edges: &[(&str, &str, &str)]) {
        for &(id, source, target) in edges {
            graph.add_edge(Edge::data(id, source, target));
        }
    }

    #[test]
    fn linear_pipeline_structure() {
        let graph = sample_linear_graph();
        assert_eq!(graph.nodes.len(), 3);
        assert_eq!(graph.edges.len(), 2);
    }

    #[test]
    fn roots_and_leaves() {
        let graph = sample_linear_graph();
        assert_eq!(graph.roots(), ["a"]);
        assert_eq!(graph.leaves(), ["c"]);
    }

    #[test]
    fn predecessors_and_successors() {
        let graph = sample_linear_graph();
        assert!(graph.predecessors("a").is_empty());
        assert_eq!(graph.predecessors("b"), ["a"]);
        assert_eq!(graph.successors("a"), ["b"]);
        assert_eq!(graph.successors("b"), ["c"]);
        assert!(graph.successors("c").is_empty());
    }

    #[test]
    fn topological_sort_linear() {
        let graph = sample_linear_graph();
        let order = graph.topological_sort().unwrap();
        assert_eq!(order, ["a", "b", "c"]);
    }

    #[test]
    fn topological_sort_parallel() {
        let mut graph = Graph::new();
        for (id, label, filter) in [
            ("root", "Root", "Input"),
            ("b1", "Branch1", "F1"),
            ("b2", "Branch2", "F2"),
            ("merge", "Merge", "Merge"),
        ] {
            graph.add_node(Node::new(id, label, filter));
        }
        wire(
            &mut graph,
            &[
                ("e1", "root", "b1"),
                ("e2", "root", "b2"),
                ("e3", "b1", "merge"),
                ("e4", "b2", "merge"),
            ],
        );

        let order = graph.topological_sort().unwrap();
        assert_eq!(order[0], "root");
        assert_eq!(order[3], "merge");
        // The two branches may come in either order between root and merge.
        let between: HashSet<&str> = order[1..3].iter().copied().collect();
        assert!(between.contains("b1"));
        assert!(between.contains("b2"));
    }

    #[test]
    fn topological_sort_detects_cycle() {
        let mut graph = Graph::new();
        graph.add_node(Node::new("a", "A", "F"));
        graph.add_node(Node::new("b", "B", "F"));
        wire(&mut graph, &[("e1", "a", "b"), ("e2", "b", "a")]);

        assert!(matches!(
            graph.topological_sort(),
            Err(SomaError::CycleDetected)
        ));
    }

    #[test]
    fn validate_accepts_valid_graph() {
        assert!(sample_linear_graph().validate().is_ok());
    }

    #[test]
    fn validate_rejects_duplicate_ids() {
        let mut graph = Graph::new();
        graph.add_node(Node::new("a", "A", "F"));
        graph.add_node(Node::new("a", "A2", "F"));
        let outcome = graph.validate();
        assert!(matches!(outcome, Err(SomaError::Compilation(_))));
    }

    #[test]
    fn validate_rejects_missing_edge_target() {
        let mut graph = Graph::new();
        graph.add_node(Node::new("a", "A", "F"));
        wire(&mut graph, &[("e1", "a", "nonexistent")]);
        let outcome = graph.validate();
        assert!(matches!(outcome, Err(SomaError::NodeNotFound(_))));
    }

    #[test]
    fn graph_serde_roundtrip() {
        let original = sample_linear_graph();
        let json = serde_json::to_string(&original).unwrap();
        let restored: Graph = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.nodes.len(), 3);
        assert_eq!(restored.edges.len(), 2);
    }

    #[test]
    fn empty_graph_is_valid() {
        let graph = Graph::new();
        assert!(graph.validate().is_ok());
        assert!(graph.topological_sort().unwrap().is_empty());
    }

    #[test]
    fn single_node_graph() {
        let mut graph = Graph::new();
        graph.add_node(Node::new("solo", "Solo", "F"));
        assert_eq!(graph.roots(), ["solo"]);
        assert_eq!(graph.leaves(), ["solo"]);
        assert_eq!(graph.topological_sort().unwrap(), ["solo"]);
    }

    #[test]
    fn node_filter_shorthand() {
        let node = Node::filter("StandardScaler");
        assert_eq!(node.id, "StandardScaler");
        assert_eq!(node.filter_name(), Some("StandardScaler"));
    }

    #[test]
    fn node_filter_with_id() {
        let node = Node::filter_with_id("my_scaler", "StandardScaler");
        assert_eq!(node.id, "my_scaler");
        assert_eq!(node.filter_name(), Some("StandardScaler"));
    }

    #[test]
    fn graph_add_filter_auto_names() {
        let mut graph = Graph::new();
        graph.add_filter("Scaler");
        graph.add_filter("PCA");
        graph.connect("Scaler", "PCA");

        assert!(graph.validate().is_ok());
        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.nodes[0].id, "Scaler");
        assert_eq!(graph.nodes[1].id, "PCA");
    }

    #[test]
    fn graph_add_filter_deduplicates() {
        let mut graph = Graph::new();
        graph.add_filter("Scaler");
        // The second use of the same name must receive a suffixed id.
        graph.add_filter("Scaler");

        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.nodes[0].id, "Scaler");
        assert_eq!(graph.nodes[1].id, "Scaler_2");
    }

    #[test]
    fn subgraph_node() {
        let inner = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);

        let mut outer = Graph::new();
        outer.add_node(Node::new("input", "Input", "Input"));
        outer.add_node(Node::subgraph("pipeline", inner));
        outer.add_node(Node::new("output", "Output", "Output"));
        wire(
            &mut outer,
            &[("e1", "input", "pipeline"), ("e2", "pipeline", "output")],
        );

        assert!(outer.validate().is_ok());
        assert_eq!(outer.nodes.len(), 3);

        let pipeline = outer.node("pipeline").unwrap();
        assert!(pipeline.filter_name().is_none());
    }

    #[test]
    fn loop_and_branch_nodes() {
        let mut graph = Graph::new();
        graph.add_node(Node::loop_node("train_loop", Some(100)));
        graph.add_node(Node::branch("check_convergence"));
        wire(&mut graph, &[("e1", "train_loop", "check_convergence")]);

        assert!(graph.validate().is_ok());

        let loop_kind = &graph.node("train_loop").unwrap().kind;
        assert!(matches!(
            loop_kind,
            NodeKind::Loop {
                max_iterations: Some(100)
            }
        ));

        let branch_kind = &graph.node("check_convergence").unwrap().kind;
        assert!(matches!(branch_kind, NodeKind::Branch));
    }

    #[test]
    fn to_mermaid_linear() {
        let rendered = sample_linear_graph().to_mermaid();
        assert!(rendered.starts_with("graph LR"));
        for fragment in ["a[Scaler]", "b[PCA]", "c[SVM]", "a --> b", "b --> c"] {
            assert!(rendered.contains(fragment));
        }
    }

    #[test]
    fn to_mermaid_branch_and_loop() {
        let mut graph = Graph::new();
        graph.add_node(Node::loop_node("train", Some(100)));
        graph.add_node(Node::branch("check"));
        wire(&mut graph, &[("e1", "train", "check")]);

        let rendered = graph.to_mermaid();
        assert!(rendered.contains("train((train (max 100)))"));
        assert!(rendered.contains("check{"));
        assert!(rendered.contains("train --> check"));
    }

    #[test]
    fn to_graphviz_output() {
        let dot = sample_linear_graph().to_graphviz();
        assert!(dot.starts_with("digraph G {"));
        assert!(dot.contains("rankdir=LR"));
        assert!(dot.contains("\"a\" [label=\"Scaler\" shape=box]"));
        assert!(dot.contains("\"a\" -> \"b\""));
        assert!(dot.ends_with("}\n"));
    }

    #[test]
    fn to_text_output() {
        let text = sample_linear_graph().to_text();
        assert!(text.contains("Graph (3 nodes, 2 edges)"));
        for id in ["a", "b", "c"] {
            assert!(text.contains(id));
        }
        assert!(text.contains("← a"));
    }

    #[test]
    fn display_trait() {
        let graph = sample_linear_graph();
        let shown = graph.to_string();
        assert!(shown.contains("Graph (3 nodes"));
    }

    #[test]
    fn node_kind_serde_roundtrip() {
        let inner = linear_pipeline(vec![Node::new("x", "X", "F")]);
        let samples = [
            Node::filter("Scaler"),
            Node::subgraph("sub", inner),
            Node::loop_node("loop", Some(50)),
            Node::branch("cond"),
        ];

        for sample in samples.iter() {
            let json = serde_json::to_string(sample).unwrap();
            let restored: Node = serde_json::from_str(&json).unwrap();
            assert_eq!(restored.id, sample.id);
        }
    }
}