torsh_functional/reduction.rs
1//! # Reduction Operations for Tensors
2//!
3//! This module provides comprehensive reduction operations that collapse tensor dimensions
4//! by applying aggregation functions across specified axes.
5//!
6//! ## Mathematical Foundation
7//!
8//! ### Basic Reductions
9//!
10//! #### Sum
11//! Computes the sum of tensor elements:
12//! ```text
13//! sum(X) = Σ(i=1 to n) x_i
14//! sum_dim(X, d) = Σ(i in dimension d) x_i
15//! ```
16//!
17//! #### Mean (Arithmetic Average)
18//! ```text
19//! mean(X) = (1/n) Σ(i=1 to n) x_i
20//! mean_dim(X, d) = (1/n_d) Σ(i in dimension d) x_i
21//! ```
22//! where `n_d` is the size of dimension `d`.
23//!
24//! #### Product
25//! ```text
26//! prod(X) = Π(i=1 to n) x_i
27//! prod_dim(X, d) = Π(i in dimension d) x_i
28//! ```
29//!
30//! #### Max/Min
31//! ```text
32//! max(X) = max{x_1, x_2, ..., x_n}
33//! min(X) = min{x_1, x_2, ..., x_n}
34//! ```
35//!
36//! ### Statistical Reductions
37//!
38//! #### Variance
39//! ```text
40//! var(X) = (1/n) Σ(i=1 to n) (x_i - μ)²
41//! where μ = mean(X)
42//! ```
43//! - **Sample variance**: Use `n-1` denominator (Bessel's correction)
44//! - **Population variance**: Use `n` denominator
45//!
46//! #### Standard Deviation
47//! ```text
48//! std(X) = √var(X)
49//! ```
50//!
51//! ### Norm Operations
52//!
53//! #### L-p Norm
54//! ```text
55//! ||X||_p = (Σ|x_i|^p)^(1/p)
56//! ```
57//! Special cases:
58//! - **L1 norm** (p=1): Manhattan distance = Σ|x_i|
59//! - **L2 norm** (p=2): Euclidean distance = √(Σx_i²)
60//! - **L∞ norm** (p=∞): Maximum absolute value = max|x_i|
61//!
62//! ### Advanced Reductions
63//!
64//! #### LogSumExp (Numerically Stable)
65//! ```text
66//! logsumexp(X) = log(Σ exp(x_i))
67//! = max(X) + log(Σ exp(x_i - max(X))) [stable version]
68//! ```
69//! Used in softmax computation to prevent overflow.
70//!
71//! #### Cumulative Sum
72//! ```text
73//! cumsum(X)[i] = Σ(j=0 to i) x_j
74//! ```
75//!
76//! ## Performance Characteristics
77//!
78//! ### Computational Complexity
79//! For tensor with `n` elements:
80//! - **Basic reductions** (sum, mean, max, min): O(n)
81//! - **Statistical** (var, std): O(n) - requires two passes (mean, then variance)
82//! - **Norm operations**: O(n)
83//! - **Cumulative operations**: O(n)
84//!
85//! ### Memory Usage
86//! - **Full reduction**: O(1) - single scalar output
87//! - **Dimension reduction**: O(n / d_size) where `d_size` is the reduced dimension size
88//! - **Cumulative operations**: O(n) - same size as input
89//!
90//! ### Optimization Strategies
91//! - **Parallel reduction**: Tree-based reduction for parallel backends
92//! - **SIMD**: Vectorized operations for element-wise computation
93//! - **Fusion**: Combine multiple reductions in single pass when possible
94//! - **Numerical stability**: Use Kahan summation or compensated summation for float precision
95//!
96//! ## Common Use Cases
97//!
98//! ### Loss Function Computation
99//! ```rust
100//! use torsh_functional::reduction::{mean, sum};
101//! use torsh_functional::random_ops::randn;
102//!
103//! fn example() -> Result<(), Box<dyn std::error::Error>> {
104//! // Compute mean squared error
105//! let predictions = randn(&[32, 10], None, None, Some(42))?; // batch_size=32, num_classes=10
106//! let targets = randn(&[32, 10], None, None, Some(43))?;
107//!
108//! let diff = predictions.sub(&targets)?;
109//! let squared = diff.pow_scalar(2.0)?;
110//! let mse_loss = mean(&squared)?;
111//!
112//! // Or compute sum for total loss
113//! let total_loss = sum(&squared)?;
114//! Ok(())
115//! }
116//! ```
117//!
118//! ### Batch Statistics for Normalization
119//! ```rust
120//! use torsh_functional::reduction::{mean_dim, std_dim};
121//! use torsh_functional::random_ops::randn;
122//!
123//! fn example() -> Result<(), Box<dyn std::error::Error>> {
124//! // Compute mean and std across batch dimension for BatchNorm
125//! let features = randn(&[64, 128, 28, 28], None, None, Some(42))?; // [N, C, H, W]
126//!
127//! // Compute statistics over batch and spatial dimensions (0, 2, 3)
128//! // keeping channel dimension (1)
129//! let batch_mean = mean_dim(&features, &[0, 2, 3], true)?;
130//! let batch_std = std_dim(&features, &[0, 2, 3], true, false)?;
131//!
132//! // Normalize: (x - mean) / std
133//! let normalized = features.sub(&batch_mean)?
134//! .div(&batch_std)?;
135//! Ok(())
136//! }
137//! ```
138//!
139//! ### Attention Score Normalization
140//! ```rust
141//! use torsh_functional::reduction::{max_dim, sum_dim};
142//! use torsh_functional::random_ops::randn;
143//!
144//! fn example() -> Result<(), Box<dyn std::error::Error>> {
145//! // Softmax using max for numerical stability
146//! let scores = randn(&[8, 64, 64], None, None, Some(42))?; // [batch, seq_len, seq_len]
147//!
148//! // Stable softmax: exp(x - max(x)) / sum(exp(x - max(x)))
149//! let (max_scores, _) = max_dim(&scores, -1, true)?;
150//! let shifted = scores.sub(&max_scores)?;
151//! let exp_scores = shifted.exp()?;
152//! let sum_exp = sum_dim(&exp_scores, &[-1], true)?;
153//! let softmax = exp_scores.div(&sum_exp)?;
154//! Ok(())
155//! }
156//! ```
157//!
158//! ## Numerical Stability Considerations
159//!
160//! ### Sum Accumulation
161//! - Use Kahan summation for large arrays to minimize floating-point error
162//! - For very large tensors, consider hierarchical reduction
163//!
164//! ### Variance Computation
165//! - Use Welford's online algorithm for single-pass stable variance
166//! - Avoid naive formula `E[X²] - E[X]²` which suffers from catastrophic cancellation
167//!
168//! ### LogSumExp
169//! - Always use max-subtraction trick: `max(x) + log(sum(exp(x - max(x))))`
170//! - Essential for preventing overflow in exp() computation
171
172use std::collections::HashMap;
173use torsh_core::{Result as TorshResult, TorshError};
174use torsh_tensor::{creation::zeros, stats::StatMode, Tensor};
175
176// ============================================================================
177// Basic Reduction Operations
178// ============================================================================
179
180/// Sum all elements in a tensor.
181///
182/// # Mathematical Definition
183/// ```text
184/// sum(X) = Σ(i=1 to n) x_i
185/// ```
186///
187/// # Arguments
188/// * `tensor` - Input tensor of any shape
189///
190/// # Returns
191/// Scalar tensor containing the sum of all elements
192///
193/// # Complexity
194/// - Time: O(n) where n is the number of elements
195/// - Space: O(1)
196///
197/// # Examples
198/// ```rust,no_run
199/// # use torsh_tensor::Tensor;
200/// # use torsh_functional::reduction::sum;
201/// # use torsh_functional::random_ops::randn;
202/// # fn example() -> torsh_core::Result<()> {
203/// let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5])?;
204/// let result = sum(&x)?; // Returns tensor with value 15.0
205/// # Ok(())
206/// # }
207/// ```
208pub fn sum(tensor: &Tensor) -> TorshResult<Tensor> {
209 tensor.sum()
210}
211
212/// Sum along specified dimensions.
213///
214/// # Mathematical Definition
215/// ```text
216/// sum_dim(X, d) = Σ(i in dimension d) x_i
217/// ```
218///
219/// # Arguments
220/// * `tensor` - Input tensor
221/// * `dim` - Dimensions to reduce (negative indexing supported)
222/// * `keepdim` - If true, reduced dimensions are retained with size 1
223///
224/// # Returns
225/// Tensor with specified dimensions reduced
226///
227/// # Shape
228/// - Input: `[..., d_i, ...]` where `d_i` is a dimension to reduce
229/// - Output (keepdim=false): `[..., ...]` with `d_i` removed
230/// - Output (keepdim=true): `[..., 1, ...]` with `d_i` replaced by 1
231///
232/// # Examples
233/// ```rust,no_run
234/// # use torsh_tensor::Tensor;
235/// # use torsh_functional::reduction::sum_dim;
236/// # use torsh_functional::random_ops::randn;
237/// # fn example() -> torsh_core::Result<()> {
238/// let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])?;
239/// // [[1, 2, 3],
240/// // [4, 5, 6]]
241///
242/// let row_sums = sum_dim(&x, &[1], false)?; // [6, 15]
243/// let col_sums = sum_dim(&x, &[0], false)?; // [5, 7, 9]
244/// let total = sum_dim(&x, &[0, 1], false)?; // 21
245/// # Ok(())
246/// # }
247/// ```
248pub fn sum_dim(tensor: &Tensor, dim: &[isize], keepdim: bool) -> TorshResult<Tensor> {
249 let i32_dims: Vec<i32> = dim.iter().map(|&d| d as i32).collect();
250 tensor.sum_dim(&i32_dims, keepdim)
251}
252
253/// Mean (arithmetic average) of all elements in a tensor.
254///
255/// # Mathematical Definition
256/// ```text
257/// mean(X) = (1/n) Σ(i=1 to n) x_i
258/// ```
259///
260/// # Arguments
261/// * `tensor` - Input tensor of any shape
262///
263/// # Returns
264/// Scalar tensor containing the mean of all elements
265///
266/// # Complexity
267/// - Time: O(n)
268/// - Space: O(1)
269///
270/// # Examples
271/// ```rust,no_run
272/// # use torsh_tensor::Tensor;
273/// # use torsh_functional::reduction::mean;
274/// # use torsh_functional::random_ops::randn;
275/// # fn example() -> torsh_core::Result<()> {
276/// let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5])?;
277/// let result = mean(&x)?; // Returns tensor with value 3.0
278/// # Ok(())
279/// # }
280/// ```
281pub fn mean(tensor: &Tensor) -> TorshResult<Tensor> {
282 tensor.mean(None, false)
283}
284
285/// Mean along specified dimensions.
286///
287/// # Mathematical Definition
288/// ```text
289/// mean_dim(X, d) = (1/n_d) Σ(i in dimension d) x_i
290/// ```
291/// where `n_d` is the size of dimension `d`.
292///
293/// # Arguments
294/// * `tensor` - Input tensor
295/// * `dim` - Dimensions to reduce
296/// * `keepdim` - If true, reduced dimensions are retained with size 1
297///
298/// # Returns
299/// Tensor with specified dimensions reduced by averaging
300///
301/// # Examples
302/// ```rust
303/// use torsh_functional::reduction::mean_dim;
304/// use torsh_functional::random_ops::randn;
305///
306/// fn example() -> Result<(), Box<dyn std::error::Error>> {
307/// // Compute mean across batch dimension for normalization
308/// let batch_data = randn(&[32, 128], None, None, None)?; // [batch, features]
309/// let feature_means = mean_dim(&batch_data, &[0], false)?; // [128]
310/// Ok(())
311/// }
312/// ```
313pub fn mean_dim(tensor: &Tensor, dim: &[isize], keepdim: bool) -> TorshResult<Tensor> {
314 let usize_dims: Vec<usize> = dim.iter().map(|&d| d as usize).collect();
315 tensor.mean(Some(&usize_dims), keepdim)
316}
317
318/// Maximum value in a tensor.
319///
320/// # Mathematical Definition
321/// ```text
322/// max(X) = max{x_1, x_2, ..., x_n}
323/// ```
324///
325/// # Arguments
326/// * `tensor` - Input tensor of any shape
327///
328/// # Returns
329/// Scalar tensor containing the maximum element value
330///
331/// # Complexity
332/// - Time: O(n)
333/// - Space: O(1)
334///
335/// # Examples
336/// ```rust,no_run
337/// # use torsh_tensor::Tensor;
338/// # use torsh_functional::reduction::max;
339/// # use torsh_functional::random_ops::randn;
340/// # fn example() -> torsh_core::Result<()> {
341/// let x = Tensor::from_vec(vec![1.0, 5.0, 3.0, 9.0, 2.0], &[5])?;
342/// let result = max(&x)?; // Returns tensor with value 9.0
343/// # Ok(())
344/// # }
345/// ```
346pub fn max(tensor: &Tensor) -> TorshResult<Tensor> {
347 tensor.max(None, false)
348}
349
350/// Maximum along specified dimension with indices.
351///
352/// # Arguments
353/// * `tensor` - Input tensor
354/// * `dim` - Dimension to reduce
355/// * `keepdim` - If true, reduced dimension is retained with size 1
356///
357/// # Returns
358/// Tuple of (max_values, indices) where:
359/// - `max_values`: Maximum values along the dimension
360/// - `indices`: Indices of maximum values (useful for pooling operations)
361///
362/// # Examples
363/// ```rust,no_run
364/// # use torsh_tensor::Tensor;
365/// # use torsh_functional::reduction::max_dim;
366/// # use torsh_functional::random_ops::randn;
367/// # fn example() -> torsh_core::Result<()> {
368/// let x = Tensor::from_vec(vec![1.0, 5.0, 3.0, 9.0, 2.0, 7.0], &[2, 3])?;
369/// let (max_vals, indices) = max_dim(&x, 1, false)?;
370/// // max_vals: [5.0, 9.0] - max of each row
371/// // indices: [1, 1] - position of max in each row
372/// # Ok(())
373/// # }
374/// ```
375pub fn max_dim(tensor: &Tensor, dim: isize, keepdim: bool) -> TorshResult<(Tensor, Tensor<i64>)> {
376 let values = tensor.max_dim(dim as i32, keepdim)?;
377 let indices = tensor.argmax(Some(dim as i32))?;
378 Ok((values, indices))
379}
380
381/// Minimum value in a tensor.
382///
383/// # Mathematical Definition
384/// ```text
385/// min(X) = min{x_1, x_2, ..., x_n}
386/// ```
387///
388/// # Arguments
389/// * `tensor` - Input tensor of any shape
390///
391/// # Returns
392/// Scalar tensor containing the minimum element value
393///
394/// # Complexity
395/// - Time: O(n)
396/// - Space: O(1)
397///
398/// # Examples
399/// ```rust,no_run
400/// # use torsh_tensor::Tensor;
401/// # use torsh_functional::reduction::min;
402/// # use torsh_functional::random_ops::randn;
403/// # fn example() -> torsh_core::Result<()> {
404/// let x = Tensor::from_vec(vec![1.0, 5.0, 3.0, 9.0, 2.0], &[5])?;
405/// let result = min(&x)?; // Returns tensor with value 1.0
406/// # Ok(())
407/// # }
408/// ```
409pub fn min(tensor: &Tensor) -> TorshResult<Tensor> {
410 tensor.min()
411}
412
413/// Minimum along specified dimension with indices.
414///
415/// # Arguments
416/// * `tensor` - Input tensor
417/// * `dim` - Dimension to reduce
418/// * `keepdim` - If true, reduced dimension is retained with size 1
419///
420/// # Returns
421/// Tuple of (min_values, indices) where:
422/// - `min_values`: Minimum values along the dimension
423/// - `indices`: Indices of minimum values
424///
425/// # Examples
426/// ```rust,no_run
427/// # use torsh_tensor::Tensor;
428/// # use torsh_functional::reduction::min_dim;
429/// # use torsh_functional::random_ops::randn;
430/// # fn example() -> torsh_core::Result<()> {
431/// let x = Tensor::from_vec(vec![1.0, 5.0, 3.0, 9.0, 2.0, 7.0], &[2, 3])?;
432/// let (min_vals, indices) = min_dim(&x, 1, false)?;
433/// // min_vals: [1.0, 2.0] - min of each row
434/// // indices: [0, 1] - position of min in each row
435/// # Ok(())
436/// # }
437/// ```
438pub fn min_dim(tensor: &Tensor, dim: isize, keepdim: bool) -> TorshResult<(Tensor, Tensor<i64>)> {
439 let values = tensor.min_dim(dim as i32, keepdim)?;
440 let indices = tensor.argmin(Some(dim as i32))?;
441 Ok((values, indices))
442}
443
444/// Index of maximum value in a tensor
445pub fn argmax(tensor: &Tensor) -> TorshResult<Tensor<i64>> {
446 tensor.argmax(None)
447}
448
449/// Index of maximum value along specified dimension
450pub fn argmax_dim(tensor: &Tensor, dim: isize, keepdim: bool) -> TorshResult<Tensor<i64>> {
451 let indices = tensor.argmax(Some(dim as i32))?;
452 if keepdim {
453 // Add dimension back for keepdim=true
454 let mut new_shape = indices.shape().dims().to_vec();
455 new_shape.insert(dim as usize, 1);
456 let new_shape_i32: Vec<i32> = new_shape.iter().map(|&x| x as i32).collect();
457 indices.view(&new_shape_i32)
458 } else {
459 Ok(indices)
460 }
461}
462
463/// Index of minimum value in a tensor
464pub fn argmin(tensor: &Tensor) -> TorshResult<Tensor<i64>> {
465 tensor.argmin(None)
466}
467
468/// Index of minimum value along specified dimension
469pub fn argmin_dim(tensor: &Tensor, dim: isize, keepdim: bool) -> TorshResult<Tensor<i64>> {
470 let indices = tensor.argmin(Some(dim as i32))?;
471 if keepdim {
472 // Add dimension back for keepdim=true
473 let mut new_shape = indices.shape().dims().to_vec();
474 new_shape.insert(dim as usize, 1);
475 let new_shape_i32: Vec<i32> = new_shape.iter().map(|&x| x as i32).collect();
476 indices.view(&new_shape_i32)
477 } else {
478 Ok(indices)
479 }
480}
481
482/// Product of all elements in a tensor
483pub fn prod(tensor: &Tensor) -> TorshResult<Tensor> {
484 // Flatten tensor and compute product manually since prod() might not be implemented
485 let flat = tensor.flatten()?;
486 let size = flat.shape().dims()[0];
487 let mut result = 1.0f32;
488 for i in 0..size {
489 result *= flat.get(&[i])?;
490 }
491 Tensor::from_vec(vec![result], &[])
492}
493
494/// Product along specified dimensions
495pub fn prod_dim(_tensor: &Tensor, _dim: &[isize], _keepdim: bool) -> TorshResult<Tensor> {
496 // Simple implementation: return error for now as this requires complex dimension reduction
497 let _ = (_dim, _keepdim); // silence unused warnings
498 Err(TorshError::Other(
499 "prod_dim not yet fully implemented".to_string(),
500 ))
501}
502
503/// Standard deviation of all elements
504pub fn std(tensor: &Tensor, unbiased: bool) -> TorshResult<Tensor> {
505 let mode = if unbiased {
506 StatMode::Sample
507 } else {
508 StatMode::Population
509 };
510 tensor.std(None, false, mode)
511}
512
513/// Standard deviation along specified dimensions
514pub fn std_dim(
515 tensor: &Tensor,
516 dim: &[isize],
517 unbiased: bool,
518 keepdim: bool,
519) -> TorshResult<Tensor> {
520 let usize_dims: Vec<usize> = dim.iter().map(|&d| d as usize).collect();
521 let mode = if unbiased {
522 StatMode::Sample
523 } else {
524 StatMode::Population
525 };
526 tensor.std(Some(&usize_dims), keepdim, mode)
527}
528
529/// Variance of all elements
530pub fn var(tensor: &Tensor, unbiased: bool) -> TorshResult<Tensor> {
531 let mode = if unbiased {
532 StatMode::Sample
533 } else {
534 StatMode::Population
535 };
536 tensor.var(None, false, mode)
537}
538
539/// Variance along specified dimensions
540pub fn var_dim(
541 tensor: &Tensor,
542 dim: &[isize],
543 unbiased: bool,
544 keepdim: bool,
545) -> TorshResult<Tensor> {
546 let usize_dims: Vec<usize> = dim.iter().map(|&d| d as usize).collect();
547 let mode = if unbiased {
548 StatMode::Sample
549 } else {
550 StatMode::Population
551 };
552 tensor.var(Some(&usize_dims), keepdim, mode)
553}
554
555/// L1 norm (sum of absolute values)
556pub fn norm_l1(tensor: &Tensor) -> TorshResult<Tensor> {
557 let abs_tensor = tensor.abs()?;
558 abs_tensor.sum()
559}
560
561/// L2 norm (Euclidean norm)
562pub fn norm_l2(tensor: &Tensor) -> TorshResult<Tensor> {
563 let squared = tensor.square()?;
564 let sum = squared.sum()?;
565 sum.sqrt()
566}
567
568/// P-norm
569pub fn norm_p(tensor: &Tensor, p: f32) -> TorshResult<Tensor> {
570 if p == 1.0 {
571 norm_l1(tensor)
572 } else if p == 2.0 {
573 norm_l2(tensor)
574 } else {
575 let abs_tensor = tensor.abs()?;
576 let powered = abs_tensor.pow_scalar(p)?;
577 let sum = powered.sum()?;
578 sum.pow_scalar(1.0 / p)
579 }
580}
581
582/// Frobenius norm (for matrices)
583pub fn norm_frobenius(tensor: &Tensor) -> TorshResult<Tensor> {
584 norm_l2(tensor)
585}
586
587/// Nuclear norm (sum of singular values)
588///
589/// The nuclear norm is defined as the sum of singular values of a matrix:
590/// ```text
591/// ||A||_* = Σ σ_i(A)
592/// ```
593/// where σ_i are the singular values obtained from SVD: A = UΣV^T
594///
595/// This is also known as the trace norm or Schatten 1-norm.
596/// It's commonly used in:
597/// - Low-rank matrix approximation
598/// - Matrix completion problems
599/// - Compressed sensing
600/// - Regularization for machine learning models
601pub fn norm_nuclear(tensor: &Tensor) -> TorshResult<Tensor> {
602 // Ensure input is 2D
603 if tensor.shape().ndim() != 2 {
604 return Err(TorshError::InvalidArgument(
605 "Nuclear norm requires 2D tensor (matrix)".to_string(),
606 ));
607 }
608
609 // Compute SVD: A = U * S * V^T
610 // The singular values S contain the information we need
611 let (_u, s, _vt) = torsh_linalg::decomposition::svd(tensor, false)?;
612
613 // Nuclear norm is the sum of singular values
614 // The singular values are already sorted in descending order
615 s.sum()
616}
617
618/// Count non-zero elements
619pub fn count_nonzero(tensor: &Tensor) -> TorshResult<Tensor> {
620 let zero_tensor = zeros(tensor.shape().dims())?;
621 let nonzero_mask = tensor.ne(&zero_tensor)?;
622 // Create a tensor of ones with the same shape and sum where mask is true
623 let ones = tensor.ones_like()?;
624 let zeros = tensor.zeros_like()?;
625 let count_tensor = ones.where_tensor(&nonzero_mask, &zeros)?;
626 count_tensor.sum()
627}
628
629/// Count non-zero elements along dimension
630pub fn count_nonzero_dim(tensor: &Tensor, dim: isize) -> TorshResult<Tensor> {
631 let zero_tensor = zeros(tensor.shape().dims())?;
632 let nonzero_mask = tensor.ne(&zero_tensor)?;
633 // Create a tensor of ones with the same shape and sum where mask is true
634 let ones = tensor.ones_like()?;
635 let zeros = tensor.zeros_like()?;
636 let count_tensor = ones.where_tensor(&nonzero_mask, &zeros)?;
637 count_tensor.sum_dim(&[dim as i32], false)
638}
639
640/// Cumulative sum
641pub fn cumsum(tensor: &Tensor, dim: isize) -> TorshResult<Tensor> {
642 tensor.cumsum(dim.try_into().expect("dimension conversion should succeed"))
643}
644
645/// Cumulative product
646pub fn cumprod(tensor: &Tensor, dim: isize) -> TorshResult<Tensor> {
647 tensor.cumprod(dim.try_into().expect("dimension conversion should succeed"))
648}
649
650/// All elements are true (non-zero)
651pub fn all(tensor: &Tensor) -> TorshResult<torsh_tensor::Tensor<bool>> {
652 tensor.all()
653}
654
655/// All elements are true along dimension
656pub fn all_dim(
657 tensor: &Tensor,
658 dim: isize,
659 keepdim: bool,
660) -> TorshResult<torsh_tensor::Tensor<bool>> {
661 tensor.all_dim(
662 dim.try_into().expect("dimension conversion should succeed"),
663 keepdim,
664 )
665}
666
667/// Any element is true (non-zero)
668pub fn any(tensor: &Tensor) -> TorshResult<torsh_tensor::Tensor<bool>> {
669 tensor.any()
670}
671
672/// Any element is true along dimension
673pub fn any_dim(
674 tensor: &Tensor,
675 dim: isize,
676 keepdim: bool,
677) -> TorshResult<torsh_tensor::Tensor<bool>> {
678 tensor.any_dim(
679 dim.try_into().expect("dimension conversion should succeed"),
680 keepdim,
681 )
682}
683
684// ============================================================================
685// Unique Operations
686// ============================================================================
687
688/// Find unique elements in a tensor
689pub fn unique(
690 tensor: &Tensor,
691 sorted: bool,
692 return_inverse: bool,
693 return_counts: bool,
694 dim: Option<isize>,
695) -> TorshResult<UniqueResult> {
696 if let Some(_d) = dim {
697 // Unique along dimension
698 return Err(TorshError::Other(
699 "unique along dimension not yet implemented".to_string(),
700 ));
701 }
702
703 // Flatten tensor and get unique values
704 let flat = tensor.flatten()?;
705 let size = flat.shape().dims()[0];
706
707 let mut unique_map: HashMap<OrderedFloat, usize> = HashMap::new();
708 let mut unique_values = Vec::new();
709 let mut inverse_indices = vec![0; size];
710
711 // Find unique values
712 for i in 0..size {
713 let value = flat.get(&[i])?;
714 let key = OrderedFloat(value);
715
716 match unique_map.get(&key) {
717 Some(&idx) => {
718 if return_inverse {
719 inverse_indices[i] = idx;
720 }
721 }
722 None => {
723 let idx = unique_values.len();
724 unique_values.push(value);
725 unique_map.insert(key, idx);
726 if return_inverse {
727 inverse_indices[i] = idx;
728 }
729 }
730 }
731 }
732
733 // Sort if requested
734 if sorted {
735 let mut indices: Vec<_> = (0..unique_values.len()).collect();
736 indices.sort_by(|&a, &b| {
737 unique_values[a]
738 .partial_cmp(&unique_values[b])
739 .expect("numeric comparison should succeed")
740 });
741
742 // Reorder unique values
743 let sorted_values: Vec<_> = indices.iter().map(|&i| unique_values[i]).collect();
744
745 // Update inverse indices if needed
746 if return_inverse {
747 let mut index_map = vec![0; indices.len()];
748 for (new_idx, &old_idx) in indices.iter().enumerate() {
749 index_map[old_idx] = new_idx;
750 }
751
752 for inv_idx in inverse_indices.iter_mut() {
753 *inv_idx = index_map[*inv_idx];
754 }
755 }
756
757 unique_values = sorted_values;
758 }
759
760 // Create output tensor
761 let output = Tensor::from_vec(unique_values.clone(), &[unique_values.len()])?;
762
763 // Compute counts if requested
764 let counts = if return_counts {
765 let mut count_vec = vec![0; unique_values.len()];
766 if return_inverse {
767 for &idx in inverse_indices.iter() {
768 count_vec[idx] += 1;
769 }
770 } else {
771 // Recompute if we didn't track inverse indices
772 for i in 0..size {
773 let value = flat.get(&[i])?;
774 for (j, &unique_val) in unique_values.iter().enumerate() {
775 if (value - unique_val).abs() < f32::EPSILON {
776 count_vec[j] += 1;
777 break;
778 }
779 }
780 }
781 }
782 let count_data: Vec<f32> = count_vec.into_iter().map(|c| c as f32).collect();
783 Some(Tensor::from_vec(count_data.clone(), &[count_data.len()])?)
784 } else {
785 None
786 };
787
788 // Create inverse tensor if requested
789 let inverse = if return_inverse {
790 let inverse_data: Vec<f32> = inverse_indices.into_iter().map(|i| i as f32).collect();
791 Some(Tensor::from_vec(
792 inverse_data.clone(),
793 &[inverse_data.len()],
794 )?)
795 } else {
796 None
797 };
798
799 Ok(UniqueResult {
800 values: output,
801 inverse,
802 counts,
803 })
804}
805
/// Find unique consecutive elements.
///
/// Run-length style deduplication over the flattened tensor: a new group is
/// started every time the value changes, so a value that reappears after a
/// different value forms a NEW group (unlike `unique`).
///
/// # Arguments
/// * `tensor` - Input tensor (flattened before processing)
/// * `return_inverse` - If true, also return, for each input element, the
///   index of the group it belongs to
/// * `return_counts` - If true, also return the length of each group
/// * `dim` - Per-dimension variant; not yet implemented, must be `None`
///
/// # Notes
/// - Adjacent values are considered equal when they differ by less than
///   `f32::EPSILON` — an ABSOLUTE tolerance, so large near-equal values are
///   treated as distinct. NOTE(review): confirm whether a relative tolerance
///   was intended.
/// - Inverse indices and counts are materialized as `f32` tensors, so exact
///   index round-tripping holds only below 2^24.
pub fn unique_consecutive(
    tensor: &Tensor,
    return_inverse: bool,
    return_counts: bool,
    dim: Option<isize>,
) -> TorshResult<UniqueResult> {
    if let Some(_d) = dim {
        // Unique consecutive along dimension
        return Err(TorshError::Other(
            "unique_consecutive along dimension not yet implemented".to_string(),
        ));
    }

    // Flatten tensor
    let flat = tensor.flatten()?;
    let size = flat.shape().dims()[0];

    // Empty input: return empty tensors for everything that was requested.
    if size == 0 {
        return Ok(UniqueResult {
            values: zeros(&[0])?,
            inverse: if return_inverse {
                Some(zeros(&[0])?)
            } else {
                None
            },
            counts: if return_counts {
                Some(zeros(&[0])?)
            } else {
                None
            },
        });
    }

    let mut unique_values = Vec::new();
    // inverse_indices[0] stays 0, which is correct: the first element always
    // belongs to group 0.
    let mut inverse_indices = vec![0; size];
    let mut counts = Vec::new();

    // Process first element
    let mut current_value = flat.get(&[0])?;
    unique_values.push(current_value);
    let mut current_count = 1;

    // Process remaining elements
    for i in 1..size {
        let value = flat.get(&[i])?;

        if (value - current_value).abs() < f32::EPSILON {
            // Same as previous
            current_count += 1;
            if return_inverse {
                inverse_indices[i] = unique_values.len() - 1;
            }
        } else {
            // Different from previous
            if return_counts {
                // The just-finished group's length is flushed before starting
                // the next one.
                counts.push(current_count);
            }

            current_value = value;
            unique_values.push(value);
            current_count = 1;

            if return_inverse {
                inverse_indices[i] = unique_values.len() - 1;
            }
        }
    }

    // Don't forget the last group
    if return_counts {
        counts.push(current_count);
    }

    // Create output tensors
    let output = Tensor::from_vec(unique_values.clone(), &[unique_values.len()])?;

    let counts_tensor = if return_counts {
        let count_data: Vec<f32> = counts.into_iter().map(|c| c as f32).collect();
        Some(Tensor::from_vec(count_data.clone(), &[count_data.len()])?)
    } else {
        None
    };

    let inverse_tensor = if return_inverse {
        let inverse_data: Vec<f32> = inverse_indices.into_iter().map(|i| i as f32).collect();
        Some(Tensor::from_vec(
            inverse_data.clone(),
            &[inverse_data.len()],
        )?)
    } else {
        None
    };

    Ok(UniqueResult {
        values: output,
        inverse: inverse_tensor,
        counts: counts_tensor,
    })
}
906
/// Result of unique operations (`unique`, `unique_consecutive`).
pub struct UniqueResult {
    /// The deduplicated values, as a 1-D tensor.
    pub values: Tensor,
    /// For each input element, the index of its value/group in `values`.
    /// `Some` only when `return_inverse` was requested; indices are stored
    /// as `f32` because the result tensor element type is `f32`.
    pub inverse: Option<Tensor>,
    /// Occurrence count of each entry in `values`, stored as `f32`.
    /// `Some` only when `return_counts` was requested.
    pub counts: Option<Tensor>,
}
913
/// Wrapper for f32 that implements `Eq` and `Hash` so it can key a `HashMap`.
///
/// Both equality and hashing are derived from one canonical bit pattern,
/// keeping the `Eq`/`Hash` contract intact (equal values must hash equally).
/// The previous epsilon-tolerant `eq` combined with raw bitwise hashing
/// violated that contract: two values within `f32::EPSILON` compared equal
/// yet could hash into different buckets, making `HashMap` deduplication
/// unreliable.
#[derive(Debug, Clone, Copy)]
struct OrderedFloat(f32);

impl OrderedFloat {
    /// Canonical bit pattern for keying: collapses `-0.0` into `+0.0` and
    /// every NaN payload into one canonical NaN so each behaves as a single
    /// key.
    fn canonical_bits(self) -> u32 {
        if self.0.is_nan() {
            f32::NAN.to_bits()
        } else if self.0 == 0.0 {
            0.0f32.to_bits()
        } else {
            self.0.to_bits()
        }
    }
}

impl PartialEq for OrderedFloat {
    fn eq(&self, other: &Self) -> bool {
        self.canonical_bits() == other.canonical_bits()
    }
}

impl Eq for OrderedFloat {}

impl std::hash::Hash for OrderedFloat {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Must agree exactly with `eq`; both use the canonical bits.
        self.canonical_bits().hash(state);
    }
}
932
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::tensor;

    // Deduplication over an unsorted input with repeated values.
    #[test]
    fn test_unique() {
        let tensor = tensor![3.0f32, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0, 5.0].unwrap();

        // Test basic unique
        let result = unique(&tensor, true, false, false, None).unwrap();
        assert_eq!(result.values.shape().dims()[0], 7); // Should have 7 unique values

        // Test with counts
        let result = unique(&tensor, true, false, true, None).unwrap();
        assert!(result.counts.is_some());
    }

    // Run-length grouping: a value that reappears after a different value
    // starts a NEW group (so the trailing 1.0s form a fourth group here).
    #[test]
    fn test_unique_consecutive() {
        let tensor = tensor![1.0f32, 1.0, 2.0, 2.0, 2.0, 3.0, 1.0, 1.0].unwrap();

        let result = unique_consecutive(&tensor, true, true, None).unwrap();

        // Should have 4 groups: [1,1], [2,2,2], [3], [1,1]
        assert_eq!(result.values.shape().dims()[0], 4);

        let expected_values = vec![1.0, 2.0, 3.0, 1.0];
        let expected_counts = vec![2.0, 3.0, 1.0, 2.0];

        // Verify values and counts
        for i in 0..4 {
            assert!((result.values.get(&[i]).unwrap() - expected_values[i]).abs() < f32::EPSILON);
            assert!(
                (result.counts.as_ref().unwrap().get(&[i]).unwrap() - expected_counts[i]).abs()
                    < f32::EPSILON
            );
        }
    }
}
972}