// torsh_functional/reduction.rs
1//! # Reduction Operations for Tensors
2//!
3//! This module provides comprehensive reduction operations that collapse tensor dimensions
4//! by applying aggregation functions across specified axes.
5//!
6//! ## Mathematical Foundation
7//!
8//! ### Basic Reductions
9//!
10//! #### Sum
11//! Computes the sum of tensor elements:
12//! ```text
13//! sum(X) = Σ(i=1 to n) x_i
14//! sum_dim(X, d) = Σ(i in dimension d) x_i
15//! ```
16//!
17//! #### Mean (Arithmetic Average)
18//! ```text
19//! mean(X) = (1/n) Σ(i=1 to n) x_i
20//! mean_dim(X, d) = (1/n_d) Σ(i in dimension d) x_i
21//! ```
22//! where `n_d` is the size of dimension `d`.
23//!
24//! #### Product
25//! ```text
26//! prod(X) = Π(i=1 to n) x_i
27//! prod_dim(X, d) = Π(i in dimension d) x_i
28//! ```
29//!
30//! #### Max/Min
31//! ```text
32//! max(X) = max{x_1, x_2, ..., x_n}
33//! min(X) = min{x_1, x_2, ..., x_n}
34//! ```
35//!
36//! ### Statistical Reductions
37//!
38//! #### Variance
39//! ```text
40//! var(X) = (1/n) Σ(i=1 to n) (x_i - μ)²
41//! where μ = mean(X)
42//! ```
43//! - **Sample variance**: Use `n-1` denominator (Bessel's correction)
44//! - **Population variance**: Use `n` denominator
45//!
46//! #### Standard Deviation
47//! ```text
48//! std(X) = √var(X)
49//! ```
50//!
51//! ### Norm Operations
52//!
53//! #### L-p Norm
54//! ```text
55//! ||X||_p = (Σ|x_i|^p)^(1/p)
56//! ```
57//! Special cases:
58//! - **L1 norm** (p=1): Manhattan distance = Σ|x_i|
59//! - **L2 norm** (p=2): Euclidean distance = √(Σx_i²)
60//! - **L∞ norm** (p=∞): Maximum absolute value = max|x_i|
61//!
62//! ### Advanced Reductions
63//!
64//! #### LogSumExp (Numerically Stable)
65//! ```text
66//! logsumexp(X) = log(Σ exp(x_i))
67//!              = max(X) + log(Σ exp(x_i - max(X)))  [stable version]
68//! ```
69//! Used in softmax computation to prevent overflow.
70//!
71//! #### Cumulative Sum
72//! ```text
73//! cumsum(X)[i] = Σ(j=0 to i) x_j
74//! ```
75//!
76//! ## Performance Characteristics
77//!
78//! ### Computational Complexity
79//! For tensor with `n` elements:
80//! - **Basic reductions** (sum, mean, max, min): O(n)
81//! - **Statistical** (var, std): O(n) - requires two passes (mean, then variance)
82//! - **Norm operations**: O(n)
83//! - **Cumulative operations**: O(n)
84//!
85//! ### Memory Usage
86//! - **Full reduction**: O(1) - single scalar output
87//! - **Dimension reduction**: O(n / d_size) where `d_size` is the reduced dimension size
88//! - **Cumulative operations**: O(n) - same size as input
89//!
90//! ### Optimization Strategies
91//! - **Parallel reduction**: Tree-based reduction for parallel backends
92//! - **SIMD**: Vectorized operations for element-wise computation
93//! - **Fusion**: Combine multiple reductions in single pass when possible
94//! - **Numerical stability**: Use Kahan summation or compensated summation for float precision
95//!
96//! ## Common Use Cases
97//!
98//! ### Loss Function Computation
99//! ```rust
100//! use torsh_functional::reduction::{mean, sum};
101//! use torsh_functional::random_ops::randn;
102//!
103//! fn example() -> Result<(), Box<dyn std::error::Error>> {
104//!     // Compute mean squared error
105//!     let predictions = randn(&[32, 10], None, None, Some(42))?;  // batch_size=32, num_classes=10
106//!     let targets = randn(&[32, 10], None, None, Some(43))?;
107//!
108//!     let diff = predictions.sub(&targets)?;
109//!     let squared = diff.pow_scalar(2.0)?;
110//!     let mse_loss = mean(&squared)?;
111//!
112//!     // Or compute sum for total loss
113//!     let total_loss = sum(&squared)?;
114//!     Ok(())
115//! }
116//! ```
117//!
118//! ### Batch Statistics for Normalization
119//! ```rust
120//! use torsh_functional::reduction::{mean_dim, std_dim};
121//! use torsh_functional::random_ops::randn;
122//!
123//! fn example() -> Result<(), Box<dyn std::error::Error>> {
124//!     // Compute mean and std across batch dimension for BatchNorm
125//!     let features = randn(&[64, 128, 28, 28], None, None, Some(42))?;  // [N, C, H, W]
126//!
127//!     // Compute statistics over batch and spatial dimensions (0, 2, 3)
128//!     // keeping channel dimension (1)
129//!     let batch_mean = mean_dim(&features, &[0, 2, 3], true)?;
//!     let batch_std = std_dim(&features, &[0, 2, 3], true, true)?;
131//!
132//!     // Normalize: (x - mean) / std
133//!     let normalized = features.sub(&batch_mean)?
134//!                              .div(&batch_std)?;
135//!     Ok(())
136//! }
137//! ```
138//!
139//! ### Attention Score Normalization
140//! ```rust
141//! use torsh_functional::reduction::{max_dim, sum_dim};
142//! use torsh_functional::random_ops::randn;
143//!
144//! fn example() -> Result<(), Box<dyn std::error::Error>> {
145//!     // Softmax using max for numerical stability
146//!     let scores = randn(&[8, 64, 64], None, None, Some(42))?;  // [batch, seq_len, seq_len]
147//!
148//!     // Stable softmax: exp(x - max(x)) / sum(exp(x - max(x)))
149//!     let (max_scores, _) = max_dim(&scores, -1, true)?;
150//!     let shifted = scores.sub(&max_scores)?;
151//!     let exp_scores = shifted.exp()?;
152//!     let sum_exp = sum_dim(&exp_scores, &[-1], true)?;
153//!     let softmax = exp_scores.div(&sum_exp)?;
154//!     Ok(())
155//! }
156//! ```
157//!
158//! ## Numerical Stability Considerations
159//!
160//! ### Sum Accumulation
161//! - Use Kahan summation for large arrays to minimize floating-point error
162//! - For very large tensors, consider hierarchical reduction
163//!
164//! ### Variance Computation
165//! - Use Welford's online algorithm for single-pass stable variance
166//! - Avoid naive formula `E[X²] - E[X]²` which suffers from catastrophic cancellation
167//!
168//! ### LogSumExp
169//! - Always use max-subtraction trick: `max(x) + log(sum(exp(x - max(x))))`
170//! - Essential for preventing overflow in exp() computation
171
172use std::collections::HashMap;
173use torsh_core::{Result as TorshResult, TorshError};
174use torsh_tensor::{creation::zeros, stats::StatMode, Tensor};
175
176// ============================================================================
177// Basic Reduction Operations
178// ============================================================================
179
180/// Sum all elements in a tensor.
181///
182/// # Mathematical Definition
183/// ```text
184/// sum(X) = Σ(i=1 to n) x_i
185/// ```
186///
187/// # Arguments
188/// * `tensor` - Input tensor of any shape
189///
190/// # Returns
191/// Scalar tensor containing the sum of all elements
192///
193/// # Complexity
194/// - Time: O(n) where n is the number of elements
195/// - Space: O(1)
196///
197/// # Examples
198/// ```rust,no_run
199/// # use torsh_tensor::Tensor;
200/// # use torsh_functional::reduction::sum;
201/// # use torsh_functional::random_ops::randn;
202/// # fn example() -> torsh_core::Result<()> {
203/// let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5])?;
204/// let result = sum(&x)?;  // Returns tensor with value 15.0
205/// # Ok(())
206/// # }
207/// ```
pub fn sum(tensor: &Tensor) -> TorshResult<Tensor> {
    // Thin wrapper: delegates the full (all-elements) reduction to the
    // tensor's own `sum` method.
    tensor.sum()
}
211
212/// Sum along specified dimensions.
213///
214/// # Mathematical Definition
215/// ```text
216/// sum_dim(X, d) = Σ(i in dimension d) x_i
217/// ```
218///
219/// # Arguments
220/// * `tensor` - Input tensor
221/// * `dim` - Dimensions to reduce (negative indexing supported)
222/// * `keepdim` - If true, reduced dimensions are retained with size 1
223///
224/// # Returns
225/// Tensor with specified dimensions reduced
226///
227/// # Shape
228/// - Input: `[..., d_i, ...]` where `d_i` is a dimension to reduce
229/// - Output (keepdim=false): `[..., ...]` with `d_i` removed
230/// - Output (keepdim=true): `[..., 1, ...]` with `d_i` replaced by 1
231///
232/// # Examples
233/// ```rust,no_run
234/// # use torsh_tensor::Tensor;
235/// # use torsh_functional::reduction::sum_dim;
236/// # use torsh_functional::random_ops::randn;
237/// # fn example() -> torsh_core::Result<()> {
238/// let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])?;
239/// // [[1, 2, 3],
240/// //  [4, 5, 6]]
241///
242/// let row_sums = sum_dim(&x, &[1], false)?;  // [6, 15]
243/// let col_sums = sum_dim(&x, &[0], false)?;  // [5, 7, 9]
244/// let total = sum_dim(&x, &[0, 1], false)?;  // 21
245/// # Ok(())
246/// # }
247/// ```
248pub fn sum_dim(tensor: &Tensor, dim: &[isize], keepdim: bool) -> TorshResult<Tensor> {
249    let i32_dims: Vec<i32> = dim.iter().map(|&d| d as i32).collect();
250    tensor.sum_dim(&i32_dims, keepdim)
251}
252
253/// Mean (arithmetic average) of all elements in a tensor.
254///
255/// # Mathematical Definition
256/// ```text
257/// mean(X) = (1/n) Σ(i=1 to n) x_i
258/// ```
259///
260/// # Arguments
261/// * `tensor` - Input tensor of any shape
262///
263/// # Returns
264/// Scalar tensor containing the mean of all elements
265///
266/// # Complexity
267/// - Time: O(n)
268/// - Space: O(1)
269///
270/// # Examples
271/// ```rust,no_run
272/// # use torsh_tensor::Tensor;
273/// # use torsh_functional::reduction::mean;
274/// # use torsh_functional::random_ops::randn;
275/// # fn example() -> torsh_core::Result<()> {
276/// let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5])?;
277/// let result = mean(&x)?;  // Returns tensor with value 3.0
278/// # Ok(())
279/// # }
280/// ```
pub fn mean(tensor: &Tensor) -> TorshResult<Tensor> {
    // `None` dims = reduce over all elements; `false` = do not keep dims.
    tensor.mean(None, false)
}
284
285/// Mean along specified dimensions.
286///
287/// # Mathematical Definition
288/// ```text
289/// mean_dim(X, d) = (1/n_d) Σ(i in dimension d) x_i
290/// ```
291/// where `n_d` is the size of dimension `d`.
292///
293/// # Arguments
294/// * `tensor` - Input tensor
295/// * `dim` - Dimensions to reduce
296/// * `keepdim` - If true, reduced dimensions are retained with size 1
297///
298/// # Returns
299/// Tensor with specified dimensions reduced by averaging
300///
301/// # Examples
302/// ```rust
303/// use torsh_functional::reduction::mean_dim;
304/// use torsh_functional::random_ops::randn;
305///
306/// fn example() -> Result<(), Box<dyn std::error::Error>> {
307///     // Compute mean across batch dimension for normalization
308///     let batch_data = randn(&[32, 128], None, None, None)?;  // [batch, features]
309///     let feature_means = mean_dim(&batch_data, &[0], false)?;  // [128]
310///     Ok(())
311/// }
312/// ```
313pub fn mean_dim(tensor: &Tensor, dim: &[isize], keepdim: bool) -> TorshResult<Tensor> {
314    let usize_dims: Vec<usize> = dim.iter().map(|&d| d as usize).collect();
315    tensor.mean(Some(&usize_dims), keepdim)
316}
317
318/// Maximum value in a tensor.
319///
320/// # Mathematical Definition
321/// ```text
322/// max(X) = max{x_1, x_2, ..., x_n}
323/// ```
324///
325/// # Arguments
326/// * `tensor` - Input tensor of any shape
327///
328/// # Returns
329/// Scalar tensor containing the maximum element value
330///
331/// # Complexity
332/// - Time: O(n)
333/// - Space: O(1)
334///
335/// # Examples
336/// ```rust,no_run
337/// # use torsh_tensor::Tensor;
338/// # use torsh_functional::reduction::max;
339/// # use torsh_functional::random_ops::randn;
340/// # fn example() -> torsh_core::Result<()> {
341/// let x = Tensor::from_vec(vec![1.0, 5.0, 3.0, 9.0, 2.0], &[5])?;
342/// let result = max(&x)?;  // Returns tensor with value 9.0
343/// # Ok(())
344/// # }
345/// ```
pub fn max(tensor: &Tensor) -> TorshResult<Tensor> {
    // `None` dims = global reduction; `false` = do not keep dims.
    tensor.max(None, false)
}
349
350/// Maximum along specified dimension with indices.
351///
352/// # Arguments
353/// * `tensor` - Input tensor
354/// * `dim` - Dimension to reduce
355/// * `keepdim` - If true, reduced dimension is retained with size 1
356///
357/// # Returns
358/// Tuple of (max_values, indices) where:
359/// - `max_values`: Maximum values along the dimension
360/// - `indices`: Indices of maximum values (useful for pooling operations)
361///
362/// # Examples
363/// ```rust,no_run
364/// # use torsh_tensor::Tensor;
365/// # use torsh_functional::reduction::max_dim;
366/// # use torsh_functional::random_ops::randn;
367/// # fn example() -> torsh_core::Result<()> {
368/// let x = Tensor::from_vec(vec![1.0, 5.0, 3.0, 9.0, 2.0, 7.0], &[2, 3])?;
369/// let (max_vals, indices) = max_dim(&x, 1, false)?;
370/// // max_vals: [5.0, 9.0] - max of each row
371/// // indices: [1, 1] - position of max in each row
372/// # Ok(())
373/// # }
374/// ```
pub fn max_dim(tensor: &Tensor, dim: isize, keepdim: bool) -> TorshResult<(Tensor, Tensor<i64>)> {
    // Values and indices are produced by two independent passes over the input.
    let values = tensor.max_dim(dim as i32, keepdim)?;
    // NOTE(review): `argmax` takes no keepdim flag — when keepdim=true the
    // indices' shape may not match `values`; confirm against the tensor API.
    let indices = tensor.argmax(Some(dim as i32))?;
    Ok((values, indices))
}
380
381/// Minimum value in a tensor.
382///
383/// # Mathematical Definition
384/// ```text
385/// min(X) = min{x_1, x_2, ..., x_n}
386/// ```
387///
388/// # Arguments
389/// * `tensor` - Input tensor of any shape
390///
391/// # Returns
392/// Scalar tensor containing the minimum element value
393///
394/// # Complexity
395/// - Time: O(n)
396/// - Space: O(1)
397///
398/// # Examples
399/// ```rust,no_run
400/// # use torsh_tensor::Tensor;
401/// # use torsh_functional::reduction::min;
402/// # use torsh_functional::random_ops::randn;
403/// # fn example() -> torsh_core::Result<()> {
404/// let x = Tensor::from_vec(vec![1.0, 5.0, 3.0, 9.0, 2.0], &[5])?;
405/// let result = min(&x)?;  // Returns tensor with value 1.0
406/// # Ok(())
407/// # }
408/// ```
pub fn min(tensor: &Tensor) -> TorshResult<Tensor> {
    // NOTE(review): unlike `max` above, the tensor API's `min` takes no
    // arguments here — presumably still a global reduction; confirm the two
    // methods are symmetric.
    tensor.min()
}
412
413/// Minimum along specified dimension with indices.
414///
415/// # Arguments
416/// * `tensor` - Input tensor
417/// * `dim` - Dimension to reduce
418/// * `keepdim` - If true, reduced dimension is retained with size 1
419///
420/// # Returns
421/// Tuple of (min_values, indices) where:
422/// - `min_values`: Minimum values along the dimension
423/// - `indices`: Indices of minimum values
424///
425/// # Examples
426/// ```rust,no_run
427/// # use torsh_tensor::Tensor;
428/// # use torsh_functional::reduction::min_dim;
429/// # use torsh_functional::random_ops::randn;
430/// # fn example() -> torsh_core::Result<()> {
431/// let x = Tensor::from_vec(vec![1.0, 5.0, 3.0, 9.0, 2.0, 7.0], &[2, 3])?;
432/// let (min_vals, indices) = min_dim(&x, 1, false)?;
433/// // min_vals: [1.0, 2.0] - min of each row
434/// // indices: [0, 1] - position of min in each row
435/// # Ok(())
436/// # }
437/// ```
pub fn min_dim(tensor: &Tensor, dim: isize, keepdim: bool) -> TorshResult<(Tensor, Tensor<i64>)> {
    // Values and indices are produced by two independent passes over the input.
    let values = tensor.min_dim(dim as i32, keepdim)?;
    // NOTE(review): `argmin` takes no keepdim flag — when keepdim=true the
    // indices' shape may not match `values`; confirm against the tensor API.
    let indices = tensor.argmin(Some(dim as i32))?;
    Ok((values, indices))
}
443
/// Index of maximum value in a tensor
///
/// `None` requests a global argmax — presumably the linear index into the
/// flattened tensor; confirm against the tensor API.
pub fn argmax(tensor: &Tensor) -> TorshResult<Tensor<i64>> {
    tensor.argmax(None)
}
448
449/// Index of maximum value along specified dimension
450pub fn argmax_dim(tensor: &Tensor, dim: isize, keepdim: bool) -> TorshResult<Tensor<i64>> {
451    let indices = tensor.argmax(Some(dim as i32))?;
452    if keepdim {
453        // Add dimension back for keepdim=true
454        let mut new_shape = indices.shape().dims().to_vec();
455        new_shape.insert(dim as usize, 1);
456        let new_shape_i32: Vec<i32> = new_shape.iter().map(|&x| x as i32).collect();
457        indices.view(&new_shape_i32)
458    } else {
459        Ok(indices)
460    }
461}
462
/// Index of minimum value in a tensor
///
/// `None` requests a global argmin — presumably the linear index into the
/// flattened tensor; confirm against the tensor API.
pub fn argmin(tensor: &Tensor) -> TorshResult<Tensor<i64>> {
    tensor.argmin(None)
}
467
468/// Index of minimum value along specified dimension
469pub fn argmin_dim(tensor: &Tensor, dim: isize, keepdim: bool) -> TorshResult<Tensor<i64>> {
470    let indices = tensor.argmin(Some(dim as i32))?;
471    if keepdim {
472        // Add dimension back for keepdim=true
473        let mut new_shape = indices.shape().dims().to_vec();
474        new_shape.insert(dim as usize, 1);
475        let new_shape_i32: Vec<i32> = new_shape.iter().map(|&x| x as i32).collect();
476        indices.view(&new_shape_i32)
477    } else {
478        Ok(indices)
479    }
480}
481
482/// Product of all elements in a tensor
483pub fn prod(tensor: &Tensor) -> TorshResult<Tensor> {
484    // Flatten tensor and compute product manually since prod() might not be implemented
485    let flat = tensor.flatten()?;
486    let size = flat.shape().dims()[0];
487    let mut result = 1.0f32;
488    for i in 0..size {
489        result *= flat.get(&[i])?;
490    }
491    Tensor::from_vec(vec![result], &[])
492}
493
/// Product along specified dimensions
///
/// Not yet implemented: always returns an error. The signature is kept so
/// callers can already be written against the intended API.
pub fn prod_dim(_tensor: &Tensor, _dim: &[isize], _keepdim: bool) -> TorshResult<Tensor> {
    // Simple implementation: return error for now as this requires complex dimension reduction
    let _ = (_dim, _keepdim); // silence unused warnings
    Err(TorshError::Other(
        "prod_dim not yet fully implemented".to_string(),
    ))
}
502
503/// Standard deviation of all elements
504pub fn std(tensor: &Tensor, unbiased: bool) -> TorshResult<Tensor> {
505    let mode = if unbiased {
506        StatMode::Sample
507    } else {
508        StatMode::Population
509    };
510    tensor.std(None, false, mode)
511}
512
513/// Standard deviation along specified dimensions
514pub fn std_dim(
515    tensor: &Tensor,
516    dim: &[isize],
517    unbiased: bool,
518    keepdim: bool,
519) -> TorshResult<Tensor> {
520    let usize_dims: Vec<usize> = dim.iter().map(|&d| d as usize).collect();
521    let mode = if unbiased {
522        StatMode::Sample
523    } else {
524        StatMode::Population
525    };
526    tensor.std(Some(&usize_dims), keepdim, mode)
527}
528
529/// Variance of all elements
530pub fn var(tensor: &Tensor, unbiased: bool) -> TorshResult<Tensor> {
531    let mode = if unbiased {
532        StatMode::Sample
533    } else {
534        StatMode::Population
535    };
536    tensor.var(None, false, mode)
537}
538
539/// Variance along specified dimensions
540pub fn var_dim(
541    tensor: &Tensor,
542    dim: &[isize],
543    unbiased: bool,
544    keepdim: bool,
545) -> TorshResult<Tensor> {
546    let usize_dims: Vec<usize> = dim.iter().map(|&d| d as usize).collect();
547    let mode = if unbiased {
548        StatMode::Sample
549    } else {
550        StatMode::Population
551    };
552    tensor.var(Some(&usize_dims), keepdim, mode)
553}
554
555/// L1 norm (sum of absolute values)
556pub fn norm_l1(tensor: &Tensor) -> TorshResult<Tensor> {
557    let abs_tensor = tensor.abs()?;
558    abs_tensor.sum()
559}
560
561/// L2 norm (Euclidean norm)
562pub fn norm_l2(tensor: &Tensor) -> TorshResult<Tensor> {
563    let squared = tensor.square()?;
564    let sum = squared.sum()?;
565    sum.sqrt()
566}
567
568/// P-norm
569pub fn norm_p(tensor: &Tensor, p: f32) -> TorshResult<Tensor> {
570    if p == 1.0 {
571        norm_l1(tensor)
572    } else if p == 2.0 {
573        norm_l2(tensor)
574    } else {
575        let abs_tensor = tensor.abs()?;
576        let powered = abs_tensor.pow_scalar(p)?;
577        let sum = powered.sum()?;
578        sum.pow_scalar(1.0 / p)
579    }
580}
581
/// Frobenius norm (for matrices)
///
/// Equals the L2 norm of the flattened entries, `||A||_F = √(Σ a_ij²)`,
/// so it is delegated to [`norm_l2`].
pub fn norm_frobenius(tensor: &Tensor) -> TorshResult<Tensor> {
    norm_l2(tensor)
}
586
/// Nuclear norm (sum of singular values)
///
/// The nuclear norm is defined as the sum of singular values of a matrix:
/// ```text
/// ||A||_* = Σ σ_i(A)
/// ```
/// where σ_i are the singular values obtained from SVD: A = UΣV^T
///
/// This is also known as the trace norm or Schatten 1-norm.
/// It's commonly used in:
/// - Low-rank matrix approximation
/// - Matrix completion problems
/// - Compressed sensing
/// - Regularization for machine learning models
///
/// # Errors
/// Returns `InvalidArgument` if the input tensor is not 2-dimensional.
pub fn norm_nuclear(tensor: &Tensor) -> TorshResult<Tensor> {
    // Ensure input is 2D
    if tensor.shape().ndim() != 2 {
        return Err(TorshError::InvalidArgument(
            "Nuclear norm requires 2D tensor (matrix)".to_string(),
        ));
    }

    // Compute SVD: A = U * S * V^T
    // The singular values S contain the information we need
    let (_u, s, _vt) = torsh_linalg::decomposition::svd(tensor, false)?;

    // Nuclear norm is the sum of singular values
    // The singular values are already sorted in descending order
    s.sum()
}
617
/// Count non-zero elements
pub fn count_nonzero(tensor: &Tensor) -> TorshResult<Tensor> {
    // Elementwise `!=` against an all-zeros tensor yields the non-zero mask.
    let zero_tensor = zeros(tensor.shape().dims())?;
    let nonzero_mask = tensor.ne(&zero_tensor)?;
    // Create a tensor of ones with the same shape and sum where mask is true
    // (assumes `a.where_tensor(mask, b)` selects from `a` where the mask is
    // true and from `b` otherwise — TODO confirm argument order).
    let ones = tensor.ones_like()?;
    let zeros = tensor.zeros_like()?; // shadows the imported `zeros` fn (after its last use above)
    let count_tensor = ones.where_tensor(&nonzero_mask, &zeros)?;
    count_tensor.sum()
}
628
/// Count non-zero elements along dimension
pub fn count_nonzero_dim(tensor: &Tensor, dim: isize) -> TorshResult<Tensor> {
    // Elementwise `!=` against an all-zeros tensor yields the non-zero mask.
    let zero_tensor = zeros(tensor.shape().dims())?;
    let nonzero_mask = tensor.ne(&zero_tensor)?;
    // Create a tensor of ones with the same shape and sum where mask is true
    // (assumes `a.where_tensor(mask, b)` selects from `a` where the mask is
    // true — TODO confirm argument order).
    let ones = tensor.ones_like()?;
    let zeros = tensor.zeros_like()?; // shadows the imported `zeros` fn (after its last use above)
    let count_tensor = ones.where_tensor(&nonzero_mask, &zeros)?;
    // `as i32` preserves sign, so negative dims pass straight to `sum_dim`.
    count_tensor.sum_dim(&[dim as i32], false)
}
639
/// Cumulative sum
///
/// NOTE(review): the `try_into` target type is inferred from the tensor API;
/// if it is unsigned, a negative `dim` panics via `expect` instead of
/// returning an `Err` — confirm intended behavior.
pub fn cumsum(tensor: &Tensor, dim: isize) -> TorshResult<Tensor> {
    tensor.cumsum(dim.try_into().expect("dimension conversion should succeed"))
}
644
/// Cumulative product
///
/// NOTE(review): the `try_into` target type is inferred from the tensor API;
/// if it is unsigned, a negative `dim` panics via `expect` instead of
/// returning an `Err` — confirm intended behavior.
pub fn cumprod(tensor: &Tensor, dim: isize) -> TorshResult<Tensor> {
    tensor.cumprod(dim.try_into().expect("dimension conversion should succeed"))
}
649
/// All elements are true (non-zero)
///
/// Thin wrapper: delegates the full reduction to the tensor method.
pub fn all(tensor: &Tensor) -> TorshResult<torsh_tensor::Tensor<bool>> {
    tensor.all()
}
654
/// All elements are true along dimension
///
/// NOTE(review): if the inferred `try_into` target is unsigned, a negative
/// `dim` panics via `expect` — confirm intended behavior.
pub fn all_dim(
    tensor: &Tensor,
    dim: isize,
    keepdim: bool,
) -> TorshResult<torsh_tensor::Tensor<bool>> {
    tensor.all_dim(
        dim.try_into().expect("dimension conversion should succeed"),
        keepdim,
    )
}
666
/// Any element is true (non-zero)
///
/// Thin wrapper: delegates the full reduction to the tensor method.
pub fn any(tensor: &Tensor) -> TorshResult<torsh_tensor::Tensor<bool>> {
    tensor.any()
}
671
/// Any element is true along dimension
///
/// NOTE(review): if the inferred `try_into` target is unsigned, a negative
/// `dim` panics via `expect` — confirm intended behavior.
pub fn any_dim(
    tensor: &Tensor,
    dim: isize,
    keepdim: bool,
) -> TorshResult<torsh_tensor::Tensor<bool>> {
    tensor.any_dim(
        dim.try_into().expect("dimension conversion should succeed"),
        keepdim,
    )
}
683
684// ============================================================================
685// Unique Operations
686// ============================================================================
687
688/// Find unique elements in a tensor
689pub fn unique(
690    tensor: &Tensor,
691    sorted: bool,
692    return_inverse: bool,
693    return_counts: bool,
694    dim: Option<isize>,
695) -> TorshResult<UniqueResult> {
696    if let Some(_d) = dim {
697        // Unique along dimension
698        return Err(TorshError::Other(
699            "unique along dimension not yet implemented".to_string(),
700        ));
701    }
702
703    // Flatten tensor and get unique values
704    let flat = tensor.flatten()?;
705    let size = flat.shape().dims()[0];
706
707    let mut unique_map: HashMap<OrderedFloat, usize> = HashMap::new();
708    let mut unique_values = Vec::new();
709    let mut inverse_indices = vec![0; size];
710
711    // Find unique values
712    for i in 0..size {
713        let value = flat.get(&[i])?;
714        let key = OrderedFloat(value);
715
716        match unique_map.get(&key) {
717            Some(&idx) => {
718                if return_inverse {
719                    inverse_indices[i] = idx;
720                }
721            }
722            None => {
723                let idx = unique_values.len();
724                unique_values.push(value);
725                unique_map.insert(key, idx);
726                if return_inverse {
727                    inverse_indices[i] = idx;
728                }
729            }
730        }
731    }
732
733    // Sort if requested
734    if sorted {
735        let mut indices: Vec<_> = (0..unique_values.len()).collect();
736        indices.sort_by(|&a, &b| {
737            unique_values[a]
738                .partial_cmp(&unique_values[b])
739                .expect("numeric comparison should succeed")
740        });
741
742        // Reorder unique values
743        let sorted_values: Vec<_> = indices.iter().map(|&i| unique_values[i]).collect();
744
745        // Update inverse indices if needed
746        if return_inverse {
747            let mut index_map = vec![0; indices.len()];
748            for (new_idx, &old_idx) in indices.iter().enumerate() {
749                index_map[old_idx] = new_idx;
750            }
751
752            for inv_idx in inverse_indices.iter_mut() {
753                *inv_idx = index_map[*inv_idx];
754            }
755        }
756
757        unique_values = sorted_values;
758    }
759
760    // Create output tensor
761    let output = Tensor::from_vec(unique_values.clone(), &[unique_values.len()])?;
762
763    // Compute counts if requested
764    let counts = if return_counts {
765        let mut count_vec = vec![0; unique_values.len()];
766        if return_inverse {
767            for &idx in inverse_indices.iter() {
768                count_vec[idx] += 1;
769            }
770        } else {
771            // Recompute if we didn't track inverse indices
772            for i in 0..size {
773                let value = flat.get(&[i])?;
774                for (j, &unique_val) in unique_values.iter().enumerate() {
775                    if (value - unique_val).abs() < f32::EPSILON {
776                        count_vec[j] += 1;
777                        break;
778                    }
779                }
780            }
781        }
782        let count_data: Vec<f32> = count_vec.into_iter().map(|c| c as f32).collect();
783        Some(Tensor::from_vec(count_data.clone(), &[count_data.len()])?)
784    } else {
785        None
786    };
787
788    // Create inverse tensor if requested
789    let inverse = if return_inverse {
790        let inverse_data: Vec<f32> = inverse_indices.into_iter().map(|i| i as f32).collect();
791        Some(Tensor::from_vec(
792            inverse_data.clone(),
793            &[inverse_data.len()],
794        )?)
795    } else {
796        None
797    };
798
799    Ok(UniqueResult {
800        values: output,
801        inverse,
802        counts,
803    })
804}
805
/// Find unique consecutive elements
///
/// Collapses runs of equal adjacent values (like `torch.unique_consecutive`):
/// each run contributes one entry to `values`, optionally with its length in
/// `counts` and per-element run indices in `inverse`.
///
/// # Errors
/// Returns an error if `dim` is `Some(_)` (not yet implemented).
pub fn unique_consecutive(
    tensor: &Tensor,
    return_inverse: bool,
    return_counts: bool,
    dim: Option<isize>,
) -> TorshResult<UniqueResult> {
    if let Some(_d) = dim {
        // Unique consecutive along dimension
        return Err(TorshError::Other(
            "unique_consecutive along dimension not yet implemented".to_string(),
        ));
    }

    // Flatten tensor
    let flat = tensor.flatten()?;
    let size = flat.shape().dims()[0];

    // Empty input: every requested output is an empty tensor.
    if size == 0 {
        return Ok(UniqueResult {
            values: zeros(&[0])?,
            inverse: if return_inverse {
                Some(zeros(&[0])?)
            } else {
                None
            },
            counts: if return_counts {
                Some(zeros(&[0])?)
            } else {
                None
            },
        });
    }

    let mut unique_values = Vec::new();
    let mut inverse_indices = vec![0; size];
    let mut counts = Vec::new();

    // Process first element
    // (inverse_indices[0] stays 0, which is the first run's index)
    let mut current_value = flat.get(&[0])?;
    unique_values.push(current_value);
    let mut current_count = 1;

    // Process remaining elements
    for i in 1..size {
        let value = flat.get(&[i])?;

        // Adjacent values within f32::EPSILON are treated as equal.
        if (value - current_value).abs() < f32::EPSILON {
            // Same as previous
            current_count += 1;
            if return_inverse {
                inverse_indices[i] = unique_values.len() - 1;
            }
        } else {
            // Different from previous: close the current run, start a new one.
            if return_counts {
                counts.push(current_count);
            }

            current_value = value;
            unique_values.push(value);
            current_count = 1;

            if return_inverse {
                inverse_indices[i] = unique_values.len() - 1;
            }
        }
    }

    // Don't forget the last group
    if return_counts {
        counts.push(current_count);
    }

    // Create output tensors
    let output = Tensor::from_vec(unique_values.clone(), &[unique_values.len()])?;

    let counts_tensor = if return_counts {
        let count_data: Vec<f32> = counts.into_iter().map(|c| c as f32).collect();
        Some(Tensor::from_vec(count_data.clone(), &[count_data.len()])?)
    } else {
        None
    };

    let inverse_tensor = if return_inverse {
        let inverse_data: Vec<f32> = inverse_indices.into_iter().map(|i| i as f32).collect();
        Some(Tensor::from_vec(
            inverse_data.clone(),
            &[inverse_data.len()],
        )?)
    } else {
        None
    };

    Ok(UniqueResult {
        values: output,
        inverse: inverse_tensor,
        counts: counts_tensor,
    })
}
906
/// Result of unique operations
pub struct UniqueResult {
    /// The unique values found in the input (1-D tensor).
    pub values: Tensor,
    /// For each input element, the index of its value within `values`
    /// (present only when `return_inverse` was requested).
    pub inverse: Option<Tensor>,
    /// Occurrence count for each entry of `values`
    /// (present only when `return_counts` was requested).
    pub counts: Option<Tensor>,
}
913
/// Wrapper for `f32` that implements `Eq` and `Hash` so floats can serve as
/// `HashMap` keys.
///
/// Equality is defined on the raw bit pattern (`to_bits`), matching the hash.
/// The previous epsilon-based `eq` violated the `Hash`/`Eq` contract
/// (`a == b` must imply `hash(a) == hash(b)`) and was not transitive; with
/// bucketed hashing it could never reliably merge nearly-equal values anyway,
/// because they hash to different buckets.
///
/// Note: under bit equality, identical-payload `NaN`s compare equal and
/// `0.0 != -0.0`.
#[derive(Debug, Clone, Copy)]
struct OrderedFloat(f32);

impl PartialEq for OrderedFloat {
    fn eq(&self, other: &Self) -> bool {
        // Bit-for-bit comparison keeps eq consistent with hash below.
        self.0.to_bits() == other.0.to_bits()
    }
}

impl Eq for OrderedFloat {}

impl std::hash::Hash for OrderedFloat {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Hash the bits of the float
        self.0.to_bits().hash(state);
    }
}
932
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::tensor;

    #[test]
    fn test_unique() {
        let tensor = tensor![3.0f32, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0, 5.0].unwrap();

        // Test basic unique
        // (1.0 and 5.0 each appear twice, so 9 elements -> 7 unique values)
        let result = unique(&tensor, true, false, false, None).unwrap();
        assert_eq!(result.values.shape().dims()[0], 7); // Should have 7 unique values

        // Test with counts
        let result = unique(&tensor, true, false, true, None).unwrap();
        assert!(result.counts.is_some());
    }

    #[test]
    fn test_unique_consecutive() {
        let tensor = tensor![1.0f32, 1.0, 2.0, 2.0, 2.0, 3.0, 1.0, 1.0].unwrap();

        let result = unique_consecutive(&tensor, true, true, None).unwrap();

        // Should have 4 groups: [1,1], [2,2,2], [3], [1,1]
        // (note: the trailing 1.0 run is a separate group from the leading one)
        assert_eq!(result.values.shape().dims()[0], 4);

        let expected_values = vec![1.0, 2.0, 3.0, 1.0];
        let expected_counts = vec![2.0, 3.0, 1.0, 2.0];

        // Verify values and counts
        for i in 0..4 {
            assert!((result.values.get(&[i]).unwrap() - expected_values[i]).abs() < f32::EPSILON);
            assert!(
                (result.counts.as_ref().unwrap().get(&[i]).unwrap() - expected_counts[i]).abs()
                    < f32::EPSILON
            );
        }
    }
}
972}