// runnx/graph.rs

1//! Computational graph representation
2//!
3//! This module defines the graph structure for ONNX models, including
4//! nodes, edges, and the overall graph representation.
5
6use crate::{
7    error::{OnnxError, Result},
8    operators::OperatorType,
9    tensor::Tensor,
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
/// A node in the computational graph
///
/// A node consumes tensors by name, applies a single operator, and produces
/// tensors by name. Edges between nodes are implied by matching tensor names
/// rather than being stored explicitly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node {
    /// Unique identifier for the node (uniqueness is enforced by `Graph::validate`)
    pub name: String,
    /// Type of operation this node performs (parsed into an `OperatorType`
    /// by `Node::get_operator_type`)
    pub op_type: String,
    /// Input tensor names
    pub inputs: Vec<String>,
    /// Output tensor names
    pub outputs: Vec<String>,
    /// Node attributes (parameters), stored as string key/value pairs
    pub attributes: HashMap<String, String>,
}
28
29impl Node {
30    /// Create a new node
31    pub fn new(name: String, op_type: String, inputs: Vec<String>, outputs: Vec<String>) -> Self {
32        Self {
33            name,
34            op_type,
35            inputs,
36            outputs,
37            attributes: HashMap::new(),
38        }
39    }
40
41    /// Add an attribute to the node
42    pub fn add_attribute<K: Into<String>, V: Into<String>>(&mut self, key: K, value: V) {
43        self.attributes.insert(key.into(), value.into());
44    }
45
46    /// Get the operator type as enum
47    pub fn get_operator_type(&self) -> Result<OperatorType> {
48        self.op_type.parse()
49    }
50}
51
/// Represents the computational graph of an ONNX model
///
/// Tensors are referenced by name throughout: `inputs`/`outputs` describe the
/// graph boundary, while `initializers` carry constant tensors (e.g. weights
/// and biases) that nodes may consume by name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Graph {
    /// Graph name
    pub name: String,
    /// List of nodes in execution order
    /// (`Graph::validate` checks tensor availability in this stored order)
    pub nodes: Vec<Node>,
    /// Input tensor specifications
    pub inputs: Vec<TensorSpec>,
    /// Output tensor specifications
    pub outputs: Vec<TensorSpec>,
    /// Initial values for parameters/constants, keyed by tensor name
    pub initializers: HashMap<String, Tensor>,
}
66
/// Tensor specification with name and shape information
///
/// Used to describe the graph's input and output boundary; shape matching
/// against concrete tensors is done by `TensorSpec::matches_tensor`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorSpec {
    /// Name of the tensor
    pub name: String,
    /// Shape of the tensor (`None` marks a dynamic dimension that accepts any size)
    pub dimensions: Vec<Option<usize>>,
    /// Data type (simplified to f32 for this implementation;
    /// `TensorSpec::new` always sets "float32")
    pub dtype: String,
}
77
78impl TensorSpec {
79    /// Create a new tensor specification
80    pub fn new(name: String, dimensions: Vec<Option<usize>>) -> Self {
81        Self {
82            name,
83            dimensions,
84            dtype: "float32".to_string(),
85        }
86    }
87
88    /// Check if the tensor spec matches a given tensor
89    pub fn matches_tensor(&self, tensor: &Tensor) -> bool {
90        let tensor_shape = tensor.shape();
91
92        if self.dimensions.len() != tensor_shape.len() {
93            return false;
94        }
95
96        for (spec_dim, &tensor_dim) in self.dimensions.iter().zip(tensor_shape.iter()) {
97            match spec_dim {
98                Some(expected) => {
99                    if *expected != tensor_dim {
100                        return false;
101                    }
102                }
103                None => {
104                    // Dynamic dimension, any size is acceptable
105                    continue;
106                }
107            }
108        }
109
110        true
111    }
112}
113
114impl Graph {
115    /// Create a new empty graph
116    pub fn new(name: String) -> Self {
117        Self {
118            name,
119            nodes: Vec::new(),
120            inputs: Vec::new(),
121            outputs: Vec::new(),
122            initializers: HashMap::new(),
123        }
124    }
125
    /// Add a node to the graph
    ///
    /// Nodes are stored in insertion order, which the rest of the graph
    /// treats as execution order.
    pub fn add_node(&mut self, node: Node) {
        self.nodes.push(node);
    }
130
    /// Add an input specification
    ///
    /// No duplicate checking is performed here.
    pub fn add_input(&mut self, input_spec: TensorSpec) {
        self.inputs.push(input_spec);
    }
135
    /// Add an output specification
    ///
    /// `Graph::validate` later checks that each declared output is actually
    /// produced by some node (or provided as an input/initializer).
    pub fn add_output(&mut self, output_spec: TensorSpec) {
        self.outputs.push(output_spec);
    }
140
    /// Add an initializer (constant tensor)
    ///
    /// Inserting a name that already exists replaces the previous tensor.
    pub fn add_initializer(&mut self, name: String, tensor: Tensor) {
        self.initializers.insert(name, tensor);
    }
145
146    /// Get input tensor names
147    pub fn input_names(&self) -> Vec<&str> {
148        self.inputs.iter().map(|spec| spec.name.as_str()).collect()
149    }
150
151    /// Get output tensor names
152    pub fn output_names(&self) -> Vec<&str> {
153        self.outputs.iter().map(|spec| spec.name.as_str()).collect()
154    }
155
156    /// Validate the graph structure
157    pub fn validate(&self) -> Result<()> {
158        // Check for duplicate node names
159        let mut node_names = std::collections::HashSet::new();
160        for node in &self.nodes {
161            if !node_names.insert(&node.name) {
162                return Err(OnnxError::graph_validation_error(format!(
163                    "Duplicate node name: {}",
164                    node.name
165                )));
166            }
167        }
168
169        // Check that all node inputs/outputs are valid tensor names
170        let mut available_tensors = std::collections::HashSet::new();
171
172        // Add input tensors
173        for input in &self.inputs {
174            available_tensors.insert(&input.name);
175        }
176
177        // Add initializer tensors
178        for name in self.initializers.keys() {
179            available_tensors.insert(name);
180        }
181
182        // Process nodes in order
183        for node in &self.nodes {
184            // Check that all inputs are available
185            for input_name in &node.inputs {
186                if !available_tensors.contains(input_name) {
187                    return Err(OnnxError::graph_validation_error(format!(
188                        "Node '{}' references unknown input tensor '{}'",
189                        node.name, input_name
190                    )));
191                }
192            }
193
194            // Add outputs to available tensors
195            for output_name in &node.outputs {
196                available_tensors.insert(output_name);
197            }
198
199            // Validate operator type
200            node.get_operator_type().map_err(|e| {
201                OnnxError::graph_validation_error(format!(
202                    "Node '{}' has invalid operator type '{}': {}",
203                    node.name, node.op_type, e
204                ))
205            })?;
206        }
207
208        // Check that all outputs are available
209        for output in &self.outputs {
210            if !available_tensors.contains(&output.name) {
211                return Err(OnnxError::graph_validation_error(format!(
212                    "Graph output '{}' is not produced by any node",
213                    output.name
214                )));
215            }
216        }
217
218        Ok(())
219    }
220
221    /// Perform topological sort to get execution order
222    pub fn topological_sort(&self) -> Result<Vec<usize>> {
223        let n = self.nodes.len();
224        let mut in_degree = vec![0; n];
225        let mut adjacency_list: Vec<Vec<usize>> = vec![vec![]; n];
226
227        // Build adjacency list and in-degree count
228        for (i, node) in self.nodes.iter().enumerate() {
229            for output in &node.outputs {
230                for (j, other_node) in self.nodes.iter().enumerate() {
231                    if i != j && other_node.inputs.contains(output) {
232                        adjacency_list[i].push(j);
233                        in_degree[j] += 1;
234                    }
235                }
236            }
237        }
238
239        // Kahn's algorithm
240        let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
241        let mut result = Vec::new();
242
243        while let Some(current) = queue.pop() {
244            result.push(current);
245
246            for &neighbor in &adjacency_list[current] {
247                in_degree[neighbor] -= 1;
248                if in_degree[neighbor] == 0 {
249                    queue.push(neighbor);
250                }
251            }
252        }
253
254        if result.len() != n {
255            return Err(OnnxError::graph_validation_error(
256                "Graph contains cycles".to_string(),
257            ));
258        }
259
260        Ok(result)
261    }
262
    /// Print the graph structure in a visual ASCII format
    ///
    /// Writes a boxed title, then sections for inputs, initializers, the
    /// node-by-node computation flow (in topological order when possible),
    /// graph outputs, summary statistics, and per-operator counts — all via
    /// `println!` to stdout.
    pub fn print_graph(&self) {
        // Calculate the width needed for the graph name
        let title = format!("GRAPH: {}", self.name);
        let min_width = title.len() + 4; // 2 spaces on each side
        let box_width = std::cmp::max(min_width, 40); // Minimum width of 40 characters

        // Create the top border
        let top_border = format!("┌{}┐", "─".repeat(box_width));

        // Create the title line with proper centering
        // NOTE(review): `title.len()` counts bytes, so a non-ASCII graph name
        // would center slightly off — confirm graph names are ASCII.
        let padding = (box_width - title.len()) / 2;
        let left_padding = " ".repeat(padding);
        let right_padding = " ".repeat(box_width - title.len() - padding);
        let title_line = format!("│{left_padding}{title}{right_padding}│");

        // Create the bottom border
        let bottom_border = format!("└{}┘", "─".repeat(box_width));

        println!("\n{top_border}");
        println!("{title_line}");
        println!("{bottom_border}");

        // Print inputs
        if !self.inputs.is_empty() {
            println!("\n📥 INPUTS:");
            for input in &self.inputs {
                // Dynamic (None) dimensions render as "?".
                let shape_str = input
                    .dimensions
                    .iter()
                    .map(|d| d.map_or("?".to_string(), |v| v.to_string()))
                    .collect::<Vec<_>>()
                    .join(" × ");
                println!("   ┌─ {} [{}] ({})", input.name, shape_str, input.dtype);
            }
        }

        // Print initializers (iteration order follows the HashMap, so it is
        // not deterministic across runs)
        if !self.initializers.is_empty() {
            println!("\n⚙️  INITIALIZERS:");
            for (name, tensor) in &self.initializers {
                let shape_str = tensor
                    .shape()
                    .iter()
                    .map(|&d| d.to_string())
                    .collect::<Vec<_>>()
                    .join(" × ");
                println!("   ┌─ {name} [{shape_str}]");
            }
        }

        // Print computation flow
        if !self.nodes.is_empty() {
            println!("\n🔄 COMPUTATION FLOW:");

            // Try to get execution order, fall back to original order if there are cycles
            let execution_order = self.topological_sort().unwrap_or_else(|_| {
                println!("   ⚠️  Warning: Graph contains cycles, showing original order");
                (0..self.nodes.len()).collect()
            });

            for (step, &node_idx) in execution_order.iter().enumerate() {
                let node = &self.nodes[node_idx];

                // Print step number
                println!("   │");
                println!("   ├─ Step {}: {}", step + 1, node.name);

                // Print operation type
                println!("   │  ┌─ Operation: {}", node.op_type);

                // Print inputs
                if !node.inputs.is_empty() {
                    println!("   │  ├─ Inputs:");
                    for input in &node.inputs {
                        println!("   │  │  └─ {input}");
                    }
                }

                // Print outputs
                if !node.outputs.is_empty() {
                    println!("   │  ├─ Outputs:");
                    for output in &node.outputs {
                        println!("   │  │  └─ {output}");
                    }
                }

                // Print attributes if any
                if !node.attributes.is_empty() {
                    println!("   │  └─ Attributes:");
                    for (key, value) in &node.attributes {
                        println!("   │     └─ {key}: {value}");
                    }
                } else {
                    println!("   │  └─ (no attributes)");
                }
            }
        }

        // Print outputs
        if !self.outputs.is_empty() {
            println!("   │");
            println!("📤 OUTPUTS:");
            for output in &self.outputs {
                let shape_str = output
                    .dimensions
                    .iter()
                    .map(|d| d.map_or("?".to_string(), |v| v.to_string()))
                    .collect::<Vec<_>>()
                    .join(" × ");
                println!("   └─ {} [{}] ({})", output.name, shape_str, output.dtype);
            }
        }

        println!("\n📊 STATISTICS:");
        println!("   ├─ Total nodes: {}", self.nodes.len());
        println!("   ├─ Input tensors: {}", self.inputs.len());
        println!("   ├─ Output tensors: {}", self.outputs.len());
        println!("   └─ Initializers: {}", self.initializers.len());

        // Print operation summary
        if !self.nodes.is_empty() {
            // BTreeMap keeps the summary alphabetically ordered.
            let mut op_counts: std::collections::BTreeMap<String, usize> =
                std::collections::BTreeMap::new();
            for node in &self.nodes {
                *op_counts.entry(node.op_type.clone()).or_insert(0) += 1;
            }

            println!("\n🎯 OPERATION SUMMARY:");
            for (op_type, count) in op_counts {
                println!("   ├─ {op_type}: {count}");
            }
        }

        println!();
    }
399
400    /// Generate a simplified DOT format for graph visualization tools
401    pub fn to_dot(&self) -> String {
402        let mut dot = String::new();
403
404        dot.push_str("digraph G {\n");
405        dot.push_str("  rankdir=TB;\n");
406        dot.push_str("  node [shape=box, style=rounded];\n\n");
407
408        // Add input nodes
409        for input in &self.inputs {
410            dot.push_str(&format!(
411                "  \"{}\" [shape=ellipse, color=green, label=\"{}\"];\n",
412                input.name, input.name
413            ));
414        }
415
416        // Add initializer nodes
417        for name in self.initializers.keys() {
418            dot.push_str(&format!(
419                "  \"{name}\" [shape=diamond, color=blue, label=\"{name}\"];\n"
420            ));
421        }
422
423        // Add operation nodes
424        for node in &self.nodes {
425            dot.push_str(&format!(
426                "  \"{}\" [label=\"{}\\n({})\"];\n",
427                node.name, node.name, node.op_type
428            ));
429        }
430
431        // Add output nodes
432        for output in &self.outputs {
433            dot.push_str(&format!(
434                "  \"{}\" [shape=ellipse, color=red, label=\"{}\"];\n",
435                output.name, output.name
436            ));
437        }
438
439        dot.push('\n');
440
441        // Add edges
442        for node in &self.nodes {
443            for input in &node.inputs {
444                dot.push_str(&format!("  \"{}\" -> \"{}\";\n", input, node.name));
445            }
446            for output in &node.outputs {
447                dot.push_str(&format!("  \"{}\" -> \"{}\";\n", node.name, output));
448            }
449        }
450
451        dot.push_str("}\n");
452        dot
453    }
454
455    /// Create a simple linear graph for testing
456    pub fn create_simple_linear() -> Self {
457        let mut graph = Graph::new("simple_linear".to_string());
458
459        // Add inputs
460        graph.add_input(TensorSpec::new("input".to_string(), vec![Some(1), Some(3)]));
461
462        // Add outputs
463        graph.add_output(TensorSpec::new(
464            "output".to_string(),
465            vec![Some(1), Some(2)],
466        ));
467
468        // Add weight initializer
469        let weights = Tensor::from_shape_vec(&[3, 2], vec![0.5, 0.3, 0.2, 0.4, 0.1, 0.6]).unwrap();
470        let bias = Tensor::from_shape_vec(&[1, 2], vec![0.1, 0.2]).unwrap();
471
472        graph.add_initializer("weights".to_string(), weights);
473        graph.add_initializer("bias".to_string(), bias);
474
475        // Add MatMul node
476        let matmul_node = Node::new(
477            "matmul".to_string(),
478            "MatMul".to_string(),
479            vec!["input".to_string(), "weights".to_string()],
480            vec!["matmul_output".to_string()],
481        );
482        graph.add_node(matmul_node);
483
484        // Add Add node (bias)
485        let add_node = Node::new(
486            "add_bias".to_string(),
487            "Add".to_string(),
488            vec!["matmul_output".to_string(), "bias".to_string()],
489            vec!["output".to_string()],
490        );
491        graph.add_node(add_node);
492
493        graph
494    }
495}
496
#[cfg(test)]
mod tests {
    //! Unit tests covering node/graph construction, spec matching,
    //! validation, topological sorting, and the text renderers.

    use super::*;

    #[test]
    fn test_node_creation() {
        let mut node = Node::new(
            "test_node".to_string(),
            "Add".to_string(),
            vec!["input1".to_string(), "input2".to_string()],
            vec!["output".to_string()],
        );

        assert_eq!(node.name, "test_node");
        assert_eq!(node.op_type, "Add");
        assert_eq!(node.inputs.len(), 2);
        assert_eq!(node.outputs.len(), 1);

        // Attributes start empty and are filled via add_attribute.
        node.add_attribute("axis", "1");
        assert_eq!(node.attributes.get("axis"), Some(&"1".to_string()));
    }

    #[test]
    fn test_tensor_spec() {
        // Third dimension is dynamic (None) and should accept any size.
        let spec = TensorSpec::new("test_tensor".to_string(), vec![Some(2), Some(3), None]);

        let matching_tensor = Tensor::zeros(&[2, 3, 5]); // 5 is dynamic
        let non_matching_tensor = Tensor::zeros(&[2, 4, 5]); // Wrong second dimension

        assert!(spec.matches_tensor(&matching_tensor));
        assert!(!spec.matches_tensor(&non_matching_tensor));
    }

    #[test]
    fn test_graph_creation() {
        let mut graph = Graph::new("test_graph".to_string());

        graph.add_input(TensorSpec::new("input".to_string(), vec![Some(1), Some(3)]));
        graph.add_output(TensorSpec::new(
            "output".to_string(),
            vec![Some(1), Some(1)],
        ));

        let node = Node::new(
            "relu".to_string(),
            "Relu".to_string(),
            vec!["input".to_string()],
            vec!["output".to_string()],
        );
        graph.add_node(node);

        assert_eq!(graph.nodes.len(), 1);
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.outputs.len(), 1);
        assert_eq!(graph.input_names(), vec!["input"]);
        assert_eq!(graph.output_names(), vec!["output"]);
    }

    #[test]
    fn test_graph_validation_success() {
        let mut graph = Graph::new("valid_graph".to_string());

        graph.add_input(TensorSpec::new("input".to_string(), vec![Some(1), Some(3)]));
        graph.add_output(TensorSpec::new(
            "output".to_string(),
            vec![Some(1), Some(3)],
        ));

        let node = Node::new(
            "relu".to_string(),
            "Relu".to_string(),
            vec!["input".to_string()],
            vec!["output".to_string()],
        );
        graph.add_node(node);

        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_graph_validation_failure() {
        let mut graph = Graph::new("invalid_graph".to_string());

        // Missing input declaration
        graph.add_output(TensorSpec::new(
            "output".to_string(),
            vec![Some(1), Some(3)],
        ));

        let node = Node::new(
            "relu".to_string(),
            "Relu".to_string(),
            vec!["missing_input".to_string()], // References unknown input
            vec!["output".to_string()],
        );
        graph.add_node(node);

        assert!(graph.validate().is_err());
    }

    #[test]
    fn test_simple_linear_graph() {
        let graph = Graph::create_simple_linear();

        assert!(graph.validate().is_ok());
        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.outputs.len(), 1);
        assert_eq!(graph.initializers.len(), 2);

        // Test topological sort
        let order = graph.topological_sort().unwrap();
        assert_eq!(order.len(), 2);
        // MatMul should come before Add
        let matmul_pos = order
            .iter()
            .position(|&i| graph.nodes[i].op_type == "MatMul")
            .unwrap();
        let add_pos = order
            .iter()
            .position(|&i| graph.nodes[i].op_type == "Add")
            .unwrap();
        assert!(matmul_pos < add_pos);
    }

    #[test]
    fn test_graph_print_functions() {
        let graph = Graph::create_simple_linear();

        // Test that print_graph doesn't panic
        graph.print_graph();

        // Test DOT format generation
        let dot_content = graph.to_dot();
        assert!(dot_content.contains("digraph G {"));
        assert!(dot_content.contains("input"));
        assert!(dot_content.contains("output"));
        assert!(dot_content.contains("MatMul"));
        assert!(dot_content.contains("Add"));
        assert!(dot_content.contains("->"));
        assert!(dot_content.ends_with("}\n"));
    }

    #[test]
    fn test_topological_sort() {
        let mut graph = Graph::new("test_topo".to_string());

        // Create a simple chain: input -> relu -> sigmoid -> output
        graph.add_input(TensorSpec::new("input".to_string(), vec![Some(1), Some(3)]));
        graph.add_output(TensorSpec::new(
            "output".to_string(),
            vec![Some(1), Some(3)],
        ));

        let relu_node = Node::new(
            "relu".to_string(),
            "Relu".to_string(),
            vec!["input".to_string()],
            vec!["relu_out".to_string()],
        );
        graph.add_node(relu_node);

        let sigmoid_node = Node::new(
            "sigmoid".to_string(),
            "Sigmoid".to_string(),
            vec!["relu_out".to_string()],
            vec!["output".to_string()],
        );
        graph.add_node(sigmoid_node);

        let order = graph.topological_sort().unwrap();
        assert_eq!(order, vec![0, 1]); // relu first, then sigmoid
    }
}