// train_station/tensor/ops/exp.rs
1//! Exponential operations for tensors
2//!
3//! Provides element-wise exponential function following PyTorch conventions with
4//! comprehensive automatic differentiation support and optimized scalar computation.
5//!
6//! # Key Features
7//!
8//! - **Element-wise Exponential**: `exp()` - Computes e^x for each element (PyTorch `exp()` equivalent)
9//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
10//! - **Optimized Scalar Math**: Uses optimized scalar exponential for accuracy and simplicity
11//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
12//! - **Zero-copy Operations**: Efficient memory usage where possible
13//! - **Mathematical Accuracy**: High-precision exponential computation
14//!
15//! # Mathematical Properties
16//!
17//! The exponential function e^x has the following properties:
18//! - e^0 = 1
19//! - e^1 ≈ 2.71828 (Euler's number)
20//! - e^(-x) = 1/e^x
21//! - e^(x+y) = e^x * e^y
22//! - Gradient: d/dx(e^x) = e^x
23//!
24//! # Performance Characteristics
25//!
26//! - **Scalar Optimization**: Optimized scalar exponential computation
27//! - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
28//! - **Cache-friendly Access**: Linear memory access patterns
29//! - **Mathematical Accuracy**: High-precision floating-point exponential
30//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
31
32use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
33use crate::tensor::core::Tensor;
34
35// Note: exp uses scalar math for accuracy and simplicity; SIMD width-load path was removed due to alignment strictness
36
37impl Tensor {
38 /// Element-wise exponential function.
39 ///
40 /// Computes e^x for each element: `output[i] = e^(self[i])`
41 ///
42 /// # Returns
43 /// A new tensor with the exponential of each element
44 ///
45 /// # Examples
46 ///
47 /// ## Basic Exponential
48 ///
49 /// ```
50 /// use train_station::Tensor;
51 ///
52 /// let a = Tensor::from_slice(&[0.0, 1.0, 2.0], vec![3]).unwrap();
53 /// let b = a.exp();
54 /// assert_eq!(b.shape().dims, vec![3]);
55 /// assert_eq!(b.get(&[0]), 1.0); // e^0 = 1
56 /// assert!((b.get(&[1]) - 2.71828).abs() < 1e-5); // e^1 ≈ 2.71828
57 /// assert!((b.get(&[2]) - 7.38906).abs() < 1e-5); // e^2 ≈ 7.38906
58 /// ```
59 ///
60 /// ## Negative Values
61 ///
62 /// ```
63 /// use train_station::Tensor;
64 ///
65 /// let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
66 /// let b = a.exp();
67 /// assert_eq!(b.shape().dims, vec![3]);
68 /// assert!((b.get(&[0]) - 0.36788).abs() < 1e-5); // e^(-1) ≈ 0.36788
69 /// assert_eq!(b.get(&[1]), 1.0); // e^0 = 1
70 /// assert!((b.get(&[2]) - 2.71828).abs() < 1e-5); // e^1 ≈ 2.71828
71 /// ```
72 #[inline]
73 #[track_caller]
74 pub fn exp(&self) -> Tensor {
75 let mut result = self.exp_optimized();
76 if self.requires_grad() && is_grad_enabled() {
77 result.set_requires_grad_internal(true);
78 let grad_fn = GradFn::Exp {
79 saved_output: Box::new(result.clone()),
80 };
81 result.set_grad_fn(grad_fn.clone());
82 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
83 }
84 result
85 }
86 /// Internal optimized exponential operation
87 ///
88 /// Performs element-wise exponential computation using optimized scalar math
89 /// for maximum accuracy and performance. This is the core implementation
90 /// used by `exp()`.
91 ///
92 /// # Returns
93 ///
94 /// A new tensor containing the exponential of each element
95 ///
96 /// # Performance Characteristics
97 ///
98 /// - **Scalar Optimization**: Uses optimized scalar exponential for accuracy
99 /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
100 /// - **Cache-friendly**: Linear memory access patterns
101 /// - **Mathematical Accuracy**: High-precision floating-point exponential
102 /// - **Zero-sized Handling**: Fast return for empty tensors
103 ///
104 /// # Implementation Details
105 ///
106 /// Uses scalar exponential computation for maximum mathematical accuracy.
107 /// SIMD optimization was removed due to alignment requirements and the
108 /// need for high-precision mathematical operations.
109 #[inline]
110 pub(crate) fn exp_optimized(&self) -> Tensor {
111 let mut output = Tensor::new(self.shape().dims.clone());
112
113 // Fast return for zero-sized tensors
114 if self.size() == 0 {
115 return output;
116 }
117
118 unsafe {
119 let src = self.as_ptr();
120 let dst = output.as_mut_ptr();
121 self.exp_scalar_fallback(src, dst);
122 }
123
124 output
125 }
126
127 /// Optimized scalar exponential fallback
128 ///
129 /// Performs element-wise exponential using optimized scalar operations with
130 /// 4x unrolling for better instruction-level parallelism and cache efficiency.
131 ///
132 /// # Arguments
133 ///
134 /// * `src` - Pointer to source tensor data
135 /// * `dst` - Pointer to output tensor data
136 ///
137 /// # Safety
138 ///
139 /// Requires valid pointers with sufficient memory for the tensor size.
140 /// All pointers must point to valid tensor data.
141 ///
142 /// # Performance Characteristics
143 ///
144 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
145 /// - **Memory Access**: Linear access patterns for cache efficiency
146 /// - **Fallback**: Handles remaining elements with scalar operations
147 /// - **Mathematical Accuracy**: Uses high-precision scalar exponential
148 #[inline]
149 unsafe fn exp_scalar_fallback(&self, src: *const f32, dst: *mut f32) {
150 let size = self.size();
151 let unroll = 4;
152 let mut offset = 0;
153 let unroll_count = size / unroll;
154 for _ in 0..unroll_count {
155 *dst.add(offset) = (*src.add(offset)).exp();
156 *dst.add(offset + 1) = (*src.add(offset + 1)).exp();
157 *dst.add(offset + 2) = (*src.add(offset + 2)).exp();
158 *dst.add(offset + 3) = (*src.add(offset + 3)).exp();
159 offset += unroll;
160 }
161 for i in offset..size {
162 *dst.add(i) = (*src.add(i)).exp();
163 }
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Verifies that `exp_optimized` matches scalar `f32::exp` element-wise.
    #[test]
    fn test_exp_basic() {
        let values = [0.0, 1.0, -1.0, 2.0];
        let input = Tensor::from_slice(&values, vec![2, 2]).unwrap();
        let result = input.exp_optimized();
        // Compare raw buffers directly; both tensors have the same size.
        unsafe {
            let got = std::slice::from_raw_parts(result.as_ptr(), result.size());
            let src = std::slice::from_raw_parts(input.as_ptr(), input.size());
            for (y, x) in got.iter().zip(src) {
                assert!((y - x.exp()).abs() < 1e-6);
            }
        }
    }
}