torsh_python/utils/
validation.rs1use crate::error::PyResult;
4use pyo3::prelude::*;
5
6pub fn validate_shape(shape: &[usize]) -> PyResult<()> {
8 for (i, &dim) in shape.iter().enumerate() {
9 if dim == 0 {
10 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
11 "Invalid shape: dimension {} cannot be zero",
12 i
13 )));
14 }
15 }
16 Ok(())
17}
18
19pub fn validate_index(index: i64, dim_size: usize) -> PyResult<usize> {
21 let positive_index = if index < 0 {
22 let abs_index = (-index) as usize;
23 if abs_index > dim_size {
24 return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
25 "Index {} is out of bounds for dimension with size {}",
26 index, dim_size
27 )));
28 }
29 dim_size - abs_index
30 } else {
31 let pos_index = index as usize;
32 if pos_index >= dim_size {
33 return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
34 "Index {} is out of bounds for dimension with size {}",
35 index, dim_size
36 )));
37 }
38 pos_index
39 };
40 Ok(positive_index)
41}
42
43pub fn validate_broadcast_shapes(shape1: &[usize], shape2: &[usize]) -> PyResult<Vec<usize>> {
45 let mut result_shape = Vec::new();
46 let max_dims = shape1.len().max(shape2.len());
47
48 for i in 0..max_dims {
49 let dim1 = if i < shape1.len() {
50 shape1[shape1.len() - 1 - i]
51 } else {
52 1
53 };
54 let dim2 = if i < shape2.len() {
55 shape2[shape2.len() - 1 - i]
56 } else {
57 1
58 };
59
60 if dim1 == dim2 || dim1 == 1 || dim2 == 1 {
61 result_shape.push(dim1.max(dim2));
62 } else {
63 return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
64 "Cannot broadcast shapes {:?} and {:?}",
65 shape1, shape2
66 )));
67 }
68 }
69
70 result_shape.reverse();
71 Ok(result_shape)
72}
73
74pub fn validate_learning_rate(lr: f32) -> PyResult<()> {
76 if lr <= 0.0 {
77 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
78 "Learning rate must be positive",
79 ));
80 }
81 Ok(())
82}
83
84pub fn validate_momentum(momentum: f32) -> PyResult<()> {
86 if !(0.0..=1.0).contains(&momentum) {
87 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
88 "Momentum must be in range [0, 1]",
89 ));
90 }
91 Ok(())
92}
93
94pub fn validate_weight_decay(weight_decay: f32) -> PyResult<()> {
96 if weight_decay < 0.0 {
97 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
98 "Weight decay must be non-negative",
99 ));
100 }
101 Ok(())
102}
103
104pub fn validate_epsilon(eps: f32) -> PyResult<()> {
106 if eps <= 0.0 {
107 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
108 "Epsilon must be positive",
109 ));
110 }
111 Ok(())
112}
113
114pub fn validate_betas(betas: (f32, f32)) -> PyResult<()> {
116 let (beta1, beta2) = betas;
117 if !(0.0..1.0).contains(&beta1) {
118 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
119 "Beta1 must be in range [0, 1)",
120 ));
121 }
122 if !(0.0..1.0).contains(&beta2) {
123 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
124 "Beta2 must be in range [0, 1)",
125 ));
126 }
127 Ok(())
128}
129
130pub fn validate_tensor_shapes_match(shape1: &[usize], shape2: &[usize]) -> PyResult<()> {
132 if shape1 != shape2 {
133 return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
134 "Tensor shapes do not match: {:?} vs {:?}",
135 shape1, shape2
136 )));
137 }
138 Ok(())
139}
140
141pub fn validate_dimension(dim: i32, ndim: usize) -> PyResult<usize> {
143 let positive_dim = if dim < 0 {
144 let abs_dim = (-dim) as usize;
145 if abs_dim > ndim {
146 return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
147 "Dimension {} is out of bounds for tensor with {} dimensions",
148 dim, ndim
149 )));
150 }
151 ndim - abs_dim
152 } else {
153 let pos_dim = dim as usize;
154 if pos_dim >= ndim {
155 return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
156 "Dimension {} is out of bounds for tensor with {} dimensions",
157 dim, ndim
158 )));
159 }
160 pos_dim
161 };
162 Ok(positive_dim)
163}
164
165pub fn validate_parameters_not_empty<T>(params: &[T]) -> PyResult<()> {
167 if params.is_empty() {
168 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
169 "Parameters list cannot be empty",
170 ));
171 }
172 Ok(())
173}