// torsh_functional/utils.rs

1//! Utility functions for torsh-functional
2//!
3//! This module contains helper functions and common patterns used
4//! throughout the functional API.
5
6use torsh_core::{Result as TorshResult, TorshError};
7use torsh_tensor::Tensor;
8
9/// Validates that input tensors have compatible shapes for element-wise operations
10pub fn validate_elementwise_shapes(a: &Tensor, b: &Tensor) -> TorshResult<()> {
11    let binding_a = a.shape();
12    let shape_a = binding_a.dims();
13    let binding_b = b.shape();
14    let shape_b = binding_b.dims();
15
16    if shape_a != shape_b {
17        return Err(TorshError::invalid_argument_with_context(
18            &format!(
19                "Tensor shapes are not compatible for element-wise operation: {:?} vs {:?}",
20                shape_a, shape_b
21            ),
22            "elementwise_operation",
23        ));
24    }
25
26    Ok(())
27}
28
29/// Validates that a value is within a specified range
30pub fn validate_range<T: PartialOrd + std::fmt::Display>(
31    value: T,
32    min: T,
33    max: T,
34    param_name: &str,
35    context: &str,
36) -> TorshResult<()> {
37    if value < min || value > max {
38        return Err(TorshError::invalid_argument_with_context(
39            &format!(
40                "{} must be in range [{}, {}], got {}",
41                param_name, min, max, value
42            ),
43            context,
44        ));
45    }
46    Ok(())
47}
48
49/// Validates that a tensor is not empty
50pub fn validate_non_empty(tensor: &Tensor, context: &str) -> TorshResult<()> {
51    if tensor.numel() == 0 {
52        return Err(TorshError::invalid_argument_with_context(
53            "Tensor cannot be empty",
54            context,
55        ));
56    }
57    Ok(())
58}
59
60/// Validates dimension index for a tensor
61pub fn validate_dimension(tensor: &Tensor, dim: i32, context: &str) -> TorshResult<()> {
62    let ndim = tensor.shape().ndim() as i32;
63    let normalized_dim = if dim < 0 { dim + ndim } else { dim };
64
65    if normalized_dim < 0 || normalized_dim >= ndim {
66        return Err(TorshError::invalid_argument_with_context(
67            &format!(
68                "Dimension {} is out of range for tensor with {} dimensions",
69                dim, ndim
70            ),
71            context,
72        ));
73    }
74    Ok(())
75}
76
77/// Validates that a parameter is positive
78pub fn validate_positive<T: PartialOrd + std::fmt::Display + Copy>(
79    value: T,
80    param_name: &str,
81    context: &str,
82) -> TorshResult<()>
83where
84    T: From<f32>,
85{
86    let zero = T::from(0.0);
87    if value <= zero {
88        return Err(TorshError::invalid_argument_with_context(
89            &format!("{} must be positive, got {}", param_name, value),
90            context,
91        ));
92    }
93    Ok(())
94}
95
/// Creates a standardized context string for function errors.
///
/// Currently the context is just the function name itself; kept as a helper
/// so the format can be changed in a single place later.
pub fn function_context(function_name: &str) -> String {
    String::from(function_name)
}
100
101/// Standard parameter validation for activation functions
102pub fn validate_activation_params<T: PartialOrd + std::fmt::Display + Copy>(
103    input: &Tensor,
104    alpha: Option<T>,
105    beta: Option<T>,
106    context: &str,
107) -> TorshResult<()>
108where
109    T: From<f32>,
110{
111    validate_non_empty(input, context)?;
112
113    if let Some(alpha_val) = alpha {
114        validate_positive(alpha_val, "alpha", context)?;
115    }
116
117    if let Some(beta_val) = beta {
118        validate_positive(beta_val, "beta", context)?;
119    }
120
121    Ok(())
122}
123
124/// Standard parameter validation for pooling operations
125pub fn validate_pooling_params(
126    input: &Tensor,
127    kernel_size: &[usize],
128    stride: &[usize],
129    _padding: &[usize],
130    context: &str,
131) -> TorshResult<()> {
132    validate_non_empty(input, context)?;
133
134    if kernel_size.is_empty() {
135        return Err(TorshError::invalid_argument_with_context(
136            "kernel_size cannot be empty",
137            context,
138        ));
139    }
140
141    if kernel_size.iter().any(|&k| k == 0) {
142        return Err(TorshError::invalid_argument_with_context(
143            "All kernel_size values must be positive",
144            context,
145        ));
146    }
147
148    if stride.iter().any(|&s| s == 0) {
149        return Err(TorshError::invalid_argument_with_context(
150            "All stride values must be positive",
151            context,
152        ));
153    }
154
155    Ok(())
156}
157
158/// Standard parameter validation for loss functions
159pub fn validate_loss_params(
160    input: &Tensor,
161    target: &Tensor,
162    reduction: &str,
163    context: &str,
164) -> TorshResult<()> {
165    validate_non_empty(input, context)?;
166    validate_non_empty(target, context)?;
167
168    match reduction {
169        "none" | "mean" | "sum" => Ok(()),
170        _ => Err(TorshError::invalid_argument_with_context(
171            &format!(
172                "Invalid reduction '{}'. Must be 'none', 'mean', or 'sum'",
173                reduction
174            ),
175            context,
176        )),
177    }
178}
179
180/// Validates tensor dimensions for specific operations
181pub fn validate_tensor_dims(
182    tensor: &Tensor,
183    expected_dims: usize,
184    context: &str,
185) -> TorshResult<()> {
186    let actual_dims = tensor.shape().ndim();
187    if actual_dims != expected_dims {
188        return Err(TorshError::invalid_argument_with_context(
189            &format!(
190                "Expected {}D tensor, got {}D tensor",
191                expected_dims, actual_dims
192            ),
193            context,
194        ));
195    }
196    Ok(())
197}
198
199/// Validates tensor shapes are broadcastable
200pub fn validate_broadcastable_shapes(a: &Tensor, b: &Tensor, context: &str) -> TorshResult<()> {
201    let binding_a = a.shape();
202    let shape_a = binding_a.dims();
203    let binding_b = b.shape();
204    let shape_b = binding_b.dims();
205
206    // Simple broadcastability check (can be expanded for more complex cases)
207    if shape_a.len() != shape_b.len() && shape_a != shape_b {
208        // Allow different lengths if one is scalar or can be broadcast
209        let a_numel = a.numel();
210        let b_numel = b.numel();
211
212        if a_numel != 1 && b_numel != 1 && shape_a != shape_b {
213            return Err(TorshError::invalid_argument_with_context(
214                &format!(
215                    "Tensor shapes {:?} and {:?} are not broadcastable",
216                    shape_a, shape_b
217                ),
218                context,
219            ));
220        }
221    }
222
223    Ok(())
224}
225
226/// Helper function to create invalid argument error with function context
227pub fn invalid_argument_error(message: &str, function_name: &str) -> TorshError {
228    TorshError::invalid_argument_with_context(message, function_name)
229}
230
/// Renders a standardized `///` doc-comment block for a function.
///
/// Emits, in order: the name, the description, an optional `Formula:` line,
/// an optional `# Parameters` list, and an optional fenced `# Example`,
/// each section separated by a bare `///` line.
pub fn create_function_docs(
    name: &str,
    description: &str,
    formula: Option<&str>,
    parameters: &[(&str, &str)],
    example: Option<&str>,
) -> String {
    use std::fmt::Write;

    let mut docs = String::new();
    // Writing into a String is infallible; discard the fmt::Result.
    let _ = writeln!(docs, "/// {}", name);
    let _ = writeln!(docs, "///");
    let _ = writeln!(docs, "/// {}", description);

    if let Some(formula_text) = formula {
        let _ = writeln!(docs, "///");
        let _ = writeln!(docs, "/// Formula: {}", formula_text);
    }

    if !parameters.is_empty() {
        let _ = writeln!(docs, "///");
        let _ = writeln!(docs, "/// # Parameters");
        for (param, desc) in parameters {
            let _ = writeln!(docs, "/// - `{}`: {}", param, desc);
        }
    }

    if let Some(example_text) = example {
        let _ = writeln!(docs, "///");
        let _ = writeln!(docs, "/// # Example");
        let _ = writeln!(docs, "/// ```rust");
        let _ = writeln!(docs, "/// {}", example_text);
        let _ = writeln!(docs, "/// ```");
    }

    docs
}
267
268/// Computes the safe logarithm of a tensor, clamping values to avoid log(0) or log(negative)
269///
270/// This utility function prevents numerical instability when computing logarithms
271/// by clamping input values to be within a safe range.
272///
273/// # Arguments
274/// * `input` - Input tensor
275/// * `eps` - Minimum value for clamping (default: 1e-8)
276/// * `max_val` - Maximum value for clamping (default: f32::MAX)
277///
278/// # Returns
279/// Tensor with logarithm applied to clamped values
280///
281/// # Example
282/// ```ignore
283/// let tensor = Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).unwrap();
284/// let log_tensor = safe_log(&tensor, None, None).unwrap();
285/// ```
286pub fn safe_log(input: &Tensor, eps: Option<f32>, max_val: Option<f32>) -> TorshResult<Tensor> {
287    let epsilon = eps.unwrap_or(1e-8_f32);
288    let maximum = max_val.unwrap_or(f32::MAX);
289
290    let clamped = input.clamp(epsilon, maximum)?;
291    clamped.log()
292}
293
294/// Computes the safe logarithm of probability values, clamping to valid probability range
295///
296/// This is specifically designed for probability tensors where values should be
297/// between 0 and 1. It clamps values to [eps, 1-eps] before taking the logarithm.
298///
299/// # Arguments
300/// * `input` - Input tensor containing probability values
301/// * `eps` - Small epsilon value for numerical stability (default: 1e-8)
302///
303/// # Returns
304/// Tensor with logarithm applied to clamped probability values
305///
306/// # Example
307/// ```ignore
308/// let probs = Tensor::from_vec(vec![0.1, 0.5, 0.9], &[3]).unwrap();
309/// let log_probs = safe_log_prob(&probs, None).unwrap();
310/// ```
311pub fn safe_log_prob(input: &Tensor, eps: Option<f32>) -> TorshResult<Tensor> {
312    let epsilon = eps.unwrap_or(1e-8_f32);
313    let clamped = input.clamp(epsilon, 1.0 - epsilon)?;
314    clamped.log()
315}
316
317/// Creates a safe version of tensor for logarithm operations by clamping
318///
319/// This is a lightweight helper that only performs clamping without the logarithm.
320/// Useful when you need to perform the clamping separately from the log operation.
321///
322/// # Arguments
323/// * `input` - Input tensor
324/// * `eps` - Minimum value for clamping (default: 1e-8)
325/// * `max_val` - Maximum value for clamping (default: f32::MAX)
326///
327/// # Returns
328/// Clamped tensor safe for logarithm operations
329pub fn safe_for_log(input: &Tensor, eps: Option<f32>, max_val: Option<f32>) -> TorshResult<Tensor> {
330    let epsilon = eps.unwrap_or(1e-8_f32);
331    let maximum = max_val.unwrap_or(f32::MAX);
332    input.clamp(epsilon, maximum)
333}
334
335/// Standardized inplace operation handling
336pub fn handle_inplace_operation<F>(
337    input: &Tensor,
338    inplace: bool,
339    operation: F,
340    _context: &str,
341) -> TorshResult<Tensor>
342where
343    F: Fn(&Tensor) -> TorshResult<Tensor>,
344{
345    if inplace {
346        // For true in-place operations, we would modify the tensor in place
347        // For now, we perform the operation and return a new tensor
348        // TODO: Implement proper in-place operations when tensor mutation is available
349        operation(input)
350    } else {
351        operation(input)
352    }
353}
354
355/// Utility for element-wise operations with inplace support
356pub fn apply_elementwise_operation<F>(
357    input: &Tensor,
358    _inplace: bool,
359    operation: F,
360    _context: &str,
361) -> TorshResult<Tensor>
362where
363    F: Fn(f32) -> f32,
364{
365    // For now, treat all operations as out-of-place since inplace is not fully implemented
366    let data = input.data()?;
367    let result_data: Vec<f32> = data.iter().map(|&x| operation(x)).collect();
368
369    Tensor::from_data(result_data, input.shape().dims().to_vec(), input.device())
370}
371
372/// Utility for conditional element-wise operations (like ReLU, LeakyReLU, etc.)
373pub fn apply_conditional_elementwise<F>(
374    input: &Tensor,
375    condition: F,
376    true_op: impl Fn(f32) -> f32,
377    false_op: impl Fn(f32) -> f32,
378    _inplace: bool,
379    _context: &str,
380) -> TorshResult<Tensor>
381where
382    F: Fn(f32) -> bool,
383{
384    let data = input.data()?;
385    let result_data: Vec<f32> = data
386        .iter()
387        .map(|&x| {
388            if condition(x) {
389                true_op(x)
390            } else {
391                false_op(x)
392            }
393        })
394        .collect();
395
396    Tensor::from_data(result_data, input.shape().dims().to_vec(), input.device())
397}
398
/// Common pattern for pooling output size calculation
///
/// Computes `floor((input + 2*padding - effective_kernel) / stride) + 1`,
/// where `effective_kernel = kernel + (kernel - 1) * (dilation - 1)` —
/// the standard PyTorch pooling/convolution output-size formula.
///
/// # Panics
/// In debug builds, panics when `kernel_size`, `stride`, or `dilation` is
/// zero, or when the padded input is smaller than the effective kernel.
/// Without these checks, such inputs underflow `usize` (or divide by zero)
/// and silently produce garbage sizes in release builds.
pub fn calculate_pooling_output_size(
    input_size: usize,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> usize {
    debug_assert!(kernel_size > 0, "kernel_size must be positive");
    debug_assert!(stride > 0, "stride must be positive");
    debug_assert!(dilation > 0, "dilation must be positive (0 underflows)");

    let effective_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1);
    let padded_input = input_size + 2 * padding;
    debug_assert!(
        padded_input >= effective_kernel_size,
        "padded input ({}) smaller than effective kernel ({})",
        padded_input,
        effective_kernel_size
    );

    (padded_input - effective_kernel_size) / stride + 1
}
410
411/// Common pattern for 2D pooling output size calculation
412pub fn calculate_pooling_output_size_2d(
413    input_size: (usize, usize),
414    kernel_size: (usize, usize),
415    stride: (usize, usize),
416    padding: (usize, usize),
417    dilation: (usize, usize),
418) -> (usize, usize) {
419    let out_h =
420        calculate_pooling_output_size(input_size.0, kernel_size.0, stride.0, padding.0, dilation.0);
421    let out_w =
422        calculate_pooling_output_size(input_size.1, kernel_size.1, stride.1, padding.1, dilation.1);
423    (out_h, out_w)
424}
425
426/// Common pattern for 3D pooling output size calculation
427pub fn calculate_pooling_output_size_3d(
428    input_size: (usize, usize, usize),
429    kernel_size: (usize, usize, usize),
430    stride: (usize, usize, usize),
431    padding: (usize, usize, usize),
432    dilation: (usize, usize, usize),
433) -> (usize, usize, usize) {
434    let out_d =
435        calculate_pooling_output_size(input_size.0, kernel_size.0, stride.0, padding.0, dilation.0);
436    let out_h =
437        calculate_pooling_output_size(input_size.1, kernel_size.1, stride.1, padding.1, dilation.1);
438    let out_w =
439        calculate_pooling_output_size(input_size.2, kernel_size.2, stride.2, padding.2, dilation.2);
440    (out_d, out_h, out_w)
441}
442
443/// Utility for creating tensors with same shape and device as input
444pub fn create_tensor_like(
445    reference: &Tensor,
446    data: Vec<f32>,
447    shape: Option<Vec<usize>>,
448) -> TorshResult<Tensor> {
449    let tensor_shape = match shape {
450        Some(s) => s,
451        None => reference.shape().dims().to_vec(),
452    };
453
454    Tensor::from_data(data, tensor_shape, reference.device())
455}
456
457/// Common pattern for element-wise tensor operations with broadcasting
458pub fn apply_binary_elementwise<F>(
459    a: &Tensor,
460    b: &Tensor,
461    operation: F,
462    _context: &str,
463) -> TorshResult<Tensor>
464where
465    F: Fn(f32, f32) -> f32,
466{
467    validate_elementwise_shapes(a, b)?;
468
469    let data_a = a.data()?;
470    let data_b = b.data()?;
471
472    let result_data: Vec<f32> = data_a
473        .iter()
474        .zip(data_b.iter())
475        .map(|(&x, &y)| operation(x, y))
476        .collect();
477
478    create_tensor_like(a, result_data, None)
479}
480
// Unit tests for the validation helpers. Tensor-based tests use
// `torsh_tensor::creation::zeros` to build fixtures.
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::zeros;

    // Covers the accepting path plus both rejecting paths
    // (below min, above max) of `validate_range`.
    #[test]
    fn test_validate_range() -> TorshResult<()> {
        // Valid range
        validate_range(5.0, 0.0, 10.0, "value", "test")?;

        // Invalid range - too small
        let result = validate_range(-1.0, 0.0, 10.0, "value", "test");
        assert!(result.is_err());

        // Invalid range - too large
        let result = validate_range(15.0, 0.0, 10.0, "value", "test");
        assert!(result.is_err());

        Ok(())
    }

    // A zero-length dimension makes numel() == 0, which must be rejected.
    #[test]
    fn test_validate_non_empty() -> TorshResult<()> {
        // Non-empty tensor
        let tensor = zeros(&[2, 3])?;
        validate_non_empty(&tensor, "test")?;

        // Empty tensor
        let empty_tensor = zeros(&[0])?;
        let result = validate_non_empty(&empty_tensor, "test");
        assert!(result.is_err());

        Ok(())
    }

    // Exercises positive indices, negative (from-the-back) indices,
    // and both out-of-range directions for a rank-3 tensor.
    #[test]
    fn test_validate_dimension() -> TorshResult<()> {
        let tensor = zeros(&[2, 3, 4])?;

        // Valid dimensions
        validate_dimension(&tensor, 0, "test")?;
        validate_dimension(&tensor, 1, "test")?;
        validate_dimension(&tensor, 2, "test")?;
        validate_dimension(&tensor, -1, "test")?; // Last dimension
        validate_dimension(&tensor, -2, "test")?; // Second to last

        // Invalid dimensions
        let result = validate_dimension(&tensor, 3, "test");
        assert!(result.is_err());

        let result = validate_dimension(&tensor, -4, "test");
        assert!(result.is_err());

        Ok(())
    }

    // Zero is not positive: both 0.0 and negatives must be rejected.
    #[test]
    fn test_validate_positive() -> TorshResult<()> {
        // Valid positive value
        validate_positive(1.5, "value", "test")?;

        // Invalid zero value
        let result = validate_positive(0.0, "value", "test");
        assert!(result.is_err());

        // Invalid negative value
        let result = validate_positive(-1.0, "value", "test");
        assert!(result.is_err());

        Ok(())
    }
}