Skip to main content

torsh_python/utils/
validation.rs

//! Input validation utilities for Python bindings.
2
3use crate::error::PyResult;
4use pyo3::prelude::*;
5
6/// Validate that a shape is valid (all dimensions > 0)
7pub fn validate_shape(shape: &[usize]) -> PyResult<()> {
8    for (i, &dim) in shape.iter().enumerate() {
9        if dim == 0 {
10            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
11                "Invalid shape: dimension {} cannot be zero",
12                i
13            )));
14        }
15    }
16    Ok(())
17}
18
19/// Validate that an index is within bounds for a given dimension
20pub fn validate_index(index: i64, dim_size: usize) -> PyResult<usize> {
21    let positive_index = if index < 0 {
22        let abs_index = (-index) as usize;
23        if abs_index > dim_size {
24            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
25                "Index {} is out of bounds for dimension with size {}",
26                index, dim_size
27            )));
28        }
29        dim_size - abs_index
30    } else {
31        let pos_index = index as usize;
32        if pos_index >= dim_size {
33            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
34                "Index {} is out of bounds for dimension with size {}",
35                index, dim_size
36            )));
37        }
38        pos_index
39    };
40    Ok(positive_index)
41}
42
43/// Validate that dimensions are compatible for broadcasting
44pub fn validate_broadcast_shapes(shape1: &[usize], shape2: &[usize]) -> PyResult<Vec<usize>> {
45    let mut result_shape = Vec::new();
46    let max_dims = shape1.len().max(shape2.len());
47
48    for i in 0..max_dims {
49        let dim1 = if i < shape1.len() {
50            shape1[shape1.len() - 1 - i]
51        } else {
52            1
53        };
54        let dim2 = if i < shape2.len() {
55            shape2[shape2.len() - 1 - i]
56        } else {
57            1
58        };
59
60        if dim1 == dim2 || dim1 == 1 || dim2 == 1 {
61            result_shape.push(dim1.max(dim2));
62        } else {
63            return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
64                "Cannot broadcast shapes {:?} and {:?}",
65                shape1, shape2
66            )));
67        }
68    }
69
70    result_shape.reverse();
71    Ok(result_shape)
72}
73
74/// Validate that a learning rate is positive
75pub fn validate_learning_rate(lr: f32) -> PyResult<()> {
76    if lr <= 0.0 {
77        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
78            "Learning rate must be positive",
79        ));
80    }
81    Ok(())
82}
83
84/// Validate that momentum is in valid range [0, 1]
85pub fn validate_momentum(momentum: f32) -> PyResult<()> {
86    if !(0.0..=1.0).contains(&momentum) {
87        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
88            "Momentum must be in range [0, 1]",
89        ));
90    }
91    Ok(())
92}
93
94/// Validate that weight decay is non-negative
95pub fn validate_weight_decay(weight_decay: f32) -> PyResult<()> {
96    if weight_decay < 0.0 {
97        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
98            "Weight decay must be non-negative",
99        ));
100    }
101    Ok(())
102}
103
104/// Validate that epsilon is positive
105pub fn validate_epsilon(eps: f32) -> PyResult<()> {
106    if eps <= 0.0 {
107        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
108            "Epsilon must be positive",
109        ));
110    }
111    Ok(())
112}
113
114/// Validate beta parameters for Adam-like optimizers
115pub fn validate_betas(betas: (f32, f32)) -> PyResult<()> {
116    let (beta1, beta2) = betas;
117    if !(0.0..1.0).contains(&beta1) {
118        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
119            "Beta1 must be in range [0, 1)",
120        ));
121    }
122    if !(0.0..1.0).contains(&beta2) {
123        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
124            "Beta2 must be in range [0, 1)",
125        ));
126    }
127    Ok(())
128}
129
130/// Validate that tensor dimensions match for operations
131pub fn validate_tensor_shapes_match(shape1: &[usize], shape2: &[usize]) -> PyResult<()> {
132    if shape1 != shape2 {
133        return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
134            "Tensor shapes do not match: {:?} vs {:?}",
135            shape1, shape2
136        )));
137    }
138    Ok(())
139}
140
141/// Validate that a dimension index is valid for a tensor
142pub fn validate_dimension(dim: i32, ndim: usize) -> PyResult<usize> {
143    let positive_dim = if dim < 0 {
144        let abs_dim = (-dim) as usize;
145        if abs_dim > ndim {
146            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
147                "Dimension {} is out of bounds for tensor with {} dimensions",
148                dim, ndim
149            )));
150        }
151        ndim - abs_dim
152    } else {
153        let pos_dim = dim as usize;
154        if pos_dim >= ndim {
155            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
156                "Dimension {} is out of bounds for tensor with {} dimensions",
157                dim, ndim
158            )));
159        }
160        pos_dim
161    };
162    Ok(positive_dim)
163}
164
165/// Validate that parameters list is not empty
166pub fn validate_parameters_not_empty<T>(params: &[T]) -> PyResult<()> {
167    if params.is_empty() {
168        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
169            "Parameters list cannot be empty",
170        ));
171    }
172    Ok(())
173}
174
175/// Validate dropout probability is in valid range [0, 1]
176pub fn validate_dropout_probability(p: f32) -> PyResult<()> {
177    if !(0.0..=1.0).contains(&p) {
178        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
179            "Dropout probability must be in range [0, 1], got {}",
180            p
181        )));
182    }
183    Ok(())
184}
185
186/// Validate kernel size is positive
187pub fn validate_kernel_size(kernel_size: usize, name: &str) -> PyResult<()> {
188    if kernel_size == 0 {
189        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
190            "{} must be positive, got 0",
191            name
192        )));
193    }
194    Ok(())
195}
196
197/// Validate stride is positive
198pub fn validate_stride(stride: usize, name: &str) -> PyResult<()> {
199    if stride == 0 {
200        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
201            "{} must be positive, got 0",
202            name
203        )));
204    }
205    Ok(())
206}
207
208/// Validate that input tensor has expected number of dimensions
209pub fn validate_tensor_ndim(
210    actual_ndim: usize,
211    expected_ndim: usize,
212    op_name: &str,
213) -> PyResult<()> {
214    if actual_ndim != expected_ndim {
215        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
216            "{} expects {}D input, got {}D",
217            op_name, expected_ndim, actual_ndim
218        )));
219    }
220    Ok(())
221}
222
223/// Validate that input tensor has at least minimum number of dimensions
224pub fn validate_tensor_min_ndim(
225    actual_ndim: usize,
226    min_ndim: usize,
227    op_name: &str,
228) -> PyResult<()> {
229    if actual_ndim < min_ndim {
230        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
231            "{} expects at least {}D input, got {}D",
232            op_name, min_ndim, actual_ndim
233        )));
234    }
235    Ok(())
236}
237
238/// Validate that number of features matches expected value
239pub fn validate_num_features(
240    actual_features: usize,
241    expected_features: usize,
242    layer_name: &str,
243) -> PyResult<()> {
244    if actual_features != expected_features {
245        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
246            "{} expected {} features, got {}",
247            layer_name, expected_features, actual_features
248        )));
249    }
250    Ok(())
251}
252
253/// Validate that a value is finite (not NaN or infinity)
254pub fn validate_finite(value: f32, name: &str) -> PyResult<()> {
255    if !value.is_finite() {
256        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
257            "{} must be finite, got {}",
258            name, value
259        )));
260    }
261    Ok(())
262}
263
264/// Validate that a range is valid (start < end)
265pub fn validate_range(start: usize, end: usize, name: &str) -> PyResult<()> {
266    if start >= end {
267        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
268            "Invalid range for {}: start ({}) must be less than end ({})",
269            name, start, end
270        )));
271    }
272    Ok(())
273}
274
275/// Validate pooling output size calculation
276pub fn validate_pooling_output_size(
277    input_size: usize,
278    kernel_size: usize,
279    stride: usize,
280    padding: usize,
281    dilation: usize,
282) -> PyResult<usize> {
283    if kernel_size == 0 || stride == 0 {
284        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
285            "Kernel size and stride must be positive",
286        ));
287    }
288
289    let effective_kernel = dilation * (kernel_size - 1) + 1;
290    if input_size + 2 * padding < effective_kernel {
291        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
292            "Input size {} (with padding {}) is too small for kernel size {} (with dilation {})",
293            input_size, padding, kernel_size, dilation
294        )));
295    }
296
297    let output_size = (input_size + 2 * padding - effective_kernel) / stride + 1;
298    Ok(output_size)
299}
300
301/// Validate convolution parameters
302pub fn validate_conv_params(
303    in_channels: usize,
304    out_channels: usize,
305    kernel_size: usize,
306) -> PyResult<()> {
307    if in_channels == 0 {
308        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
309            "in_channels must be positive",
310        ));
311    }
312    if out_channels == 0 {
313        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
314            "out_channels must be positive",
315        ));
316    }
317    if kernel_size == 0 {
318        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
319            "kernel_size must be positive",
320        ));
321    }
322    Ok(())
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    // -------------------------------------------------------------------------
    // Shapes
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_shape_valid() {
        let shapes: [&[usize]; 2] = [&[1, 2, 3], &[10, 20, 30, 40]];
        for shape in shapes {
            assert!(validate_shape(shape).is_ok(), "expected {:?} to be valid", shape);
        }
    }

    #[test]
    fn test_validate_shape_with_zero() {
        let shapes: [&[usize]; 2] = [&[1, 0, 3], &[0]];
        for shape in shapes {
            assert!(validate_shape(shape).is_err(), "expected {:?} to be rejected", shape);
        }
    }

    #[test]
    fn test_validate_shape_empty() {
        // A scalar tensor has no dimensions, so its empty shape is accepted.
        assert!(validate_shape(&[]).is_ok());
    }

    // -------------------------------------------------------------------------
    // Indices
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_index_positive() {
        for (index, expected) in [(0_i64, 0_usize), (5, 5), (9, 9)] {
            assert_eq!(validate_index(index, 10).unwrap(), expected, "index {}", index);
        }
    }

    #[test]
    fn test_validate_index_negative() {
        for (index, expected) in [(-1_i64, 9_usize), (-5, 5), (-10, 0)] {
            assert_eq!(validate_index(index, 10).unwrap(), expected, "index {}", index);
        }
    }

    #[test]
    fn test_validate_index_out_of_bounds_positive() {
        for index in [10_i64, 100] {
            assert!(validate_index(index, 10).is_err(), "index {}", index);
        }
    }

    #[test]
    fn test_validate_index_out_of_bounds_negative() {
        for index in [-11_i64, -100] {
            assert!(validate_index(index, 10).is_err(), "index {}", index);
        }
    }

    // -------------------------------------------------------------------------
    // Broadcasting
    // -------------------------------------------------------------------------

    fn assert_broadcasts(lhs: &[usize], rhs: &[usize], expected: &[usize]) {
        assert_eq!(validate_broadcast_shapes(lhs, rhs).unwrap(), expected);
    }

    #[test]
    fn test_validate_broadcast_shapes_compatible() {
        assert_broadcasts(&[3, 4], &[3, 4], &[3, 4]);
        assert_broadcasts(&[3, 1], &[3, 4], &[3, 4]);
        assert_broadcasts(&[1, 4], &[3, 4], &[3, 4]);
        assert_broadcasts(&[3, 4], &[4], &[3, 4]);
    }

    #[test]
    fn test_validate_broadcast_shapes_incompatible() {
        let pairs: [(&[usize], &[usize]); 2] = [(&[3, 4], &[3, 5]), (&[2, 3], &[3, 4])];
        for (lhs, rhs) in pairs {
            assert!(
                validate_broadcast_shapes(lhs, rhs).is_err(),
                "{:?} and {:?} should not broadcast",
                lhs,
                rhs
            );
        }
    }

    // -------------------------------------------------------------------------
    // Optimizer hyper-parameters
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_learning_rate_valid() {
        for lr in [0.001_f32, 0.1, 1.0, 10.0] {
            assert!(validate_learning_rate(lr).is_ok(), "lr {}", lr);
        }
    }

    #[test]
    fn test_validate_learning_rate_invalid() {
        for lr in [0.0_f32, -0.1] {
            assert!(validate_learning_rate(lr).is_err(), "lr {}", lr);
        }
    }

    #[test]
    fn test_validate_momentum_valid() {
        for momentum in [0.0_f32, 0.5, 0.9, 1.0] {
            assert!(validate_momentum(momentum).is_ok(), "momentum {}", momentum);
        }
    }

    #[test]
    fn test_validate_momentum_invalid() {
        for momentum in [-0.1_f32, 1.1] {
            assert!(validate_momentum(momentum).is_err(), "momentum {}", momentum);
        }
    }

    #[test]
    fn test_validate_weight_decay_valid() {
        for decay in [0.0_f32, 0.01, 1.0] {
            assert!(validate_weight_decay(decay).is_ok(), "weight decay {}", decay);
        }
    }

    #[test]
    fn test_validate_weight_decay_invalid() {
        assert!(validate_weight_decay(-0.1).is_err());
    }

    #[test]
    fn test_validate_epsilon_valid() {
        for eps in [1e-8_f32, 1e-5, 0.1] {
            assert!(validate_epsilon(eps).is_ok(), "eps {}", eps);
        }
    }

    #[test]
    fn test_validate_epsilon_invalid() {
        for eps in [0.0_f32, -1e-8] {
            assert!(validate_epsilon(eps).is_err(), "eps {}", eps);
        }
    }

    #[test]
    fn test_validate_betas_valid() {
        for betas in [(0.0_f32, 0.0_f32), (0.9, 0.999), (0.5, 0.5)] {
            assert!(validate_betas(betas).is_ok(), "betas {:?}", betas);
        }
    }

    #[test]
    fn test_validate_betas_invalid() {
        for betas in [(-0.1_f32, 0.5_f32), (0.5, 1.0), (1.0, 0.5), (1.1, 0.5)] {
            assert!(validate_betas(betas).is_err(), "betas {:?}", betas);
        }
    }

    // -------------------------------------------------------------------------
    // Tensor shape matching and dimensions
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_tensor_shapes_match_valid() {
        assert!(validate_tensor_shapes_match(&[3, 4], &[3, 4]).is_ok());
        assert!(validate_tensor_shapes_match(&[], &[]).is_ok());
    }

    #[test]
    fn test_validate_tensor_shapes_match_invalid() {
        let pairs: [(&[usize], &[usize]); 2] = [(&[3, 4], &[3, 5]), (&[3, 4], &[4, 3])];
        for (lhs, rhs) in pairs {
            assert!(validate_tensor_shapes_match(lhs, rhs).is_err(), "{:?} vs {:?}", lhs, rhs);
        }
    }

    #[test]
    fn test_validate_dimension_positive() {
        for (dim, expected) in [(0_i32, 0_usize), (2, 2), (3, 3)] {
            assert_eq!(validate_dimension(dim, 4).unwrap(), expected, "dim {}", dim);
        }
    }

    #[test]
    fn test_validate_dimension_negative() {
        for (dim, expected) in [(-1_i32, 3_usize), (-2, 2), (-4, 0)] {
            assert_eq!(validate_dimension(dim, 4).unwrap(), expected, "dim {}", dim);
        }
    }

    #[test]
    fn test_validate_dimension_out_of_bounds() {
        for dim in [4_i32, -5] {
            assert!(validate_dimension(dim, 4).is_err(), "dim {}", dim);
        }
    }

    // -------------------------------------------------------------------------
    // Parameter lists and dropout
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_parameters_not_empty_valid() {
        assert!(validate_parameters_not_empty(&[1, 2, 3]).is_ok());
    }

    #[test]
    fn test_validate_parameters_not_empty_invalid() {
        let empty: &[i32] = &[];
        assert!(validate_parameters_not_empty(empty).is_err());
    }

    #[test]
    fn test_validate_dropout_probability_valid() {
        for p in [0.0_f32, 0.5, 1.0] {
            assert!(validate_dropout_probability(p).is_ok(), "p {}", p);
        }
    }

    #[test]
    fn test_validate_dropout_probability_invalid() {
        for p in [-0.1_f32, 1.1] {
            assert!(validate_dropout_probability(p).is_err(), "p {}", p);
        }
    }

    // -------------------------------------------------------------------------
    // Kernel sizes and strides
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_kernel_size_valid() {
        for size in [1_usize, 3, 5] {
            assert!(validate_kernel_size(size, "kernel").is_ok(), "kernel size {}", size);
        }
    }

    #[test]
    fn test_validate_kernel_size_invalid() {
        assert!(validate_kernel_size(0, "kernel").is_err());
    }

    #[test]
    fn test_validate_stride_valid() {
        for stride in [1_usize, 2] {
            assert!(validate_stride(stride, "stride").is_ok(), "stride {}", stride);
        }
    }

    #[test]
    fn test_validate_stride_invalid() {
        assert!(validate_stride(0, "stride").is_err());
    }

    // -------------------------------------------------------------------------
    // Tensor rank and feature counts
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_tensor_ndim_valid() {
        for (actual, expected, op) in [(4_usize, 4_usize, "conv2d"), (2, 2, "linear")] {
            assert!(validate_tensor_ndim(actual, expected, op).is_ok(), "{}", op);
        }
    }

    #[test]
    fn test_validate_tensor_ndim_invalid() {
        for actual in [3_usize, 5] {
            assert!(validate_tensor_ndim(actual, 4, "conv2d").is_err(), "{}D", actual);
        }
    }

    #[test]
    fn test_validate_tensor_min_ndim_valid() {
        for actual in [4_usize, 2] {
            assert!(validate_tensor_min_ndim(actual, 2, "operation").is_ok(), "{}D", actual);
        }
    }

    #[test]
    fn test_validate_tensor_min_ndim_invalid() {
        assert!(validate_tensor_min_ndim(1, 2, "operation").is_err());
    }

    #[test]
    fn test_validate_num_features_valid() {
        assert!(validate_num_features(64, 64, "BatchNorm").is_ok());
    }

    #[test]
    fn test_validate_num_features_invalid() {
        assert!(validate_num_features(32, 64, "BatchNorm").is_err());
    }

    // -------------------------------------------------------------------------
    // Finite values and ranges
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_finite_valid() {
        for value in [0.0_f32, 1.0, -1.0] {
            assert!(validate_finite(value, "value").is_ok(), "value {}", value);
        }
    }

    #[test]
    fn test_validate_finite_invalid() {
        for value in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
            assert!(validate_finite(value, "value").is_err(), "value {}", value);
        }
    }

    #[test]
    fn test_validate_range_valid() {
        for (start, end) in [(0_usize, 10_usize), (5, 10)] {
            assert!(validate_range(start, end, "range").is_ok(), "{}..{}", start, end);
        }
    }

    #[test]
    fn test_validate_range_invalid() {
        for (start, end) in [(10_usize, 10_usize), (10, 5)] {
            assert!(validate_range(start, end, "range").is_err(), "{}..{}", start, end);
        }
    }

    // -------------------------------------------------------------------------
    // Pooling output size
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_pooling_output_size_valid() {
        // Arguments: (input, kernel, stride, padding, dilation).
        // (28 + 0 - 2) / 2 + 1 = 14
        assert_eq!(validate_pooling_output_size(28, 2, 2, 0, 1).unwrap(), 14);
        // (32 + 2 - 3) / 1 + 1 = 32
        assert_eq!(validate_pooling_output_size(32, 3, 1, 1, 1).unwrap(), 32);
    }

    #[test]
    fn test_validate_pooling_output_size_invalid_zero_kernel() {
        assert!(validate_pooling_output_size(28, 0, 2, 0, 1).is_err());
    }

    #[test]
    fn test_validate_pooling_output_size_invalid_zero_stride() {
        assert!(validate_pooling_output_size(28, 2, 0, 0, 1).is_err());
    }

    #[test]
    fn test_validate_pooling_output_size_invalid_too_small() {
        // A 5-wide kernel cannot fit in an unpadded input of length 2.
        assert!(validate_pooling_output_size(2, 5, 1, 0, 1).is_err());
    }

    // -------------------------------------------------------------------------
    // Convolution parameters
    // -------------------------------------------------------------------------

    #[test]
    fn test_validate_conv_params_valid() {
        for (cin, cout, kernel) in [(3_usize, 64_usize, 3_usize), (64, 128, 5)] {
            assert!(
                validate_conv_params(cin, cout, kernel).is_ok(),
                "({}, {}, {})",
                cin,
                cout,
                kernel
            );
        }
    }

    #[test]
    fn test_validate_conv_params_invalid_in_channels() {
        assert!(validate_conv_params(0, 64, 3).is_err());
    }

    #[test]
    fn test_validate_conv_params_invalid_out_channels() {
        assert!(validate_conv_params(3, 0, 3).is_err());
    }

    #[test]
    fn test_validate_conv_params_invalid_kernel_size() {
        assert!(validate_conv_params(3, 64, 0).is_err());
    }
}