train_station/tensor/ops/
softmax.rs

1//! Softmax activation function
2//!
3//! Provides the softmax activation function following PyTorch conventions with
4//! comprehensive GradTrack support and numerically stable computation.
5//!
6//! # Key Features
7//!
8//! - **Softmax Activation**: `softmax(dim)` - Computes softmax along specified dimension (PyTorch `softmax()` equivalent)
9//! - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
10//! - **Numerical Stability**: Avoids overflow using max subtraction technique
11//! - **Mathematical Accuracy**: High-precision softmax computation
12//! - **Dimension Flexibility**: Supports softmax along any dimension
13//! - **Probability Output**: Values sum to 1 along the specified dimension
14//!
15//! # Mathematical Properties
16//!
17//! The softmax activation function has the following properties:
18//! - **Definition**: softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x)))
19//! - **Range**: (0, 1) - outputs are always positive and sum to 1
20//! - **Numerical Stability**: Subtracts max value to prevent overflow
21//! - **Monotonicity**: Preserves relative ordering of input values
22//! - **Continuity**: Continuous and differentiable everywhere
23//! - **Gradient**: Complex gradient computation involving the softmax output
24//! - **Probability Interpretation**: Outputs can be interpreted as probabilities
25//!
26//! # Performance Characteristics
27//!
28//! - **Numerical Stability**: Avoids overflow using max subtraction technique
29//! - **Scalar Implementation**: Optimized scalar computation for mathematical accuracy
30//! - **Cache-friendly Access**: Optimized memory access patterns for dimension operations
31//! - **Mathematical Accuracy**: High-precision exponential and division operations
32//! - **GradTrack Optimization**: Efficient automatic differentiation with NoGradTrack support
33
34use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
35use crate::tensor::core::Tensor;
36
37impl Tensor {
38    /// Computes softmax activation along the specified dimension
39    ///
40    /// Applies the softmax function along dimension `dim`, transforming values into
41    /// probabilities that sum to 1 along that dimension. Uses numerically stable
42    /// computation to avoid overflow: `softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x)))`
43    ///
44    /// # Arguments
45    ///
46    /// * `dim` - Dimension along which to compute softmax (0-based indexing)
47    ///
48    /// # Returns
49    ///
50    /// A new tensor with softmax applied along the specified dimension.
51    /// Values are in range (0, 1) and sum to 1 along `dim`.
52    ///
53    /// # Performance Characteristics
54    ///
55    /// - **Numerical Stability**: Avoids overflow using max subtraction technique
56    /// - **Scalar Implementation**: Optimized scalar computation for mathematical accuracy
57    /// - **Cache-friendly**: Optimized memory access patterns for dimension operations
58    /// - **Mathematical Accuracy**: High-precision exponential and division operations
59    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
60    ///
61    /// # Implementation Details
62    ///
63    /// Uses a numerically stable three-pass algorithm:
64    /// 1. **Max Computation**: Find the maximum value along the specified dimension
65    /// 2. **Exponential Sum**: Compute exp(x - max) and sum the results
66    /// 3. **Normalization**: Divide each exp(x - max) by the sum to get probabilities
67    ///
68    /// This approach prevents overflow by subtracting the maximum value before
69    /// computing exponentials, ensuring numerical stability for any input range.
70    ///
71    /// # Examples
72    ///
73    /// ## Basic Softmax Activation
74    ///
75    /// ```
76    /// use train_station::Tensor;
77    ///
78    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
79    /// let b = a.softmax(0);
80    /// assert_eq!(b.shape().dims(), vec![3]);
81    ///
82    /// // Verify probabilities sum to 1
83    /// let sum = b.get(&[0]) + b.get(&[1]) + b.get(&[2]);
84    /// assert!((sum - 1.0).abs() < 1e-6);
85    ///
86    /// // Verify relative ordering is preserved
87    /// assert!(b.get(&[0]) < b.get(&[1]));
88    /// assert!(b.get(&[1]) < b.get(&[2]));
89    /// ```
90    ///
91    /// ## 2D Softmax Along Different Dimensions
92    ///
93    /// ```
94    /// use train_station::Tensor;
95    ///
96    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
97    /// let b = a.softmax(0); // Softmax along first dimension
98    /// assert_eq!(b.shape().dims(), vec![2, 2]);
99    ///
100    /// // Each column should sum to 1
101    /// let col1_sum = b.get(&[0, 0]) + b.get(&[1, 0]);
102    /// let col2_sum = b.get(&[0, 1]) + b.get(&[1, 1]);
103    /// assert!((col1_sum - 1.0).abs() < 1e-6);
104    /// assert!((col2_sum - 1.0).abs() < 1e-6);
105    /// ```
106    ///
107    /// # Panics
108    /// - Panics if `dim` is out of bounds for the tensor's rank
109    /// - Panics if the dimension size is 0
110    #[track_caller]
111    pub fn softmax(&self, dim: usize) -> Tensor {
112        let rank = self.shape().rank();
113        assert!(
114            dim < rank,
115            "softmax dim {} out of bounds for rank {}",
116            dim,
117            rank
118        );
119        let dims = self.shape().dims().to_vec();
120        let reduce = dims[dim];
121        assert!(reduce > 0, "cannot softmax over empty dimension");
122
123        let inner: usize = dims[dim + 1..].iter().product();
124        let outer: usize = dims[..dim].iter().product();
125
126        let mut out = Tensor::new(dims.clone());
127        unsafe {
128            let xptr = self.as_ptr();
129            let yptr = out.as_mut_ptr();
130            // For each slice along `dim`, find max then compute exp and sum, then normalize
131            for o in 0..outer {
132                for i in 0..inner {
133                    // 1) max
134                    let mut maxv = f32::NEG_INFINITY;
135                    for j in 0..reduce {
136                        let off = o * (reduce * inner) + j * inner + i;
137                        let v = *xptr.add(off);
138                        if v > maxv {
139                            maxv = v;
140                        }
141                    }
142                    // 2) exp sum
143                    let mut sum = 0.0f32;
144                    for j in 0..reduce {
145                        let off = o * (reduce * inner) + j * inner + i;
146                        let e = (*xptr.add(off) - maxv).exp();
147                        *yptr.add(off) = e;
148                        sum += e;
149                    }
150                    // 3) normalize
151                    let inv = 1.0f32 / sum;
152                    for j in 0..reduce {
153                        let off = o * (reduce * inner) + j * inner + i;
154                        *yptr.add(off) *= inv;
155                    }
156                }
157            }
158        }
159
160        if self.requires_grad() && is_grad_enabled() {
161            let mut result = out.clone();
162            result.set_requires_grad_internal(true);
163            let grad_fn = GradFn::Softmax {
164                dim,
165                saved_output: Box::new(out.clone()),
166            };
167            result.set_grad_fn(grad_fn.clone());
168            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
169            return result;
170        }
171
172        out
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Softmax over a 1-D tensor must yield values that sum to 1.
    #[test]
    fn test_softmax_forward_basic() {
        let input = Tensor::from_slice(&[0.0, 1.0, 2.0], vec![3]).unwrap();
        let probs = input.softmax(0);
        let total = probs.sum();
        // SAFETY: `sum()` produces a tensor with at least one element, so
        // reading the first element through the raw pointer is in-bounds.
        let total_value = unsafe { *total.as_ptr() };
        assert!((total_value - 1.0).abs() < 1e-6);
    }
}