torsh_python/utils/
validation.rs

1//! Input validation utilities for Python bindings
2
3use crate::error::PyResult;
4use pyo3::prelude::*;
5
6/// Validate that a shape is valid (all dimensions > 0)
7pub fn validate_shape(shape: &[usize]) -> PyResult<()> {
8    for (i, &dim) in shape.iter().enumerate() {
9        if dim == 0 {
10            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
11                "Invalid shape: dimension {} cannot be zero",
12                i
13            )));
14        }
15    }
16    Ok(())
17}
18
19/// Validate that an index is within bounds for a given dimension
20pub fn validate_index(index: i64, dim_size: usize) -> PyResult<usize> {
21    let positive_index = if index < 0 {
22        let abs_index = (-index) as usize;
23        if abs_index > dim_size {
24            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
25                "Index {} is out of bounds for dimension with size {}",
26                index, dim_size
27            )));
28        }
29        dim_size - abs_index
30    } else {
31        let pos_index = index as usize;
32        if pos_index >= dim_size {
33            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
34                "Index {} is out of bounds for dimension with size {}",
35                index, dim_size
36            )));
37        }
38        pos_index
39    };
40    Ok(positive_index)
41}
42
43/// Validate that dimensions are compatible for broadcasting
44pub fn validate_broadcast_shapes(shape1: &[usize], shape2: &[usize]) -> PyResult<Vec<usize>> {
45    let mut result_shape = Vec::new();
46    let max_dims = shape1.len().max(shape2.len());
47
48    for i in 0..max_dims {
49        let dim1 = if i < shape1.len() {
50            shape1[shape1.len() - 1 - i]
51        } else {
52            1
53        };
54        let dim2 = if i < shape2.len() {
55            shape2[shape2.len() - 1 - i]
56        } else {
57            1
58        };
59
60        if dim1 == dim2 || dim1 == 1 || dim2 == 1 {
61            result_shape.push(dim1.max(dim2));
62        } else {
63            return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
64                "Cannot broadcast shapes {:?} and {:?}",
65                shape1, shape2
66            )));
67        }
68    }
69
70    result_shape.reverse();
71    Ok(result_shape)
72}
73
74/// Validate that a learning rate is positive
75pub fn validate_learning_rate(lr: f32) -> PyResult<()> {
76    if lr <= 0.0 {
77        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
78            "Learning rate must be positive",
79        ));
80    }
81    Ok(())
82}
83
84/// Validate that momentum is in valid range [0, 1]
85pub fn validate_momentum(momentum: f32) -> PyResult<()> {
86    if !(0.0..=1.0).contains(&momentum) {
87        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
88            "Momentum must be in range [0, 1]",
89        ));
90    }
91    Ok(())
92}
93
94/// Validate that weight decay is non-negative
95pub fn validate_weight_decay(weight_decay: f32) -> PyResult<()> {
96    if weight_decay < 0.0 {
97        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
98            "Weight decay must be non-negative",
99        ));
100    }
101    Ok(())
102}
103
104/// Validate that epsilon is positive
105pub fn validate_epsilon(eps: f32) -> PyResult<()> {
106    if eps <= 0.0 {
107        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
108            "Epsilon must be positive",
109        ));
110    }
111    Ok(())
112}
113
114/// Validate beta parameters for Adam-like optimizers
115pub fn validate_betas(betas: (f32, f32)) -> PyResult<()> {
116    let (beta1, beta2) = betas;
117    if !(0.0..1.0).contains(&beta1) {
118        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
119            "Beta1 must be in range [0, 1)",
120        ));
121    }
122    if !(0.0..1.0).contains(&beta2) {
123        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
124            "Beta2 must be in range [0, 1)",
125        ));
126    }
127    Ok(())
128}
129
130/// Validate that tensor dimensions match for operations
131pub fn validate_tensor_shapes_match(shape1: &[usize], shape2: &[usize]) -> PyResult<()> {
132    if shape1 != shape2 {
133        return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
134            "Tensor shapes do not match: {:?} vs {:?}",
135            shape1, shape2
136        )));
137    }
138    Ok(())
139}
140
141/// Validate that a dimension index is valid for a tensor
142pub fn validate_dimension(dim: i32, ndim: usize) -> PyResult<usize> {
143    let positive_dim = if dim < 0 {
144        let abs_dim = (-dim) as usize;
145        if abs_dim > ndim {
146            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
147                "Dimension {} is out of bounds for tensor with {} dimensions",
148                dim, ndim
149            )));
150        }
151        ndim - abs_dim
152    } else {
153        let pos_dim = dim as usize;
154        if pos_dim >= ndim {
155            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
156                "Dimension {} is out of bounds for tensor with {} dimensions",
157                dim, ndim
158            )));
159        }
160        pos_dim
161    };
162    Ok(positive_dim)
163}
164
165/// Validate that parameters list is not empty
166pub fn validate_parameters_not_empty<T>(params: &[T]) -> PyResult<()> {
167    if params.is_empty() {
168        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
169            "Parameters list cannot be empty",
170        ));
171    }
172    Ok(())
173}