tensorlogic_scirs_backend/
shape_inference.rs1use crate::Scirs2Exec;
4use std::collections::HashMap;
5use tensorlogic_infer::{ExecutorError, ShapeInferenceContext, TensorShape};
6use tensorlogic_ir::{EinsumGraph, OpType};
7
/// Shape inference for [`EinsumGraph`]s evaluated with a [`Scirs2Exec`] executor.
///
/// Holds a registry of static tensor shapes keyed by tensor name. The registry is written by
/// [`Scirs2ShapeInference::register_shape`] and by
/// [`Scirs2ShapeInference::infer_graph_shapes`], which records the shape of every graph
/// tensor it finds in the executor.
#[derive(Debug, Clone)]
pub struct Scirs2ShapeInference {
    /// Known static shapes, keyed by tensor name.
    shapes: HashMap<String, Vec<usize>>,
}
13
14impl Scirs2ShapeInference {
15 pub fn new() -> Self {
17 Scirs2ShapeInference {
18 shapes: HashMap::new(),
19 }
20 }
21
22 pub fn register_shape(&mut self, name: String, shape: Vec<usize>) {
24 self.shapes.insert(name, shape);
25 }
26
27 pub fn infer_graph_shapes(
29 &mut self,
30 graph: &EinsumGraph,
31 executor: &Scirs2Exec,
32 ) -> Result<ShapeInferenceContext, ExecutorError> {
33 let mut context = ShapeInferenceContext::new();
34
35 for (idx, tensor_name) in graph.tensors.iter().enumerate() {
37 if let Some(tensor) = executor.get_tensor(tensor_name) {
38 let shape = tensor.shape().to_vec();
39 context.set_tensor_shape(idx, TensorShape::static_shape(shape.clone()));
40 self.shapes.insert(tensor_name.clone(), shape);
41 }
42 }
43
44 for node in &graph.nodes {
46 self.infer_node_shape(node, &mut context)?;
47 }
48
49 Ok(context)
50 }
51
52 fn infer_node_shape(
54 &self,
55 node: &tensorlogic_ir::EinsumNode,
56 context: &mut ShapeInferenceContext,
57 ) -> Result<(), ExecutorError> {
58 let input_shapes: Vec<TensorShape> = node
59 .inputs
60 .iter()
61 .filter_map(|&idx| context.get_tensor_shape(idx).cloned())
62 .collect();
63
64 if input_shapes.len() != node.inputs.len() {
65 return Err(ExecutorError::ShapeMismatch(
66 "Not all input shapes are known".to_string(),
67 ));
68 }
69
70 let output_shape = match &node.op {
72 OpType::Einsum { spec } => self.infer_einsum_shape(spec, &input_shapes)?,
73 OpType::ElemUnary { .. } => {
74 input_shapes[0].clone()
76 }
77 OpType::ElemBinary { .. } => {
78 self.infer_binary_shape(&input_shapes[0], &input_shapes[1])?
80 }
81 OpType::Reduce { axes, .. } => {
82 self.infer_reduce_shape(&input_shapes[0], axes)?
84 }
85 };
86
87 if let Some(&output_idx) = node.outputs.first() {
89 context.set_tensor_shape(output_idx, output_shape);
90 }
91
92 Ok(())
93 }
94
95 fn infer_einsum_shape(
97 &self,
98 spec: &str,
99 _input_shapes: &[TensorShape],
100 ) -> Result<TensorShape, ExecutorError> {
101 let parts: Vec<&str> = spec.split("->").collect();
103 if parts.len() != 2 {
104 return Err(ExecutorError::InvalidEinsumSpec(format!(
105 "Invalid einsum spec: {}",
106 spec
107 )));
108 }
109
110 let output_spec = parts[1].trim();
111
112 Ok(TensorShape::dynamic(output_spec.len()))
115 }
116
117 fn infer_binary_shape(
119 &self,
120 shape1: &TensorShape,
121 shape2: &TensorShape,
122 ) -> Result<TensorShape, ExecutorError> {
123 if let (Some(s1), Some(s2)) = (shape1.as_static(), shape2.as_static()) {
125 if s1 == s2 {
126 return Ok(TensorShape::static_shape(s1));
127 } else if s1.is_empty() {
128 return Ok(TensorShape::static_shape(s2));
130 } else if s2.is_empty() {
131 return Ok(TensorShape::static_shape(s1));
133 } else {
134 return Err(ExecutorError::ShapeMismatch(format!(
135 "Incompatible shapes: {:?} and {:?}",
136 s1, s2
137 )));
138 }
139 }
140
141 Ok(TensorShape::dynamic(shape1.rank().max(shape2.rank())))
143 }
144
145 fn infer_reduce_shape(
147 &self,
148 shape: &TensorShape,
149 axes: &[usize],
150 ) -> Result<TensorShape, ExecutorError> {
151 if let Some(dims) = shape.as_static() {
152 let mut result_dims = dims.clone();
153 for &axis in axes.iter().rev() {
155 if axis < result_dims.len() {
156 result_dims.remove(axis);
157 }
158 }
159 return Ok(TensorShape::static_shape(result_dims));
160 }
161
162 let new_rank = shape.rank().saturating_sub(axes.len());
164 Ok(TensorShape::dynamic(new_rank))
165 }
166}
167
168impl Default for Scirs2ShapeInference {
169 fn default() -> Self {
170 Self::new()
171 }
172}
173
174pub fn validate_tensor_shapes(
176 executor: &Scirs2Exec,
177 expected_shapes: &HashMap<String, Vec<usize>>,
178) -> Result<(), ExecutorError> {
179 for (name, expected_shape) in expected_shapes {
180 if let Some(tensor) = executor.get_tensor(name) {
181 let actual_shape = tensor.shape();
182 if actual_shape != expected_shape.as_slice() {
183 return Err(ExecutorError::ShapeMismatch(format!(
184 "Tensor '{}': expected shape {:?}, got {:?}",
185 name, expected_shape, actual_shape
186 )));
187 }
188 }
189 }
190 Ok(())
191}
192
#[cfg(all(test, feature = "integration-tests"))]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;
    use tensorlogic_compiler::compile_to_einsum;
    use tensorlogic_ir::{TLExpr, Term};

    /// Builds a zero-filled `f64` tensor with the given shape.
    fn create_test_tensor(shape: &[usize]) -> ArrayD<f64> {
        ArrayD::zeros(shape.to_vec())
    }

    #[test]
    fn test_shape_inference_basic() {
        let x = TLExpr::pred("x", vec![Term::var("i"), Term::var("j")]);
        let y = TLExpr::pred("y", vec![Term::var("i"), Term::var("j")]);
        let expr = TLExpr::add(x, y);
        let graph = compile_to_einsum(&expr).unwrap();

        let mut executor = Scirs2Exec::new();
        executor.add_tensor(graph.tensors[0].clone(), create_test_tensor(&[3, 4]));
        executor.add_tensor(graph.tensors[1].clone(), create_test_tensor(&[3, 4]));

        let mut inference = Scirs2ShapeInference::new();
        let context = inference.infer_graph_shapes(&graph, &executor).unwrap();

        // Context slots follow `graph.tensors` order and are seeded from the executor.
        let expected = TensorShape::static_shape(vec![3, 4]);
        assert_eq!(context.get_tensor_shape(0), Some(&expected));
        assert_eq!(context.get_tensor_shape(1), Some(&expected));
    }

    #[test]
    fn test_validate_shapes_success() {
        let mut executor = Scirs2Exec::new();
        executor.add_tensor("x".to_string(), create_test_tensor(&[2, 3]));
        executor.add_tensor("y".to_string(), create_test_tensor(&[4, 5]));

        let mut expected = HashMap::new();
        expected.insert("x".to_string(), vec![2, 3]);
        expected.insert("y".to_string(), vec![4, 5]);

        let result = validate_tensor_shapes(&executor, &expected);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_shapes_mismatch() {
        let mut executor = Scirs2Exec::new();
        executor.add_tensor("x".to_string(), create_test_tensor(&[2, 3]));

        let mut expected = HashMap::new();
        expected.insert("x".to_string(), vec![3, 4]);

        let result = validate_tensor_shapes(&executor, &expected);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_shapes_skips_missing_tensors() {
        let executor = Scirs2Exec::new();

        let mut expected = HashMap::new();
        expected.insert("absent".to_string(), vec![1, 2]);

        assert!(validate_tensor_shapes(&executor, &expected).is_ok());
    }

    #[test]
    fn test_infer_unary_shape() {
        let inference = Scirs2ShapeInference::new();
        let input_shape = TensorShape::static_shape(vec![2, 3, 4]);

        let node = tensorlogic_ir::EinsumNode {
            inputs: vec![0],
            outputs: vec![1],
            op: OpType::ElemUnary {
                op: "relu".to_string(),
            },
            metadata: None,
        };

        let mut context = ShapeInferenceContext::new();
        context.set_tensor_shape(0, input_shape.clone());

        inference.infer_node_shape(&node, &mut context).unwrap();

        let output_shape = context.get_tensor_shape(1).unwrap();
        assert_eq!(output_shape, &input_shape);
    }

    #[test]
    fn test_infer_reduce_shape() {
        let inference = Scirs2ShapeInference::new();

        let result = inference
            .infer_reduce_shape(&TensorShape::static_shape(vec![2, 3, 4]), &[1])
            .unwrap();

        let result_dims = result.as_static().unwrap();
        assert_eq!(result_dims, vec![2, 4]);
    }

    #[test]
    fn test_infer_reduce_shape_multiple_axes() {
        let inference = Scirs2ShapeInference::new();

        let result = inference
            .infer_reduce_shape(&TensorShape::static_shape(vec![2, 3, 4]), &[0, 2])
            .unwrap();

        assert_eq!(result.as_static().unwrap(), vec![3]);
    }

    #[test]
    fn test_infer_binary_shape_matching() {
        let inference = Scirs2ShapeInference::new();

        let shape1 = TensorShape::static_shape(vec![2, 3]);
        let shape2 = TensorShape::static_shape(vec![2, 3]);

        let result = inference.infer_binary_shape(&shape1, &shape2).unwrap();

        let result_dims = result.as_static().unwrap();
        assert_eq!(result_dims, vec![2, 3]);
    }

    #[test]
    fn test_infer_binary_shape_scalar_broadcast() {
        let inference = Scirs2ShapeInference::new();

        let scalar = TensorShape::static_shape(vec![]);
        let matrix = TensorShape::static_shape(vec![2, 3]);

        // Scalar broadcasting works regardless of operand order.
        let left = inference.infer_binary_shape(&scalar, &matrix).unwrap();
        assert_eq!(left.as_static().unwrap(), vec![2, 3]);

        let right = inference.infer_binary_shape(&matrix, &scalar).unwrap();
        assert_eq!(right.as_static().unwrap(), vec![2, 3]);
    }

    #[test]
    fn test_infer_binary_shape_mismatch() {
        let inference = Scirs2ShapeInference::new();

        let shape1 = TensorShape::static_shape(vec![2, 3]);
        let shape2 = TensorShape::static_shape(vec![4, 5]);

        let result = inference.infer_binary_shape(&shape1, &shape2);
        assert!(result.is_err());
    }
}