Skip to main content

somatize_core/
graph.rs

1//! Computational graph — DAG of filter nodes connected by edges.
2//!
3//! The graph is the user-facing representation of a pipeline topology.
4//! It gets compiled into an [`ExecutionPlan`] by the compiler.
5
6use crate::error::{Result, SomaError};
7use crate::strategy::TrainingStrategy;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10
/// Unique identifier for a node in a graph.
///
/// Currently a type alias. Will be promoted to a newtype in a future version
/// for stronger type safety (tracked in architecture-review.md).
pub type NodeId = String;

/// Unique identifier for an edge in a graph.
///
/// Uniqueness is a convention only: [`Graph::validate`] rejects duplicate
/// node ids but does not inspect edge ids.
pub type EdgeId = String;
19
/// What kind of computation a node represents.
///
/// Serialized as an internally tagged enum: the variant name is written to a
/// `"type"` field alongside the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum NodeKind {
    /// A single filter (the common case), identified by its filter name.
    Filter { filter_name: String },
    /// A nested sub-graph (compiled recursively). [`Graph::validate`] also
    /// validates it recursively.
    SubGraph { graph: Box<Graph> },
    /// A loop node — body is the sub-graph of successors.
    /// `None` means no iteration cap is recorded on the node itself.
    Loop { max_iterations: Option<usize> },
    /// A branch/conditional node — arms determined by control edges.
    Branch,
}
34
/// A node in the computational graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node {
    /// Identifier referenced by edges; must be unique within one graph
    /// (checked by [`Graph::validate`]).
    pub id: NodeId,
    /// Human-readable label used by the visualizers.
    pub label: String,
    /// What this node computes.
    pub kind: NodeKind,
    /// Execution target: "local" (reserved, always local), or a worker tag.
    /// None means: use default (remote if workers available, else local).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub target: Option<String>,
}
46
47impl Node {
48    /// Create a filter node (backward-compatible with old 3-arg constructor).
49    pub fn new(
50        id: impl Into<String>,
51        label: impl Into<String>,
52        filter_name: impl Into<String>,
53    ) -> Self {
54        Self {
55            id: id.into(),
56            label: label.into(),
57            kind: NodeKind::Filter {
58                filter_name: filter_name.into(),
59            },
60            target: None,
61        }
62    }
63
64    /// Create a filter node with explicit id and filter_name.
65    pub fn filter_with_id(id: impl Into<String>, filter_name: impl Into<String>) -> Self {
66        let id = id.into();
67        Self {
68            label: id.clone(),
69            id,
70            kind: NodeKind::Filter {
71                filter_name: filter_name.into(),
72            },
73            target: None,
74        }
75    }
76
77    /// Create a filter node where id defaults to filter_name.
78    pub fn filter(filter_name: impl Into<String>) -> Self {
79        let name = filter_name.into();
80        Self {
81            id: name.clone(),
82            label: name.clone(),
83            kind: NodeKind::Filter { filter_name: name },
84            target: None,
85        }
86    }
87
88    /// Create a sub-graph node.
89    pub fn subgraph(id: impl Into<String>, graph: Graph) -> Self {
90        let id = id.into();
91        Self {
92            id: id.clone(),
93            label: id,
94            kind: NodeKind::SubGraph {
95                graph: Box::new(graph),
96            },
97            target: None,
98        }
99    }
100
101    /// Create a loop node.
102    pub fn loop_node(id: impl Into<String>, max_iterations: Option<usize>) -> Self {
103        let id = id.into();
104        Self {
105            id: id.clone(),
106            label: id,
107            kind: NodeKind::Loop { max_iterations },
108            target: None,
109        }
110    }
111
112    /// Create a branch/conditional node.
113    pub fn branch(id: impl Into<String>) -> Self {
114        let id = id.into();
115        Self {
116            id: id.clone(),
117            label: id,
118            kind: NodeKind::Branch,
119            target: None,
120        }
121    }
122
123    /// Set the execution target for this node.
124    pub fn with_target(mut self, target: impl Into<String>) -> Self {
125        self.target = Some(target.into());
126        self
127    }
128
129    /// Whether this node is forced local.
130    pub fn is_local(&self) -> bool {
131        self.target.as_deref() == Some("local")
132    }
133
134    /// Get the filter name if this is a Filter node.
135    pub fn filter_name(&self) -> Option<&str> {
136        match &self.kind {
137            NodeKind::Filter { filter_name } => Some(filter_name),
138            _ => None,
139        }
140    }
141}
142
/// Type of connection between nodes.
///
/// The visualizers render `Data` as a solid arrow and `Control` as a dashed one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum EdgeKind {
    /// Normal data flow: output of source becomes input of target.
    Data,
    /// Control flow edge (for conditional/loop logic).
    Control,
}
151
/// A directed edge connecting two nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edge {
    /// Edge identifier (uniqueness is not enforced by `Graph::validate`).
    pub id: EdgeId,
    /// Id of the node the edge starts at.
    pub source: NodeId,
    /// Id of the node the edge points to.
    pub target: NodeId,
    /// Data or control flow.
    pub kind: EdgeKind,
    /// Optional text shown on the edge by the Mermaid/Graphviz renderers.
    pub label: Option<String>,
}
161
162impl Edge {
163    pub fn data(
164        id: impl Into<String>,
165        source: impl Into<String>,
166        target: impl Into<String>,
167    ) -> Self {
168        Self {
169            id: id.into(),
170            source: source.into(),
171            target: target.into(),
172            kind: EdgeKind::Data,
173            label: None,
174        }
175    }
176
177    pub fn control(
178        id: impl Into<String>,
179        source: impl Into<String>,
180        target: impl Into<String>,
181    ) -> Self {
182        Self {
183            id: id.into(),
184            source: source.into(),
185            target: target.into(),
186            kind: EdgeKind::Control,
187            label: None,
188        }
189    }
190}
191
/// A directed graph of computational nodes.
///
/// Nodes and edges are kept in insertion order, and edges refer to nodes by
/// [`NodeId`]. Structural invariants (unique node ids, existing edge
/// endpoints, no cycles) are checked only by [`Graph::validate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Graph {
    /// All nodes, in insertion order.
    pub nodes: Vec<Node>,
    /// All edges, in insertion order.
    pub edges: Vec<Edge>,
    /// Training strategy for distributed execution.
    /// Inherited by subgraphs unless overridden.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub training_strategy: Option<TrainingStrategy>,
}
202
203impl Graph {
204    pub fn new() -> Self {
205        Self {
206            nodes: Vec::new(),
207            edges: Vec::new(),
208            training_strategy: None,
209        }
210    }
211
212    /// Set the training strategy for this graph.
213    pub fn with_strategy(mut self, strategy: TrainingStrategy) -> Self {
214        self.training_strategy = Some(strategy);
215        self
216    }
217
218    /// Set the training strategy (mutable).
219    pub fn set_strategy(&mut self, strategy: TrainingStrategy) {
220        self.training_strategy = Some(strategy);
221    }
222
223    /// Get the effective training strategy (defaults to Local).
224    pub fn effective_strategy(&self) -> &TrainingStrategy {
225        static LOCAL: TrainingStrategy = TrainingStrategy::Local;
226        self.training_strategy.as_ref().unwrap_or(&LOCAL)
227    }
228
229    pub fn add_node(&mut self, node: Node) {
230        self.nodes.push(node);
231    }
232
233    /// Add a filter node using the filter name as the node id.
234    /// If a node with that name already exists, appends a suffix.
235    pub fn add_filter(&mut self, filter_name: impl Into<String>) -> &str {
236        let name = filter_name.into();
237        let id = if self.nodes.iter().any(|n| n.id == name) {
238            let mut i = 2;
239            loop {
240                let candidate = format!("{name}_{i}");
241                if !self.nodes.iter().any(|n| n.id == candidate) {
242                    break candidate;
243                }
244                i += 1;
245            }
246        } else {
247            name.clone()
248        };
249        self.nodes.push(Node::filter_with_id(&id, &name));
250        &self.nodes.last().unwrap().id
251    }
252
253    pub fn add_edge(&mut self, edge: Edge) {
254        self.edges.push(edge);
255    }
256
257    /// Connect two nodes with a data edge (auto-generates edge id).
258    pub fn connect(&mut self, source: impl Into<String>, target: impl Into<String>) {
259        let id = format!("e_{}", self.edges.len());
260        self.edges.push(Edge::data(id, source, target));
261    }
262
263    /// Get a node by its ID.
264    pub fn node(&self, id: &str) -> Option<&Node> {
265        self.nodes.iter().find(|n| n.id == id)
266    }
267
268    /// Get all node IDs.
269    pub fn node_ids(&self) -> Vec<&str> {
270        self.nodes.iter().map(|n| n.id.as_str()).collect()
271    }
272
273    /// Get predecessors of a node (nodes with edges pointing to it).
274    pub fn predecessors(&self, node_id: &str) -> Vec<&str> {
275        self.edges
276            .iter()
277            .filter(|e| e.target == node_id)
278            .map(|e| e.source.as_str())
279            .collect()
280    }
281
282    /// Get successors of a node (nodes it points to).
283    pub fn successors(&self, node_id: &str) -> Vec<&str> {
284        self.edges
285            .iter()
286            .filter(|e| e.source == node_id)
287            .map(|e| e.target.as_str())
288            .collect()
289    }
290
291    /// Find root nodes (no incoming edges).
292    pub fn roots(&self) -> Vec<&str> {
293        let has_incoming: HashSet<&str> = self.edges.iter().map(|e| e.target.as_str()).collect();
294        self.nodes
295            .iter()
296            .filter(|n| !has_incoming.contains(n.id.as_str()))
297            .map(|n| n.id.as_str())
298            .collect()
299    }
300
301    /// Find leaf nodes (no outgoing edges).
302    pub fn leaves(&self) -> Vec<&str> {
303        let has_outgoing: HashSet<&str> = self.edges.iter().map(|e| e.source.as_str()).collect();
304        self.nodes
305            .iter()
306            .filter(|n| !has_outgoing.contains(n.id.as_str()))
307            .map(|n| n.id.as_str())
308            .collect()
309    }
310
311    /// Compute in-degree for each node.
312    fn in_degrees(&self) -> HashMap<&str, usize> {
313        let mut degrees: HashMap<&str, usize> =
314            self.nodes.iter().map(|n| (n.id.as_str(), 0)).collect();
315        for edge in &self.edges {
316            *degrees.entry(edge.target.as_str()).or_insert(0) += 1;
317        }
318        degrees
319    }
320
321    /// Topological sort using Kahn's algorithm.
322    /// Returns Err if the graph contains a cycle.
323    pub fn topological_sort(&self) -> Result<Vec<&str>> {
324        let mut in_deg = self.in_degrees();
325        let mut queue: Vec<&str> = in_deg
326            .iter()
327            .filter(|(_, deg)| **deg == 0)
328            .map(|(&id, _)| id)
329            .collect();
330        queue.sort(); // deterministic order
331
332        let mut sorted = Vec::with_capacity(self.nodes.len());
333
334        while let Some(node) = queue.pop() {
335            sorted.push(node);
336            let mut next = Vec::new();
337            for succ in self.successors(node) {
338                if let Some(deg) = in_deg.get_mut(succ) {
339                    *deg -= 1;
340                    if *deg == 0 {
341                        next.push(succ);
342                    }
343                }
344            }
345            next.sort();
346            // Insert at beginning so we process in deterministic order
347            for n in next.into_iter().rev() {
348                queue.push(n);
349            }
350        }
351
352        if sorted.len() != self.nodes.len() {
353            return Err(SomaError::CycleDetected);
354        }
355
356        Ok(sorted)
357    }
358
359    /// Validate the graph structure (recursively validates sub-graphs).
360    pub fn validate(&self) -> Result<()> {
361        // Check for duplicate node IDs
362        let mut seen = HashSet::new();
363        for node in &self.nodes {
364            if !seen.insert(&node.id) {
365                return Err(SomaError::Compilation(format!(
366                    "duplicate node id: `{}`",
367                    node.id
368                )));
369            }
370        }
371
372        // Check that all edge endpoints reference existing nodes
373        let node_ids: HashSet<&str> = self.nodes.iter().map(|n| n.id.as_str()).collect();
374        for edge in &self.edges {
375            if !node_ids.contains(edge.source.as_str()) {
376                return Err(SomaError::NodeNotFound(edge.source.clone()));
377            }
378            if !node_ids.contains(edge.target.as_str()) {
379                return Err(SomaError::NodeNotFound(edge.target.clone()));
380            }
381        }
382
383        // Check for cycles
384        self.topological_sort()?;
385
386        // Recursively validate sub-graphs
387        for node in &self.nodes {
388            if let NodeKind::SubGraph { graph } = &node.kind {
389                graph.validate()?;
390            }
391        }
392
393        Ok(())
394    }
395}
396
397// ── Visualization ──
398
399impl Graph {
400    /// Render as a Mermaid diagram.
401    ///
402    /// ```text
403    /// graph LR
404    ///     scaler[scaler]
405    ///     model[model]
406    ///     scaler --> model
407    /// ```
408    pub fn to_mermaid(&self) -> String {
409        use std::fmt::Write;
410        let mut out = String::from("graph LR\n");
411        for node in &self.nodes {
412            let shape = match &node.kind {
413                NodeKind::Filter { .. } => format!("    {}[{}]", node.id, node.label),
414                NodeKind::SubGraph { .. } => format!("    {}[[{}]]", node.id, node.label),
415                NodeKind::Loop { max_iterations } => {
416                    let label = match max_iterations {
417                        Some(n) => format!("{} (max {})", node.label, n),
418                        None => node.label.clone(),
419                    };
420                    format!("    {}(({}))", node.id, label)
421                }
422                NodeKind::Branch => format!("    {}{{{{{}}}}}", node.id, node.label),
423            };
424            let _ = writeln!(out, "{shape}");
425        }
426        for edge in &self.edges {
427            let arrow = match edge.kind {
428                EdgeKind::Data => "-->",
429                EdgeKind::Control => "-.->",
430            };
431            if let Some(label) = &edge.label {
432                let _ = writeln!(
433                    out,
434                    "    {} {}|{}| {}",
435                    edge.source, arrow, label, edge.target
436                );
437            } else {
438                let _ = writeln!(out, "    {} {} {}", edge.source, arrow, edge.target);
439            }
440        }
441        out
442    }
443
444    /// Render as Graphviz DOT format.
445    pub fn to_graphviz(&self) -> String {
446        use std::fmt::Write;
447        let mut out = String::from("digraph G {\n    rankdir=LR;\n");
448        for node in &self.nodes {
449            let shape = match &node.kind {
450                NodeKind::Filter { .. } => "box",
451                NodeKind::SubGraph { .. } => "doubleoctagon",
452                NodeKind::Loop { .. } => "ellipse",
453                NodeKind::Branch => "diamond",
454            };
455            let _ = writeln!(
456                out,
457                "    \"{}\" [label=\"{}\" shape={}];",
458                node.id, node.label, shape
459            );
460        }
461        for edge in &self.edges {
462            let style = match edge.kind {
463                EdgeKind::Data => "",
464                EdgeKind::Control => " [style=dashed]",
465            };
466            let label = edge
467                .label
468                .as_ref()
469                .map(|l| format!(" [label=\"{l}\"]"))
470                .unwrap_or_default();
471            let attrs = if style.is_empty() && label.is_empty() {
472                String::new()
473            } else if label.is_empty() {
474                style.to_string()
475            } else {
476                label
477            };
478            let _ = writeln!(
479                out,
480                "    \"{}\" -> \"{}\"{};",
481                edge.source, edge.target, attrs
482            );
483        }
484        out.push_str("}\n");
485        out
486    }
487
488    /// Render as an ASCII text tree for terminal display.
489    pub fn to_text(&self) -> String {
490        use std::fmt::Write;
491        let mut out = String::new();
492        let sorted = self.topological_sort().unwrap_or_default();
493        let total_nodes = self.nodes.len();
494        let total_edges = self.edges.len();
495        let _ = writeln!(out, "Graph ({total_nodes} nodes, {total_edges} edges)");
496
497        for (i, node_id) in sorted.iter().enumerate() {
498            let node = match self.node(node_id) {
499                Some(n) => n,
500                None => continue,
501            };
502            let is_last = i == sorted.len() - 1;
503            let prefix = if is_last { "└── " } else { "├── " };
504            let kind_tag = match &node.kind {
505                NodeKind::Filter { filter_name } => {
506                    if filter_name == &node.id {
507                        String::new()
508                    } else {
509                        format!(" ({})", filter_name)
510                    }
511                }
512                NodeKind::SubGraph { graph } => {
513                    format!(" [subgraph: {} nodes]", graph.nodes.len())
514                }
515                NodeKind::Loop { max_iterations } => match max_iterations {
516                    Some(n) => format!(" [loop max={n}]"),
517                    None => " [loop]".into(),
518                },
519                NodeKind::Branch => " [branch]".into(),
520            };
521            let preds = self.predecessors(node_id);
522            let pred_info = if preds.is_empty() {
523                String::new()
524            } else {
525                format!(" ← {}", preds.join(", "))
526            };
527            let _ = writeln!(out, "{prefix}{}{kind_tag}{pred_info}", node.id);
528        }
529        out
530    }
531}
532
533impl std::fmt::Display for Graph {
534    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
535        write!(f, "{}", self.to_text())
536    }
537}
538
539impl Default for Graph {
540    fn default() -> Self {
541        Self::new()
542    }
543}
544
545/// Builder for constructing linear pipelines easily.
546pub fn linear_pipeline(nodes: Vec<Node>) -> Graph {
547    let mut graph = Graph::new();
548    for (i, node) in nodes.iter().enumerate() {
549        graph.add_node(node.clone());
550        if i > 0 {
551            graph.add_edge(Edge::data(format!("e_{}", i), &nodes[i - 1].id, &node.id));
552        }
553    }
554    graph
555}
556
#[cfg(test)]
mod tests {
    use super::*;

    /// Three filter nodes `a -> b -> c` with labels differing from filter names.
    fn sample_linear_graph() -> Graph {
        linear_pipeline(vec![
            Node::new("a", "Scaler", "StandardScaler"),
            Node::new("b", "PCA", "PCA"),
            Node::new("c", "SVM", "SVM"),
        ])
    }

    #[test]
    fn linear_pipeline_structure() {
        let g = sample_linear_graph();
        assert_eq!(g.nodes.len(), 3);
        assert_eq!(g.edges.len(), 2);
    }

    #[test]
    fn roots_and_leaves() {
        let g = sample_linear_graph();
        assert_eq!(g.roots(), vec!["a"]);
        assert_eq!(g.leaves(), vec!["c"]);
    }

    #[test]
    fn predecessors_and_successors() {
        let g = sample_linear_graph();
        assert!(g.predecessors("a").is_empty());
        assert_eq!(g.predecessors("b"), vec!["a"]);
        assert_eq!(g.successors("a"), vec!["b"]);
        assert_eq!(g.successors("b"), vec!["c"]);
        assert!(g.successors("c").is_empty());
    }

    #[test]
    fn topological_sort_linear() {
        let g = sample_linear_graph();
        let sorted = g.topological_sort().unwrap();
        assert_eq!(sorted, vec!["a", "b", "c"]);
    }

    #[test]
    fn topological_sort_parallel() {
        let mut g = Graph::new();
        g.add_node(Node::new("root", "Root", "Input"));
        g.add_node(Node::new("b1", "Branch1", "F1"));
        g.add_node(Node::new("b2", "Branch2", "F2"));
        g.add_node(Node::new("merge", "Merge", "Merge"));
        g.add_edge(Edge::data("e1", "root", "b1"));
        g.add_edge(Edge::data("e2", "root", "b2"));
        g.add_edge(Edge::data("e3", "b1", "merge"));
        g.add_edge(Edge::data("e4", "b2", "merge"));

        let sorted = g.topological_sort().unwrap();
        // root must be first, merge must be last
        assert_eq!(sorted[0], "root");
        assert_eq!(sorted[3], "merge");
        // b1 and b2 can be in any order between root and merge
        let middle: HashSet<&str> = sorted[1..3].iter().copied().collect();
        assert!(middle.contains("b1"));
        assert!(middle.contains("b2"));
    }

    #[test]
    fn topological_sort_detects_cycle() {
        let mut g = Graph::new();
        g.add_node(Node::new("a", "A", "F"));
        g.add_node(Node::new("b", "B", "F"));
        g.add_edge(Edge::data("e1", "a", "b"));
        g.add_edge(Edge::data("e2", "b", "a")); // cycle!

        let result = g.topological_sort();
        assert!(matches!(result, Err(SomaError::CycleDetected)));
    }

    #[test]
    fn validate_accepts_valid_graph() {
        let g = sample_linear_graph();
        assert!(g.validate().is_ok());
    }

    #[test]
    fn validate_rejects_duplicate_ids() {
        let mut g = Graph::new();
        g.add_node(Node::new("a", "A", "F"));
        g.add_node(Node::new("a", "A2", "F"));
        assert!(matches!(g.validate(), Err(SomaError::Compilation(_))));
    }

    #[test]
    fn validate_rejects_missing_edge_target() {
        let mut g = Graph::new();
        g.add_node(Node::new("a", "A", "F"));
        g.add_edge(Edge::data("e1", "a", "nonexistent"));
        assert!(matches!(g.validate(), Err(SomaError::NodeNotFound(_))));
    }

    #[test]
    fn graph_serde_roundtrip() {
        let g = sample_linear_graph();
        let json = serde_json::to_string(&g).unwrap();
        let deserialized: Graph = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.nodes.len(), 3);
        assert_eq!(deserialized.edges.len(), 2);
        // Content, not just counts, must survive the roundtrip.
        assert_eq!(deserialized.node_ids(), vec!["a", "b", "c"]);
        assert_eq!(deserialized.edges[0].source, "a");
        assert_eq!(deserialized.edges[1].target, "c");
    }

    #[test]
    fn empty_graph_is_valid() {
        let g = Graph::new();
        assert!(g.validate().is_ok());
        assert!(g.topological_sort().unwrap().is_empty());
    }

    #[test]
    fn single_node_graph() {
        let mut g = Graph::new();
        g.add_node(Node::new("solo", "Solo", "F"));
        assert_eq!(g.roots(), vec!["solo"]);
        assert_eq!(g.leaves(), vec!["solo"]);
        assert_eq!(g.topological_sort().unwrap(), vec!["solo"]);
    }

    // ── NodeKind tests ──

    #[test]
    fn node_filter_shorthand() {
        let n = Node::filter("StandardScaler");
        assert_eq!(n.id, "StandardScaler");
        assert_eq!(n.filter_name(), Some("StandardScaler"));
    }

    #[test]
    fn node_filter_with_id() {
        let n = Node::filter_with_id("my_scaler", "StandardScaler");
        assert_eq!(n.id, "my_scaler");
        assert_eq!(n.filter_name(), Some("StandardScaler"));
    }

    #[test]
    fn graph_add_filter_auto_names() {
        let mut g = Graph::new();
        g.add_filter("Scaler");
        g.add_filter("PCA");
        g.connect("Scaler", "PCA");

        assert!(g.validate().is_ok());
        assert_eq!(g.nodes.len(), 2);
        assert_eq!(g.nodes[0].id, "Scaler");
        assert_eq!(g.nodes[1].id, "PCA");
    }

    #[test]
    fn graph_add_filter_deduplicates() {
        let mut g = Graph::new();
        g.add_filter("Scaler");
        g.add_filter("Scaler"); // duplicate name → gets suffix

        assert_eq!(g.nodes.len(), 2);
        assert_eq!(g.nodes[0].id, "Scaler");
        assert_eq!(g.nodes[1].id, "Scaler_2");
    }

    #[test]
    fn subgraph_node() {
        let inner = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);

        let mut outer = Graph::new();
        outer.add_node(Node::new("input", "Input", "Input"));
        outer.add_node(Node::subgraph("pipeline", inner));
        outer.add_node(Node::new("output", "Output", "Output"));
        outer.add_edge(Edge::data("e1", "input", "pipeline"));
        outer.add_edge(Edge::data("e2", "pipeline", "output"));

        assert!(outer.validate().is_ok());
        assert_eq!(outer.nodes.len(), 3);

        // SubGraph node has no filter_name
        assert!(outer.node("pipeline").unwrap().filter_name().is_none());
    }

    #[test]
    fn loop_and_branch_nodes() {
        let mut g = Graph::new();
        g.add_node(Node::loop_node("train_loop", Some(100)));
        g.add_node(Node::branch("check_convergence"));
        g.add_edge(Edge::data("e1", "train_loop", "check_convergence"));

        assert!(g.validate().is_ok());
        assert!(matches!(
            g.node("train_loop").unwrap().kind,
            NodeKind::Loop {
                max_iterations: Some(100)
            }
        ));
        assert!(matches!(
            g.node("check_convergence").unwrap().kind,
            NodeKind::Branch
        ));
    }

    // ── Visualization tests ──

    #[test]
    fn to_mermaid_linear() {
        let g = sample_linear_graph();
        let m = g.to_mermaid();
        assert!(m.starts_with("graph LR"));
        assert!(m.contains("a[Scaler]"));
        assert!(m.contains("b[PCA]"));
        assert!(m.contains("c[SVM]"));
        assert!(m.contains("a --> b"));
        assert!(m.contains("b --> c"));
    }

    #[test]
    fn to_mermaid_branch_and_loop() {
        let mut g = Graph::new();
        g.add_node(Node::loop_node("train", Some(100)));
        g.add_node(Node::branch("check"));
        g.add_edge(Edge::data("e1", "train", "check"));

        let m = g.to_mermaid();
        assert!(m.contains("train((train (max 100)))"));
        assert!(m.contains("check{"));
        assert!(m.contains("train --> check"));
    }

    #[test]
    fn to_graphviz_output() {
        let g = sample_linear_graph();
        let dot = g.to_graphviz();
        assert!(dot.starts_with("digraph G {"));
        assert!(dot.contains("rankdir=LR"));
        assert!(dot.contains("\"a\" [label=\"Scaler\" shape=box]"));
        assert!(dot.contains("\"a\" -> \"b\""));
        assert!(dot.ends_with("}\n"));
    }

    #[test]
    fn to_text_output() {
        let g = sample_linear_graph();
        let text = g.to_text();
        // Pin the exact rendering: `contains("a")` alone was vacuous because
        // the header line "Graph (...)" already contains an `a`.
        let expected = "Graph (3 nodes, 2 edges)\n\
                        ├── a (StandardScaler)\n\
                        ├── b (PCA) ← a\n\
                        └── c (SVM) ← b\n";
        assert_eq!(text, expected);
    }

    #[test]
    fn display_trait() {
        let g = sample_linear_graph();
        let s = format!("{g}");
        assert!(s.contains("Graph (3 nodes"));
        assert_eq!(s, g.to_text());
    }

    #[test]
    fn node_kind_serde_roundtrip() {
        let inner = linear_pipeline(vec![Node::new("x", "X", "F")]);
        let nodes = vec![
            Node::filter("Scaler"),
            Node::subgraph("sub", inner),
            Node::loop_node("loop", Some(50)),
            Node::branch("cond"),
        ];

        for node in &nodes {
            let json = serde_json::to_string(node).unwrap();
            let parsed: Node = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed.id, node.id);
            assert_eq!(parsed.label, node.label);
            assert_eq!(parsed.filter_name(), node.filter_name());
        }
    }
}