// train_station/tensor/ops/softmax.rs

//! Softmax activation function
//!
//! Provides the softmax activation function following PyTorch conventions with
//! comprehensive GradTrack support and numerically stable computation.
//!
//! # Key Features
//!
//! - **Softmax Activation**: `softmax(dim)` - Computes softmax along specified dimension (PyTorch `softmax()` equivalent)
//! - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
//! - **Numerical Stability**: Avoids overflow using max subtraction technique
//! - **Mathematical Accuracy**: High-precision softmax computation
//! - **Dimension Flexibility**: Supports softmax along any dimension
//! - **Probability Output**: Values sum to 1 along the specified dimension
//!
//! # Mathematical Properties
//!
//! The softmax activation function has the following properties:
//! - **Definition**: softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x)))
//! - **Range**: (0, 1) - outputs are always positive and sum to 1
//! - **Numerical Stability**: Subtracts max value to prevent overflow
//! - **Monotonicity**: Preserves relative ordering of input values
//! - **Continuity**: Continuous and differentiable everywhere
//! - **Gradient**: Complex gradient computation involving the softmax output
//! - **Probability Interpretation**: Outputs can be interpreted as probabilities
//!
//! # Performance Characteristics
//!
//! - **Numerical Stability**: Avoids overflow using max subtraction technique
//! - **Scalar Implementation**: Optimized scalar computation for mathematical accuracy
//! - **Cache-friendly Access**: Optimized memory access patterns for dimension operations
//! - **Mathematical Accuracy**: High-precision exponential and division operations
//! - **GradTrack Optimization**: Efficient automatic differentiation with NoGradTrack support
34use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
35use crate::tensor::core::Tensor;
36
37impl Tensor {
38    /// Computes softmax activation along the specified dimension
39    ///
40    /// Applies the softmax function along dimension `dim`, transforming values into
41    /// probabilities that sum to 1 along that dimension. Uses numerically stable
42    /// computation to avoid overflow: `softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x)))`
43    ///
44    /// # Arguments
45    ///
46    /// * `dim` - Dimension along which to compute softmax (0-based indexing)
47    ///
48    /// # Returns
49    ///
50    /// A new tensor with softmax applied along the specified dimension.
51    /// Values are in range (0, 1) and sum to 1 along `dim`.
52    ///
53    /// # Performance Characteristics
54    ///
55    /// - **Numerical Stability**: Avoids overflow using max subtraction technique
56    /// - **Scalar Implementation**: Optimized scalar computation for mathematical accuracy
57    /// - **Cache-friendly**: Optimized memory access patterns for dimension operations
58    /// - **Mathematical Accuracy**: High-precision exponential and division operations
59    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
60    ///
61    /// # Implementation Details
62    ///
63    /// Uses a numerically stable three-pass algorithm:
64    /// 1. **Max Computation**: Find the maximum value along the specified dimension
65    /// 2. **Exponential Sum**: Compute exp(x - max) and sum the results
66    /// 3. **Normalization**: Divide each exp(x - max) by the sum to get probabilities
67    ///
68    /// This approach prevents overflow by subtracting the maximum value before
69    /// computing exponentials, ensuring numerical stability for any input range.
70    ///
71    /// # Examples
72    ///
73    /// ## Basic Softmax Activation
74    ///
75    /// ```
76    /// use train_station::Tensor;
77    ///
78    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
79    /// let b = a.softmax(0);
80    /// assert_eq!(b.shape().dims, vec![3]);
81    ///
82    /// // Verify probabilities sum to 1
83    /// let sum = b.get(&[0]) + b.get(&[1]) + b.get(&[2]);
84    /// assert!((sum - 1.0).abs() < 1e-6);
85    ///
86    /// // Verify relative ordering is preserved
87    /// assert!(b.get(&[0]) < b.get(&[1]));
88    /// assert!(b.get(&[1]) < b.get(&[2]));
89    /// ```
90    ///
91    /// ## 2D Softmax Along Different Dimensions
92    ///
93    /// ```
94    /// use train_station::Tensor;
95    ///
96    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
97    /// let b = a.softmax(0); // Softmax along first dimension
98    /// assert_eq!(b.shape().dims, vec![2, 2]);
99    ///
100    /// // Each column should sum to 1
101    /// let col1_sum = b.get(&[0, 0]) + b.get(&[1, 0]);
102    /// let col2_sum = b.get(&[0, 1]) + b.get(&[1, 1]);
103    /// assert!((col1_sum - 1.0).abs() < 1e-6);
104    /// assert!((col2_sum - 1.0).abs() < 1e-6);
105    /// ```
106    ///
107    /// # Panics
108    /// - Panics if `dim` is out of bounds for the tensor's rank
109    /// - Panics if the dimension size is 0
110    pub fn softmax(&self, dim: usize) -> Tensor {
111        let rank = self.shape().rank();
112        assert!(
113            dim < rank,
114            "softmax dim {} out of bounds for rank {}",
115            dim,
116            rank
117        );
118        let dims = self.shape().dims.clone();
119        let reduce = dims[dim];
120        assert!(reduce > 0, "cannot softmax over empty dimension");
121
122        let inner: usize = dims[dim + 1..].iter().product();
123        let outer: usize = dims[..dim].iter().product();
124
125        let mut out = Tensor::new(dims.clone());
126        unsafe {
127            let xptr = self.as_ptr();
128            let yptr = out.as_mut_ptr();
129            // For each slice along `dim`, find max then compute exp and sum, then normalize
130            for o in 0..outer {
131                for i in 0..inner {
132                    // 1) max
133                    let mut maxv = f32::NEG_INFINITY;
134                    for j in 0..reduce {
135                        let off = o * (reduce * inner) + j * inner + i;
136                        let v = *xptr.add(off);
137                        if v > maxv {
138                            maxv = v;
139                        }
140                    }
141                    // 2) exp sum
142                    let mut sum = 0.0f32;
143                    for j in 0..reduce {
144                        let off = o * (reduce * inner) + j * inner + i;
145                        let e = (*xptr.add(off) - maxv).exp();
146                        *yptr.add(off) = e;
147                        sum += e;
148                    }
149                    // 3) normalize
150                    let inv = 1.0f32 / sum;
151                    for j in 0..reduce {
152                        let off = o * (reduce * inner) + j * inner + i;
153                        *yptr.add(off) *= inv;
154                    }
155                }
156            }
157        }
158
159        if self.requires_grad() && is_grad_enabled() {
160            let mut result = out.clone();
161            result.set_requires_grad_internal(true);
162            let grad_fn = GradFn::Softmax {
163                dim,
164                saved_output: Box::new(out.clone()),
165            };
166            result.set_grad_fn(grad_fn.clone());
167            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
168            return result;
169        }
170
171        out
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Softmax over a 1-D tensor must produce probabilities that sum to one.
    #[test]
    fn test_softmax_forward_basic() {
        let input = Tensor::from_slice(&[0.0, 1.0, 2.0], vec![3]).unwrap();
        let probs = input.softmax(0);
        let total = probs.sum();
        // SAFETY: `sum()` produces a tensor with at least one element, so
        // reading the first value through its data pointer is in bounds.
        let value = unsafe { *total.as_ptr() };
        assert!((value - 1.0).abs() < 1e-6);
    }
}