1use crate::{
7 error::{OnnxError, Result},
8 operators::OperatorType,
9 tensor::Tensor,
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct Node {
17 pub name: String,
19 pub op_type: String,
21 pub inputs: Vec<String>,
23 pub outputs: Vec<String>,
25 pub attributes: HashMap<String, String>,
27}
28
29impl Node {
30 pub fn new(name: String, op_type: String, inputs: Vec<String>, outputs: Vec<String>) -> Self {
32 Self {
33 name,
34 op_type,
35 inputs,
36 outputs,
37 attributes: HashMap::new(),
38 }
39 }
40
41 pub fn add_attribute<K: Into<String>, V: Into<String>>(&mut self, key: K, value: V) {
43 self.attributes.insert(key.into(), value.into());
44 }
45
46 pub fn get_operator_type(&self) -> Result<OperatorType> {
48 self.op_type.parse()
49 }
50}
51
/// A computation graph: operator [`Node`]s plus the named tensors flowing
/// between them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Graph {
    /// Graph name; shown as the title by `print_graph`.
    pub name: String,
    /// Operator nodes in declaration order. `validate` requires every node input
    /// to be a graph input, an initializer, or an output of an earlier node.
    pub nodes: Vec<Node>,
    /// Declared graph inputs; `validate` treats them as available before any node.
    pub inputs: Vec<TensorSpec>,
    /// Declared graph outputs; `validate` checks each one is available by the end.
    pub outputs: Vec<TensorSpec>,
    /// Named tensors available to nodes without being produced by one
    /// (e.g. the weights and bias in `create_simple_linear`).
    pub initializers: HashMap<String, Tensor>,
}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct TensorSpec {
70 pub name: String,
72 pub dimensions: Vec<Option<usize>>,
74 pub dtype: String,
76}
77
78impl TensorSpec {
79 pub fn new(name: String, dimensions: Vec<Option<usize>>) -> Self {
81 Self {
82 name,
83 dimensions,
84 dtype: "float32".to_string(),
85 }
86 }
87
88 pub fn matches_tensor(&self, tensor: &Tensor) -> bool {
90 let tensor_shape = tensor.shape();
91
92 if self.dimensions.len() != tensor_shape.len() {
93 return false;
94 }
95
96 for (spec_dim, &tensor_dim) in self.dimensions.iter().zip(tensor_shape.iter()) {
97 match spec_dim {
98 Some(expected) => {
99 if *expected != tensor_dim {
100 return false;
101 }
102 }
103 None => {
104 continue;
106 }
107 }
108 }
109
110 true
111 }
112}
113
114impl Graph {
115 pub fn new(name: String) -> Self {
117 Self {
118 name,
119 nodes: Vec::new(),
120 inputs: Vec::new(),
121 outputs: Vec::new(),
122 initializers: HashMap::new(),
123 }
124 }
125
126 pub fn add_node(&mut self, node: Node) {
128 self.nodes.push(node);
129 }
130
131 pub fn add_input(&mut self, input_spec: TensorSpec) {
133 self.inputs.push(input_spec);
134 }
135
136 pub fn add_output(&mut self, output_spec: TensorSpec) {
138 self.outputs.push(output_spec);
139 }
140
141 pub fn add_initializer(&mut self, name: String, tensor: Tensor) {
143 self.initializers.insert(name, tensor);
144 }
145
146 pub fn input_names(&self) -> Vec<&str> {
148 self.inputs.iter().map(|spec| spec.name.as_str()).collect()
149 }
150
151 pub fn output_names(&self) -> Vec<&str> {
153 self.outputs.iter().map(|spec| spec.name.as_str()).collect()
154 }
155
156 pub fn validate(&self) -> Result<()> {
158 let mut node_names = std::collections::HashSet::new();
160 for node in &self.nodes {
161 if !node_names.insert(&node.name) {
162 return Err(OnnxError::graph_validation_error(format!(
163 "Duplicate node name: {}",
164 node.name
165 )));
166 }
167 }
168
169 let mut available_tensors = std::collections::HashSet::new();
171
172 for input in &self.inputs {
174 available_tensors.insert(&input.name);
175 }
176
177 for name in self.initializers.keys() {
179 available_tensors.insert(name);
180 }
181
182 for node in &self.nodes {
184 for input_name in &node.inputs {
186 if !available_tensors.contains(input_name) {
187 return Err(OnnxError::graph_validation_error(format!(
188 "Node '{}' references unknown input tensor '{}'",
189 node.name, input_name
190 )));
191 }
192 }
193
194 for output_name in &node.outputs {
196 available_tensors.insert(output_name);
197 }
198
199 node.get_operator_type().map_err(|e| {
201 OnnxError::graph_validation_error(format!(
202 "Node '{}' has invalid operator type '{}': {}",
203 node.name, node.op_type, e
204 ))
205 })?;
206 }
207
208 for output in &self.outputs {
210 if !available_tensors.contains(&output.name) {
211 return Err(OnnxError::graph_validation_error(format!(
212 "Graph output '{}' is not produced by any node",
213 output.name
214 )));
215 }
216 }
217
218 Ok(())
219 }
220
221 pub fn topological_sort(&self) -> Result<Vec<usize>> {
223 let n = self.nodes.len();
224 let mut in_degree = vec![0; n];
225 let mut adjacency_list: Vec<Vec<usize>> = vec![vec![]; n];
226
227 for (i, node) in self.nodes.iter().enumerate() {
229 for output in &node.outputs {
230 for (j, other_node) in self.nodes.iter().enumerate() {
231 if i != j && other_node.inputs.contains(output) {
232 adjacency_list[i].push(j);
233 in_degree[j] += 1;
234 }
235 }
236 }
237 }
238
239 let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
241 let mut result = Vec::new();
242
243 while let Some(current) = queue.pop() {
244 result.push(current);
245
246 for &neighbor in &adjacency_list[current] {
247 in_degree[neighbor] -= 1;
248 if in_degree[neighbor] == 0 {
249 queue.push(neighbor);
250 }
251 }
252 }
253
254 if result.len() != n {
255 return Err(OnnxError::graph_validation_error(
256 "Graph contains cycles".to_string(),
257 ));
258 }
259
260 Ok(result)
261 }
262
263 pub fn print_graph(&self) {
265 let title = format!("GRAPH: {}", self.name);
267 let min_width = title.len() + 4; let box_width = std::cmp::max(min_width, 40); let top_border = format!("┌{}┐", "─".repeat(box_width));
272
273 let padding = (box_width - title.len()) / 2;
275 let left_padding = " ".repeat(padding);
276 let right_padding = " ".repeat(box_width - title.len() - padding);
277 let title_line = format!("│{left_padding}{title}{right_padding}│");
278
279 let bottom_border = format!("└{}┘", "─".repeat(box_width));
281
282 println!("\n{top_border}");
283 println!("{title_line}");
284 println!("{bottom_border}");
285
286 if !self.inputs.is_empty() {
288 println!("\n📥 INPUTS:");
289 for input in &self.inputs {
290 let shape_str = input
291 .dimensions
292 .iter()
293 .map(|d| d.map_or("?".to_string(), |v| v.to_string()))
294 .collect::<Vec<_>>()
295 .join(" × ");
296 println!(" ┌─ {} [{}] ({})", input.name, shape_str, input.dtype);
297 }
298 }
299
300 if !self.initializers.is_empty() {
302 println!("\n⚙️ INITIALIZERS:");
303 for (name, tensor) in &self.initializers {
304 let shape_str = tensor
305 .shape()
306 .iter()
307 .map(|&d| d.to_string())
308 .collect::<Vec<_>>()
309 .join(" × ");
310 println!(" ┌─ {name} [{shape_str}]");
311 }
312 }
313
314 if !self.nodes.is_empty() {
316 println!("\n🔄 COMPUTATION FLOW:");
317
318 let execution_order = self.topological_sort().unwrap_or_else(|_| {
320 println!(" ⚠️ Warning: Graph contains cycles, showing original order");
321 (0..self.nodes.len()).collect()
322 });
323
324 for (step, &node_idx) in execution_order.iter().enumerate() {
325 let node = &self.nodes[node_idx];
326
327 println!(" │");
329 println!(" ├─ Step {}: {}", step + 1, node.name);
330
331 println!(" │ ┌─ Operation: {}", node.op_type);
333
334 if !node.inputs.is_empty() {
336 println!(" │ ├─ Inputs:");
337 for input in &node.inputs {
338 println!(" │ │ └─ {input}");
339 }
340 }
341
342 if !node.outputs.is_empty() {
344 println!(" │ ├─ Outputs:");
345 for output in &node.outputs {
346 println!(" │ │ └─ {output}");
347 }
348 }
349
350 if !node.attributes.is_empty() {
352 println!(" │ └─ Attributes:");
353 for (key, value) in &node.attributes {
354 println!(" │ └─ {key}: {value}");
355 }
356 } else {
357 println!(" │ └─ (no attributes)");
358 }
359 }
360 }
361
362 if !self.outputs.is_empty() {
364 println!(" │");
365 println!("📤 OUTPUTS:");
366 for output in &self.outputs {
367 let shape_str = output
368 .dimensions
369 .iter()
370 .map(|d| d.map_or("?".to_string(), |v| v.to_string()))
371 .collect::<Vec<_>>()
372 .join(" × ");
373 println!(" └─ {} [{}] ({})", output.name, shape_str, output.dtype);
374 }
375 }
376
377 println!("\n📊 STATISTICS:");
378 println!(" ├─ Total nodes: {}", self.nodes.len());
379 println!(" ├─ Input tensors: {}", self.inputs.len());
380 println!(" ├─ Output tensors: {}", self.outputs.len());
381 println!(" └─ Initializers: {}", self.initializers.len());
382
383 if !self.nodes.is_empty() {
385 let mut op_counts: std::collections::BTreeMap<String, usize> =
386 std::collections::BTreeMap::new();
387 for node in &self.nodes {
388 *op_counts.entry(node.op_type.clone()).or_insert(0) += 1;
389 }
390
391 println!("\n🎯 OPERATION SUMMARY:");
392 for (op_type, count) in op_counts {
393 println!(" ├─ {op_type}: {count}");
394 }
395 }
396
397 println!();
398 }
399
400 pub fn to_dot(&self) -> String {
402 let mut dot = String::new();
403
404 dot.push_str("digraph G {\n");
405 dot.push_str(" rankdir=TB;\n");
406 dot.push_str(" node [shape=box, style=rounded];\n\n");
407
408 for input in &self.inputs {
410 dot.push_str(&format!(
411 " \"{}\" [shape=ellipse, color=green, label=\"{}\"];\n",
412 input.name, input.name
413 ));
414 }
415
416 for name in self.initializers.keys() {
418 dot.push_str(&format!(
419 " \"{name}\" [shape=diamond, color=blue, label=\"{name}\"];\n"
420 ));
421 }
422
423 for node in &self.nodes {
425 dot.push_str(&format!(
426 " \"{}\" [label=\"{}\\n({})\"];\n",
427 node.name, node.name, node.op_type
428 ));
429 }
430
431 for output in &self.outputs {
433 dot.push_str(&format!(
434 " \"{}\" [shape=ellipse, color=red, label=\"{}\"];\n",
435 output.name, output.name
436 ));
437 }
438
439 dot.push('\n');
440
441 for node in &self.nodes {
443 for input in &node.inputs {
444 dot.push_str(&format!(" \"{}\" -> \"{}\";\n", input, node.name));
445 }
446 for output in &node.outputs {
447 dot.push_str(&format!(" \"{}\" -> \"{}\";\n", node.name, output));
448 }
449 }
450
451 dot.push_str("}\n");
452 dot
453 }
454
455 pub fn create_simple_linear() -> Self {
457 let mut graph = Graph::new("simple_linear".to_string());
458
459 graph.add_input(TensorSpec::new("input".to_string(), vec![Some(1), Some(3)]));
461
462 graph.add_output(TensorSpec::new(
464 "output".to_string(),
465 vec![Some(1), Some(2)],
466 ));
467
468 let weights = Tensor::from_shape_vec(&[3, 2], vec![0.5, 0.3, 0.2, 0.4, 0.1, 0.6]).unwrap();
470 let bias = Tensor::from_shape_vec(&[1, 2], vec![0.1, 0.2]).unwrap();
471
472 graph.add_initializer("weights".to_string(), weights);
473 graph.add_initializer("bias".to_string(), bias);
474
475 let matmul_node = Node::new(
477 "matmul".to_string(),
478 "MatMul".to_string(),
479 vec!["input".to_string(), "weights".to_string()],
480 vec!["matmul_output".to_string()],
481 );
482 graph.add_node(matmul_node);
483
484 let add_node = Node::new(
486 "add_bias".to_string(),
487 "Add".to_string(),
488 vec!["matmul_output".to_string(), "bias".to_string()],
489 vec!["output".to_string()],
490 );
491 graph.add_node(add_node);
492
493 graph
494 }
495}
496
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a node from string slices to keep test graphs compact.
    fn make_node(name: &str, op_type: &str, inputs: &[&str], outputs: &[&str]) -> Node {
        Node::new(
            name.to_string(),
            op_type.to_string(),
            inputs.iter().map(|s| s.to_string()).collect(),
            outputs.iter().map(|s| s.to_string()).collect(),
        )
    }

    #[test]
    fn test_node_creation() {
        let mut node = Node::new(
            "test_node".to_string(),
            "Add".to_string(),
            vec!["input1".to_string(), "input2".to_string()],
            vec!["output".to_string()],
        );

        assert_eq!(node.name, "test_node");
        assert_eq!(node.op_type, "Add");
        assert_eq!(node.inputs.len(), 2);
        assert_eq!(node.outputs.len(), 1);

        node.add_attribute("axis", "1");
        assert_eq!(node.attributes.get("axis"), Some(&"1".to_string()));
    }

    #[test]
    fn test_tensor_spec() {
        let spec = TensorSpec::new("test_tensor".to_string(), vec![Some(2), Some(3), None]);

        let matching_tensor = Tensor::zeros(&[2, 3, 5]);
        let non_matching_tensor = Tensor::zeros(&[2, 4, 5]);
        assert!(spec.matches_tensor(&matching_tensor));
        assert!(!spec.matches_tensor(&non_matching_tensor));
    }

    #[test]
    fn test_tensor_spec_rank_mismatch_and_dynamic_dims() {
        let spec = TensorSpec::new("x".to_string(), vec![Some(2), Some(3)]);
        assert!(!spec.matches_tensor(&Tensor::zeros(&[2, 3, 1])));
        assert!(!spec.matches_tensor(&Tensor::zeros(&[2])));

        let dynamic = TensorSpec::new("y".to_string(), vec![None, None]);
        assert!(dynamic.matches_tensor(&Tensor::zeros(&[7, 9])));
        assert_eq!(dynamic.dtype, "float32");
    }

    #[test]
    fn test_graph_creation() {
        let mut graph = Graph::new("test_graph".to_string());

        graph.add_input(TensorSpec::new("input".to_string(), vec![Some(1), Some(3)]));
        graph.add_output(TensorSpec::new(
            "output".to_string(),
            vec![Some(1), Some(1)],
        ));
        graph.add_node(make_node("relu", "Relu", &["input"], &["output"]));

        assert_eq!(graph.nodes.len(), 1);
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.outputs.len(), 1);
        assert_eq!(graph.input_names(), vec!["input"]);
        assert_eq!(graph.output_names(), vec!["output"]);
    }

    #[test]
    fn test_graph_validation_success() {
        let mut graph = Graph::new("valid_graph".to_string());

        graph.add_input(TensorSpec::new("input".to_string(), vec![Some(1), Some(3)]));
        graph.add_output(TensorSpec::new(
            "output".to_string(),
            vec![Some(1), Some(3)],
        ));
        graph.add_node(make_node("relu", "Relu", &["input"], &["output"]));

        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_graph_validation_failure() {
        let mut graph = Graph::new("invalid_graph".to_string());

        graph.add_output(TensorSpec::new(
            "output".to_string(),
            vec![Some(1), Some(3)],
        ));
        graph.add_node(make_node("relu", "Relu", &["missing_input"], &["output"]));

        assert!(graph.validate().is_err());
    }

    #[test]
    fn test_graph_validation_rejects_duplicate_node_names() {
        let mut graph = Graph::new("duplicates".to_string());
        graph.add_input(TensorSpec::new("input".to_string(), vec![Some(1)]));
        graph.add_node(make_node("n", "Relu", &["input"], &["a"]));
        graph.add_node(make_node("n", "Relu", &["a"], &["b"]));

        assert!(graph.validate().is_err());
    }

    #[test]
    fn test_graph_validation_rejects_unproduced_output() {
        let mut graph = Graph::new("dangling_output".to_string());
        graph.add_input(TensorSpec::new("input".to_string(), vec![Some(1)]));
        graph.add_output(TensorSpec::new("missing".to_string(), vec![Some(1)]));
        graph.add_node(make_node("relu", "Relu", &["input"], &["relu_out"]));

        assert!(graph.validate().is_err());
    }

    #[test]
    fn test_empty_graph() {
        let graph = Graph::new("empty".to_string());
        assert!(graph.validate().is_ok());
        assert_eq!(graph.topological_sort().unwrap(), Vec::<usize>::new());
    }

    #[test]
    fn test_simple_linear_graph() {
        let graph = Graph::create_simple_linear();

        assert!(graph.validate().is_ok());
        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.outputs.len(), 1);
        assert_eq!(graph.initializers.len(), 2);

        let order = graph.topological_sort().unwrap();
        assert_eq!(order.len(), 2);
        let position_of = |op: &str| {
            order
                .iter()
                .position(|&i| graph.nodes[i].op_type == op)
                .unwrap()
        };
        assert!(position_of("MatMul") < position_of("Add"));
    }

    #[test]
    fn test_graph_print_functions() {
        let graph = Graph::create_simple_linear();

        graph.print_graph();

        let dot_content = graph.to_dot();
        assert!(dot_content.contains("digraph G {"));
        assert!(dot_content.contains("input"));
        assert!(dot_content.contains("output"));
        assert!(dot_content.contains("MatMul"));
        assert!(dot_content.contains("Add"));
        assert!(dot_content.contains("->"));
        assert!(dot_content.ends_with("}\n"));
    }

    #[test]
    fn test_topological_sort() {
        let mut graph = Graph::new("test_topo".to_string());

        graph.add_input(TensorSpec::new("input".to_string(), vec![Some(1), Some(3)]));
        graph.add_output(TensorSpec::new(
            "output".to_string(),
            vec![Some(1), Some(3)],
        ));
        graph.add_node(make_node("relu", "Relu", &["input"], &["relu_out"]));
        graph.add_node(make_node("sigmoid", "Sigmoid", &["relu_out"], &["output"]));

        let order = graph.topological_sort().unwrap();
        assert_eq!(order, vec![0, 1]);
    }

    #[test]
    fn test_topological_sort_diamond_declared_out_of_order() {
        let mut graph = Graph::new("diamond".to_string());
        // The consumer is declared before both of its producers.
        graph.add_node(make_node("sum", "Add", &["a_out", "b_out"], &["output"]));
        graph.add_node(make_node("a", "Relu", &["input"], &["a_out"]));
        graph.add_node(make_node("b", "Relu", &["input"], &["b_out"]));

        let order = graph.topological_sort().unwrap();
        assert_eq!(order.len(), 3);
        assert_eq!(*order.last().unwrap(), 0);
    }

    #[test]
    fn test_topological_sort_detects_cycle() {
        let mut graph = Graph::new("cycle".to_string());
        graph.add_node(make_node("a", "Relu", &["y"], &["x"]));
        graph.add_node(make_node("b", "Relu", &["x"], &["y"]));

        assert!(graph.topological_sort().is_err());
    }
}