1use crate::{Result, Shape, TensorError};
7
/// Broad classification of the shape errors raised by tensor operations.
///
/// Each category has a human-readable title ([`ShapeErrorCategory::name`],
/// also available through `Display`) and a generic remediation hint
/// ([`ShapeErrorCategory::fix_suggestion`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShapeErrorCategory {
    /// Operands of an elementwise operation do not have identical shapes.
    ElementwiseMismatch,
    /// Two shapes cannot be broadcast against each other.
    BroadcastIncompatible,
    /// The inner dimensions of a matrix multiplication do not agree.
    MatMulIncompatible,
    /// A convolution's input/kernel/parameter combination is invalid.
    ConvolutionInvalid,
    /// A reduction axis does not address a dimension of the tensor.
    ReductionAxisInvalid,
    /// A reshape target does not preserve the total number of elements.
    ReshapeInvalid,
    /// Tensors being concatenated disagree outside the concatenation axis.
    ConcatenationInvalid,
    /// A transpose permutation is not a valid reordering of the axes.
    TransposeInvalid,
    /// Padding values are invalid.
    PaddingInvalid,
    /// A generic rank or dimension requirement was not met.
    DimensionConstraintViolated,
}

impl ShapeErrorCategory {
    /// Short human-readable title for this category, e.g.
    /// `"Reshape Invalid"`; used as the bracketed prefix of error messages.
    pub fn name(&self) -> &'static str {
        match self {
            Self::ElementwiseMismatch => "Elementwise Shape Mismatch",
            Self::BroadcastIncompatible => "Broadcasting Incompatibility",
            Self::MatMulIncompatible => "Matrix Multiplication Incompatibility",
            Self::ConvolutionInvalid => "Convolution Parameter Invalid",
            Self::ReductionAxisInvalid => "Reduction Axis Invalid",
            Self::ReshapeInvalid => "Reshape Invalid",
            Self::ConcatenationInvalid => "Concatenation Invalid",
            Self::TransposeInvalid => "Transpose Invalid",
            Self::PaddingInvalid => "Padding Invalid",
            Self::DimensionConstraintViolated => "Dimension Constraint Violated",
        }
    }

    /// Generic, category-wide hint on how to fix an error of this kind.
    pub fn fix_suggestion(&self) -> &'static str {
        match self {
            Self::ElementwiseMismatch => {
                "Ensure input tensors have identical shapes for elementwise operations"
            }
            Self::BroadcastIncompatible => {
                "Review NumPy broadcasting rules: dimensions must be equal or one of them must be 1"
            }
            Self::MatMulIncompatible => "For matmul(A, B), ensure A.shape[-1] == B.shape[-2]",
            Self::ConvolutionInvalid => {
                "Check kernel size, stride, padding, and dilation parameters"
            }
            Self::ReductionAxisInvalid => {
                "Ensure reduction axis is within [0, ndim) or use -1 for last axis"
            }
            Self::ReshapeInvalid => "New shape must have same total number of elements as original",
            Self::ConcatenationInvalid => {
                "All tensors must have same shape except in the concatenation dimension"
            }
            Self::TransposeInvalid => "Permutation must be a valid reordering of axes [0..ndim)",
            Self::PaddingInvalid => "Padding values must be non-negative",
            Self::DimensionConstraintViolated => {
                "Review operation documentation for dimension requirements"
            }
        }
    }
}

/// Formats the category as its human-readable [`ShapeErrorCategory::name`].
impl std::fmt::Display for ShapeErrorCategory {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.name())
    }
}
78
79pub struct ShapeErrorBuilder {
81 operation: String,
82 category: ShapeErrorCategory,
83 expected: String,
84 got: String,
85 details: Vec<String>,
86 suggestions: Vec<String>,
87}
88
89impl ShapeErrorBuilder {
90 pub fn new(operation: &str, category: ShapeErrorCategory) -> Self {
92 Self {
93 operation: operation.to_string(),
94 category,
95 expected: String::new(),
96 got: String::new(),
97 details: Vec::new(),
98 suggestions: vec![category.fix_suggestion().to_string()],
99 }
100 }
101
102 pub fn expected(mut self, expected: &str) -> Self {
104 self.expected = expected.to_string();
105 self
106 }
107
108 pub fn got(mut self, got: &str) -> Self {
110 self.got = got.to_string();
111 self
112 }
113
114 pub fn detail(mut self, detail: &str) -> Self {
116 self.details.push(detail.to_string());
117 self
118 }
119
120 pub fn suggestion(mut self, suggestion: &str) -> Self {
122 self.suggestions.push(suggestion.to_string());
123 self
124 }
125
126 pub fn build(self) -> TensorError {
128 let mut message = format!(
129 "[{}] in operation '{}'",
130 self.category.name(),
131 self.operation
132 );
133
134 if !self.expected.is_empty() {
135 message.push_str(&format!("\nExpected: {}", self.expected));
136 }
137
138 if !self.got.is_empty() {
139 message.push_str(&format!("\nGot: {}", self.got));
140 }
141
142 if !self.details.is_empty() {
143 message.push_str("\n\nDetails:");
144 for detail in &self.details {
145 message.push_str(&format!("\n • {}", detail));
146 }
147 }
148
149 if !self.suggestions.is_empty() {
150 message.push_str("\n\nSuggestions:");
151 for suggestion in &self.suggestions {
152 message.push_str(&format!("\n • {}", suggestion));
153 }
154 }
155
156 TensorError::invalid_shape_simple(message)
157 }
158}
159
/// Stateless namespace for constructors of descriptive shape errors.
///
/// All functionality is provided through associated functions; the unit
/// value itself carries no data.
#[derive(Debug, Clone, Copy, Default)]
pub struct ShapeErrorUtils;
162
163impl ShapeErrorUtils {
164 pub fn elementwise_mismatch(operation: &str, shape1: &Shape, shape2: &Shape) -> TensorError {
166 ShapeErrorBuilder::new(operation, ShapeErrorCategory::ElementwiseMismatch)
167 .expected(&format!("identical shapes: {}", shape1))
168 .got(&format!("shapes {} and {}", shape1, shape2))
169 .detail("Elementwise operations require tensors with identical shapes")
170 .build()
171 }
172
173 pub fn broadcast_incompatible(operation: &str, shape1: &Shape, shape2: &Shape) -> TensorError {
175 ShapeErrorBuilder::new(operation, ShapeErrorCategory::BroadcastIncompatible)
176 .expected(&format!(
177 "broadcastable shapes (matching dims or dim=1): {} and {}",
178 shape1, shape2
179 ))
180 .got(&format!(
181 "non-broadcastable shapes {} and {}",
182 shape1, shape2
183 ))
184 .detail("Broadcasting rules: dimensions must match or one must be 1")
185 .build()
186 }
187
188 pub fn matmul_incompatible(
190 operation: &str,
191 shape_a: &Shape,
192 shape_b: &Shape,
193 transpose_a: bool,
194 transpose_b: bool,
195 ) -> TensorError {
196 let (m, k1) = if transpose_a {
197 (shape_a.dims()[1], shape_a.dims()[0])
198 } else {
199 (shape_a.dims()[0], shape_a.dims()[1])
200 };
201
202 let (k2, n) = if transpose_b {
203 (shape_b.dims()[1], shape_b.dims()[0])
204 } else {
205 (shape_b.dims()[0], shape_b.dims()[1])
206 };
207
208 ShapeErrorBuilder::new(operation, ShapeErrorCategory::MatMulIncompatible)
209 .expected(&format!(
210 "compatible matrix dimensions: inner dimensions must match (k1={} should equal k2={})",
211 k1, k2
212 ))
213 .got(&format!(
214 "A{}: {} ({}×{}), B{}: {} ({}×{})",
215 if transpose_a { ".T" } else { "" },
216 shape_a,
217 m,
218 k1,
219 if transpose_b { ".T" } else { "" },
220 shape_b,
221 k2,
222 n
223 ))
224 .detail(&format!("Result shape would be: ({}, {})", m, n))
225 .detail(&format!(
226 "Transpose flags: transpose_a={}, transpose_b={}",
227 transpose_a, transpose_b
228 ))
229 .build()
230 }
231
232 pub fn invalid_reduction_axis(operation: &str, axis: isize, shape: &Shape) -> TensorError {
234 let ndim = shape.rank();
235 ShapeErrorBuilder::new(operation, ShapeErrorCategory::ReductionAxisInvalid)
236 .expected(&format!("axis in range [0, {}) or [-{}, -1]", ndim, ndim))
237 .got(&format!("axis = {}", axis))
238 .detail(&format!("Tensor shape: {}", shape))
239 .detail(&format!("Number of dimensions: {}", ndim))
240 .suggestion("Use axis=-1 to reduce over the last dimension")
241 .build()
242 }
243
244 pub fn invalid_reshape(
246 operation: &str,
247 original_shape: &Shape,
248 new_shape: &[usize],
249 ) -> TensorError {
250 let original_size: usize = original_shape.dims().iter().product();
251 let new_size: usize = new_shape.iter().product();
252
253 ShapeErrorBuilder::new(operation, ShapeErrorCategory::ReshapeInvalid)
254 .expected(&format!(
255 "new shape with total elements = {} (same as original)",
256 original_size
257 ))
258 .got(&format!(
259 "shape {:?} with total elements = {}",
260 new_shape, new_size
261 ))
262 .detail(&format!("Original shape: {}", original_shape))
263 .detail(&format!("Original size: {}", original_size))
264 .detail(&format!("New shape: {:?}", new_shape))
265 .detail(&format!("New size: {}", new_size))
266 .suggestion("Use -1 in one dimension to infer its size automatically")
267 .build()
268 }
269
270 pub fn concatenation_mismatch(operation: &str, shapes: &[Shape], axis: usize) -> TensorError {
272 let mut builder =
273 ShapeErrorBuilder::new(operation, ShapeErrorCategory::ConcatenationInvalid);
274
275 if let Some(first_shape) = shapes.first() {
276 builder = builder.expected(&format!(
277 "all tensors to have same shape as first tensor {} (except in axis {})",
278 first_shape, axis
279 ));
280
281 for (i, shape) in shapes.iter().enumerate().skip(1) {
282 if shape != first_shape {
283 let mut diff_axes = Vec::new();
284 for (ax, (d1, d2)) in first_shape.dims().iter().zip(shape.dims()).enumerate() {
285 if d1 != d2 && ax != axis {
286 diff_axes.push(ax);
287 }
288 }
289 if !diff_axes.is_empty() {
290 builder = builder.detail(&format!(
291 "Tensor {} differs from first tensor in axes {:?} (non-concat axes must match)",
292 i, diff_axes
293 ));
294 }
295 }
296 }
297 }
298
299 builder = builder.detail(&format!("Concatenation axis: {}", axis));
300 for (i, shape) in shapes.iter().enumerate() {
301 builder = builder.detail(&format!("Tensor {}: {}", i, shape));
302 }
303
304 builder.build()
305 }
306
307 pub fn dimension_constraint(
309 operation: &str,
310 constraint_description: &str,
311 shape: &Shape,
312 ) -> TensorError {
313 ShapeErrorBuilder::new(operation, ShapeErrorCategory::DimensionConstraintViolated)
314 .expected(constraint_description)
315 .got(&format!("shape {}", shape))
316 .detail(&format!("Actual rank: {}", shape.rank()))
317 .build()
318 }
319
320 pub fn invalid_transpose(operation: &str, shape: &Shape, axes: &[usize]) -> TensorError {
322 let ndim = shape.rank();
323 let expected_axes: Vec<usize> = (0..ndim).collect();
324
325 ShapeErrorBuilder::new(operation, ShapeErrorCategory::TransposeInvalid)
326 .expected(&format!("permutation of {:?}", expected_axes))
327 .got(&format!("axes {:?}", axes))
328 .detail(&format!("Tensor shape: {}", shape))
329 .detail(&format!("Number of dimensions: {}", ndim))
330 .detail("Permutation must contain each axis index exactly once")
331 .build()
332 }
333
334 pub fn convolution_invalid(
336 operation: &str,
337 input_shape: &Shape,
338 kernel_shape: &Shape,
339 details: &str,
340 ) -> TensorError {
341 ShapeErrorBuilder::new(operation, ShapeErrorCategory::ConvolutionInvalid)
342 .detail(&format!("Input shape: {}", input_shape))
343 .detail(&format!("Kernel shape: {}", kernel_shape))
344 .detail(details)
345 .suggestion("Check that kernel size, stride, padding, and dilation are valid")
346 .suggestion("Ensure input channels match kernel input channels")
347 .build()
348 }
349
350 pub fn rank_mismatch(
352 operation: &str,
353 expected_rank: usize,
354 actual_shape: &Shape,
355 ) -> TensorError {
356 ShapeErrorBuilder::new(operation, ShapeErrorCategory::DimensionConstraintViolated)
357 .expected(&format!("{}-dimensional tensor", expected_rank))
358 .got(&format!(
359 "{}-dimensional tensor with shape {}",
360 actual_shape.rank(),
361 actual_shape
362 ))
363 .build()
364 }
365
366 pub fn rank_range_mismatch(
368 operation: &str,
369 min_rank: usize,
370 max_rank: Option<usize>,
371 actual_shape: &Shape,
372 ) -> TensorError {
373 let expected = if let Some(max) = max_rank {
374 format!("tensor with rank in range [{}, {}]", min_rank, max)
375 } else {
376 format!("tensor with rank >= {}", min_rank)
377 };
378
379 ShapeErrorBuilder::new(operation, ShapeErrorCategory::DimensionConstraintViolated)
380 .expected(&expected)
381 .got(&format!(
382 "rank {} tensor with shape {}",
383 actual_shape.rank(),
384 actual_shape
385 ))
386 .build()
387 }
388}
389
390pub fn validate_elementwise_shapes(operation: &str, shape1: &Shape, shape2: &Shape) -> Result<()> {
392 if shape1 != shape2 {
393 Err(ShapeErrorUtils::elementwise_mismatch(
394 operation, shape1, shape2,
395 ))
396 } else {
397 Ok(())
398 }
399}
400
401pub fn validate_broadcast_shapes(operation: &str, shape1: &Shape, shape2: &Shape) -> Result<Shape> {
403 shape1
404 .broadcast_shape(shape2)
405 .ok_or_else(|| ShapeErrorUtils::broadcast_incompatible(operation, shape1, shape2))
406}
407
408pub fn validate_matmul_shapes(
410 operation: &str,
411 shape_a: &Shape,
412 shape_b: &Shape,
413 transpose_a: bool,
414 transpose_b: bool,
415) -> Result<Shape> {
416 if shape_a.rank() != 2 || shape_b.rank() != 2 {
417 return Err(TensorError::invalid_shape_simple(format!(
418 "Matrix multiplication requires 2D tensors, got shapes {} and {}",
419 shape_a, shape_b
420 )));
421 }
422
423 let dims_a = shape_a.dims();
424 let dims_b = shape_b.dims();
425
426 let (m, k1) = if transpose_a {
427 (dims_a[1], dims_a[0])
428 } else {
429 (dims_a[0], dims_a[1])
430 };
431
432 let (k2, n) = if transpose_b {
433 (dims_b[1], dims_b[0])
434 } else {
435 (dims_b[0], dims_b[1])
436 };
437
438 if k1 != k2 {
439 Err(ShapeErrorUtils::matmul_incompatible(
440 operation,
441 shape_a,
442 shape_b,
443 transpose_a,
444 transpose_b,
445 ))
446 } else {
447 Ok(Shape::from_slice(&[m, n]))
448 }
449}
450
451pub fn validate_reduction_axis(operation: &str, axis: isize, shape: &Shape) -> Result<usize> {
453 let ndim = shape.rank() as isize;
454 let normalized_axis = if axis < 0 { ndim + axis } else { axis };
455
456 if normalized_axis < 0 || normalized_axis >= ndim {
457 Err(ShapeErrorUtils::invalid_reduction_axis(
458 operation, axis, shape,
459 ))
460 } else {
461 Ok(normalized_axis as usize)
462 }
463}
464
465pub fn validate_reshape(
467 operation: &str,
468 original_shape: &Shape,
469 new_shape: &[usize],
470) -> Result<()> {
471 let original_size: usize = original_shape.dims().iter().product();
472 let new_size: usize = new_shape.iter().product();
473
474 if original_size != new_size {
475 Err(ShapeErrorUtils::invalid_reshape(
476 operation,
477 original_shape,
478 new_shape,
479 ))
480 } else {
481 Ok(())
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders an error through `Display` so tests can search its message.
    fn message_of(err: TensorError) -> String {
        format!("{}", err)
    }

    #[test]
    fn test_elementwise_mismatch_error() {
        let lhs = Shape::from_slice(&[3, 4]);
        let rhs = Shape::from_slice(&[3, 5]);
        let text = message_of(ShapeErrorUtils::elementwise_mismatch("add", &lhs, &rhs));
        for needle in &["Elementwise Shape Mismatch", "add"] {
            assert!(text.contains(*needle), "missing {:?} in {}", needle, text);
        }
    }

    #[test]
    fn test_matmul_incompatible_error() {
        let lhs = Shape::from_slice(&[3, 4]);
        let rhs = Shape::from_slice(&[5, 6]);
        let text = message_of(ShapeErrorUtils::matmul_incompatible(
            "matmul", &lhs, &rhs, false, false,
        ));
        for needle in &["Matrix Multiplication Incompatibility", "matmul"] {
            assert!(text.contains(*needle), "missing {:?} in {}", needle, text);
        }
    }

    #[test]
    fn test_validate_matmul_shapes() {
        let lhs = Shape::from_slice(&[3, 4]);
        let rhs = Shape::from_slice(&[4, 5]);
        let product = validate_matmul_shapes("matmul", &lhs, &rhs, false, false)
            .expect("test: operation should succeed");
        assert_eq!(product.dims(), &[3, 5]);
    }

    #[test]
    fn test_validate_matmul_shapes_incompatible() {
        let lhs = Shape::from_slice(&[3, 4]);
        let rhs = Shape::from_slice(&[5, 6]);
        assert!(validate_matmul_shapes("matmul", &lhs, &rhs, false, false).is_err());
    }

    #[test]
    fn test_validate_reduction_axis() {
        let shape = Shape::from_slice(&[3, 4, 5]);
        for &axis in &[0isize, 1, 2, -1, -2] {
            assert!(
                validate_reduction_axis("sum", axis, &shape).is_ok(),
                "axis {} should be accepted",
                axis
            );
        }
        for &axis in &[3isize, -4] {
            assert!(
                validate_reduction_axis("sum", axis, &shape).is_err(),
                "axis {} should be rejected",
                axis
            );
        }
    }

    #[test]
    fn test_validate_reshape() {
        let shape = Shape::from_slice(&[3, 4]);
        let accepted: [&[usize]; 2] = [&[12], &[2, 6]];
        for target in accepted.iter() {
            assert!(validate_reshape("reshape", &shape, target).is_ok());
        }
        assert!(validate_reshape("reshape", &shape, &[2, 7]).is_err());
    }
}