Skip to main content

tensorlogic_ir/graph/
node.rs

1//! Computation nodes in the tensor graph.
2
3use serde::{Deserialize, Serialize};
4
5use crate::error::IrError;
6use crate::metadata::Metadata;
7
8use super::{EinsumSpec, OpType};
9
10#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
11pub struct EinsumNode {
12    pub op: OpType,
13    pub inputs: Vec<usize>,
14    /// Tensor indices that this node produces/writes to.
15    /// Most operations produce a single tensor, but some may produce multiple.
16    pub outputs: Vec<usize>,
17    /// Optional metadata for debugging and provenance tracking
18    #[serde(skip_serializing_if = "Option::is_none")]
19    pub metadata: Option<Metadata>,
20}
21
22impl EinsumNode {
23    pub fn new(spec: impl Into<String>, inputs: Vec<usize>, outputs: Vec<usize>) -> Self {
24        EinsumNode {
25            op: OpType::Einsum { spec: spec.into() },
26            inputs,
27            outputs,
28            metadata: None,
29        }
30    }
31
32    pub fn einsum(spec: impl Into<String>, inputs: Vec<usize>, outputs: Vec<usize>) -> Self {
33        Self::new(spec, inputs, outputs)
34    }
35
36    pub fn elem_unary(op: impl Into<String>, input: usize, output: usize) -> Self {
37        EinsumNode {
38            op: OpType::ElemUnary { op: op.into() },
39            inputs: vec![input],
40            outputs: vec![output],
41            metadata: None,
42        }
43    }
44
45    pub fn elem_binary(op: impl Into<String>, left: usize, right: usize, output: usize) -> Self {
46        EinsumNode {
47            op: OpType::ElemBinary { op: op.into() },
48            inputs: vec![left, right],
49            outputs: vec![output],
50            metadata: None,
51        }
52    }
53
54    pub fn reduce(op: impl Into<String>, axes: Vec<usize>, input: usize, output: usize) -> Self {
55        EinsumNode {
56            op: OpType::Reduce {
57                op: op.into(),
58                axes,
59            },
60            inputs: vec![input],
61            outputs: vec![output],
62            metadata: None,
63        }
64    }
65
66    /// Creates a node with automatic output tracking.
67    /// The output tensor index should be provided by the caller after calling add_tensor().
68    /// This is a convenience method for the common case of single-output operations.
69    pub fn with_single_output(
70        spec: impl Into<String>,
71        inputs: Vec<usize>,
72        output_idx: usize,
73    ) -> Self {
74        Self::new(spec, inputs, vec![output_idx])
75    }
76
77    pub fn validate(&self, num_tensors: usize) -> Result<(), IrError> {
78        if let OpType::Einsum { spec } = &self.op {
79            if spec.is_empty() {
80                return Err(IrError::EmptyEinsumSpec);
81            }
82        }
83
84        for &idx in &self.inputs {
85            if idx >= num_tensors {
86                return Err(IrError::TensorIndexOutOfBounds {
87                    index: idx,
88                    max: num_tensors - 1,
89                });
90            }
91        }
92
93        for &idx in &self.outputs {
94            if idx >= num_tensors {
95                return Err(IrError::TensorIndexOutOfBounds {
96                    index: idx,
97                    max: num_tensors - 1,
98                });
99            }
100        }
101
102        Ok(())
103    }
104
105    /// Get the primary output tensor index (first output).
106    /// Most operations produce a single tensor.
107    pub fn primary_output(&self) -> Option<usize> {
108        self.outputs.first().copied()
109    }
110
111    /// Check if this node produces a specific tensor.
112    pub fn produces(&self, tensor_idx: usize) -> bool {
113        self.outputs.contains(&tensor_idx)
114    }
115
116    /// Parse and validate the einsum spec if this is an Einsum operation.
117    pub fn parse_einsum_spec(&self) -> Result<Option<EinsumSpec>, IrError> {
118        match &self.op {
119            OpType::Einsum { spec } => {
120                let parsed = EinsumSpec::parse(spec)?;
121                parsed.validate_input_count(self.inputs.len())?;
122                Ok(Some(parsed))
123            }
124            _ => Ok(None),
125        }
126    }
127
128    /// Get a human-readable description of this node's operation.
129    pub fn operation_description(&self) -> String {
130        match &self.op {
131            OpType::Einsum { spec } => format!("Einsum({})", spec),
132            OpType::ElemUnary { op } => format!("ElemUnary({})", op),
133            OpType::ElemBinary { op } => format!("ElemBinary({})", op),
134            OpType::Reduce { op, axes } => format!("Reduce({}, axes={:?})", op, axes),
135        }
136    }
137
138    /// Attach metadata to this node.
139    pub fn with_metadata(mut self, metadata: Metadata) -> Self {
140        self.metadata = Some(metadata);
141        self
142    }
143
144    /// Get the metadata if present.
145    pub fn get_metadata(&self) -> Option<&Metadata> {
146        self.metadata.as_ref()
147    }
148
149    /// Set the metadata for this node.
150    pub fn set_metadata(&mut self, metadata: Metadata) {
151        self.metadata = Some(metadata);
152    }
153}