tensorlogic_infer/shape.rs

//! Tensor shape inference and validation.

use std::collections::HashMap;

use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};

/// Shape information for a tensor dimension
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DimSize {
    /// Statically known size
    Static(usize),
    /// Dynamic size (known only at runtime)
    Dynamic,
    /// Symbolic dimension (e.g., batch size)
    Symbolic(String),
}

impl DimSize {
    pub fn is_static(&self) -> bool {
        matches!(self, DimSize::Static(_))
    }

    pub fn as_static(&self) -> Option<usize> {
        match self {
            DimSize::Static(size) => Some(*size),
            _ => None,
        }
    }
}

/// Tensor shape representation
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorShape {
    pub dims: Vec<DimSize>,
}

impl TensorShape {
    pub fn new(dims: Vec<DimSize>) -> Self {
        TensorShape { dims }
    }

    pub fn static_shape(sizes: Vec<usize>) -> Self {
        TensorShape {
            dims: sizes.into_iter().map(DimSize::Static).collect(),
        }
    }

    pub fn dynamic(rank: usize) -> Self {
        TensorShape {
            dims: vec![DimSize::Dynamic; rank],
        }
    }

    pub fn rank(&self) -> usize {
        self.dims.len()
    }

    pub fn is_static(&self) -> bool {
        self.dims.iter().all(|d| d.is_static())
    }

    pub fn as_static(&self) -> Option<Vec<usize>> {
        self.dims.iter().map(|d| d.as_static()).collect()
    }

    /// Check whether two shapes are compatible: the ranks must match, and each
    /// pair of static dimensions must either be equal or include a 1
    /// (broadcastable). Dynamic and symbolic dimensions always match.
    pub fn compatible_with(&self, other: &TensorShape) -> bool {
        if self.rank() != other.rank() {
            return false;
        }

        for (a, b) in self.dims.iter().zip(other.dims.iter()) {
            match (a, b) {
                (DimSize::Static(size_a), DimSize::Static(size_b)) => {
                    if size_a != size_b && *size_a != 1 && *size_b != 1 {
                        return false;
                    }
                }
                _ => {
                    // Dynamic or symbolic dims are always compatible
                }
            }
        }

        true
    }
}

/// Shape inference context
pub struct ShapeInferenceContext {
    tensor_shapes: HashMap<usize, TensorShape>,
}

impl ShapeInferenceContext {
    pub fn new() -> Self {
        ShapeInferenceContext {
            tensor_shapes: HashMap::new(),
        }
    }

    pub fn set_tensor_shape(&mut self, tensor_idx: usize, shape: TensorShape) {
        self.tensor_shapes.insert(tensor_idx, shape);
    }

    pub fn get_tensor_shape(&self, tensor_idx: usize) -> Option<&TensorShape> {
        self.tensor_shapes.get(&tensor_idx)
    }

    /// Infer shapes for all tensors in a graph
    pub fn infer_graph_shapes(
        &mut self,
        graph: &EinsumGraph,
        input_shapes: &HashMap<usize, TensorShape>,
    ) -> Result<(), String> {
        // Copy input shapes
        for (idx, shape) in input_shapes {
            self.tensor_shapes.insert(*idx, shape.clone());
        }

        // Infer shapes for each node
        for (node_idx, node) in graph.nodes.iter().enumerate() {
            // Simplified: assume node `node_idx` writes its single output to
            // the tensor slot appended after the graph's existing tensors
            let output_idx = node_idx + graph.tensors.len();
            let output_shape = self.infer_node_shape(node)?;
            self.tensor_shapes.insert(output_idx, output_shape);
        }

        Ok(())
    }

    fn infer_node_shape(&self, node: &EinsumNode) -> Result<TensorShape, String> {
        match &node.op {
            OpType::Einsum { spec } => {
                // Parse einsum spec to infer output shape
                self.infer_einsum_shape(spec, &node.inputs)
            }
            OpType::ElemUnary { op: _ } => {
                // Unary ops preserve shape
                if let Some(input_shape) = self.get_tensor_shape(node.inputs[0]) {
                    Ok(input_shape.clone())
                } else {
                    Err("Input shape not available for unary op".to_string())
                }
            }
            OpType::ElemBinary { op: _ } => {
                // Binary ops require broadcast-compatible shapes
                if node.inputs.len() < 2 {
                    return Err("Binary op requires 2 inputs".to_string());
                }

                let shape_a = self
                    .get_tensor_shape(node.inputs[0])
                    .ok_or("Input 0 shape not available")?;
                let shape_b = self
                    .get_tensor_shape(node.inputs[1])
                    .ok_or("Input 1 shape not available")?;

                if !shape_a.compatible_with(shape_b) {
                    return Err(format!(
                        "Incompatible shapes for binary op: {:?} vs {:?}",
                        shape_a, shape_b
                    ));
                }

                // Return the broadcasted shape: where one side is a static 1,
                // take the other side's dimension
                let dims = shape_a
                    .dims
                    .iter()
                    .zip(shape_b.dims.iter())
                    .map(|(a, b)| match (a, b) {
                        (DimSize::Static(1), _) => b.clone(),
                        _ => a.clone(),
                    })
                    .collect();
                Ok(TensorShape::new(dims))
            }
            OpType::Reduce { op: _, axes } => {
                if let Some(input_shape) = self.get_tensor_shape(node.inputs[0]) {
                    // Remove reduced dimensions, iterating in reverse so that
                    // earlier indices stay valid (assumes `axes` is ascending)
                    let mut output_dims = input_shape.dims.clone();
                    for &axis in axes.iter().rev() {
                        if axis < output_dims.len() {
                            output_dims.remove(axis);
                        }
                    }
                    Ok(TensorShape::new(output_dims))
                } else {
                    Err("Input shape not available for reduce op".to_string())
                }
            }
        }
    }

    fn infer_einsum_shape(&self, spec: &str, inputs: &[usize]) -> Result<TensorShape, String> {
        // Parse einsum specification
        let (input_specs, output_spec) = if let Some(arrow_pos) = spec.find("->") {
            let input_part = &spec[..arrow_pos];
            let output_part = &spec[arrow_pos + 2..];
            (input_part, Some(output_part))
        } else {
            (spec, None)
        };

        // Parse input specs (e.g., "ab,bc" -> ["ab", "bc"])
        let input_specs: Vec<&str> = input_specs.split(',').map(|s| s.trim()).collect();

        if input_specs.len() != inputs.len() {
            return Err(format!(
                "Einsum spec has {} inputs but {} tensors provided",
                input_specs.len(),
                inputs.len()
            ));
        }

        // Build dimension size map from input tensors, and count how often
        // each index letter appears (used for implicit output inference)
        let mut dim_sizes: HashMap<char, DimSize> = HashMap::new();
        let mut axis_counts: HashMap<char, usize> = HashMap::new();

        for (spec_idx, &input_idx) in inputs.iter().enumerate() {
            let input_shape = self
                .get_tensor_shape(input_idx)
                .ok_or_else(|| format!("Input {} shape not available", input_idx))?;

            let axes = input_specs[spec_idx].chars().collect::<Vec<_>>();

            if axes.len() != input_shape.rank() {
                return Err(format!(
                    "Input {} spec '{}' has {} axes but tensor has rank {}",
                    spec_idx,
                    input_specs[spec_idx],
                    axes.len(),
                    input_shape.rank()
                ));
            }

            // Map each axis character to its dimension size
            for (axis_idx, axis_char) in axes.iter().enumerate() {
                let dim_size = input_shape.dims[axis_idx].clone();
                *axis_counts.entry(*axis_char).or_insert(0) += 1;

                if let Some(existing) = dim_sizes.get(axis_char) {
                    // Check consistency
                    if let (DimSize::Static(size1), DimSize::Static(size2)) = (existing, &dim_size)
                    {
                        if size1 != size2 {
                            return Err(format!(
                                "Dimension '{}' has inconsistent sizes: {} vs {}",
                                axis_char, size1, size2
                            ));
                        }
                    }
                } else {
                    dim_sizes.insert(*axis_char, dim_size);
                }
            }
        }

        // Determine output shape
        let output_dims = if let Some(output_axes) = output_spec {
            // Explicit output specification
            output_axes
                .chars()
                .map(|c| dim_sizes.get(&c).cloned().unwrap_or(DimSize::Dynamic))
                .collect()
        } else {
            // Implicit output: indices that appear exactly once, in
            // alphabetical order (repeated indices are summed out)
            let mut output_axes: Vec<char> = axis_counts
                .iter()
                .filter(|(_, &count)| count == 1)
                .map(|(&c, _)| c)
                .collect();
            output_axes.sort_unstable();
            output_axes
                .into_iter()
                .map(|c| dim_sizes[&c].clone())
                .collect()
        };

        Ok(TensorShape::new(output_dims))
    }
}

impl Default for ShapeInferenceContext {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_shape_static() {
        let shape = TensorShape::static_shape(vec![3, 4, 5]);
        assert_eq!(shape.rank(), 3);
        assert!(shape.is_static());
        assert_eq!(shape.as_static(), Some(vec![3, 4, 5]));
    }
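
    // Sketch of an additional case: a shape with a symbolic dimension (e.g. a
    // batch size) is not static, but it still counts as compatible with a
    // static shape of the same rank.
    #[test]
    fn test_tensor_shape_symbolic() {
        let shape = TensorShape::new(vec![
            DimSize::Symbolic("batch".to_string()),
            DimSize::Static(4),
        ]);
        assert_eq!(shape.rank(), 2);
        assert!(!shape.is_static());
        assert_eq!(shape.as_static(), None);
        assert!(shape.compatible_with(&TensorShape::static_shape(vec![8, 4])));
    }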

    #[test]
    fn test_tensor_shape_dynamic() {
        let shape = TensorShape::dynamic(3);
        assert_eq!(shape.rank(), 3);
        assert!(!shape.is_static());
        assert_eq!(shape.as_static(), None);
    }

    #[test]
    fn test_shape_compatibility() {
        let shape1 = TensorShape::static_shape(vec![3, 4]);
        let shape2 = TensorShape::static_shape(vec![3, 4]);
        assert!(shape1.compatible_with(&shape2));

        let shape3 = TensorShape::static_shape(vec![3, 1]);
        assert!(shape1.compatible_with(&shape3)); // Broadcasting

        let shape4 = TensorShape::static_shape(vec![3, 5]);
        assert!(!shape1.compatible_with(&shape4));
    }
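
    // Further compatibility cases: ranks must match exactly (there is no
    // implicit rank promotion), while dynamic dims are compatible with
    // anything of the same rank.
    #[test]
    fn test_shape_compatibility_rank_and_dynamic() {
        let shape_2d = TensorShape::static_shape(vec![3, 4]);
        let shape_3d = TensorShape::static_shape(vec![1, 3, 4]);
        assert!(!shape_2d.compatible_with(&shape_3d)); // Different ranks

        let dynamic = TensorShape::dynamic(2);
        assert!(shape_2d.compatible_with(&dynamic));
        assert!(dynamic.compatible_with(&shape_2d));
    }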

    #[test]
    fn test_shape_inference_context() {
        let mut ctx = ShapeInferenceContext::new();
        let shape = TensorShape::static_shape(vec![2, 3]);

        ctx.set_tensor_shape(0, shape.clone());
        assert_eq!(ctx.get_tensor_shape(0), Some(&shape));
        assert_eq!(ctx.get_tensor_shape(1), None);
    }

    #[test]
    fn test_einsum_shape_inference() {
        let mut ctx = ShapeInferenceContext::new();

        // Set up input shapes
        ctx.set_tensor_shape(0, TensorShape::static_shape(vec![3, 4]));
        ctx.set_tensor_shape(1, TensorShape::static_shape(vec![4, 5]));

        // "ab,bc->ac" should produce shape [3, 5]
        let shape = ctx.infer_einsum_shape("ab,bc->ac", &[0, 1]).unwrap();
        assert_eq!(shape.rank(), 2);
        assert_eq!(shape.as_static(), Some(vec![3, 5]));
    }
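
    // The explicit output spec also controls axis order, so "ab,bc->ca"
    // yields the transposed matmul result shape.
    #[test]
    fn test_einsum_shape_inference_transposed_output() {
        let mut ctx = ShapeInferenceContext::new();
        ctx.set_tensor_shape(0, TensorShape::static_shape(vec![3, 4]));
        ctx.set_tensor_shape(1, TensorShape::static_shape(vec![4, 5]));

        // "ab,bc->ca" should produce shape [5, 3]
        let shape = ctx.infer_einsum_shape("ab,bc->ca", &[0, 1]).unwrap();
        assert_eq!(shape.as_static(), Some(vec![5, 3]));
    }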

    #[test]
    fn test_einsum_shape_inference_explicit() {
        let mut ctx = ShapeInferenceContext::new();
        ctx.set_tensor_shape(0, TensorShape::static_shape(vec![2, 3, 4]));

        // "abc->ab" should produce shape [2, 3]
        let shape = ctx.infer_einsum_shape("abc->ab", &[0]).unwrap();
        assert_eq!(shape.rank(), 2);
        assert_eq!(shape.as_static(), Some(vec![2, 3]));
    }
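
    // Implicit output (no "->"): this assumes the usual einsum convention that
    // indices appearing more than once are summed out and the remaining
    // indices are ordered alphabetically.
    #[test]
    fn test_einsum_shape_inference_implicit_output() {
        let mut ctx = ShapeInferenceContext::new();
        ctx.set_tensor_shape(0, TensorShape::static_shape(vec![3, 4]));
        ctx.set_tensor_shape(1, TensorShape::static_shape(vec![4, 5]));

        // "ab,bc" is equivalent to "ab,bc->ac" and should produce shape [3, 5]
        let shape = ctx.infer_einsum_shape("ab,bc", &[0, 1]).unwrap();
        assert_eq!(shape.as_static(), Some(vec![3, 5]));
    }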

    #[test]
    fn test_einsum_shape_inference_diagonal() {
        let mut ctx = ShapeInferenceContext::new();
        ctx.set_tensor_shape(0, TensorShape::static_shape(vec![3, 3]));

        // "aa->a" should produce shape [3]
        let shape = ctx.infer_einsum_shape("aa->a", &[0]).unwrap();
        assert_eq!(shape.rank(), 1);
        assert_eq!(shape.as_static(), Some(vec![3]));
    }

    #[test]
    fn test_einsum_shape_inference_batch_matmul() {
        let mut ctx = ShapeInferenceContext::new();
        ctx.set_tensor_shape(0, TensorShape::static_shape(vec![10, 3, 4]));
        ctx.set_tensor_shape(1, TensorShape::static_shape(vec![10, 4, 5]));

        // "bik,bkj->bij" should produce shape [10, 3, 5]
        let shape = ctx.infer_einsum_shape("bik,bkj->bij", &[0, 1]).unwrap();
        assert_eq!(shape.rank(), 3);
        assert_eq!(shape.as_static(), Some(vec![10, 3, 5]));
    }
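
    // Error path: each input spec must name exactly as many axes as the
    // corresponding tensor's rank.
    #[test]
    fn test_einsum_shape_inference_rank_mismatch() {
        let mut ctx = ShapeInferenceContext::new();
        ctx.set_tensor_shape(0, TensorShape::static_shape(vec![3, 4]));

        // "abc->ab" names 3 axes but tensor 0 only has rank 2
        let result = ctx.infer_einsum_shape("abc->ab", &[0]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("rank"));
    }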

    #[test]
    fn test_einsum_shape_inference_inconsistent_dims() {
        let mut ctx = ShapeInferenceContext::new();
        ctx.set_tensor_shape(0, TensorShape::static_shape(vec![3, 4]));
        ctx.set_tensor_shape(1, TensorShape::static_shape(vec![5, 6]));

        // "ab,bc->ac" should fail because 'b' has different sizes (4 vs 5)
        let result = ctx.infer_einsum_shape("ab,bc->ac", &[0, 1]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("inconsistent"));
    }
}