train_station/tensor/ops/softmax.rs
//! Softmax activation function
//!
//! Provides the softmax activation function following PyTorch conventions with
//! comprehensive GradTrack support and numerically stable computation.
//!
//! # Key Features
//!
//! - **Softmax Activation**: `softmax(dim)` - Computes softmax along specified dimension (PyTorch `softmax()` equivalent)
//! - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
//! - **Numerical Stability**: Avoids overflow using max subtraction technique
//! - **Mathematical Accuracy**: High-precision softmax computation
//! - **Dimension Flexibility**: Supports softmax along any dimension
//! - **Probability Output**: Values sum to 1 along the specified dimension
//!
//! # Mathematical Properties
//!
//! The softmax activation function has the following properties:
//! - **Definition**: softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x)))
//! - **Range**: (0, 1) - outputs are always positive and sum to 1
//! - **Numerical Stability**: Subtracts max value to prevent overflow
//! - **Monotonicity**: Preserves relative ordering of input values
//! - **Continuity**: Continuous and differentiable everywhere
//! - **Gradient**: Complex gradient computation involving the softmax output
//! - **Probability Interpretation**: Outputs can be interpreted as probabilities
//!
//! # Performance Characteristics
//!
//! - **Numerical Stability**: Avoids overflow using max subtraction technique
//! - **Scalar Implementation**: Optimized scalar computation for mathematical accuracy
//! - **Cache-friendly Access**: Optimized memory access patterns for dimension operations
//! - **Mathematical Accuracy**: High-precision exponential and division operations
//! - **GradTrack Optimization**: Efficient automatic differentiation with NoGradTrack support

34use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
35use crate::tensor::core::Tensor;
36
37impl Tensor {
38 /// Computes softmax activation along the specified dimension
39 ///
40 /// Applies the softmax function along dimension `dim`, transforming values into
41 /// probabilities that sum to 1 along that dimension. Uses numerically stable
42 /// computation to avoid overflow: `softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x)))`
43 ///
44 /// # Arguments
45 ///
46 /// * `dim` - Dimension along which to compute softmax (0-based indexing)
47 ///
48 /// # Returns
49 ///
50 /// A new tensor with softmax applied along the specified dimension.
51 /// Values are in range (0, 1) and sum to 1 along `dim`.
52 ///
53 /// # Performance Characteristics
54 ///
55 /// - **Numerical Stability**: Avoids overflow using max subtraction technique
56 /// - **Scalar Implementation**: Optimized scalar computation for mathematical accuracy
57 /// - **Cache-friendly**: Optimized memory access patterns for dimension operations
58 /// - **Mathematical Accuracy**: High-precision exponential and division operations
59 /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
60 ///
61 /// # Implementation Details
62 ///
63 /// Uses a numerically stable three-pass algorithm:
64 /// 1. **Max Computation**: Find the maximum value along the specified dimension
65 /// 2. **Exponential Sum**: Compute exp(x - max) and sum the results
66 /// 3. **Normalization**: Divide each exp(x - max) by the sum to get probabilities
67 ///
68 /// This approach prevents overflow by subtracting the maximum value before
69 /// computing exponentials, ensuring numerical stability for any input range.
70 ///
71 /// # Examples
72 ///
73 /// ## Basic Softmax Activation
74 ///
75 /// ```
76 /// use train_station::Tensor;
77 ///
78 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
79 /// let b = a.softmax(0);
80 /// assert_eq!(b.shape().dims, vec![3]);
81 ///
82 /// // Verify probabilities sum to 1
83 /// let sum = b.get(&[0]) + b.get(&[1]) + b.get(&[2]);
84 /// assert!((sum - 1.0).abs() < 1e-6);
85 ///
86 /// // Verify relative ordering is preserved
87 /// assert!(b.get(&[0]) < b.get(&[1]));
88 /// assert!(b.get(&[1]) < b.get(&[2]));
89 /// ```
90 ///
91 /// ## 2D Softmax Along Different Dimensions
92 ///
93 /// ```
94 /// use train_station::Tensor;
95 ///
96 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
97 /// let b = a.softmax(0); // Softmax along first dimension
98 /// assert_eq!(b.shape().dims, vec![2, 2]);
99 ///
100 /// // Each column should sum to 1
101 /// let col1_sum = b.get(&[0, 0]) + b.get(&[1, 0]);
102 /// let col2_sum = b.get(&[0, 1]) + b.get(&[1, 1]);
103 /// assert!((col1_sum - 1.0).abs() < 1e-6);
104 /// assert!((col2_sum - 1.0).abs() < 1e-6);
105 /// ```
106 ///
107 /// # Panics
108 /// - Panics if `dim` is out of bounds for the tensor's rank
109 /// - Panics if the dimension size is 0
110 #[track_caller]
111 pub fn softmax(&self, dim: usize) -> Tensor {
112 let rank = self.shape().rank();
113 assert!(
114 dim < rank,
115 "softmax dim {} out of bounds for rank {}",
116 dim,
117 rank
118 );
119 let dims = self.shape().dims.clone();
120 let reduce = dims[dim];
121 assert!(reduce > 0, "cannot softmax over empty dimension");
122
123 let inner: usize = dims[dim + 1..].iter().product();
124 let outer: usize = dims[..dim].iter().product();
125
126 let mut out = Tensor::new(dims.clone());
127 unsafe {
128 let xptr = self.as_ptr();
129 let yptr = out.as_mut_ptr();
130 // For each slice along `dim`, find max then compute exp and sum, then normalize
131 for o in 0..outer {
132 for i in 0..inner {
133 // 1) max
134 let mut maxv = f32::NEG_INFINITY;
135 for j in 0..reduce {
136 let off = o * (reduce * inner) + j * inner + i;
137 let v = *xptr.add(off);
138 if v > maxv {
139 maxv = v;
140 }
141 }
142 // 2) exp sum
143 let mut sum = 0.0f32;
144 for j in 0..reduce {
145 let off = o * (reduce * inner) + j * inner + i;
146 let e = (*xptr.add(off) - maxv).exp();
147 *yptr.add(off) = e;
148 sum += e;
149 }
150 // 3) normalize
151 let inv = 1.0f32 / sum;
152 for j in 0..reduce {
153 let off = o * (reduce * inner) + j * inner + i;
154 *yptr.add(off) *= inv;
155 }
156 }
157 }
158 }
159
160 if self.requires_grad() && is_grad_enabled() {
161 let mut result = out.clone();
162 result.set_requires_grad_internal(true);
163 let grad_fn = GradFn::Softmax {
164 dim,
165 saved_output: Box::new(out.clone()),
166 };
167 result.set_grad_fn(grad_fn.clone());
168 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
169 return result;
170 }
171
172 out
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Softmax output over a 1-D tensor must form a probability
    /// distribution, i.e. its elements sum to 1.
    #[test]
    fn test_softmax_forward_basic() {
        let input = Tensor::from_slice(&[0.0, 1.0, 2.0], vec![3]).unwrap();
        let probs = input.softmax(0);
        let total = probs.sum();
        // SAFETY: `sum()` yields a tensor with at least one element, so
        // reading the first element through the raw pointer is in-bounds.
        let value = unsafe { *total.as_ptr() };
        assert!((value - 1.0).abs() < 1e-6);
    }
}
189}