Skip to main content

rstorch_python/utils/
validation.rs

1//! Input validation utilities for Python bindings
2
3use crate::error::PyResult;
4use pyo3::prelude::*;
5
6/// Validate that a shape is valid (all dimensions > 0)
7pub fn validate_shape(shape: &[usize]) -> PyResult<()> {
8    for (i, &dim) in shape.iter().enumerate() {
9        if dim == 0 {
10            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
11                "Invalid shape: dimension {} cannot be zero",
12                i
13            )));
14        }
15    }
16    Ok(())
17}
18
19/// Validate that an index is within bounds for a given dimension
20pub fn validate_index(index: i64, dim_size: usize) -> PyResult<usize> {
21    let positive_index = if index < 0 {
22        let abs_index = (-index) as usize;
23        if abs_index > dim_size {
24            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
25                "Index {} is out of bounds for dimension with size {}",
26                index, dim_size
27            )));
28        }
29        dim_size - abs_index
30    } else {
31        let pos_index = index as usize;
32        if pos_index >= dim_size {
33            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
34                "Index {} is out of bounds for dimension with size {}",
35                index, dim_size
36            )));
37        }
38        pos_index
39    };
40    Ok(positive_index)
41}
42
43/// Validate that dimensions are compatible for broadcasting
44pub fn validate_broadcast_shapes(shape1: &[usize], shape2: &[usize]) -> PyResult<Vec<usize>> {
45    let mut result_shape = Vec::new();
46    let max_dims = shape1.len().max(shape2.len());
47
48    for i in 0..max_dims {
49        let dim1 = if i < shape1.len() {
50            shape1[shape1.len() - 1 - i]
51        } else {
52            1
53        };
54        let dim2 = if i < shape2.len() {
55            shape2[shape2.len() - 1 - i]
56        } else {
57            1
58        };
59
60        if dim1 == dim2 || dim1 == 1 || dim2 == 1 {
61            result_shape.push(dim1.max(dim2));
62        } else {
63            return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
64                "Cannot broadcast shapes {:?} and {:?}",
65                shape1, shape2
66            )));
67        }
68    }
69
70    result_shape.reverse();
71    Ok(result_shape)
72}
73
74/// Validate that a learning rate is positive
75pub fn validate_learning_rate(lr: f32) -> PyResult<()> {
76    if lr <= 0.0 {
77        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
78            "Learning rate must be positive",
79        ));
80    }
81    Ok(())
82}
83
84/// Validate that momentum is in valid range [0, 1]
85pub fn validate_momentum(momentum: f32) -> PyResult<()> {
86    if !(0.0..=1.0).contains(&momentum) {
87        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
88            "Momentum must be in range [0, 1]",
89        ));
90    }
91    Ok(())
92}
93
94/// Validate that weight decay is non-negative
95pub fn validate_weight_decay(weight_decay: f32) -> PyResult<()> {
96    if weight_decay < 0.0 {
97        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
98            "Weight decay must be non-negative",
99        ));
100    }
101    Ok(())
102}
103
104/// Validate that epsilon is positive
105pub fn validate_epsilon(eps: f32) -> PyResult<()> {
106    if eps <= 0.0 {
107        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
108            "Epsilon must be positive",
109        ));
110    }
111    Ok(())
112}
113
114/// Validate beta parameters for Adam-like optimizers
115pub fn validate_betas(betas: (f32, f32)) -> PyResult<()> {
116    let (beta1, beta2) = betas;
117    if !(0.0..1.0).contains(&beta1) {
118        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
119            "Beta1 must be in range [0, 1)",
120        ));
121    }
122    if !(0.0..1.0).contains(&beta2) {
123        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
124            "Beta2 must be in range [0, 1)",
125        ));
126    }
127    Ok(())
128}
129
130/// Validate that tensor dimensions match for operations
131pub fn validate_tensor_shapes_match(shape1: &[usize], shape2: &[usize]) -> PyResult<()> {
132    if shape1 != shape2 {
133        return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
134            "Tensor shapes do not match: {:?} vs {:?}",
135            shape1, shape2
136        )));
137    }
138    Ok(())
139}
140
141/// Validate that a dimension index is valid for a tensor
142pub fn validate_dimension(dim: i32, ndim: usize) -> PyResult<usize> {
143    let positive_dim = if dim < 0 {
144        let abs_dim = (-dim) as usize;
145        if abs_dim > ndim {
146            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
147                "Dimension {} is out of bounds for tensor with {} dimensions",
148                dim, ndim
149            )));
150        }
151        ndim - abs_dim
152    } else {
153        let pos_dim = dim as usize;
154        if pos_dim >= ndim {
155            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
156                "Dimension {} is out of bounds for tensor with {} dimensions",
157                dim, ndim
158            )));
159        }
160        pos_dim
161    };
162    Ok(positive_dim)
163}
164
165/// Validate that parameters list is not empty
166pub fn validate_parameters_not_empty<T>(params: &[T]) -> PyResult<()> {
167    if params.is_empty() {
168        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
169            "Parameters list cannot be empty",
170        ));
171    }
172    Ok(())
173}
174
175/// Validate dropout probability is in valid range [0, 1]
176pub fn validate_dropout_probability(p: f32) -> PyResult<()> {
177    if !(0.0..=1.0).contains(&p) {
178        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
179            "Dropout probability must be in range [0, 1], got {}",
180            p
181        )));
182    }
183    Ok(())
184}
185
186/// Validate kernel size is positive
187pub fn validate_kernel_size(kernel_size: usize, name: &str) -> PyResult<()> {
188    if kernel_size == 0 {
189        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
190            "{} must be positive, got 0",
191            name
192        )));
193    }
194    Ok(())
195}
196
197/// Validate stride is positive
198pub fn validate_stride(stride: usize, name: &str) -> PyResult<()> {
199    if stride == 0 {
200        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
201            "{} must be positive, got 0",
202            name
203        )));
204    }
205    Ok(())
206}
207
208/// Validate that input tensor has expected number of dimensions
209pub fn validate_tensor_ndim(
210    actual_ndim: usize,
211    expected_ndim: usize,
212    op_name: &str,
213) -> PyResult<()> {
214    if actual_ndim != expected_ndim {
215        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
216            "{} expects {}D input, got {}D",
217            op_name, expected_ndim, actual_ndim
218        )));
219    }
220    Ok(())
221}
222
223/// Validate that input tensor has at least minimum number of dimensions
224pub fn validate_tensor_min_ndim(
225    actual_ndim: usize,
226    min_ndim: usize,
227    op_name: &str,
228) -> PyResult<()> {
229    if actual_ndim < min_ndim {
230        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
231            "{} expects at least {}D input, got {}D",
232            op_name, min_ndim, actual_ndim
233        )));
234    }
235    Ok(())
236}
237
238/// Validate that number of features matches expected value
239pub fn validate_num_features(
240    actual_features: usize,
241    expected_features: usize,
242    layer_name: &str,
243) -> PyResult<()> {
244    if actual_features != expected_features {
245        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
246            "{} expected {} features, got {}",
247            layer_name, expected_features, actual_features
248        )));
249    }
250    Ok(())
251}
252
253/// Validate that a value is finite (not NaN or infinity)
254pub fn validate_finite(value: f32, name: &str) -> PyResult<()> {
255    if !value.is_finite() {
256        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
257            "{} must be finite, got {}",
258            name, value
259        )));
260    }
261    Ok(())
262}
263
264/// Validate that a range is valid (start < end)
265pub fn validate_range(start: usize, end: usize, name: &str) -> PyResult<()> {
266    if start >= end {
267        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
268            "Invalid range for {}: start ({}) must be less than end ({})",
269            name, start, end
270        )));
271    }
272    Ok(())
273}
274
275/// Validate pooling output size calculation
276pub fn validate_pooling_output_size(
277    input_size: usize,
278    kernel_size: usize,
279    stride: usize,
280    padding: usize,
281    dilation: usize,
282) -> PyResult<usize> {
283    if kernel_size == 0 || stride == 0 {
284        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
285            "Kernel size and stride must be positive",
286        ));
287    }
288
289    let effective_kernel = dilation * (kernel_size - 1) + 1;
290    if input_size + 2 * padding < effective_kernel {
291        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
292            "Input size {} (with padding {}) is too small for kernel size {} (with dilation {})",
293            input_size, padding, kernel_size, dilation
294        )));
295    }
296
297    let output_size = (input_size + 2 * padding - effective_kernel) / stride + 1;
298    Ok(output_size)
299}
300
301/// Validate convolution parameters
302pub fn validate_conv_params(
303    in_channels: usize,
304    out_channels: usize,
305    kernel_size: usize,
306) -> PyResult<()> {
307    if in_channels == 0 {
308        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
309            "in_channels must be positive",
310        ));
311    }
312    if out_channels == 0 {
313        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
314            "out_channels must be positive",
315        ));
316    }
317    if kernel_size == 0 {
318        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
319            "kernel_size must be positive",
320        ));
321    }
322    Ok(())
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    // -------------------------------------------------------------------------
    // Shapes
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_shape_valid() {
        let shapes: [&[usize]; 2] = [&[1, 2, 3], &[10, 20, 30, 40]];
        for shape in shapes.iter() {
            assert!(validate_shape(shape).is_ok());
        }
    }

    #[test]
    fn test_validate_shape_with_zero() {
        let shapes: [&[usize]; 2] = [&[1, 0, 3], &[0]];
        for shape in shapes.iter() {
            assert!(validate_shape(shape).is_err());
        }
    }

    #[test]
    fn test_validate_shape_empty() {
        // A scalar has no dimensions, so none of them can be zero.
        let scalar: &[usize] = &[];
        assert!(validate_shape(scalar).is_ok());
    }

    // -------------------------------------------------------------------------
    // Indices
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_index_positive() {
        for &(index, expected) in [(0_i64, 0_usize), (5, 5), (9, 9)].iter() {
            let resolved = validate_index(index, 10).expect("validate index should succeed");
            assert_eq!(resolved, expected);
        }
    }

    #[test]
    fn test_validate_index_negative() {
        for &(index, expected) in [(-1_i64, 9_usize), (-5, 5), (-10, 0)].iter() {
            let resolved = validate_index(index, 10).expect("validate index should succeed");
            assert_eq!(resolved, expected);
        }
    }

    #[test]
    fn test_validate_index_out_of_bounds_positive() {
        for &index in [10_i64, 100].iter() {
            assert!(validate_index(index, 10).is_err());
        }
    }

    #[test]
    fn test_validate_index_out_of_bounds_negative() {
        for &index in [-11_i64, -100].iter() {
            assert!(validate_index(index, 10).is_err());
        }
    }

    // -------------------------------------------------------------------------
    // Broadcasting
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_broadcast_shapes_compatible() {
        let cases: [(&[usize], &[usize]); 4] = [
            (&[3, 4], &[3, 4]),
            (&[3, 1], &[3, 4]),
            (&[1, 4], &[3, 4]),
            (&[3, 4], &[4]),
        ];
        for &(lhs, rhs) in cases.iter() {
            let broadcast = validate_broadcast_shapes(lhs, rhs)
                .expect("validate broadcast shapes should succeed");
            assert_eq!(broadcast, vec![3, 4]);
        }
    }

    #[test]
    fn test_validate_broadcast_shapes_incompatible() {
        let cases: [(&[usize], &[usize]); 2] = [(&[3, 4], &[3, 5]), (&[2, 3], &[3, 4])];
        for &(lhs, rhs) in cases.iter() {
            assert!(validate_broadcast_shapes(lhs, rhs).is_err());
        }
    }

    // -------------------------------------------------------------------------
    // Learning rate
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_learning_rate_valid() {
        for &lr in [0.001_f32, 0.1, 1.0, 10.0].iter() {
            assert!(validate_learning_rate(lr).is_ok());
        }
    }

    #[test]
    fn test_validate_learning_rate_invalid() {
        for &lr in [0.0_f32, -0.1].iter() {
            assert!(validate_learning_rate(lr).is_err());
        }
    }

    // -------------------------------------------------------------------------
    // Momentum
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_momentum_valid() {
        for &momentum in [0.0_f32, 0.5, 0.9, 1.0].iter() {
            assert!(validate_momentum(momentum).is_ok());
        }
    }

    #[test]
    fn test_validate_momentum_invalid() {
        for &momentum in [-0.1_f32, 1.1].iter() {
            assert!(validate_momentum(momentum).is_err());
        }
    }

    // -------------------------------------------------------------------------
    // Weight decay
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_weight_decay_valid() {
        for &decay in [0.0_f32, 0.01, 1.0].iter() {
            assert!(validate_weight_decay(decay).is_ok());
        }
    }

    #[test]
    fn test_validate_weight_decay_invalid() {
        assert!(validate_weight_decay(-0.1).is_err());
    }

    // -------------------------------------------------------------------------
    // Epsilon
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_epsilon_valid() {
        for &eps in [1e-8_f32, 1e-5, 0.1].iter() {
            assert!(validate_epsilon(eps).is_ok());
        }
    }

    #[test]
    fn test_validate_epsilon_invalid() {
        for &eps in [0.0_f32, -1e-8].iter() {
            assert!(validate_epsilon(eps).is_err());
        }
    }

    // -------------------------------------------------------------------------
    // Betas
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_betas_valid() {
        for &betas in [(0.0_f32, 0.0_f32), (0.9, 0.999), (0.5, 0.5)].iter() {
            assert!(validate_betas(betas).is_ok());
        }
    }

    #[test]
    fn test_validate_betas_invalid() {
        let cases = [(-0.1_f32, 0.5_f32), (0.5, 1.0), (1.0, 0.5), (1.1, 0.5)];
        for &betas in cases.iter() {
            assert!(validate_betas(betas).is_err());
        }
    }

    // -------------------------------------------------------------------------
    // Shape equality
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_tensor_shapes_match_valid() {
        let cases: [(&[usize], &[usize]); 2] = [(&[3, 4], &[3, 4]), (&[], &[])];
        for &(lhs, rhs) in cases.iter() {
            assert!(validate_tensor_shapes_match(lhs, rhs).is_ok());
        }
    }

    #[test]
    fn test_validate_tensor_shapes_match_invalid() {
        let cases: [(&[usize], &[usize]); 2] = [(&[3, 4], &[3, 5]), (&[3, 4], &[4, 3])];
        for &(lhs, rhs) in cases.iter() {
            assert!(validate_tensor_shapes_match(lhs, rhs).is_err());
        }
    }

    // -------------------------------------------------------------------------
    // Dimensions
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_dimension_positive() {
        for &(dim, expected) in [(0_i32, 0_usize), (2, 2), (3, 3)].iter() {
            let axis = validate_dimension(dim, 4).expect("validate dimension should succeed");
            assert_eq!(axis, expected);
        }
    }

    #[test]
    fn test_validate_dimension_negative() {
        for &(dim, expected) in [(-1_i32, 3_usize), (-2, 2), (-4, 0)].iter() {
            let axis = validate_dimension(dim, 4).expect("validate dimension should succeed");
            assert_eq!(axis, expected);
        }
    }

    #[test]
    fn test_validate_dimension_out_of_bounds() {
        for &dim in [4_i32, -5].iter() {
            assert!(validate_dimension(dim, 4).is_err());
        }
    }

    // -------------------------------------------------------------------------
    // Parameter lists
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_parameters_not_empty_valid() {
        assert!(validate_parameters_not_empty(&[1, 2, 3]).is_ok());
    }

    #[test]
    fn test_validate_parameters_not_empty_invalid() {
        let no_params: &[i32] = &[];
        assert!(validate_parameters_not_empty(no_params).is_err());
    }

    // -------------------------------------------------------------------------
    // Dropout
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_dropout_probability_valid() {
        for &p in [0.0_f32, 0.5, 1.0].iter() {
            assert!(validate_dropout_probability(p).is_ok());
        }
    }

    #[test]
    fn test_validate_dropout_probability_invalid() {
        for &p in [-0.1_f32, 1.1].iter() {
            assert!(validate_dropout_probability(p).is_err());
        }
    }

    // -------------------------------------------------------------------------
    // Kernel size and stride
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_kernel_size_valid() {
        for &size in [1_usize, 3, 5].iter() {
            assert!(validate_kernel_size(size, "kernel").is_ok());
        }
    }

    #[test]
    fn test_validate_kernel_size_invalid() {
        assert!(validate_kernel_size(0, "kernel").is_err());
    }

    #[test]
    fn test_validate_stride_valid() {
        for &stride in [1_usize, 2].iter() {
            assert!(validate_stride(stride, "stride").is_ok());
        }
    }

    #[test]
    fn test_validate_stride_invalid() {
        assert!(validate_stride(0, "stride").is_err());
    }

    // -------------------------------------------------------------------------
    // Tensor rank
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_tensor_ndim_valid() {
        for &(actual, expected, op) in [(4_usize, 4_usize, "conv2d"), (2, 2, "linear")].iter() {
            assert!(validate_tensor_ndim(actual, expected, op).is_ok());
        }
    }

    #[test]
    fn test_validate_tensor_ndim_invalid() {
        for &actual in [3_usize, 5].iter() {
            assert!(validate_tensor_ndim(actual, 4, "conv2d").is_err());
        }
    }

    #[test]
    fn test_validate_tensor_min_ndim_valid() {
        for &actual in [4_usize, 2].iter() {
            assert!(validate_tensor_min_ndim(actual, 2, "operation").is_ok());
        }
    }

    #[test]
    fn test_validate_tensor_min_ndim_invalid() {
        assert!(validate_tensor_min_ndim(1, 2, "operation").is_err());
    }

    // -------------------------------------------------------------------------
    // Feature counts
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_num_features_valid() {
        assert!(validate_num_features(64, 64, "BatchNorm").is_ok());
    }

    #[test]
    fn test_validate_num_features_invalid() {
        assert!(validate_num_features(32, 64, "BatchNorm").is_err());
    }

    // -------------------------------------------------------------------------
    // Finite values
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_finite_valid() {
        for &value in [0.0_f32, 1.0, -1.0].iter() {
            assert!(validate_finite(value, "value").is_ok());
        }
    }

    #[test]
    fn test_validate_finite_invalid() {
        for &value in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY].iter() {
            assert!(validate_finite(value, "value").is_err());
        }
    }

    // -------------------------------------------------------------------------
    // Ranges
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_range_valid() {
        for &(start, end) in [(0_usize, 10_usize), (5, 10)].iter() {
            assert!(validate_range(start, end, "range").is_ok());
        }
    }

    #[test]
    fn test_validate_range_invalid() {
        for &(start, end) in [(10_usize, 10_usize), (10, 5)].iter() {
            assert!(validate_range(start, end, "range").is_err());
        }
    }

    // -------------------------------------------------------------------------
    // Pooling output size
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_pooling_output_size_valid() {
        // (input, kernel, stride, padding, dilation, expected output):
        //   (28 + 0 - 2) / 2 + 1 = 14
        //   (32 + 2 - 3) / 1 + 1 = 32
        let cases = [(28_usize, 2_usize, 2_usize, 0_usize, 1_usize, 14_usize), (32, 3, 1, 1, 1, 32)];
        for &(input, kernel, stride, padding, dilation, expected) in cases.iter() {
            let output = validate_pooling_output_size(input, kernel, stride, padding, dilation)
                .expect("validate pooling output size should succeed");
            assert_eq!(output, expected);
        }
    }

    #[test]
    fn test_validate_pooling_output_size_invalid_zero_kernel() {
        assert!(validate_pooling_output_size(28, 0, 2, 0, 1).is_err());
    }

    #[test]
    fn test_validate_pooling_output_size_invalid_zero_stride() {
        assert!(validate_pooling_output_size(28, 2, 0, 0, 1).is_err());
    }

    #[test]
    fn test_validate_pooling_output_size_invalid_too_small() {
        // A 5-wide kernel cannot fit in an unpadded input of length 2.
        assert!(validate_pooling_output_size(2, 5, 1, 0, 1).is_err());
    }

    // -------------------------------------------------------------------------
    // Convolution parameters
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_conv_params_valid() {
        for &(cin, cout, kernel) in [(3_usize, 64_usize, 3_usize), (64, 128, 5)].iter() {
            assert!(validate_conv_params(cin, cout, kernel).is_ok());
        }
    }

    #[test]
    fn test_validate_conv_params_invalid_in_channels() {
        assert!(validate_conv_params(0, 64, 3).is_err());
    }

    #[test]
    fn test_validate_conv_params_invalid_out_channels() {
        assert!(validate_conv_params(3, 0, 3).is_err());
    }

    #[test]
    fn test_validate_conv_params_invalid_kernel_size() {
        assert!(validate_conv_params(3, 64, 0).is_err());
    }
}