use std::collections::HashMap;

use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};
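
/// The size of a single tensor dimension: statically known, unknown until
/// runtime, or tied to a named symbolic size.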
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DimSize {
    /// A dimension whose size is known statically.
    Static(usize),
    /// A dimension whose size is only known at runtime.
    Dynamic,
    /// A dimension identified by a symbolic name.
    Symbolic(String),
}

impl DimSize {
    pub fn is_static(&self) -> bool {
        matches!(self, DimSize::Static(_))
    }

    pub fn as_static(&self) -> Option<usize> {
        match self {
            DimSize::Static(size) => Some(*size),
            _ => None,
        }
    }
}
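
/// The shape of a tensor as an ordered list of [`DimSize`]s, one per axis.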
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorShape {
    pub dims: Vec<DimSize>,
}

impl TensorShape {
    pub fn new(dims: Vec<DimSize>) -> Self {
        TensorShape { dims }
    }

    pub fn static_shape(sizes: Vec<usize>) -> Self {
        TensorShape {
            dims: sizes.into_iter().map(DimSize::Static).collect(),
        }
    }

    pub fn dynamic(rank: usize) -> Self {
        TensorShape {
            dims: vec![DimSize::Dynamic; rank],
        }
    }

    pub fn rank(&self) -> usize {
        self.dims.len()
    }

    pub fn is_static(&self) -> bool {
        self.dims.iter().all(|d| d.is_static())
    }

    pub fn as_static(&self) -> Option<Vec<usize>> {
        self.dims.iter().map(|d| d.as_static()).collect()
    }
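
    /// Returns true if `other` can be broadcast against `self`: the ranks
    /// match and every pair of static dimensions is either equal or has one
    /// side equal to 1. Dynamic and symbolic dimensions are treated as
    /// compatible.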
    pub fn compatible_with(&self, other: &TensorShape) -> bool {
        if self.rank() != other.rank() {
            return false;
        }

        for (a, b) in self.dims.iter().zip(other.dims.iter()) {
            match (a, b) {
                (DimSize::Static(size_a), DimSize::Static(size_b)) => {
                    if size_a != size_b && *size_a != 1 && *size_b != 1 {
                        return false;
                    }
                }
                _ => {
                    // Any pairing involving a dynamic or symbolic dimension
                    // cannot be checked statically, so it is accepted.
                }
            }
        }

        true
    }
}
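
/// Tracks the inferred [`TensorShape`] for each tensor index in an
/// [`EinsumGraph`] and propagates shapes through the graph's nodes.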
pub struct ShapeInferenceContext {
    tensor_shapes: HashMap<usize, TensorShape>,
}

impl ShapeInferenceContext {
    pub fn new() -> Self {
        ShapeInferenceContext {
            tensor_shapes: HashMap::new(),
        }
    }

    pub fn set_tensor_shape(&mut self, tensor_idx: usize, shape: TensorShape) {
        self.tensor_shapes.insert(tensor_idx, shape);
    }

    pub fn get_tensor_shape(&self, tensor_idx: usize) -> Option<&TensorShape> {
        self.tensor_shapes.get(&tensor_idx)
    }
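
    /// Seeds the context with the provided input shapes, then walks the
    /// graph's nodes in order and records each node's inferred output shape.
    /// Node `i`'s output is stored at tensor index `graph.tensors.len() + i`.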
    pub fn infer_graph_shapes(
        &mut self,
        graph: &EinsumGraph,
        input_shapes: &HashMap<usize, TensorShape>,
    ) -> Result<(), String> {
        for (idx, shape) in input_shapes {
            self.tensor_shapes.insert(*idx, shape.clone());
        }

        for (node_idx, node) in graph.nodes.iter().enumerate() {
            let output_idx = node_idx + graph.tensors.len();
            let output_shape = self.infer_node_shape(node)?;
            self.tensor_shapes.insert(output_idx, output_shape);
        }

        Ok(())
    }
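
    /// Infers the output shape of a single node from the shapes of its
    /// inputs, dispatching on the node's [`OpType`].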
    fn infer_node_shape(&self, node: &EinsumNode) -> Result<TensorShape, String> {
        match &node.op {
            OpType::Einsum { spec } => self.infer_einsum_shape(spec, &node.inputs),
            OpType::ElemUnary { op: _ } => {
                if let Some(input_shape) = self.get_tensor_shape(node.inputs[0]) {
                    Ok(input_shape.clone())
                } else {
                    Err("Input shape not available for unary op".to_string())
                }
            }
            OpType::ElemBinary { op: _ } => {
                if node.inputs.len() < 2 {
                    return Err("Binary op requires 2 inputs".to_string());
                }

                let shape_a = self
                    .get_tensor_shape(node.inputs[0])
                    .ok_or("Input 0 shape not available")?;
                let shape_b = self
                    .get_tensor_shape(node.inputs[1])
                    .ok_or("Input 1 shape not available")?;

                if !shape_a.compatible_with(shape_b) {
                    return Err(format!(
                        "Incompatible shapes for binary op: {:?} vs {:?}",
                        shape_a, shape_b
                    ));
                }

                Ok(shape_a.clone())
            }
            OpType::Reduce { op: _, axes } => {
                if let Some(input_shape) = self.get_tensor_shape(node.inputs[0]) {
                    let mut output_dims = input_shape.dims.clone();
                    // Remove reduced axes in reverse order so earlier removals
                    // do not shift the indices of axes still to be removed
                    // (assumes `axes` is sorted ascending).
                    for &axis in axes.iter().rev() {
                        if axis < output_dims.len() {
                            output_dims.remove(axis);
                        }
                    }
                    Ok(TensorShape::new(output_dims))
                } else {
                    Err("Input shape not available for reduce op".to_string())
                }
            }
        }
    }
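
    /// Infers the output shape for an einsum `spec` such as `"ab,bc->ac"`:
    /// each axis label is bound to the size it takes in the inputs (erroring
    /// on conflicting static sizes), and the output shape is read off from
    /// the labels after `->`, or from all labels in sorted order when no
    /// explicit output is given.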
    fn infer_einsum_shape(&self, spec: &str, inputs: &[usize]) -> Result<TensorShape, String> {
        // Split the spec into input subscripts and an optional explicit output.
        let (input_specs, output_spec) = if let Some(arrow_pos) = spec.find("->") {
            let input_part = &spec[..arrow_pos];
            let output_part = &spec[arrow_pos + 2..];
            (input_part, Some(output_part))
        } else {
            (spec, None)
        };

        let input_specs: Vec<&str> = input_specs.split(',').map(|s| s.trim()).collect();

        if input_specs.len() != inputs.len() {
            return Err(format!(
                "Einsum spec has {} inputs but {} tensors provided",
                input_specs.len(),
                inputs.len()
            ));
        }

        // Bind each axis label to the dimension size observed in the inputs.
        let mut dim_sizes: HashMap<char, DimSize> = HashMap::new();

        for (spec_idx, &input_idx) in inputs.iter().enumerate() {
            let input_shape = self
                .get_tensor_shape(input_idx)
                .ok_or_else(|| format!("Input {} shape not available", input_idx))?;

            let axes = input_specs[spec_idx].chars().collect::<Vec<_>>();

            if axes.len() != input_shape.rank() {
                return Err(format!(
                    "Input {} spec '{}' has {} axes but tensor has rank {}",
                    spec_idx,
                    input_specs[spec_idx],
                    axes.len(),
                    input_shape.rank()
                ));
            }

            for (axis_idx, axis_char) in axes.iter().enumerate() {
                let dim_size = input_shape.dims[axis_idx].clone();

                if let Some(existing) = dim_sizes.get(axis_char) {
                    if let (DimSize::Static(size1), DimSize::Static(size2)) =
                        (existing, &dim_size)
                    {
                        if size1 != size2 {
                            return Err(format!(
                                "Dimension '{}' has inconsistent sizes: {} vs {}",
                                axis_char, size1, size2
                            ));
                        }
                    }
                } else {
                    dim_sizes.insert(*axis_char, dim_size);
                }
            }
        }

        let output_dims = if let Some(output_axes) = output_spec {
            output_axes
                .chars()
                .map(|c| dim_sizes.get(&c).cloned().unwrap_or(DimSize::Dynamic))
                .collect()
        } else {
            // No explicit output: keep every axis label, in sorted order.
            let mut all_axes: Vec<char> = dim_sizes.keys().copied().collect();
            all_axes.sort();
            all_axes
                .into_iter()
                .map(|c| dim_sizes[&c].clone())
                .collect()
        };

        Ok(TensorShape::new(output_dims))
    }
}

impl Default for ShapeInferenceContext {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_shape_static() {
        let shape = TensorShape::static_shape(vec![3, 4, 5]);
        assert_eq!(shape.rank(), 3);
        assert!(shape.is_static());
        assert_eq!(shape.as_static(), Some(vec![3, 4, 5]));
    }

    #[test]
    fn test_tensor_shape_dynamic() {
        let shape = TensorShape::dynamic(3);
        assert_eq!(shape.rank(), 3);
        assert!(!shape.is_static());
        assert_eq!(shape.as_static(), None);
    }

    #[test]
    fn test_shape_compatibility() {
        let shape1 = TensorShape::static_shape(vec![3, 4]);
        let shape2 = TensorShape::static_shape(vec![3, 4]);
        assert!(shape1.compatible_with(&shape2));

        let shape3 = TensorShape::static_shape(vec![3, 1]);
        assert!(shape1.compatible_with(&shape3));

        let shape4 = TensorShape::static_shape(vec![3, 5]);
        assert!(!shape1.compatible_with(&shape4));
    }

    #[test]
    fn test_shape_inference_context() {
        let mut ctx = ShapeInferenceContext::new();
        let shape = TensorShape::static_shape(vec![2, 3]);

        ctx.set_tensor_shape(0, shape.clone());
        assert_eq!(ctx.get_tensor_shape(0), Some(&shape));
        assert_eq!(ctx.get_tensor_shape(1), None);
    }

    #[test]
    fn test_einsum_shape_inference() {
        let mut ctx = ShapeInferenceContext::new();

        ctx.set_tensor_shape(0, TensorShape::static_shape(vec![3, 4]));
        ctx.set_tensor_shape(1, TensorShape::static_shape(vec![4, 5]));

        let shape = ctx.infer_einsum_shape("ab,bc->ac", &[0, 1]).unwrap();
        assert_eq!(shape.rank(), 2);
        assert_eq!(shape.as_static(), Some(vec![3, 5]));
    }

    #[test]
    fn test_einsum_shape_inference_explicit() {
        let mut ctx = ShapeInferenceContext::new();
        ctx.set_tensor_shape(0, TensorShape::static_shape(vec![2, 3, 4]));

        let shape = ctx.infer_einsum_shape("abc->ab", &[0]).unwrap();
        assert_eq!(shape.rank(), 2);
        assert_eq!(shape.as_static(), Some(vec![2, 3]));
    }

    #[test]
    fn test_einsum_shape_inference_diagonal() {
        let mut ctx = ShapeInferenceContext::new();
        ctx.set_tensor_shape(0, TensorShape::static_shape(vec![3, 3]));

        let shape = ctx.infer_einsum_shape("aa->a", &[0]).unwrap();
        assert_eq!(shape.rank(), 1);
        assert_eq!(shape.as_static(), Some(vec![3]));
    }

    #[test]
    fn test_einsum_shape_inference_batch_matmul() {
        let mut ctx = ShapeInferenceContext::new();
        ctx.set_tensor_shape(0, TensorShape::static_shape(vec![10, 3, 4]));
        ctx.set_tensor_shape(1, TensorShape::static_shape(vec![10, 4, 5]));

        let shape = ctx.infer_einsum_shape("bik,bkj->bij", &[0, 1]).unwrap();
        assert_eq!(shape.rank(), 3);
        assert_eq!(shape.as_static(), Some(vec![10, 3, 5]));
    }

    #[test]
    fn test_einsum_shape_inference_inconsistent_dims() {
        let mut ctx = ShapeInferenceContext::new();
        ctx.set_tensor_shape(0, TensorShape::static_shape(vec![3, 4]));
        ctx.set_tensor_shape(1, TensorShape::static_shape(vec![5, 6]));

        let result = ctx.infer_einsum_shape("ab,bc->ac", &[0, 1]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("inconsistent"));
    }
}