1use torsh_core::{Result as TorshResult, TorshError};
7use torsh_tensor::Tensor;
8
9pub fn validate_elementwise_shapes(a: &Tensor, b: &Tensor) -> TorshResult<()> {
11 let binding_a = a.shape();
12 let shape_a = binding_a.dims();
13 let binding_b = b.shape();
14 let shape_b = binding_b.dims();
15
16 if shape_a != shape_b {
17 return Err(TorshError::invalid_argument_with_context(
18 &format!(
19 "Tensor shapes are not compatible for element-wise operation: {:?} vs {:?}",
20 shape_a, shape_b
21 ),
22 "elementwise_operation",
23 ));
24 }
25
26 Ok(())
27}
28
29pub fn validate_range<T: PartialOrd + std::fmt::Display>(
31 value: T,
32 min: T,
33 max: T,
34 param_name: &str,
35 context: &str,
36) -> TorshResult<()> {
37 if value < min || value > max {
38 return Err(TorshError::invalid_argument_with_context(
39 &format!(
40 "{} must be in range [{}, {}], got {}",
41 param_name, min, max, value
42 ),
43 context,
44 ));
45 }
46 Ok(())
47}
48
49pub fn validate_non_empty(tensor: &Tensor, context: &str) -> TorshResult<()> {
51 if tensor.numel() == 0 {
52 return Err(TorshError::invalid_argument_with_context(
53 "Tensor cannot be empty",
54 context,
55 ));
56 }
57 Ok(())
58}
59
60pub fn validate_dimension(tensor: &Tensor, dim: i32, context: &str) -> TorshResult<()> {
62 let ndim = tensor.shape().ndim() as i32;
63 let normalized_dim = if dim < 0 { dim + ndim } else { dim };
64
65 if normalized_dim < 0 || normalized_dim >= ndim {
66 return Err(TorshError::invalid_argument_with_context(
67 &format!(
68 "Dimension {} is out of range for tensor with {} dimensions",
69 dim, ndim
70 ),
71 context,
72 ));
73 }
74 Ok(())
75}
76
77pub fn validate_positive<T: PartialOrd + std::fmt::Display + Copy>(
79 value: T,
80 param_name: &str,
81 context: &str,
82) -> TorshResult<()>
83where
84 T: From<f32>,
85{
86 let zero = T::from(0.0);
87 if value <= zero {
88 return Err(TorshError::invalid_argument_with_context(
89 &format!("{} must be positive, got {}", param_name, value),
90 context,
91 ));
92 }
93 Ok(())
94}
95
/// Build an error-reporting context string from a function name.
/// Currently this is simply the name itself, returned as an owned `String`.
pub fn function_context(function_name: &str) -> String {
    String::from(function_name)
}
100
101pub fn validate_activation_params<T: PartialOrd + std::fmt::Display + Copy>(
103 input: &Tensor,
104 alpha: Option<T>,
105 beta: Option<T>,
106 context: &str,
107) -> TorshResult<()>
108where
109 T: From<f32>,
110{
111 validate_non_empty(input, context)?;
112
113 if let Some(alpha_val) = alpha {
114 validate_positive(alpha_val, "alpha", context)?;
115 }
116
117 if let Some(beta_val) = beta {
118 validate_positive(beta_val, "beta", context)?;
119 }
120
121 Ok(())
122}
123
124pub fn validate_pooling_params(
126 input: &Tensor,
127 kernel_size: &[usize],
128 stride: &[usize],
129 _padding: &[usize],
130 context: &str,
131) -> TorshResult<()> {
132 validate_non_empty(input, context)?;
133
134 if kernel_size.is_empty() {
135 return Err(TorshError::invalid_argument_with_context(
136 "kernel_size cannot be empty",
137 context,
138 ));
139 }
140
141 if kernel_size.iter().any(|&k| k == 0) {
142 return Err(TorshError::invalid_argument_with_context(
143 "All kernel_size values must be positive",
144 context,
145 ));
146 }
147
148 if stride.iter().any(|&s| s == 0) {
149 return Err(TorshError::invalid_argument_with_context(
150 "All stride values must be positive",
151 context,
152 ));
153 }
154
155 Ok(())
156}
157
158pub fn validate_loss_params(
160 input: &Tensor,
161 target: &Tensor,
162 reduction: &str,
163 context: &str,
164) -> TorshResult<()> {
165 validate_non_empty(input, context)?;
166 validate_non_empty(target, context)?;
167
168 match reduction {
169 "none" | "mean" | "sum" => Ok(()),
170 _ => Err(TorshError::invalid_argument_with_context(
171 &format!(
172 "Invalid reduction '{}'. Must be 'none', 'mean', or 'sum'",
173 reduction
174 ),
175 context,
176 )),
177 }
178}
179
180pub fn validate_tensor_dims(
182 tensor: &Tensor,
183 expected_dims: usize,
184 context: &str,
185) -> TorshResult<()> {
186 let actual_dims = tensor.shape().ndim();
187 if actual_dims != expected_dims {
188 return Err(TorshError::invalid_argument_with_context(
189 &format!(
190 "Expected {}D tensor, got {}D tensor",
191 expected_dims, actual_dims
192 ),
193 context,
194 ));
195 }
196 Ok(())
197}
198
199pub fn validate_broadcastable_shapes(a: &Tensor, b: &Tensor, context: &str) -> TorshResult<()> {
201 let binding_a = a.shape();
202 let shape_a = binding_a.dims();
203 let binding_b = b.shape();
204 let shape_b = binding_b.dims();
205
206 if shape_a.len() != shape_b.len() && shape_a != shape_b {
208 let a_numel = a.numel();
210 let b_numel = b.numel();
211
212 if a_numel != 1 && b_numel != 1 && shape_a != shape_b {
213 return Err(TorshError::invalid_argument_with_context(
214 &format!(
215 "Tensor shapes {:?} and {:?} are not broadcastable",
216 shape_a, shape_b
217 ),
218 context,
219 ));
220 }
221 }
222
223 Ok(())
224}
225
/// Build an invalid-argument [`TorshError`] tagged with the calling
/// function's name as context.
///
/// Thin convenience wrapper around
/// [`TorshError::invalid_argument_with_context`]; exists so call sites
/// read uniformly across this module.
pub fn invalid_argument_error(message: &str, function_name: &str) -> TorshError {
    TorshError::invalid_argument_with_context(message, function_name)
}
230
/// Render rustdoc-style `///` documentation text for a function.
///
/// Always emits the name and description; the formula, parameter list,
/// and example sections are included only when the corresponding argument
/// is `Some`/non-empty.
pub fn create_function_docs(
    name: &str,
    description: &str,
    formula: Option<&str>,
    parameters: &[(&str, &str)],
    example: Option<&str>,
) -> String {
    use std::fmt::Write as _;

    let mut out = String::new();
    // Writing into a String is infallible, so the Results are discarded.
    let _ = writeln!(out, "/// {}", name);
    out.push_str("///\n");
    let _ = writeln!(out, "/// {}", description);

    if let Some(f) = formula {
        out.push_str("///\n");
        let _ = writeln!(out, "/// Formula: {}", f);
    }

    if !parameters.is_empty() {
        out.push_str("///\n/// # Parameters\n");
        for (param, desc) in parameters {
            let _ = writeln!(out, "/// - `{}`: {}", param, desc);
        }
    }

    if let Some(ex) = example {
        out.push_str("///\n/// # Example\n/// ```rust\n");
        let _ = writeln!(out, "/// {}", ex);
        out.push_str("/// ```\n");
    }

    out
}
267
268pub fn safe_log(input: &Tensor, eps: Option<f32>, max_val: Option<f32>) -> TorshResult<Tensor> {
287 let epsilon = eps.unwrap_or(1e-8_f32);
288 let maximum = max_val.unwrap_or(f32::MAX);
289
290 let clamped = input.clamp(epsilon, maximum)?;
291 clamped.log()
292}
293
294pub fn safe_log_prob(input: &Tensor, eps: Option<f32>) -> TorshResult<Tensor> {
312 let epsilon = eps.unwrap_or(1e-8_f32);
313 let clamped = input.clamp(epsilon, 1.0 - epsilon)?;
314 clamped.log()
315}
316
317pub fn safe_for_log(input: &Tensor, eps: Option<f32>, max_val: Option<f32>) -> TorshResult<Tensor> {
330 let epsilon = eps.unwrap_or(1e-8_f32);
331 let maximum = max_val.unwrap_or(f32::MAX);
332 input.clamp(epsilon, maximum)
333}
334
335pub fn handle_inplace_operation<F>(
337 input: &Tensor,
338 inplace: bool,
339 operation: F,
340 _context: &str,
341) -> TorshResult<Tensor>
342where
343 F: Fn(&Tensor) -> TorshResult<Tensor>,
344{
345 if inplace {
346 operation(input)
350 } else {
351 operation(input)
352 }
353}
354
355pub fn apply_elementwise_operation<F>(
357 input: &Tensor,
358 _inplace: bool,
359 operation: F,
360 _context: &str,
361) -> TorshResult<Tensor>
362where
363 F: Fn(f32) -> f32,
364{
365 let data = input.data()?;
367 let result_data: Vec<f32> = data.iter().map(|&x| operation(x)).collect();
368
369 Tensor::from_data(result_data, input.shape().dims().to_vec(), input.device())
370}
371
372pub fn apply_conditional_elementwise<F>(
374 input: &Tensor,
375 condition: F,
376 true_op: impl Fn(f32) -> f32,
377 false_op: impl Fn(f32) -> f32,
378 _inplace: bool,
379 _context: &str,
380) -> TorshResult<Tensor>
381where
382 F: Fn(f32) -> bool,
383{
384 let data = input.data()?;
385 let result_data: Vec<f32> = data
386 .iter()
387 .map(|&x| {
388 if condition(x) {
389 true_op(x)
390 } else {
391 false_op(x)
392 }
393 })
394 .collect();
395
396 Tensor::from_data(result_data, input.shape().dims().to_vec(), input.device())
397}
398
/// Compute the output length of one pooled dimension:
/// `floor((input + 2*padding - effective_kernel) / stride) + 1`,
/// where the effective kernel spans `kernel + (kernel-1)*(dilation-1)`
/// elements once dilation gaps are included.
///
/// NOTE(review): like the original, this uses unchecked `usize`
/// arithmetic — a zero `dilation`/`stride` or a kernel larger than the
/// padded input will underflow/panic in debug builds.
pub fn calculate_pooling_output_size(
    input_size: usize,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> usize {
    // Each of the (kernel_size - 1) gaps grows by (dilation - 1) elements.
    let dilated_kernel = kernel_size + (kernel_size - 1) * (dilation - 1);
    let padded_input = input_size + 2 * padding;
    (padded_input - dilated_kernel) / stride + 1
}
410
411pub fn calculate_pooling_output_size_2d(
413 input_size: (usize, usize),
414 kernel_size: (usize, usize),
415 stride: (usize, usize),
416 padding: (usize, usize),
417 dilation: (usize, usize),
418) -> (usize, usize) {
419 let out_h =
420 calculate_pooling_output_size(input_size.0, kernel_size.0, stride.0, padding.0, dilation.0);
421 let out_w =
422 calculate_pooling_output_size(input_size.1, kernel_size.1, stride.1, padding.1, dilation.1);
423 (out_h, out_w)
424}
425
426pub fn calculate_pooling_output_size_3d(
428 input_size: (usize, usize, usize),
429 kernel_size: (usize, usize, usize),
430 stride: (usize, usize, usize),
431 padding: (usize, usize, usize),
432 dilation: (usize, usize, usize),
433) -> (usize, usize, usize) {
434 let out_d =
435 calculate_pooling_output_size(input_size.0, kernel_size.0, stride.0, padding.0, dilation.0);
436 let out_h =
437 calculate_pooling_output_size(input_size.1, kernel_size.1, stride.1, padding.1, dilation.1);
438 let out_w =
439 calculate_pooling_output_size(input_size.2, kernel_size.2, stride.2, padding.2, dilation.2);
440 (out_d, out_h, out_w)
441}
442
443pub fn create_tensor_like(
445 reference: &Tensor,
446 data: Vec<f32>,
447 shape: Option<Vec<usize>>,
448) -> TorshResult<Tensor> {
449 let tensor_shape = match shape {
450 Some(s) => s,
451 None => reference.shape().dims().to_vec(),
452 };
453
454 Tensor::from_data(data, tensor_shape, reference.device())
455}
456
457pub fn apply_binary_elementwise<F>(
459 a: &Tensor,
460 b: &Tensor,
461 operation: F,
462 _context: &str,
463) -> TorshResult<Tensor>
464where
465 F: Fn(f32, f32) -> f32,
466{
467 validate_elementwise_shapes(a, b)?;
468
469 let data_a = a.data()?;
470 let data_b = b.data()?;
471
472 let result_data: Vec<f32> = data_a
473 .iter()
474 .zip(data_b.iter())
475 .map(|(&x, &y)| operation(x, y))
476 .collect();
477
478 create_tensor_like(a, result_data, None)
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::zeros;

    #[test]
    fn test_validate_range() -> TorshResult<()> {
        // An in-range value passes.
        validate_range(5.0, 0.0, 10.0, "value", "test")?;

        // Values below the minimum or above the maximum are rejected.
        assert!(validate_range(-1.0, 0.0, 10.0, "value", "test").is_err());
        assert!(validate_range(15.0, 0.0, 10.0, "value", "test").is_err());

        Ok(())
    }

    #[test]
    fn test_validate_non_empty() -> TorshResult<()> {
        // A 2x3 tensor has elements and passes.
        validate_non_empty(&zeros(&[2, 3])?, "test")?;

        // A zero-length tensor is rejected.
        assert!(validate_non_empty(&zeros(&[0])?, "test").is_err());

        Ok(())
    }

    #[test]
    fn test_validate_dimension() -> TorshResult<()> {
        let tensor = zeros(&[2, 3, 4])?;

        // All in-range indices, positive and negative, are accepted.
        for dim in [0, 1, 2, -1, -2] {
            validate_dimension(&tensor, dim, "test")?;
        }

        // Out-of-range indices on either side are rejected.
        assert!(validate_dimension(&tensor, 3, "test").is_err());
        assert!(validate_dimension(&tensor, -4, "test").is_err());

        Ok(())
    }

    #[test]
    fn test_validate_positive() -> TorshResult<()> {
        // Strictly positive passes.
        validate_positive(1.5, "value", "test")?;

        // Zero and negative values are rejected.
        assert!(validate_positive(0.0, "value", "test").is_err());
        assert!(validate_positive(-1.0, "value", "test").is_err());

        Ok(())
    }
}
552}