train_station/tensor/ops/exp.rs
//! Exponential operations for tensors
//!
//! Provides element-wise exponential function following PyTorch conventions with
//! comprehensive automatic differentiation support and optimized scalar computation.
//!
//! # Key Features
//!
//! - **Element-wise Exponential**: `exp()` - Computes e^x for each element (PyTorch `exp()` equivalent)
//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
//! - **Optimized Scalar Math**: Uses optimized scalar exponential for accuracy and simplicity
//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
//! - **Zero-copy Operations**: Efficient memory usage where possible
//! - **Mathematical Accuracy**: High-precision exponential computation
//!
//! # Mathematical Properties
//!
//! The exponential function e^x has the following properties:
//! - e^0 = 1
//! - e^1 ≈ 2.71828 (Euler's number)
//! - e^(-x) = 1/e^x
//! - e^(x+y) = e^x * e^y
//! - Gradient: d/dx(e^x) = e^x
//!
//! # Performance Characteristics
//!
//! - **Scalar Optimization**: Optimized scalar exponential computation
//! - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
//! - **Cache-friendly Access**: Linear memory access patterns
//! - **Mathematical Accuracy**: High-precision floating-point exponential
//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support

32use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
33use crate::tensor::core::Tensor;
34
// Note: exp uses scalar math for accuracy and simplicity; the SIMD width-load path
// was removed due to alignment strictness.
36
37impl Tensor {
38 /// Element-wise exponential function.
39 ///
40 /// Computes e^x for each element: `output[i] = e^(self[i])`
41 ///
42 /// # Returns
43 /// A new tensor with the exponential of each element
44 ///
45 /// # Examples
46 ///
47 /// ## Basic Exponential
48 ///
49 /// ```
50 /// use train_station::Tensor;
51 ///
52 /// let a = Tensor::from_slice(&[0.0, 1.0, 2.0], vec![3]).unwrap();
53 /// let b = a.exp();
54 /// assert_eq!(b.shape().dims, vec![3]);
55 /// assert_eq!(b.get(&[0]), 1.0); // e^0 = 1
56 /// assert!((b.get(&[1]) - 2.71828).abs() < 1e-5); // e^1 ≈ 2.71828
57 /// assert!((b.get(&[2]) - 7.38906).abs() < 1e-5); // e^2 ≈ 7.38906
58 /// ```
59 ///
60 /// ## Negative Values
61 ///
62 /// ```
63 /// use train_station::Tensor;
64 ///
65 /// let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
66 /// let b = a.exp();
67 /// assert_eq!(b.shape().dims, vec![3]);
68 /// assert!((b.get(&[0]) - 0.36788).abs() < 1e-5); // e^(-1) ≈ 0.36788
69 /// assert_eq!(b.get(&[1]), 1.0); // e^0 = 1
70 /// assert!((b.get(&[2]) - 2.71828).abs() < 1e-5); // e^1 ≈ 2.71828
71 /// ```
72 #[inline]
73 pub fn exp(&self) -> Tensor {
74 let mut result = self.exp_optimized();
75 if self.requires_grad() && is_grad_enabled() {
76 result.set_requires_grad_internal(true);
77 let grad_fn = GradFn::Exp {
78 saved_output: Box::new(result.clone()),
79 };
80 result.set_grad_fn(grad_fn.clone());
81 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
82 }
83 result
84 }
85 /// Internal optimized exponential operation
86 ///
87 /// Performs element-wise exponential computation using optimized scalar math
88 /// for maximum accuracy and performance. This is the core implementation
89 /// used by `exp()`.
90 ///
91 /// # Returns
92 ///
93 /// A new tensor containing the exponential of each element
94 ///
95 /// # Performance Characteristics
96 ///
97 /// - **Scalar Optimization**: Uses optimized scalar exponential for accuracy
98 /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
99 /// - **Cache-friendly**: Linear memory access patterns
100 /// - **Mathematical Accuracy**: High-precision floating-point exponential
101 /// - **Zero-sized Handling**: Fast return for empty tensors
102 ///
103 /// # Implementation Details
104 ///
105 /// Uses scalar exponential computation for maximum mathematical accuracy.
106 /// SIMD optimization was removed due to alignment requirements and the
107 /// need for high-precision mathematical operations.
108 #[inline]
109 pub(crate) fn exp_optimized(&self) -> Tensor {
110 let mut output = Tensor::new(self.shape().dims.clone());
111
112 // Fast return for zero-sized tensors
113 if self.size() == 0 {
114 return output;
115 }
116
117 unsafe {
118 let src = self.as_ptr();
119 let dst = output.as_mut_ptr();
120 self.exp_scalar_fallback(src, dst);
121 }
122
123 output
124 }
125
126 /// Optimized scalar exponential fallback
127 ///
128 /// Performs element-wise exponential using optimized scalar operations with
129 /// 4x unrolling for better instruction-level parallelism and cache efficiency.
130 ///
131 /// # Arguments
132 ///
133 /// * `src` - Pointer to source tensor data
134 /// * `dst` - Pointer to output tensor data
135 ///
136 /// # Safety
137 ///
138 /// Requires valid pointers with sufficient memory for the tensor size.
139 /// All pointers must point to valid tensor data.
140 ///
141 /// # Performance Characteristics
142 ///
143 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
144 /// - **Memory Access**: Linear access patterns for cache efficiency
145 /// - **Fallback**: Handles remaining elements with scalar operations
146 /// - **Mathematical Accuracy**: Uses high-precision scalar exponential
147 #[inline]
148 unsafe fn exp_scalar_fallback(&self, src: *const f32, dst: *mut f32) {
149 let size = self.size();
150 let unroll = 4;
151 let mut offset = 0;
152 let unroll_count = size / unroll;
153 for _ in 0..unroll_count {
154 *dst.add(offset) = (*src.add(offset)).exp();
155 *dst.add(offset + 1) = (*src.add(offset + 1)).exp();
156 *dst.add(offset + 2) = (*src.add(offset + 2)).exp();
157 *dst.add(offset + 3) = (*src.add(offset + 3)).exp();
158 offset += unroll;
159 }
160 for i in offset..size {
161 *dst.add(i) = (*src.add(i)).exp();
162 }
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// `exp_optimized` must agree with scalar `f32::exp` on every element.
    #[test]
    fn test_exp_basic() {
        let values = [0.0, 1.0, -1.0, 2.0];
        let input = Tensor::from_slice(&values, vec![2, 2]).unwrap();
        let output = input.exp_optimized();
        unsafe {
            let got = std::slice::from_raw_parts(output.as_ptr(), output.size());
            let src = std::slice::from_raw_parts(input.as_ptr(), input.size());
            for (actual, expected) in got.iter().zip(src.iter().map(|v| v.exp())) {
                assert!((actual - expected).abs() < 1e-6);
            }
        }
    }
}