tensorlogic_ir/graph/
node.rs1use serde::{Deserialize, Serialize};
4
5use crate::error::IrError;
6use crate::metadata::Metadata;
7
8use super::{EinsumSpec, OpType};
9
10#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
11pub struct EinsumNode {
12 pub op: OpType,
13 pub inputs: Vec<usize>,
14 pub outputs: Vec<usize>,
17 #[serde(skip_serializing_if = "Option::is_none")]
19 pub metadata: Option<Metadata>,
20}
21
22impl EinsumNode {
23 pub fn new(spec: impl Into<String>, inputs: Vec<usize>, outputs: Vec<usize>) -> Self {
24 EinsumNode {
25 op: OpType::Einsum { spec: spec.into() },
26 inputs,
27 outputs,
28 metadata: None,
29 }
30 }
31
32 pub fn einsum(spec: impl Into<String>, inputs: Vec<usize>, outputs: Vec<usize>) -> Self {
33 Self::new(spec, inputs, outputs)
34 }
35
36 pub fn elem_unary(op: impl Into<String>, input: usize, output: usize) -> Self {
37 EinsumNode {
38 op: OpType::ElemUnary { op: op.into() },
39 inputs: vec![input],
40 outputs: vec![output],
41 metadata: None,
42 }
43 }
44
45 pub fn elem_binary(op: impl Into<String>, left: usize, right: usize, output: usize) -> Self {
46 EinsumNode {
47 op: OpType::ElemBinary { op: op.into() },
48 inputs: vec![left, right],
49 outputs: vec![output],
50 metadata: None,
51 }
52 }
53
54 pub fn reduce(op: impl Into<String>, axes: Vec<usize>, input: usize, output: usize) -> Self {
55 EinsumNode {
56 op: OpType::Reduce {
57 op: op.into(),
58 axes,
59 },
60 inputs: vec![input],
61 outputs: vec![output],
62 metadata: None,
63 }
64 }
65
66 pub fn with_single_output(
70 spec: impl Into<String>,
71 inputs: Vec<usize>,
72 output_idx: usize,
73 ) -> Self {
74 Self::new(spec, inputs, vec![output_idx])
75 }
76
77 pub fn validate(&self, num_tensors: usize) -> Result<(), IrError> {
78 if let OpType::Einsum { spec } = &self.op {
79 if spec.is_empty() {
80 return Err(IrError::EmptyEinsumSpec);
81 }
82 }
83
84 for &idx in &self.inputs {
85 if idx >= num_tensors {
86 return Err(IrError::TensorIndexOutOfBounds {
87 index: idx,
88 max: num_tensors - 1,
89 });
90 }
91 }
92
93 for &idx in &self.outputs {
94 if idx >= num_tensors {
95 return Err(IrError::TensorIndexOutOfBounds {
96 index: idx,
97 max: num_tensors - 1,
98 });
99 }
100 }
101
102 Ok(())
103 }
104
105 pub fn primary_output(&self) -> Option<usize> {
108 self.outputs.first().copied()
109 }
110
111 pub fn produces(&self, tensor_idx: usize) -> bool {
113 self.outputs.contains(&tensor_idx)
114 }
115
116 pub fn parse_einsum_spec(&self) -> Result<Option<EinsumSpec>, IrError> {
118 match &self.op {
119 OpType::Einsum { spec } => {
120 let parsed = EinsumSpec::parse(spec)?;
121 parsed.validate_input_count(self.inputs.len())?;
122 Ok(Some(parsed))
123 }
124 _ => Ok(None),
125 }
126 }
127
128 pub fn operation_description(&self) -> String {
130 match &self.op {
131 OpType::Einsum { spec } => format!("Einsum({})", spec),
132 OpType::ElemUnary { op } => format!("ElemUnary({})", op),
133 OpType::ElemBinary { op } => format!("ElemBinary({})", op),
134 OpType::Reduce { op, axes } => format!("Reduce({}, axes={:?})", op, axes),
135 }
136 }
137
138 pub fn with_metadata(mut self, metadata: Metadata) -> Self {
140 self.metadata = Some(metadata);
141 self
142 }
143
144 pub fn get_metadata(&self) -> Option<&Metadata> {
146 self.metadata.as_ref()
147 }
148
149 pub fn set_metadata(&mut self, metadata: Metadata) {
151 self.metadata = Some(metadata);
152 }
153}