Skip to main content

somatize_core/
graph.rs

1//! Computational graph — DAG of filter nodes connected by edges.
2//!
3//! The graph is the user-facing representation of a pipeline topology.
4//! It gets compiled into an [`ExecutionPlan`] by the compiler.
5
6use crate::error::{Result, SomaError};
7use crate::strategy::TrainingStrategy;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10
/// Unique identifier for a node in a graph.
///
/// Currently a type alias. Will be promoted to a newtype in a future version
/// for stronger type safety (tracked in architecture-review.md).
///
/// Ids are compared by plain string equality; [`Graph::validate`] rejects a
/// graph in which two nodes share the same id.
pub type NodeId = String;
16
/// Unique identifier for an edge in a graph.
///
/// A plain `String` alias. Ids are supplied by the caller through
/// [`Edge::data`] / [`Edge::control`]; [`Graph::connect`] generates ids of the
/// form `e_<n>` on the caller's behalf.
pub type EdgeId = String;
19
/// What kind of computation a node represents.
///
/// Serialized with serde's internally tagged representation: the variant name
/// is stored in a `type` field next to the variant's own fields.
///
/// Marked `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum NodeKind {
    /// A single filter (the common case).
    ///
    /// `filter_name` is the name of the filter this node stands for.
    Filter { filter_name: String },
    /// A nested sub-graph (compiled recursively).
    ///
    /// The graph is boxed because `Graph` itself contains nodes, which would
    /// otherwise make the type infinitely sized.
    SubGraph { graph: Box<Graph> },
    /// A loop node — body is the sub-graph of successors.
    ///
    /// `max_iterations` is an optional iteration cap; `None` means no cap was given.
    Loop { max_iterations: Option<usize> },
    /// A branch/conditional node — arms determined by control edges.
    Branch,
}
34
/// A node in the computational graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node {
    /// Identifier that edges use to refer to this node.
    pub id: NodeId,
    /// Human-readable label; the Mermaid and Graphviz renderers display it.
    pub label: String,
    /// The kind of computation this node represents.
    pub kind: NodeKind,
}
42
43impl Node {
44    /// Create a filter node (backward-compatible with old 3-arg constructor).
45    pub fn new(
46        id: impl Into<String>,
47        label: impl Into<String>,
48        filter_name: impl Into<String>,
49    ) -> Self {
50        Self {
51            id: id.into(),
52            label: label.into(),
53            kind: NodeKind::Filter {
54                filter_name: filter_name.into(),
55            },
56        }
57    }
58
59    /// Create a filter node with explicit id and filter_name.
60    pub fn filter_with_id(id: impl Into<String>, filter_name: impl Into<String>) -> Self {
61        let id = id.into();
62        Self {
63            label: id.clone(),
64            id,
65            kind: NodeKind::Filter {
66                filter_name: filter_name.into(),
67            },
68        }
69    }
70
71    /// Create a filter node where id defaults to filter_name.
72    pub fn filter(filter_name: impl Into<String>) -> Self {
73        let name = filter_name.into();
74        Self {
75            id: name.clone(),
76            label: name.clone(),
77            kind: NodeKind::Filter { filter_name: name },
78        }
79    }
80
81    /// Create a sub-graph node.
82    pub fn subgraph(id: impl Into<String>, graph: Graph) -> Self {
83        let id = id.into();
84        Self {
85            id: id.clone(),
86            label: id,
87            kind: NodeKind::SubGraph {
88                graph: Box::new(graph),
89            },
90        }
91    }
92
93    /// Create a loop node.
94    pub fn loop_node(id: impl Into<String>, max_iterations: Option<usize>) -> Self {
95        let id = id.into();
96        Self {
97            id: id.clone(),
98            label: id,
99            kind: NodeKind::Loop { max_iterations },
100        }
101    }
102
103    /// Create a branch/conditional node.
104    pub fn branch(id: impl Into<String>) -> Self {
105        let id = id.into();
106        Self {
107            id: id.clone(),
108            label: id,
109            kind: NodeKind::Branch,
110        }
111    }
112
113    /// Get the filter name if this is a Filter node.
114    pub fn filter_name(&self) -> Option<&str> {
115        match &self.kind {
116            NodeKind::Filter { filter_name } => Some(filter_name),
117            _ => None,
118        }
119    }
120}
121
/// Type of connection between nodes.
///
/// The renderers draw `Data` edges as solid arrows and `Control` edges as
/// dashed/dotted arrows.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum EdgeKind {
    /// Normal data flow: output of source becomes input of target.
    Data,
    /// Control flow edge (for conditional/loop logic).
    Control,
}
130
/// A directed edge connecting two nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edge {
    /// Identifier of this edge.
    pub id: EdgeId,
    /// Id of the node the edge starts from.
    pub source: NodeId,
    /// Id of the node the edge points to.
    pub target: NodeId,
    /// Whether this is a data-flow or a control-flow edge.
    pub kind: EdgeKind,
    /// Optional text the Mermaid and Graphviz renderers attach to the edge.
    pub label: Option<String>,
}
140
141impl Edge {
142    pub fn data(
143        id: impl Into<String>,
144        source: impl Into<String>,
145        target: impl Into<String>,
146    ) -> Self {
147        Self {
148            id: id.into(),
149            source: source.into(),
150            target: target.into(),
151            kind: EdgeKind::Data,
152            label: None,
153        }
154    }
155
156    pub fn control(
157        id: impl Into<String>,
158        source: impl Into<String>,
159        target: impl Into<String>,
160    ) -> Self {
161        Self {
162            id: id.into(),
163            source: source.into(),
164            target: target.into(),
165            kind: EdgeKind::Control,
166            label: None,
167        }
168    }
169}
170
/// A directed graph of computational nodes.
///
/// Nodes and edges are stored in plain vectors in insertion order; edges refer
/// to nodes by id. Nothing is checked on insertion — call [`Graph::validate`]
/// to verify ids, edge endpoints and acyclicity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Graph {
    /// All nodes, in the order they were added.
    pub nodes: Vec<Node>,
    /// All edges, in the order they were added.
    pub edges: Vec<Edge>,
    /// Training strategy for distributed execution.
    /// Inherited by subgraphs unless overridden.
    ///
    /// Left out of the serialized form when `None`, and read back as `None`
    /// when the field is absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub training_strategy: Option<TrainingStrategy>,
}
181
182impl Graph {
183    pub fn new() -> Self {
184        Self {
185            nodes: Vec::new(),
186            edges: Vec::new(),
187            training_strategy: None,
188        }
189    }
190
191    /// Set the training strategy for this graph.
192    pub fn with_strategy(mut self, strategy: TrainingStrategy) -> Self {
193        self.training_strategy = Some(strategy);
194        self
195    }
196
197    /// Set the training strategy (mutable).
198    pub fn set_strategy(&mut self, strategy: TrainingStrategy) {
199        self.training_strategy = Some(strategy);
200    }
201
202    /// Get the effective training strategy (defaults to Local).
203    pub fn effective_strategy(&self) -> &TrainingStrategy {
204        static LOCAL: TrainingStrategy = TrainingStrategy::Local;
205        self.training_strategy.as_ref().unwrap_or(&LOCAL)
206    }
207
208    pub fn add_node(&mut self, node: Node) {
209        self.nodes.push(node);
210    }
211
212    /// Add a filter node using the filter name as the node id.
213    /// If a node with that name already exists, appends a suffix.
214    pub fn add_filter(&mut self, filter_name: impl Into<String>) -> &str {
215        let name = filter_name.into();
216        let id = if self.nodes.iter().any(|n| n.id == name) {
217            let mut i = 2;
218            loop {
219                let candidate = format!("{name}_{i}");
220                if !self.nodes.iter().any(|n| n.id == candidate) {
221                    break candidate;
222                }
223                i += 1;
224            }
225        } else {
226            name.clone()
227        };
228        self.nodes.push(Node::filter_with_id(&id, &name));
229        &self.nodes.last().unwrap().id
230    }
231
232    pub fn add_edge(&mut self, edge: Edge) {
233        self.edges.push(edge);
234    }
235
236    /// Connect two nodes with a data edge (auto-generates edge id).
237    pub fn connect(&mut self, source: impl Into<String>, target: impl Into<String>) {
238        let id = format!("e_{}", self.edges.len());
239        self.edges.push(Edge::data(id, source, target));
240    }
241
242    /// Get a node by its ID.
243    pub fn node(&self, id: &str) -> Option<&Node> {
244        self.nodes.iter().find(|n| n.id == id)
245    }
246
247    /// Get all node IDs.
248    pub fn node_ids(&self) -> Vec<&str> {
249        self.nodes.iter().map(|n| n.id.as_str()).collect()
250    }
251
252    /// Get predecessors of a node (nodes with edges pointing to it).
253    pub fn predecessors(&self, node_id: &str) -> Vec<&str> {
254        self.edges
255            .iter()
256            .filter(|e| e.target == node_id)
257            .map(|e| e.source.as_str())
258            .collect()
259    }
260
261    /// Get successors of a node (nodes it points to).
262    pub fn successors(&self, node_id: &str) -> Vec<&str> {
263        self.edges
264            .iter()
265            .filter(|e| e.source == node_id)
266            .map(|e| e.target.as_str())
267            .collect()
268    }
269
270    /// Find root nodes (no incoming edges).
271    pub fn roots(&self) -> Vec<&str> {
272        let has_incoming: HashSet<&str> = self.edges.iter().map(|e| e.target.as_str()).collect();
273        self.nodes
274            .iter()
275            .filter(|n| !has_incoming.contains(n.id.as_str()))
276            .map(|n| n.id.as_str())
277            .collect()
278    }
279
280    /// Find leaf nodes (no outgoing edges).
281    pub fn leaves(&self) -> Vec<&str> {
282        let has_outgoing: HashSet<&str> = self.edges.iter().map(|e| e.source.as_str()).collect();
283        self.nodes
284            .iter()
285            .filter(|n| !has_outgoing.contains(n.id.as_str()))
286            .map(|n| n.id.as_str())
287            .collect()
288    }
289
290    /// Compute in-degree for each node.
291    fn in_degrees(&self) -> HashMap<&str, usize> {
292        let mut degrees: HashMap<&str, usize> =
293            self.nodes.iter().map(|n| (n.id.as_str(), 0)).collect();
294        for edge in &self.edges {
295            *degrees.entry(edge.target.as_str()).or_insert(0) += 1;
296        }
297        degrees
298    }
299
300    /// Topological sort using Kahn's algorithm.
301    /// Returns Err if the graph contains a cycle.
302    pub fn topological_sort(&self) -> Result<Vec<&str>> {
303        let mut in_deg = self.in_degrees();
304        let mut queue: Vec<&str> = in_deg
305            .iter()
306            .filter(|(_, deg)| **deg == 0)
307            .map(|(&id, _)| id)
308            .collect();
309        queue.sort(); // deterministic order
310
311        let mut sorted = Vec::with_capacity(self.nodes.len());
312
313        while let Some(node) = queue.pop() {
314            sorted.push(node);
315            let mut next = Vec::new();
316            for succ in self.successors(node) {
317                if let Some(deg) = in_deg.get_mut(succ) {
318                    *deg -= 1;
319                    if *deg == 0 {
320                        next.push(succ);
321                    }
322                }
323            }
324            next.sort();
325            // Insert at beginning so we process in deterministic order
326            for n in next.into_iter().rev() {
327                queue.push(n);
328            }
329        }
330
331        if sorted.len() != self.nodes.len() {
332            return Err(SomaError::CycleDetected);
333        }
334
335        Ok(sorted)
336    }
337
338    /// Validate the graph structure (recursively validates sub-graphs).
339    pub fn validate(&self) -> Result<()> {
340        // Check for duplicate node IDs
341        let mut seen = HashSet::new();
342        for node in &self.nodes {
343            if !seen.insert(&node.id) {
344                return Err(SomaError::Compilation(format!(
345                    "duplicate node id: `{}`",
346                    node.id
347                )));
348            }
349        }
350
351        // Check that all edge endpoints reference existing nodes
352        let node_ids: HashSet<&str> = self.nodes.iter().map(|n| n.id.as_str()).collect();
353        for edge in &self.edges {
354            if !node_ids.contains(edge.source.as_str()) {
355                return Err(SomaError::NodeNotFound(edge.source.clone()));
356            }
357            if !node_ids.contains(edge.target.as_str()) {
358                return Err(SomaError::NodeNotFound(edge.target.clone()));
359            }
360        }
361
362        // Check for cycles
363        self.topological_sort()?;
364
365        // Recursively validate sub-graphs
366        for node in &self.nodes {
367            if let NodeKind::SubGraph { graph } = &node.kind {
368                graph.validate()?;
369            }
370        }
371
372        Ok(())
373    }
374}
375
376// ── Visualization ──
377
378impl Graph {
379    /// Render as a Mermaid diagram.
380    ///
381    /// ```text
382    /// graph LR
383    ///     scaler[scaler]
384    ///     model[model]
385    ///     scaler --> model
386    /// ```
387    pub fn to_mermaid(&self) -> String {
388        use std::fmt::Write;
389        let mut out = String::from("graph LR\n");
390        for node in &self.nodes {
391            let shape = match &node.kind {
392                NodeKind::Filter { .. } => format!("    {}[{}]", node.id, node.label),
393                NodeKind::SubGraph { .. } => format!("    {}[[{}]]", node.id, node.label),
394                NodeKind::Loop { max_iterations } => {
395                    let label = match max_iterations {
396                        Some(n) => format!("{} (max {})", node.label, n),
397                        None => node.label.clone(),
398                    };
399                    format!("    {}(({}))", node.id, label)
400                }
401                NodeKind::Branch => format!("    {}{{{{{}}}}}", node.id, node.label),
402            };
403            let _ = writeln!(out, "{shape}");
404        }
405        for edge in &self.edges {
406            let arrow = match edge.kind {
407                EdgeKind::Data => "-->",
408                EdgeKind::Control => "-.->",
409            };
410            if let Some(label) = &edge.label {
411                let _ = writeln!(
412                    out,
413                    "    {} {}|{}| {}",
414                    edge.source, arrow, label, edge.target
415                );
416            } else {
417                let _ = writeln!(out, "    {} {} {}", edge.source, arrow, edge.target);
418            }
419        }
420        out
421    }
422
423    /// Render as Graphviz DOT format.
424    pub fn to_graphviz(&self) -> String {
425        use std::fmt::Write;
426        let mut out = String::from("digraph G {\n    rankdir=LR;\n");
427        for node in &self.nodes {
428            let shape = match &node.kind {
429                NodeKind::Filter { .. } => "box",
430                NodeKind::SubGraph { .. } => "doubleoctagon",
431                NodeKind::Loop { .. } => "ellipse",
432                NodeKind::Branch => "diamond",
433            };
434            let _ = writeln!(
435                out,
436                "    \"{}\" [label=\"{}\" shape={}];",
437                node.id, node.label, shape
438            );
439        }
440        for edge in &self.edges {
441            let style = match edge.kind {
442                EdgeKind::Data => "",
443                EdgeKind::Control => " [style=dashed]",
444            };
445            let label = edge
446                .label
447                .as_ref()
448                .map(|l| format!(" [label=\"{l}\"]"))
449                .unwrap_or_default();
450            let attrs = if style.is_empty() && label.is_empty() {
451                String::new()
452            } else if label.is_empty() {
453                style.to_string()
454            } else {
455                label
456            };
457            let _ = writeln!(
458                out,
459                "    \"{}\" -> \"{}\"{};",
460                edge.source, edge.target, attrs
461            );
462        }
463        out.push_str("}\n");
464        out
465    }
466
467    /// Render as an ASCII text tree for terminal display.
468    pub fn to_text(&self) -> String {
469        use std::fmt::Write;
470        let mut out = String::new();
471        let sorted = self.topological_sort().unwrap_or_default();
472        let total_nodes = self.nodes.len();
473        let total_edges = self.edges.len();
474        let _ = writeln!(out, "Graph ({total_nodes} nodes, {total_edges} edges)");
475
476        for (i, node_id) in sorted.iter().enumerate() {
477            let node = match self.node(node_id) {
478                Some(n) => n,
479                None => continue,
480            };
481            let is_last = i == sorted.len() - 1;
482            let prefix = if is_last { "└── " } else { "├── " };
483            let kind_tag = match &node.kind {
484                NodeKind::Filter { filter_name } => {
485                    if filter_name == &node.id {
486                        String::new()
487                    } else {
488                        format!(" ({})", filter_name)
489                    }
490                }
491                NodeKind::SubGraph { graph } => {
492                    format!(" [subgraph: {} nodes]", graph.nodes.len())
493                }
494                NodeKind::Loop { max_iterations } => match max_iterations {
495                    Some(n) => format!(" [loop max={n}]"),
496                    None => " [loop]".into(),
497                },
498                NodeKind::Branch => " [branch]".into(),
499            };
500            let preds = self.predecessors(node_id);
501            let pred_info = if preds.is_empty() {
502                String::new()
503            } else {
504                format!(" ← {}", preds.join(", "))
505            };
506            let _ = writeln!(out, "{prefix}{}{kind_tag}{pred_info}", node.id);
507        }
508        out
509    }
510}
511
512impl std::fmt::Display for Graph {
513    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
514        write!(f, "{}", self.to_text())
515    }
516}
517
518impl Default for Graph {
519    fn default() -> Self {
520        Self::new()
521    }
522}
523
524/// Builder for constructing linear pipelines easily.
525pub fn linear_pipeline(nodes: Vec<Node>) -> Graph {
526    let mut graph = Graph::new();
527    for (i, node) in nodes.iter().enumerate() {
528        graph.add_node(node.clone());
529        if i > 0 {
530            graph.add_edge(Edge::data(format!("e_{}", i), &nodes[i - 1].id, &node.id));
531        }
532    }
533    graph
534}
535
// Unit tests for graph construction, ordering, validation, serde and rendering.
#[cfg(test)]
mod tests {
    use super::*;

    /// Three filter nodes chained as `a -> b -> c`.
    fn sample_linear_graph() -> Graph {
        linear_pipeline(vec![
            Node::new("a", "Scaler", "StandardScaler"),
            Node::new("b", "PCA", "PCA"),
            Node::new("c", "SVM", "SVM"),
        ])
    }

    /// A chain of three nodes has two edges.
    #[test]
    fn linear_pipeline_structure() {
        let g = sample_linear_graph();
        assert_eq!(g.nodes.len(), 3);
        assert_eq!(g.edges.len(), 2);
    }

    /// The chain starts at `a` and ends at `c`.
    #[test]
    fn roots_and_leaves() {
        let g = sample_linear_graph();
        assert_eq!(g.roots(), vec!["a"]);
        assert_eq!(g.leaves(), vec!["c"]);
    }

    /// Neighbour queries follow the edge direction.
    #[test]
    fn predecessors_and_successors() {
        let g = sample_linear_graph();
        assert!(g.predecessors("a").is_empty());
        assert_eq!(g.predecessors("b"), vec!["a"]);
        assert_eq!(g.successors("a"), vec!["b"]);
        assert_eq!(g.successors("b"), vec!["c"]);
        assert!(g.successors("c").is_empty());
    }

    /// A chain sorts into its own order.
    #[test]
    fn topological_sort_linear() {
        let g = sample_linear_graph();
        let sorted = g.topological_sort().unwrap();
        assert_eq!(sorted, vec!["a", "b", "c"]);
    }

    /// Diamond-shaped graph: two parallel branches between a root and a merge node.
    #[test]
    fn topological_sort_parallel() {
        let mut g = Graph::new();
        g.add_node(Node::new("root", "Root", "Input"));
        g.add_node(Node::new("b1", "Branch1", "F1"));
        g.add_node(Node::new("b2", "Branch2", "F2"));
        g.add_node(Node::new("merge", "Merge", "Merge"));
        g.add_edge(Edge::data("e1", "root", "b1"));
        g.add_edge(Edge::data("e2", "root", "b2"));
        g.add_edge(Edge::data("e3", "b1", "merge"));
        g.add_edge(Edge::data("e4", "b2", "merge"));

        let sorted = g.topological_sort().unwrap();
        // root must be first, merge must be last
        assert_eq!(sorted[0], "root");
        assert_eq!(sorted[3], "merge");
        // b1 and b2 can be in any order between root and merge
        let middle: HashSet<&str> = sorted[1..3].iter().copied().collect();
        assert!(middle.contains("b1"));
        assert!(middle.contains("b2"));
    }

    /// Two nodes pointing at each other must be reported as a cycle.
    #[test]
    fn topological_sort_detects_cycle() {
        let mut g = Graph::new();
        g.add_node(Node::new("a", "A", "F"));
        g.add_node(Node::new("b", "B", "F"));
        g.add_edge(Edge::data("e1", "a", "b"));
        g.add_edge(Edge::data("e2", "b", "a")); // cycle!

        let result = g.topological_sort();
        assert!(matches!(result, Err(SomaError::CycleDetected)));
    }

    /// A well-formed chain passes validation.
    #[test]
    fn validate_accepts_valid_graph() {
        let g = sample_linear_graph();
        assert!(g.validate().is_ok());
    }

    /// Two nodes with the same id are a compilation error.
    #[test]
    fn validate_rejects_duplicate_ids() {
        let mut g = Graph::new();
        g.add_node(Node::new("a", "A", "F"));
        g.add_node(Node::new("a", "A2", "F"));
        assert!(matches!(g.validate(), Err(SomaError::Compilation(_))));
    }

    /// An edge pointing at a node that does not exist is rejected.
    #[test]
    fn validate_rejects_missing_edge_target() {
        let mut g = Graph::new();
        g.add_node(Node::new("a", "A", "F"));
        g.add_edge(Edge::data("e1", "a", "nonexistent"));
        assert!(matches!(g.validate(), Err(SomaError::NodeNotFound(_))));
    }

    /// JSON round-trip keeps the node and edge counts.
    #[test]
    fn graph_serde_roundtrip() {
        let g = sample_linear_graph();
        let json = serde_json::to_string(&g).unwrap();
        let deserialized: Graph = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.nodes.len(), 3);
        assert_eq!(deserialized.edges.len(), 2);
    }

    /// The empty graph validates and sorts to an empty list.
    #[test]
    fn empty_graph_is_valid() {
        let g = Graph::new();
        assert!(g.validate().is_ok());
        assert!(g.topological_sort().unwrap().is_empty());
    }

    /// A lone node is both a root and a leaf.
    #[test]
    fn single_node_graph() {
        let mut g = Graph::new();
        g.add_node(Node::new("solo", "Solo", "F"));
        assert_eq!(g.roots(), vec!["solo"]);
        assert_eq!(g.leaves(), vec!["solo"]);
        assert_eq!(g.topological_sort().unwrap(), vec!["solo"]);
    }

    // ── NodeKind tests ──

    /// `Node::filter` uses the filter name as the id.
    #[test]
    fn node_filter_shorthand() {
        let n = Node::filter("StandardScaler");
        assert_eq!(n.id, "StandardScaler");
        assert_eq!(n.filter_name(), Some("StandardScaler"));
    }

    /// `Node::filter_with_id` keeps id and filter name separate.
    #[test]
    fn node_filter_with_id() {
        let n = Node::filter_with_id("my_scaler", "StandardScaler");
        assert_eq!(n.id, "my_scaler");
        assert_eq!(n.filter_name(), Some("StandardScaler"));
    }

    /// `add_filter` ids can be used directly with `connect`.
    #[test]
    fn graph_add_filter_auto_names() {
        let mut g = Graph::new();
        g.add_filter("Scaler");
        g.add_filter("PCA");
        g.connect("Scaler", "PCA");

        assert!(g.validate().is_ok());
        assert_eq!(g.nodes.len(), 2);
        assert_eq!(g.nodes[0].id, "Scaler");
        assert_eq!(g.nodes[1].id, "PCA");
    }

    /// Adding the same filter twice yields a suffixed id for the second node.
    #[test]
    fn graph_add_filter_deduplicates() {
        let mut g = Graph::new();
        g.add_filter("Scaler");
        g.add_filter("Scaler"); // duplicate name → gets suffix

        assert_eq!(g.nodes.len(), 2);
        assert_eq!(g.nodes[0].id, "Scaler");
        assert_eq!(g.nodes[1].id, "Scaler_2");
    }

    /// A graph nested inside a node validates together with its parent.
    #[test]
    fn subgraph_node() {
        let inner = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);

        let mut outer = Graph::new();
        outer.add_node(Node::new("input", "Input", "Input"));
        outer.add_node(Node::subgraph("pipeline", inner));
        outer.add_node(Node::new("output", "Output", "Output"));
        outer.add_edge(Edge::data("e1", "input", "pipeline"));
        outer.add_edge(Edge::data("e2", "pipeline", "output"));

        assert!(outer.validate().is_ok());
        assert_eq!(outer.nodes.len(), 3);

        // SubGraph node has no filter_name
        assert!(outer.node("pipeline").unwrap().filter_name().is_none());
    }

    /// Loop and branch constructors produce the matching `NodeKind`.
    #[test]
    fn loop_and_branch_nodes() {
        let mut g = Graph::new();
        g.add_node(Node::loop_node("train_loop", Some(100)));
        g.add_node(Node::branch("check_convergence"));
        g.add_edge(Edge::data("e1", "train_loop", "check_convergence"));

        assert!(g.validate().is_ok());
        assert!(matches!(
            g.node("train_loop").unwrap().kind,
            NodeKind::Loop {
                max_iterations: Some(100)
            }
        ));
        assert!(matches!(
            g.node("check_convergence").unwrap().kind,
            NodeKind::Branch
        ));
    }

    // ── Visualization tests ──

    /// Mermaid output lists every node with its label and every edge.
    #[test]
    fn to_mermaid_linear() {
        let g = sample_linear_graph();
        let m = g.to_mermaid();
        assert!(m.starts_with("graph LR"));
        assert!(m.contains("a[Scaler]"));
        assert!(m.contains("b[PCA]"));
        assert!(m.contains("c[SVM]"));
        assert!(m.contains("a --> b"));
        assert!(m.contains("b --> c"));
    }

    /// Loop and branch nodes get their own Mermaid shapes.
    #[test]
    fn to_mermaid_branch_and_loop() {
        let mut g = Graph::new();
        g.add_node(Node::loop_node("train", Some(100)));
        g.add_node(Node::branch("check"));
        g.add_edge(Edge::data("e1", "train", "check"));

        let m = g.to_mermaid();
        assert!(m.contains("train((train (max 100)))"));
        assert!(m.contains("check{"));
        assert!(m.contains("train --> check"));
    }

    /// DOT output has the expected header, node line, edge line and footer.
    #[test]
    fn to_graphviz_output() {
        let g = sample_linear_graph();
        let dot = g.to_graphviz();
        assert!(dot.starts_with("digraph G {"));
        assert!(dot.contains("rankdir=LR"));
        assert!(dot.contains("\"a\" [label=\"Scaler\" shape=box]"));
        assert!(dot.contains("\"a\" -> \"b\""));
        assert!(dot.ends_with("}\n"));
    }

    /// The text tree has a summary header and shows predecessors.
    #[test]
    fn to_text_output() {
        let g = sample_linear_graph();
        let text = g.to_text();
        assert!(text.contains("Graph (3 nodes, 2 edges)"));
        assert!(text.contains("a"));
        assert!(text.contains("b"));
        assert!(text.contains("c"));
        assert!(text.contains("← a"));
    }

    /// `Display` goes through the text renderer.
    #[test]
    fn display_trait() {
        let g = sample_linear_graph();
        let s = format!("{g}");
        assert!(s.contains("Graph (3 nodes"));
    }

    /// Every node kind survives a JSON round-trip with its id intact.
    #[test]
    fn node_kind_serde_roundtrip() {
        let inner = linear_pipeline(vec![Node::new("x", "X", "F")]);
        let nodes = vec![
            Node::filter("Scaler"),
            Node::subgraph("sub", inner),
            Node::loop_node("loop", Some(50)),
            Node::branch("cond"),
        ];

        for node in &nodes {
            let json = serde_json::to_string(node).unwrap();
            let parsed: Node = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed.id, node.id);
        }
    }
}