1use crate::error::PyResult;
4use pyo3::prelude::*;
5
6pub fn validate_shape(shape: &[usize]) -> PyResult<()> {
8 for (i, &dim) in shape.iter().enumerate() {
9 if dim == 0 {
10 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
11 "Invalid shape: dimension {} cannot be zero",
12 i
13 )));
14 }
15 }
16 Ok(())
17}
18
19pub fn validate_index(index: i64, dim_size: usize) -> PyResult<usize> {
21 let positive_index = if index < 0 {
22 let abs_index = (-index) as usize;
23 if abs_index > dim_size {
24 return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
25 "Index {} is out of bounds for dimension with size {}",
26 index, dim_size
27 )));
28 }
29 dim_size - abs_index
30 } else {
31 let pos_index = index as usize;
32 if pos_index >= dim_size {
33 return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
34 "Index {} is out of bounds for dimension with size {}",
35 index, dim_size
36 )));
37 }
38 pos_index
39 };
40 Ok(positive_index)
41}
42
43pub fn validate_broadcast_shapes(shape1: &[usize], shape2: &[usize]) -> PyResult<Vec<usize>> {
45 let mut result_shape = Vec::new();
46 let max_dims = shape1.len().max(shape2.len());
47
48 for i in 0..max_dims {
49 let dim1 = if i < shape1.len() {
50 shape1[shape1.len() - 1 - i]
51 } else {
52 1
53 };
54 let dim2 = if i < shape2.len() {
55 shape2[shape2.len() - 1 - i]
56 } else {
57 1
58 };
59
60 if dim1 == dim2 || dim1 == 1 || dim2 == 1 {
61 result_shape.push(dim1.max(dim2));
62 } else {
63 return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
64 "Cannot broadcast shapes {:?} and {:?}",
65 shape1, shape2
66 )));
67 }
68 }
69
70 result_shape.reverse();
71 Ok(result_shape)
72}
73
74pub fn validate_learning_rate(lr: f32) -> PyResult<()> {
76 if lr <= 0.0 {
77 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
78 "Learning rate must be positive",
79 ));
80 }
81 Ok(())
82}
83
84pub fn validate_momentum(momentum: f32) -> PyResult<()> {
86 if !(0.0..=1.0).contains(&momentum) {
87 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
88 "Momentum must be in range [0, 1]",
89 ));
90 }
91 Ok(())
92}
93
94pub fn validate_weight_decay(weight_decay: f32) -> PyResult<()> {
96 if weight_decay < 0.0 {
97 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
98 "Weight decay must be non-negative",
99 ));
100 }
101 Ok(())
102}
103
104pub fn validate_epsilon(eps: f32) -> PyResult<()> {
106 if eps <= 0.0 {
107 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
108 "Epsilon must be positive",
109 ));
110 }
111 Ok(())
112}
113
114pub fn validate_betas(betas: (f32, f32)) -> PyResult<()> {
116 let (beta1, beta2) = betas;
117 if !(0.0..1.0).contains(&beta1) {
118 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
119 "Beta1 must be in range [0, 1)",
120 ));
121 }
122 if !(0.0..1.0).contains(&beta2) {
123 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
124 "Beta2 must be in range [0, 1)",
125 ));
126 }
127 Ok(())
128}
129
130pub fn validate_tensor_shapes_match(shape1: &[usize], shape2: &[usize]) -> PyResult<()> {
132 if shape1 != shape2 {
133 return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
134 "Tensor shapes do not match: {:?} vs {:?}",
135 shape1, shape2
136 )));
137 }
138 Ok(())
139}
140
141pub fn validate_dimension(dim: i32, ndim: usize) -> PyResult<usize> {
143 let positive_dim = if dim < 0 {
144 let abs_dim = (-dim) as usize;
145 if abs_dim > ndim {
146 return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
147 "Dimension {} is out of bounds for tensor with {} dimensions",
148 dim, ndim
149 )));
150 }
151 ndim - abs_dim
152 } else {
153 let pos_dim = dim as usize;
154 if pos_dim >= ndim {
155 return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
156 "Dimension {} is out of bounds for tensor with {} dimensions",
157 dim, ndim
158 )));
159 }
160 pos_dim
161 };
162 Ok(positive_dim)
163}
164
165pub fn validate_parameters_not_empty<T>(params: &[T]) -> PyResult<()> {
167 if params.is_empty() {
168 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
169 "Parameters list cannot be empty",
170 ));
171 }
172 Ok(())
173}
174
175pub fn validate_dropout_probability(p: f32) -> PyResult<()> {
177 if !(0.0..=1.0).contains(&p) {
178 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
179 "Dropout probability must be in range [0, 1], got {}",
180 p
181 )));
182 }
183 Ok(())
184}
185
186pub fn validate_kernel_size(kernel_size: usize, name: &str) -> PyResult<()> {
188 if kernel_size == 0 {
189 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
190 "{} must be positive, got 0",
191 name
192 )));
193 }
194 Ok(())
195}
196
197pub fn validate_stride(stride: usize, name: &str) -> PyResult<()> {
199 if stride == 0 {
200 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
201 "{} must be positive, got 0",
202 name
203 )));
204 }
205 Ok(())
206}
207
208pub fn validate_tensor_ndim(
210 actual_ndim: usize,
211 expected_ndim: usize,
212 op_name: &str,
213) -> PyResult<()> {
214 if actual_ndim != expected_ndim {
215 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
216 "{} expects {}D input, got {}D",
217 op_name, expected_ndim, actual_ndim
218 )));
219 }
220 Ok(())
221}
222
223pub fn validate_tensor_min_ndim(
225 actual_ndim: usize,
226 min_ndim: usize,
227 op_name: &str,
228) -> PyResult<()> {
229 if actual_ndim < min_ndim {
230 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
231 "{} expects at least {}D input, got {}D",
232 op_name, min_ndim, actual_ndim
233 )));
234 }
235 Ok(())
236}
237
238pub fn validate_num_features(
240 actual_features: usize,
241 expected_features: usize,
242 layer_name: &str,
243) -> PyResult<()> {
244 if actual_features != expected_features {
245 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
246 "{} expected {} features, got {}",
247 layer_name, expected_features, actual_features
248 )));
249 }
250 Ok(())
251}
252
253pub fn validate_finite(value: f32, name: &str) -> PyResult<()> {
255 if !value.is_finite() {
256 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
257 "{} must be finite, got {}",
258 name, value
259 )));
260 }
261 Ok(())
262}
263
264pub fn validate_range(start: usize, end: usize, name: &str) -> PyResult<()> {
266 if start >= end {
267 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
268 "Invalid range for {}: start ({}) must be less than end ({})",
269 name, start, end
270 )));
271 }
272 Ok(())
273}
274
275pub fn validate_pooling_output_size(
277 input_size: usize,
278 kernel_size: usize,
279 stride: usize,
280 padding: usize,
281 dilation: usize,
282) -> PyResult<usize> {
283 if kernel_size == 0 || stride == 0 {
284 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
285 "Kernel size and stride must be positive",
286 ));
287 }
288
289 let effective_kernel = dilation * (kernel_size - 1) + 1;
290 if input_size + 2 * padding < effective_kernel {
291 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
292 "Input size {} (with padding {}) is too small for kernel size {} (with dilation {})",
293 input_size, padding, kernel_size, dilation
294 )));
295 }
296
297 let output_size = (input_size + 2 * padding - effective_kernel) / stride + 1;
298 Ok(output_size)
299}
300
301pub fn validate_conv_params(
303 in_channels: usize,
304 out_channels: usize,
305 kernel_size: usize,
306) -> PyResult<()> {
307 if in_channels == 0 {
308 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
309 "in_channels must be positive",
310 ));
311 }
312 if out_channels == 0 {
313 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
314 "out_channels must be positive",
315 ));
316 }
317 if kernel_size == 0 {
318 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
319 "kernel_size must be positive",
320 ));
321 }
322 Ok(())
323}
324
#[cfg(test)]
mod tests {
    //! Unit tests for the validation helpers, written as small case tables so
    //! each test reads as "inputs -> expected outcome".

    use super::*;

    // ---- validate_shape ----

    #[test]
    fn test_validate_shape_valid() {
        let shapes: [&[usize]; 2] = [&[1, 2, 3], &[10, 20, 30, 40]];
        for &shape in shapes.iter() {
            assert!(validate_shape(shape).is_ok(), "shape {:?} should be valid", shape);
        }
    }

    #[test]
    fn test_validate_shape_with_zero() {
        let shapes: [&[usize]; 2] = [&[1, 0, 3], &[0]];
        for &shape in shapes.iter() {
            assert!(validate_shape(shape).is_err(), "shape {:?} should be rejected", shape);
        }
    }

    #[test]
    fn test_validate_shape_empty() {
        // A rank-0 shape has no dimensions that could be zero.
        assert!(validate_shape(&[]).is_ok());
    }

    // ---- validate_index ----

    #[test]
    fn test_validate_index_positive() {
        for &(index, expected) in &[(0i64, 0usize), (5, 5), (9, 9)] {
            assert_eq!(validate_index(index, 10).unwrap(), expected);
        }
    }

    #[test]
    fn test_validate_index_negative() {
        for &(index, expected) in &[(-1i64, 9usize), (-5, 5), (-10, 0)] {
            assert_eq!(validate_index(index, 10).unwrap(), expected);
        }
    }

    #[test]
    fn test_validate_index_out_of_bounds_positive() {
        for &index in &[10i64, 100] {
            assert!(validate_index(index, 10).is_err(), "index {} should be rejected", index);
        }
    }

    #[test]
    fn test_validate_index_out_of_bounds_negative() {
        for &index in &[-11i64, -100] {
            assert!(validate_index(index, 10).is_err(), "index {} should be rejected", index);
        }
    }

    // ---- validate_broadcast_shapes ----

    #[test]
    fn test_validate_broadcast_shapes_compatible() {
        let cases: [(&[usize], &[usize], &[usize]); 4] = [
            (&[3, 4], &[3, 4], &[3, 4]),
            (&[3, 1], &[3, 4], &[3, 4]),
            (&[1, 4], &[3, 4], &[3, 4]),
            (&[3, 4], &[4], &[3, 4]),
        ];
        for &(lhs, rhs, expected) in cases.iter() {
            assert_eq!(validate_broadcast_shapes(lhs, rhs).unwrap(), expected);
        }
    }

    #[test]
    fn test_validate_broadcast_shapes_incompatible() {
        let cases: [(&[usize], &[usize]); 2] = [(&[3, 4], &[3, 5]), (&[2, 3], &[3, 4])];
        for &(lhs, rhs) in cases.iter() {
            assert!(validate_broadcast_shapes(lhs, rhs).is_err());
        }
    }

    // ---- optimizer hyper-parameters ----

    #[test]
    fn test_validate_learning_rate_valid() {
        for &lr in &[0.001f32, 0.1, 1.0, 10.0] {
            assert!(validate_learning_rate(lr).is_ok(), "lr {} should be valid", lr);
        }
    }

    #[test]
    fn test_validate_learning_rate_invalid() {
        for &lr in &[0.0f32, -0.1] {
            assert!(validate_learning_rate(lr).is_err(), "lr {} should be rejected", lr);
        }
    }

    #[test]
    fn test_validate_momentum_valid() {
        for &momentum in &[0.0f32, 0.5, 0.9, 1.0] {
            assert!(validate_momentum(momentum).is_ok());
        }
    }

    #[test]
    fn test_validate_momentum_invalid() {
        for &momentum in &[-0.1f32, 1.1] {
            assert!(validate_momentum(momentum).is_err());
        }
    }

    #[test]
    fn test_validate_weight_decay_valid() {
        for &decay in &[0.0f32, 0.01, 1.0] {
            assert!(validate_weight_decay(decay).is_ok());
        }
    }

    #[test]
    fn test_validate_weight_decay_invalid() {
        assert!(validate_weight_decay(-0.1).is_err());
    }

    #[test]
    fn test_validate_epsilon_valid() {
        for &eps in &[1e-8f32, 1e-5, 0.1] {
            assert!(validate_epsilon(eps).is_ok());
        }
    }

    #[test]
    fn test_validate_epsilon_invalid() {
        for &eps in &[0.0f32, -1e-8] {
            assert!(validate_epsilon(eps).is_err());
        }
    }

    #[test]
    fn test_validate_betas_valid() {
        for &betas in &[(0.0f32, 0.0f32), (0.9, 0.999), (0.5, 0.5)] {
            assert!(validate_betas(betas).is_ok(), "betas {:?} should be valid", betas);
        }
    }

    #[test]
    fn test_validate_betas_invalid() {
        for &betas in &[(-0.1f32, 0.5f32), (0.5, 1.0), (1.0, 0.5), (1.1, 0.5)] {
            assert!(validate_betas(betas).is_err(), "betas {:?} should be rejected", betas);
        }
    }

    // ---- shapes and dimensions ----

    #[test]
    fn test_validate_tensor_shapes_match_valid() {
        let cases: [(&[usize], &[usize]); 2] = [(&[3, 4], &[3, 4]), (&[], &[])];
        for &(lhs, rhs) in cases.iter() {
            assert!(validate_tensor_shapes_match(lhs, rhs).is_ok());
        }
    }

    #[test]
    fn test_validate_tensor_shapes_match_invalid() {
        let cases: [(&[usize], &[usize]); 2] = [(&[3, 4], &[3, 5]), (&[3, 4], &[4, 3])];
        for &(lhs, rhs) in cases.iter() {
            assert!(validate_tensor_shapes_match(lhs, rhs).is_err());
        }
    }

    #[test]
    fn test_validate_dimension_positive() {
        for &(dim, expected) in &[(0i32, 0usize), (2, 2), (3, 3)] {
            assert_eq!(validate_dimension(dim, 4).unwrap(), expected);
        }
    }

    #[test]
    fn test_validate_dimension_negative() {
        for &(dim, expected) in &[(-1i32, 3usize), (-2, 2), (-4, 0)] {
            assert_eq!(validate_dimension(dim, 4).unwrap(), expected);
        }
    }

    #[test]
    fn test_validate_dimension_out_of_bounds() {
        for &dim in &[4i32, -5] {
            assert!(validate_dimension(dim, 4).is_err(), "dim {} should be rejected", dim);
        }
    }

    // ---- parameter lists and probabilities ----

    #[test]
    fn test_validate_parameters_not_empty_valid() {
        assert!(validate_parameters_not_empty(&[1, 2, 3]).is_ok());
    }

    #[test]
    fn test_validate_parameters_not_empty_invalid() {
        let empty: &[i32] = &[];
        assert!(validate_parameters_not_empty(empty).is_err());
    }

    #[test]
    fn test_validate_dropout_probability_valid() {
        for &p in &[0.0f32, 0.5, 1.0] {
            assert!(validate_dropout_probability(p).is_ok());
        }
    }

    #[test]
    fn test_validate_dropout_probability_invalid() {
        for &p in &[-0.1f32, 1.1] {
            assert!(validate_dropout_probability(p).is_err());
        }
    }

    // ---- layer geometry ----

    #[test]
    fn test_validate_kernel_size_valid() {
        for &size in &[1usize, 3, 5] {
            assert!(validate_kernel_size(size, "kernel").is_ok());
        }
    }

    #[test]
    fn test_validate_kernel_size_invalid() {
        assert!(validate_kernel_size(0, "kernel").is_err());
    }

    #[test]
    fn test_validate_stride_valid() {
        for &stride in &[1usize, 2] {
            assert!(validate_stride(stride, "stride").is_ok());
        }
    }

    #[test]
    fn test_validate_stride_invalid() {
        assert!(validate_stride(0, "stride").is_err());
    }

    #[test]
    fn test_validate_tensor_ndim_valid() {
        for &(actual, expected, op) in &[(4usize, 4usize, "conv2d"), (2, 2, "linear")] {
            assert!(validate_tensor_ndim(actual, expected, op).is_ok());
        }
    }

    #[test]
    fn test_validate_tensor_ndim_invalid() {
        for &actual in &[3usize, 5] {
            assert!(validate_tensor_ndim(actual, 4, "conv2d").is_err());
        }
    }

    #[test]
    fn test_validate_tensor_min_ndim_valid() {
        for &actual in &[4usize, 2] {
            assert!(validate_tensor_min_ndim(actual, 2, "operation").is_ok());
        }
    }

    #[test]
    fn test_validate_tensor_min_ndim_invalid() {
        assert!(validate_tensor_min_ndim(1, 2, "operation").is_err());
    }

    #[test]
    fn test_validate_num_features_valid() {
        assert!(validate_num_features(64, 64, "BatchNorm").is_ok());
    }

    #[test]
    fn test_validate_num_features_invalid() {
        assert!(validate_num_features(32, 64, "BatchNorm").is_err());
    }

    // ---- scalar sanity checks ----

    #[test]
    fn test_validate_finite_valid() {
        for &value in &[0.0f32, 1.0, -1.0] {
            assert!(validate_finite(value, "value").is_ok());
        }
    }

    #[test]
    fn test_validate_finite_invalid() {
        for &value in &[f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
            assert!(validate_finite(value, "value").is_err());
        }
    }

    #[test]
    fn test_validate_range_valid() {
        for &(start, end) in &[(0usize, 10usize), (5, 10)] {
            assert!(validate_range(start, end, "range").is_ok());
        }
    }

    #[test]
    fn test_validate_range_invalid() {
        for &(start, end) in &[(10usize, 10usize), (10, 5)] {
            assert!(validate_range(start, end, "range").is_err());
        }
    }

    // ---- pooling output size ----

    #[test]
    fn test_validate_pooling_output_size_valid() {
        // (28 - 2) / 2 + 1 = 14
        assert_eq!(validate_pooling_output_size(28, 2, 2, 0, 1).unwrap(), 14);
        // (32 + 2 * 1 - 3) / 1 + 1 = 32, i.e. "same" padding.
        assert_eq!(validate_pooling_output_size(32, 3, 1, 1, 1).unwrap(), 32);
    }

    #[test]
    fn test_validate_pooling_output_size_invalid_zero_kernel() {
        assert!(validate_pooling_output_size(28, 0, 2, 0, 1).is_err());
    }

    #[test]
    fn test_validate_pooling_output_size_invalid_zero_stride() {
        assert!(validate_pooling_output_size(28, 2, 0, 0, 1).is_err());
    }

    #[test]
    fn test_validate_pooling_output_size_invalid_too_small() {
        // A 5-wide kernel cannot fit in an unpadded input of length 2.
        assert!(validate_pooling_output_size(2, 5, 1, 0, 1).is_err());
    }

    // ---- convolution parameters ----

    #[test]
    fn test_validate_conv_params_valid() {
        for &(in_ch, out_ch, kernel) in &[(3usize, 64usize, 3usize), (64, 128, 5)] {
            assert!(validate_conv_params(in_ch, out_ch, kernel).is_ok());
        }
    }

    #[test]
    fn test_validate_conv_params_invalid_in_channels() {
        assert!(validate_conv_params(0, 64, 3).is_err());
    }

    #[test]
    fn test_validate_conv_params_invalid_out_channels() {
        assert!(validate_conv_params(3, 0, 3).is_err());
    }

    #[test]
    fn test_validate_conv_params_invalid_kernel_size() {
        assert!(validate_conv_params(3, 64, 0).is_err());
    }
}
737}