
tensorlogic_scirs_backend/shape_inference.rs

//! Shape inference and validation support.

use crate::Scirs2Exec;
use std::collections::HashMap;
use tensorlogic_infer::{ExecutorError, ShapeInferenceContext, TensorShape};
use tensorlogic_ir::{EinsumGraph, OpType};

/// Shape inference engine for SciRS2 backend
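///
/// # Example
///
/// A minimal usage sketch (marked `ignore`: it assumes a compiled `graph` and a
/// populated `executor` are already in scope, and the surrounding import path
/// is left to the crate layout):
///
/// ```ignore
/// let mut inference = Scirs2ShapeInference::new();
/// inference.register_shape("x".to_string(), vec![2, 3]);
/// let context = inference.infer_graph_shapes(&graph, &executor)?;
/// ```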
pub struct Scirs2ShapeInference {
    /// Known tensor shapes
    shapes: HashMap<String, Vec<usize>>,
}

impl Scirs2ShapeInference {
    /// Create a new shape inference engine
    pub fn new() -> Self {
        Scirs2ShapeInference {
            shapes: HashMap::new(),
        }
    }

    /// Register a tensor shape
    pub fn register_shape(&mut self, name: String, shape: Vec<usize>) {
        self.shapes.insert(name, shape);
    }

    /// Infer shapes for all tensors in a graph
    pub fn infer_graph_shapes(
        &mut self,
        graph: &EinsumGraph,
        executor: &Scirs2Exec,
    ) -> Result<ShapeInferenceContext, ExecutorError> {
        let mut context = ShapeInferenceContext::new();

        // Register input tensor shapes from executor
        for (idx, tensor_name) in graph.tensors.iter().enumerate() {
            if let Some(tensor) = executor.get_tensor(tensor_name) {
                let shape = tensor.shape().to_vec();
                context.set_tensor_shape(idx, TensorShape::static_shape(shape.clone()));
                self.shapes.insert(tensor_name.clone(), shape);
            }
        }

        // Infer shapes for each operation in the graph
        for node in &graph.nodes {
            self.infer_node_shape(node, &mut context)?;
        }

        Ok(context)
    }

    /// Infer output shape for a single node
    fn infer_node_shape(
        &self,
        node: &tensorlogic_ir::EinsumNode,
        context: &mut ShapeInferenceContext,
    ) -> Result<(), ExecutorError> {
        let input_shapes: Vec<TensorShape> = node
            .inputs
            .iter()
            .filter_map(|&idx| context.get_tensor_shape(idx).cloned())
            .collect();

        if input_shapes.len() != node.inputs.len() {
            return Err(ExecutorError::ShapeMismatch(
                "Not all input shapes are known".to_string(),
            ));
        }

        // Infer output shape based on operation type
        let output_shape = match &node.op {
            OpType::Einsum { spec } => self.infer_einsum_shape(spec, &input_shapes)?,
            OpType::ElemUnary { .. } => {
                // Unary operations preserve shape
                input_shapes[0].clone()
            }
            OpType::ElemBinary { .. } => {
                // Binary operations require compatible shapes
                self.infer_binary_shape(&input_shapes[0], &input_shapes[1])?
            }
            OpType::Reduce { axes, .. } => {
                // Reduction removes specified axes
                self.infer_reduce_shape(&input_shapes[0], axes)?
            }
        };

        // Register output shape
        if let Some(&output_idx) = node.outputs.first() {
            context.set_tensor_shape(output_idx, output_shape);
        }

        Ok(())
    }

    /// Infer shape for einsum operation
    fn infer_einsum_shape(
        &self,
        spec: &str,
        _input_shapes: &[TensorShape],
    ) -> Result<TensorShape, ExecutorError> {
        // Parse einsum specification
        let parts: Vec<&str> = spec.split("->").collect();
        if parts.len() != 2 {
            return Err(ExecutorError::InvalidEinsumSpec(format!(
                "Invalid einsum spec: {}",
                spec
            )));
        }

        let output_spec = parts[1].trim();

        // For now, return a dynamic shape with rank equal to the number of
        // output indices; full inference would need to bind each index label
        // to an input dimension size (see the sketch below).
        Ok(TensorShape::dynamic(output_spec.len()))
    }
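
    /// Sketch of full einsum output-shape inference, kept alongside the
    /// fallback above for reference. This is an illustrative, currently unused
    /// helper: it binds each index label in the input specs to its dimension
    /// size and looks the output labels up in that map. It assumes one ASCII
    /// label per dimension and no broadcasting (`...`) in the spec.
    #[allow(dead_code)]
    fn infer_einsum_shape_from_labels(
        spec: &str,
        input_shapes: &[Vec<usize>],
    ) -> Result<TensorShape, ExecutorError> {
        let parts: Vec<&str> = spec.split("->").collect();
        if parts.len() != 2 {
            return Err(ExecutorError::InvalidEinsumSpec(format!(
                "Invalid einsum spec: {}",
                spec
            )));
        }

        // Bind each index label to a dimension size, checking consistency.
        let mut dim_sizes: HashMap<char, usize> = HashMap::new();
        for (input_spec, shape) in parts[0].split(',').zip(input_shapes) {
            for (label, &dim) in input_spec.trim().chars().zip(shape) {
                match dim_sizes.get(&label) {
                    Some(&existing) if existing != dim => {
                        return Err(ExecutorError::ShapeMismatch(format!(
                            "Index '{}' bound to sizes {} and {}",
                            label, existing, dim
                        )));
                    }
                    _ => {
                        dim_sizes.insert(label, dim);
                    }
                }
            }
        }

        // The output shape is the bound size of each output label, in order.
        let mut output_dims = Vec::new();
        for label in parts[1].trim().chars() {
            match dim_sizes.get(&label) {
                Some(&dim) => output_dims.push(dim),
                None => {
                    return Err(ExecutorError::InvalidEinsumSpec(format!(
                        "Output index '{}' does not appear in any input",
                        label
                    )));
                }
            }
        }
        Ok(TensorShape::static_shape(output_dims))
    }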

    /// Infer shape for binary element-wise operation
    fn infer_binary_shape(
        &self,
        shape1: &TensorShape,
        shape2: &TensorShape,
    ) -> Result<TensorShape, ExecutorError> {
        // Check if both shapes are static
        if let (Some(s1), Some(s2)) = (shape1.as_static(), shape2.as_static()) {
            if s1 == s2 {
                return Ok(TensorShape::static_shape(s1));
            } else if s1.is_empty() {
                // Scalar broadcast
                return Ok(TensorShape::static_shape(s2));
            } else if s2.is_empty() {
                // Scalar broadcast
                return Ok(TensorShape::static_shape(s1));
            } else {
                return Err(ExecutorError::ShapeMismatch(format!(
                    "Incompatible shapes: {:?} and {:?}",
                    s1, s2
                )));
            }
        }

        // If either shape is dynamic, return dynamic
        Ok(TensorShape::dynamic(shape1.rank().max(shape2.rank())))
    }
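
    /// Sketch of general element-wise broadcasting (NumPy-style), shown for
    /// reference only; `infer_binary_shape` above intentionally supports just
    /// exact matches and scalar broadcast. Dimensions are aligned from the
    /// trailing end, and a size-1 dimension broadcasts against any other size.
    #[allow(dead_code)]
    fn broadcast_static_shapes(s1: &[usize], s2: &[usize]) -> Result<Vec<usize>, ExecutorError> {
        let rank = s1.len().max(s2.len());
        let mut result = vec![0usize; rank];
        for i in 0..rank {
            // Missing leading dimensions are treated as size 1.
            let d1 = if i < s1.len() { s1[s1.len() - 1 - i] } else { 1 };
            let d2 = if i < s2.len() { s2[s2.len() - 1 - i] } else { 1 };
            result[rank - 1 - i] = match (d1, d2) {
                (a, b) if a == b => a,
                (1, b) => b,
                (a, 1) => a,
                (a, b) => {
                    return Err(ExecutorError::ShapeMismatch(format!(
                        "Cannot broadcast dimension {} against {}",
                        a, b
                    )));
                }
            };
        }
        Ok(result)
    }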

    /// Infer shape for reduction operation
    fn infer_reduce_shape(
        &self,
        shape: &TensorShape,
        axes: &[usize],
    ) -> Result<TensorShape, ExecutorError> {
        if let Some(dims) = shape.as_static() {
            let mut result_dims = dims.clone();
            // Remove reduced axes in reverse order so earlier indices stay
            // valid; this assumes `axes` is sorted in ascending order.
            for &axis in axes.iter().rev() {
                if axis < result_dims.len() {
                    result_dims.remove(axis);
                }
            }
            return Ok(TensorShape::static_shape(result_dims));
        }

        // Dynamic or symbolic shape
        let new_rank = shape.rank().saturating_sub(axes.len());
        Ok(TensorShape::dynamic(new_rank))
    }
}

impl Default for Scirs2ShapeInference {
    fn default() -> Self {
        Self::new()
    }
}

/// Validate tensor shapes match expected shapes
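///
/// # Example
///
/// A minimal sketch, assuming `executor` already holds a tensor named `"x"`
/// (marked `ignore` because it depends on that setup):
///
/// ```ignore
/// let mut expected = HashMap::new();
/// expected.insert("x".to_string(), vec![2, 3]);
/// validate_tensor_shapes(&executor, &expected)?;
/// ```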
pub fn validate_tensor_shapes(
    executor: &Scirs2Exec,
    expected_shapes: &HashMap<String, Vec<usize>>,
) -> Result<(), ExecutorError> {
    for (name, expected_shape) in expected_shapes {
        if let Some(tensor) = executor.get_tensor(name) {
            let actual_shape = tensor.shape();
            if actual_shape != expected_shape.as_slice() {
                return Err(ExecutorError::ShapeMismatch(format!(
                    "Tensor '{}': expected shape {:?}, got {:?}",
                    name, expected_shape, actual_shape
                )));
            }
        }
    }
    Ok(())
}

#[cfg(all(test, feature = "integration-tests"))]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;
    use tensorlogic_compiler::compile_to_einsum;
    use tensorlogic_ir::{TLExpr, Term};

    fn create_test_tensor(shape: &[usize]) -> ArrayD<f64> {
        ArrayD::zeros(shape.to_vec())
    }

    #[test]
    fn test_shape_inference_basic() {
        let x = TLExpr::pred("x", vec![Term::var("i"), Term::var("j")]);
        let y = TLExpr::pred("y", vec![Term::var("i"), Term::var("j")]);
        let expr = TLExpr::add(x, y);
        let graph = compile_to_einsum(&expr).unwrap();

        let mut executor = Scirs2Exec::new();
        executor.add_tensor(graph.tensors[0].clone(), create_test_tensor(&[3, 4]));
        executor.add_tensor(graph.tensors[1].clone(), create_test_tensor(&[3, 4]));

        let mut inference = Scirs2ShapeInference::new();
        let context = inference.infer_graph_shapes(&graph, &executor).unwrap();

        // Check that shapes were inferred
        assert!(context.get_tensor_shape(0).is_some());
        assert!(context.get_tensor_shape(1).is_some());
    }

    #[test]
    fn test_validate_shapes_success() {
        let mut executor = Scirs2Exec::new();
        executor.add_tensor("x".to_string(), create_test_tensor(&[2, 3]));
        executor.add_tensor("y".to_string(), create_test_tensor(&[4, 5]));

        let mut expected = HashMap::new();
        expected.insert("x".to_string(), vec![2, 3]);
        expected.insert("y".to_string(), vec![4, 5]);

        let result = validate_tensor_shapes(&executor, &expected);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_shapes_mismatch() {
        let mut executor = Scirs2Exec::new();
        executor.add_tensor("x".to_string(), create_test_tensor(&[2, 3]));

        let mut expected = HashMap::new();
        expected.insert("x".to_string(), vec![3, 4]); // Wrong shape

        let result = validate_tensor_shapes(&executor, &expected);
        assert!(result.is_err());
    }

    #[test]
    fn test_infer_unary_shape() {
        let inference = Scirs2ShapeInference::new();
        let input_shape = TensorShape::static_shape(vec![2, 3, 4]);

        // Unary operations preserve shape
        let node = tensorlogic_ir::EinsumNode {
            inputs: vec![0],
            outputs: vec![1],
            op: OpType::ElemUnary {
                op: "relu".to_string(),
            },
            metadata: None,
        };

        let mut context = ShapeInferenceContext::new();
        context.set_tensor_shape(0, input_shape.clone());

        inference.infer_node_shape(&node, &mut context).unwrap();

        let output_shape = context.get_tensor_shape(1).unwrap();
        assert_eq!(output_shape, &input_shape);
    }

    #[test]
    fn test_infer_reduce_shape() {
        let inference = Scirs2ShapeInference::new();

        // Reduce along axis 1: [2, 3, 4] -> [2, 4]
        let result = inference
            .infer_reduce_shape(&TensorShape::static_shape(vec![2, 3, 4]), &[1])
            .unwrap();

        let result_dims = result.as_static().unwrap();
        assert_eq!(result_dims, vec![2, 4]);
    }

    #[test]
    fn test_infer_binary_shape_matching() {
        let inference = Scirs2ShapeInference::new();

        let shape1 = TensorShape::static_shape(vec![2, 3]);
        let shape2 = TensorShape::static_shape(vec![2, 3]);

        let result = inference.infer_binary_shape(&shape1, &shape2).unwrap();

        let result_dims = result.as_static().unwrap();
        assert_eq!(result_dims, vec![2, 3]);
    }

    #[test]
    fn test_infer_binary_shape_scalar_broadcast() {
        let inference = Scirs2ShapeInference::new();

        let shape1 = TensorShape::static_shape(vec![]); // Scalar
        let shape2 = TensorShape::static_shape(vec![2, 3]);

        let result = inference.infer_binary_shape(&shape1, &shape2).unwrap();

        let result_dims = result.as_static().unwrap();
        assert_eq!(result_dims, vec![2, 3]);
    }

    #[test]
    fn test_infer_binary_shape_mismatch() {
        let inference = Scirs2ShapeInference::new();

        let shape1 = TensorShape::static_shape(vec![2, 3]);
        let shape2 = TensorShape::static_shape(vec![4, 5]);

        let result = inference.infer_binary_shape(&shape1, &shape2);
        assert!(result.is_err());
    }
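
    // Two additional cases sketched from the behavior above; they rely only on
    // APIs already exercised in this module.
    #[test]
    fn test_infer_einsum_shape_invalid_spec() {
        let inference = Scirs2ShapeInference::new();

        // A spec without "->" should be rejected.
        let result = inference.infer_einsum_shape("ij,jk", &[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_infer_reduce_shape_multiple_axes() {
        let inference = Scirs2ShapeInference::new();

        // Reduce along axes 1 and 3: [2, 3, 4, 5] -> [2, 4]
        let result = inference
            .infer_reduce_shape(&TensorShape::static_shape(vec![2, 3, 4, 5]), &[1, 3])
            .unwrap();

        let result_dims = result.as_static().unwrap();
        assert_eq!(result_dims, vec![2, 4]);
    }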
}