// tenflowers_core/shape_error_taxonomy.rs
//! Standardized shape error taxonomy for TenfloweRS
//!
//! This module provides utilities for creating consistent, helpful shape error messages
//! across all tensor operations. All operations should use these utilities to ensure
//! a uniform error reporting experience.

use crate::{Result, Shape, TensorError};
7
/// Category of shape error for better diagnostics.
///
/// Each variant maps to a human-readable name ([`ShapeErrorCategory::name`])
/// and a generic fix hint ([`ShapeErrorCategory::fix_suggestion`]) used when
/// building error messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShapeErrorCategory {
    /// Shapes don't match for elementwise operations
    ElementwiseMismatch,
    /// Broadcasting rules violated
    BroadcastIncompatible,
    /// Matrix multiplication dimension mismatch
    MatMulIncompatible,
    /// Convolution parameter mismatch
    ConvolutionInvalid,
    /// Reduction axis invalid
    ReductionAxisInvalid,
    /// Reshape/view parameters invalid
    ReshapeInvalid,
    /// Concatenation/stacking dimension mismatch
    ConcatenationInvalid,
    /// Transpose/permutation invalid
    TransposeInvalid,
    /// Padding parameters invalid
    PaddingInvalid,
    /// General dimension constraint violation
    DimensionConstraintViolated,
}
32
33impl ShapeErrorCategory {
34    /// Get a user-friendly name for this error category
35    pub fn name(&self) -> &'static str {
36        match self {
37            Self::ElementwiseMismatch => "Elementwise Shape Mismatch",
38            Self::BroadcastIncompatible => "Broadcasting Incompatibility",
39            Self::MatMulIncompatible => "Matrix Multiplication Incompatibility",
40            Self::ConvolutionInvalid => "Convolution Parameter Invalid",
41            Self::ReductionAxisInvalid => "Reduction Axis Invalid",
42            Self::ReshapeInvalid => "Reshape Invalid",
43            Self::ConcatenationInvalid => "Concatenation Invalid",
44            Self::TransposeInvalid => "Transpose Invalid",
45            Self::PaddingInvalid => "Padding Invalid",
46            Self::DimensionConstraintViolated => "Dimension Constraint Violated",
47        }
48    }
49
50    /// Get a description of how to fix this category of error
51    pub fn fix_suggestion(&self) -> &'static str {
52        match self {
53            Self::ElementwiseMismatch => {
54                "Ensure input tensors have identical shapes for elementwise operations"
55            }
56            Self::BroadcastIncompatible => {
57                "Review NumPy broadcasting rules: dimensions must be equal or one of them must be 1"
58            }
59            Self::MatMulIncompatible => "For matmul(A, B), ensure A.shape[-1] == B.shape[-2]",
60            Self::ConvolutionInvalid => {
61                "Check kernel size, stride, padding, and dilation parameters"
62            }
63            Self::ReductionAxisInvalid => {
64                "Ensure reduction axis is within [0, ndim) or use -1 for last axis"
65            }
66            Self::ReshapeInvalid => "New shape must have same total number of elements as original",
67            Self::ConcatenationInvalid => {
68                "All tensors must have same shape except in the concatenation dimension"
69            }
70            Self::TransposeInvalid => "Permutation must be a valid reordering of axes [0..ndim)",
71            Self::PaddingInvalid => "Padding values must be non-negative",
72            Self::DimensionConstraintViolated => {
73                "Review operation documentation for dimension requirements"
74            }
75        }
76    }
77}
78
/// Detailed shape error builder for maximum clarity.
///
/// Accumulates an operation name, an error category, expected/actual shape
/// descriptions, free-form detail lines, and fix suggestions, then renders
/// them into a single multi-line [`TensorError`] via `build`.
pub struct ShapeErrorBuilder {
    // Name of the operation that failed (e.g. "matmul").
    operation: String,
    // Category driving the header text and the default suggestion.
    category: ShapeErrorCategory,
    // Description of what shape(s) the operation required; empty = omitted.
    expected: String,
    // Description of the shape(s) actually received; empty = omitted.
    got: String,
    // Extra context lines rendered as a bulleted "Details:" section.
    details: Vec<String>,
    // Fix hints rendered as a bulleted "Suggestions:" section.
    suggestions: Vec<String>,
}
88
89impl ShapeErrorBuilder {
90    /// Create a new shape error builder
91    pub fn new(operation: &str, category: ShapeErrorCategory) -> Self {
92        Self {
93            operation: operation.to_string(),
94            category,
95            expected: String::new(),
96            got: String::new(),
97            details: Vec::new(),
98            suggestions: vec![category.fix_suggestion().to_string()],
99        }
100    }
101
102    /// Set expected shape description
103    pub fn expected(mut self, expected: &str) -> Self {
104        self.expected = expected.to_string();
105        self
106    }
107
108    /// Set actual shape description
109    pub fn got(mut self, got: &str) -> Self {
110        self.got = got.to_string();
111        self
112    }
113
114    /// Add a detail line
115    pub fn detail(mut self, detail: &str) -> Self {
116        self.details.push(detail.to_string());
117        self
118    }
119
120    /// Add a suggestion for fixing the error
121    pub fn suggestion(mut self, suggestion: &str) -> Self {
122        self.suggestions.push(suggestion.to_string());
123        self
124    }
125
126    /// Build the final error
127    pub fn build(self) -> TensorError {
128        let mut message = format!(
129            "[{}] in operation '{}'",
130            self.category.name(),
131            self.operation
132        );
133
134        if !self.expected.is_empty() {
135            message.push_str(&format!("\nExpected: {}", self.expected));
136        }
137
138        if !self.got.is_empty() {
139            message.push_str(&format!("\nGot:      {}", self.got));
140        }
141
142        if !self.details.is_empty() {
143            message.push_str("\n\nDetails:");
144            for detail in &self.details {
145                message.push_str(&format!("\n  • {}", detail));
146            }
147        }
148
149        if !self.suggestions.is_empty() {
150            message.push_str("\n\nSuggestions:");
151            for suggestion in &self.suggestions {
152                message.push_str(&format!("\n  • {}", suggestion));
153            }
154        }
155
156        TensorError::invalid_shape_simple(message)
157    }
158}
159
/// Utilities for common shape error scenarios.
///
/// A namespace-only unit struct: all functionality lives in associated
/// functions that build pre-configured [`ShapeErrorBuilder`] errors.
pub struct ShapeErrorUtils;
162
163impl ShapeErrorUtils {
164    /// Create error for elementwise operation shape mismatch
165    pub fn elementwise_mismatch(operation: &str, shape1: &Shape, shape2: &Shape) -> TensorError {
166        ShapeErrorBuilder::new(operation, ShapeErrorCategory::ElementwiseMismatch)
167            .expected(&format!("identical shapes: {}", shape1))
168            .got(&format!("shapes {} and {}", shape1, shape2))
169            .detail("Elementwise operations require tensors with identical shapes")
170            .build()
171    }
172
173    /// Create error for broadcasting incompatibility
174    pub fn broadcast_incompatible(operation: &str, shape1: &Shape, shape2: &Shape) -> TensorError {
175        ShapeErrorBuilder::new(operation, ShapeErrorCategory::BroadcastIncompatible)
176            .expected(&format!(
177                "broadcastable shapes (matching dims or dim=1): {} and {}",
178                shape1, shape2
179            ))
180            .got(&format!(
181                "non-broadcastable shapes {} and {}",
182                shape1, shape2
183            ))
184            .detail("Broadcasting rules: dimensions must match or one must be 1")
185            .build()
186    }
187
188    /// Create error for matrix multiplication dimension mismatch
189    pub fn matmul_incompatible(
190        operation: &str,
191        shape_a: &Shape,
192        shape_b: &Shape,
193        transpose_a: bool,
194        transpose_b: bool,
195    ) -> TensorError {
196        let (m, k1) = if transpose_a {
197            (shape_a.dims()[1], shape_a.dims()[0])
198        } else {
199            (shape_a.dims()[0], shape_a.dims()[1])
200        };
201
202        let (k2, n) = if transpose_b {
203            (shape_b.dims()[1], shape_b.dims()[0])
204        } else {
205            (shape_b.dims()[0], shape_b.dims()[1])
206        };
207
208        ShapeErrorBuilder::new(operation, ShapeErrorCategory::MatMulIncompatible)
209            .expected(&format!(
210                "compatible matrix dimensions: inner dimensions must match (k1={} should equal k2={})",
211                k1, k2
212            ))
213            .got(&format!(
214                "A{}: {} ({}×{}), B{}: {} ({}×{})",
215                if transpose_a { ".T" } else { "" },
216                shape_a,
217                m,
218                k1,
219                if transpose_b { ".T" } else { "" },
220                shape_b,
221                k2,
222                n
223            ))
224            .detail(&format!("Result shape would be: ({}, {})", m, n))
225            .detail(&format!(
226                "Transpose flags: transpose_a={}, transpose_b={}",
227                transpose_a, transpose_b
228            ))
229            .build()
230    }
231
232    /// Create error for invalid reduction axis
233    pub fn invalid_reduction_axis(operation: &str, axis: isize, shape: &Shape) -> TensorError {
234        let ndim = shape.rank();
235        ShapeErrorBuilder::new(operation, ShapeErrorCategory::ReductionAxisInvalid)
236            .expected(&format!("axis in range [0, {}) or [-{}, -1]", ndim, ndim))
237            .got(&format!("axis = {}", axis))
238            .detail(&format!("Tensor shape: {}", shape))
239            .detail(&format!("Number of dimensions: {}", ndim))
240            .suggestion("Use axis=-1 to reduce over the last dimension")
241            .build()
242    }
243
244    /// Create error for invalid reshape
245    pub fn invalid_reshape(
246        operation: &str,
247        original_shape: &Shape,
248        new_shape: &[usize],
249    ) -> TensorError {
250        let original_size: usize = original_shape.dims().iter().product();
251        let new_size: usize = new_shape.iter().product();
252
253        ShapeErrorBuilder::new(operation, ShapeErrorCategory::ReshapeInvalid)
254            .expected(&format!(
255                "new shape with total elements = {} (same as original)",
256                original_size
257            ))
258            .got(&format!(
259                "shape {:?} with total elements = {}",
260                new_shape, new_size
261            ))
262            .detail(&format!("Original shape: {}", original_shape))
263            .detail(&format!("Original size: {}", original_size))
264            .detail(&format!("New shape: {:?}", new_shape))
265            .detail(&format!("New size: {}", new_size))
266            .suggestion("Use -1 in one dimension to infer its size automatically")
267            .build()
268    }
269
270    /// Create error for concatenation shape mismatch
271    pub fn concatenation_mismatch(operation: &str, shapes: &[Shape], axis: usize) -> TensorError {
272        let mut builder =
273            ShapeErrorBuilder::new(operation, ShapeErrorCategory::ConcatenationInvalid);
274
275        if let Some(first_shape) = shapes.first() {
276            builder = builder.expected(&format!(
277                "all tensors to have same shape as first tensor {} (except in axis {})",
278                first_shape, axis
279            ));
280
281            for (i, shape) in shapes.iter().enumerate().skip(1) {
282                if shape != first_shape {
283                    let mut diff_axes = Vec::new();
284                    for (ax, (d1, d2)) in first_shape.dims().iter().zip(shape.dims()).enumerate() {
285                        if d1 != d2 && ax != axis {
286                            diff_axes.push(ax);
287                        }
288                    }
289                    if !diff_axes.is_empty() {
290                        builder = builder.detail(&format!(
291                            "Tensor {} differs from first tensor in axes {:?} (non-concat axes must match)",
292                            i, diff_axes
293                        ));
294                    }
295                }
296            }
297        }
298
299        builder = builder.detail(&format!("Concatenation axis: {}", axis));
300        for (i, shape) in shapes.iter().enumerate() {
301            builder = builder.detail(&format!("Tensor {}: {}", i, shape));
302        }
303
304        builder.build()
305    }
306
307    /// Create error for dimension constraint violation
308    pub fn dimension_constraint(
309        operation: &str,
310        constraint_description: &str,
311        shape: &Shape,
312    ) -> TensorError {
313        ShapeErrorBuilder::new(operation, ShapeErrorCategory::DimensionConstraintViolated)
314            .expected(constraint_description)
315            .got(&format!("shape {}", shape))
316            .detail(&format!("Actual rank: {}", shape.rank()))
317            .build()
318    }
319
320    /// Create error for invalid transpose/permutation
321    pub fn invalid_transpose(operation: &str, shape: &Shape, axes: &[usize]) -> TensorError {
322        let ndim = shape.rank();
323        let expected_axes: Vec<usize> = (0..ndim).collect();
324
325        ShapeErrorBuilder::new(operation, ShapeErrorCategory::TransposeInvalid)
326            .expected(&format!("permutation of {:?}", expected_axes))
327            .got(&format!("axes {:?}", axes))
328            .detail(&format!("Tensor shape: {}", shape))
329            .detail(&format!("Number of dimensions: {}", ndim))
330            .detail("Permutation must contain each axis index exactly once")
331            .build()
332    }
333
334    /// Create error for convolution parameter mismatch
335    pub fn convolution_invalid(
336        operation: &str,
337        input_shape: &Shape,
338        kernel_shape: &Shape,
339        details: &str,
340    ) -> TensorError {
341        ShapeErrorBuilder::new(operation, ShapeErrorCategory::ConvolutionInvalid)
342            .detail(&format!("Input shape: {}", input_shape))
343            .detail(&format!("Kernel shape: {}", kernel_shape))
344            .detail(details)
345            .suggestion("Check that kernel size, stride, padding, and dilation are valid")
346            .suggestion("Ensure input channels match kernel input channels")
347            .build()
348    }
349
350    /// Create error for rank mismatch
351    pub fn rank_mismatch(
352        operation: &str,
353        expected_rank: usize,
354        actual_shape: &Shape,
355    ) -> TensorError {
356        ShapeErrorBuilder::new(operation, ShapeErrorCategory::DimensionConstraintViolated)
357            .expected(&format!("{}-dimensional tensor", expected_rank))
358            .got(&format!(
359                "{}-dimensional tensor with shape {}",
360                actual_shape.rank(),
361                actual_shape
362            ))
363            .build()
364    }
365
366    /// Create error for rank range mismatch
367    pub fn rank_range_mismatch(
368        operation: &str,
369        min_rank: usize,
370        max_rank: Option<usize>,
371        actual_shape: &Shape,
372    ) -> TensorError {
373        let expected = if let Some(max) = max_rank {
374            format!("tensor with rank in range [{}, {}]", min_rank, max)
375        } else {
376            format!("tensor with rank >= {}", min_rank)
377        };
378
379        ShapeErrorBuilder::new(operation, ShapeErrorCategory::DimensionConstraintViolated)
380            .expected(&expected)
381            .got(&format!(
382                "rank {} tensor with shape {}",
383                actual_shape.rank(),
384                actual_shape
385            ))
386            .build()
387    }
388}
389
390/// Validate shape compatibility and return detailed error if invalid
391pub fn validate_elementwise_shapes(operation: &str, shape1: &Shape, shape2: &Shape) -> Result<()> {
392    if shape1 != shape2 {
393        Err(ShapeErrorUtils::elementwise_mismatch(
394            operation, shape1, shape2,
395        ))
396    } else {
397        Ok(())
398    }
399}
400
401/// Validate broadcast compatibility
402pub fn validate_broadcast_shapes(operation: &str, shape1: &Shape, shape2: &Shape) -> Result<Shape> {
403    shape1
404        .broadcast_shape(shape2)
405        .ok_or_else(|| ShapeErrorUtils::broadcast_incompatible(operation, shape1, shape2))
406}
407
408/// Validate matrix multiplication shapes
409pub fn validate_matmul_shapes(
410    operation: &str,
411    shape_a: &Shape,
412    shape_b: &Shape,
413    transpose_a: bool,
414    transpose_b: bool,
415) -> Result<Shape> {
416    if shape_a.rank() != 2 || shape_b.rank() != 2 {
417        return Err(TensorError::invalid_shape_simple(format!(
418            "Matrix multiplication requires 2D tensors, got shapes {} and {}",
419            shape_a, shape_b
420        )));
421    }
422
423    let dims_a = shape_a.dims();
424    let dims_b = shape_b.dims();
425
426    let (m, k1) = if transpose_a {
427        (dims_a[1], dims_a[0])
428    } else {
429        (dims_a[0], dims_a[1])
430    };
431
432    let (k2, n) = if transpose_b {
433        (dims_b[1], dims_b[0])
434    } else {
435        (dims_b[0], dims_b[1])
436    };
437
438    if k1 != k2 {
439        Err(ShapeErrorUtils::matmul_incompatible(
440            operation,
441            shape_a,
442            shape_b,
443            transpose_a,
444            transpose_b,
445        ))
446    } else {
447        Ok(Shape::from_slice(&[m, n]))
448    }
449}
450
451/// Validate reduction axis
452pub fn validate_reduction_axis(operation: &str, axis: isize, shape: &Shape) -> Result<usize> {
453    let ndim = shape.rank() as isize;
454    let normalized_axis = if axis < 0 { ndim + axis } else { axis };
455
456    if normalized_axis < 0 || normalized_axis >= ndim {
457        Err(ShapeErrorUtils::invalid_reduction_axis(
458            operation, axis, shape,
459        ))
460    } else {
461        Ok(normalized_axis as usize)
462    }
463}
464
465/// Validate reshape compatibility
466pub fn validate_reshape(
467    operation: &str,
468    original_shape: &Shape,
469    new_shape: &[usize],
470) -> Result<()> {
471    let original_size: usize = original_shape.dims().iter().product();
472    let new_size: usize = new_shape.iter().product();
473
474    if original_size != new_size {
475        Err(ShapeErrorUtils::invalid_reshape(
476            operation,
477            original_shape,
478            new_shape,
479        ))
480    } else {
481        Ok(())
482    }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    // Error-construction helpers: verify the rendered message carries the
    // category header and the operation name.
    #[test]
    fn test_elementwise_mismatch_error() {
        let shape1 = Shape::from_slice(&[3, 4]);
        let shape2 = Shape::from_slice(&[3, 5]);
        let err = ShapeErrorUtils::elementwise_mismatch("add", &shape1, &shape2);
        let msg = format!("{}", err);
        assert!(msg.contains("Elementwise Shape Mismatch"));
        assert!(msg.contains("add"));
    }

    #[test]
    fn test_matmul_incompatible_error() {
        let shape_a = Shape::from_slice(&[3, 4]);
        let shape_b = Shape::from_slice(&[5, 6]);
        let err = ShapeErrorUtils::matmul_incompatible("matmul", &shape_a, &shape_b, false, false);
        let msg = format!("{}", err);
        assert!(msg.contains("Matrix Multiplication Incompatibility"));
        assert!(msg.contains("matmul"));
    }

    // (3,4) x (4,5) is valid and must yield the (3,5) output shape.
    #[test]
    fn test_validate_matmul_shapes() {
        let shape_a = Shape::from_slice(&[3, 4]);
        let shape_b = Shape::from_slice(&[4, 5]);
        let result = validate_matmul_shapes("matmul", &shape_a, &shape_b, false, false);
        assert!(result.is_ok());
        let output_shape = result.expect("test: operation should succeed");
        assert_eq!(output_shape.dims(), &[3, 5]);
    }

    // Inner dimensions 4 and 5 don't match — must be rejected.
    #[test]
    fn test_validate_matmul_shapes_incompatible() {
        let shape_a = Shape::from_slice(&[3, 4]);
        let shape_b = Shape::from_slice(&[5, 6]);
        let result = validate_matmul_shapes("matmul", &shape_a, &shape_b, false, false);
        assert!(result.is_err());
    }

    // In-range positive and negative axes pass; out-of-range (3, -4) fail
    // for a rank-3 shape.
    #[test]
    fn test_validate_reduction_axis() {
        let shape = Shape::from_slice(&[3, 4, 5]);
        assert!(validate_reduction_axis("sum", 0, &shape).is_ok());
        assert!(validate_reduction_axis("sum", 1, &shape).is_ok());
        assert!(validate_reduction_axis("sum", 2, &shape).is_ok());
        assert!(validate_reduction_axis("sum", -1, &shape).is_ok());
        assert!(validate_reduction_axis("sum", -2, &shape).is_ok());
        assert!(validate_reduction_axis("sum", 3, &shape).is_err());
        assert!(validate_reduction_axis("sum", -4, &shape).is_err());
    }

    // Reshape is accepted iff the element count (12) is preserved.
    #[test]
    fn test_validate_reshape() {
        let shape = Shape::from_slice(&[3, 4]);
        assert!(validate_reshape("reshape", &shape, &[12]).is_ok());
        assert!(validate_reshape("reshape", &shape, &[2, 6]).is_ok());
        assert!(validate_reshape("reshape", &shape, &[2, 7]).is_err());
    }
}