train_station/tensor/ops/exp.rs

1//! Exponential operations for tensors
2//!
3//! Provides element-wise exponential function following PyTorch conventions with
4//! comprehensive automatic differentiation support and optimized scalar computation.
5//!
6//! # Key Features
7//!
8//! - **Element-wise Exponential**: `exp()` - Computes e^x for each element (PyTorch `exp()` equivalent)
9//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
10//! - **Optimized Scalar Math**: Uses optimized scalar exponential for accuracy and simplicity
11//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
12//! - **Zero-copy Operations**: Efficient memory usage where possible
13//! - **Mathematical Accuracy**: High-precision exponential computation
14//!
15//! # Mathematical Properties
16//!
17//! The exponential function e^x has the following properties:
18//! - e^0 = 1
19//! - e^1 ≈ 2.71828 (Euler's number)
20//! - e^(-x) = 1/e^x
21//! - e^(x+y) = e^x * e^y
22//! - Gradient: d/dx(e^x) = e^x
23//!
24//! # Performance Characteristics
25//!
26//! - **Scalar Optimization**: Optimized scalar exponential computation
27//! - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
28//! - **Cache-friendly Access**: Linear memory access patterns
29//! - **Mathematical Accuracy**: High-precision floating-point exponential
30//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
31
32use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
33use crate::tensor::core::Tensor;
34
35// Note: exp uses scalar math for accuracy and simplicity; SIMD width-load path was removed due to alignment strictness
36
37impl Tensor {
38    /// Element-wise exponential function.
39    ///
40    /// Computes e^x for each element: `output[i] = e^(self[i])`
41    ///
42    /// # Returns
43    /// A new tensor with the exponential of each element
44    ///
45    /// # Examples
46    ///
47    /// ## Basic Exponential
48    ///
49    /// ```
50    /// use train_station::Tensor;
51    ///
52    /// let a = Tensor::from_slice(&[0.0, 1.0, 2.0], vec![3]).unwrap();
53    /// let b = a.exp();
54    /// assert_eq!(b.shape().dims, vec![3]);
55    /// assert_eq!(b.get(&[0]), 1.0); // e^0 = 1
56    /// assert!((b.get(&[1]) - 2.71828).abs() < 1e-5); // e^1 ≈ 2.71828
57    /// assert!((b.get(&[2]) - 7.38906).abs() < 1e-5); // e^2 ≈ 7.38906
58    /// ```
59    ///
60    /// ## Negative Values
61    ///
62    /// ```
63    /// use train_station::Tensor;
64    ///
65    /// let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
66    /// let b = a.exp();
67    /// assert_eq!(b.shape().dims, vec![3]);
68    /// assert!((b.get(&[0]) - 0.36788).abs() < 1e-5); // e^(-1) ≈ 0.36788
69    /// assert_eq!(b.get(&[1]), 1.0); // e^0 = 1
70    /// assert!((b.get(&[2]) - 2.71828).abs() < 1e-5); // e^1 ≈ 2.71828
71    /// ```
72    #[inline]
73    pub fn exp(&self) -> Tensor {
74        let mut result = self.exp_optimized();
75        if self.requires_grad() && is_grad_enabled() {
76            result.set_requires_grad_internal(true);
77            let grad_fn = GradFn::Exp {
78                saved_output: Box::new(result.clone()),
79            };
80            result.set_grad_fn(grad_fn.clone());
81            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
82        }
83        result
84    }
85    /// Internal optimized exponential operation
86    ///
87    /// Performs element-wise exponential computation using optimized scalar math
88    /// for maximum accuracy and performance. This is the core implementation
89    /// used by `exp()`.
90    ///
91    /// # Returns
92    ///
93    /// A new tensor containing the exponential of each element
94    ///
95    /// # Performance Characteristics
96    ///
97    /// - **Scalar Optimization**: Uses optimized scalar exponential for accuracy
98    /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
99    /// - **Cache-friendly**: Linear memory access patterns
100    /// - **Mathematical Accuracy**: High-precision floating-point exponential
101    /// - **Zero-sized Handling**: Fast return for empty tensors
102    ///
103    /// # Implementation Details
104    ///
105    /// Uses scalar exponential computation for maximum mathematical accuracy.
106    /// SIMD optimization was removed due to alignment requirements and the
107    /// need for high-precision mathematical operations.
108    #[inline]
109    pub(crate) fn exp_optimized(&self) -> Tensor {
110        let mut output = Tensor::new(self.shape().dims.clone());
111
112        // Fast return for zero-sized tensors
113        if self.size() == 0 {
114            return output;
115        }
116
117        unsafe {
118            let src = self.as_ptr();
119            let dst = output.as_mut_ptr();
120            self.exp_scalar_fallback(src, dst);
121        }
122
123        output
124    }
125
126    /// Optimized scalar exponential fallback
127    ///
128    /// Performs element-wise exponential using optimized scalar operations with
129    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
130    ///
131    /// # Arguments
132    ///
133    /// * `src` - Pointer to source tensor data
134    /// * `dst` - Pointer to output tensor data
135    ///
136    /// # Safety
137    ///
138    /// Requires valid pointers with sufficient memory for the tensor size.
139    /// All pointers must point to valid tensor data.
140    ///
141    /// # Performance Characteristics
142    ///
143    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
144    /// - **Memory Access**: Linear access patterns for cache efficiency
145    /// - **Fallback**: Handles remaining elements with scalar operations
146    /// - **Mathematical Accuracy**: Uses high-precision scalar exponential
147    #[inline]
148    unsafe fn exp_scalar_fallback(&self, src: *const f32, dst: *mut f32) {
149        let size = self.size();
150        let unroll = 4;
151        let mut offset = 0;
152        let unroll_count = size / unroll;
153        for _ in 0..unroll_count {
154            *dst.add(offset) = (*src.add(offset)).exp();
155            *dst.add(offset + 1) = (*src.add(offset + 1)).exp();
156            *dst.add(offset + 2) = (*src.add(offset + 2)).exp();
157            *dst.add(offset + 3) = (*src.add(offset + 3)).exp();
158            offset += unroll;
159        }
160        for i in offset..size {
161            *dst.add(i) = (*src.add(i)).exp();
162        }
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// `exp_optimized` must match the standard library's `f32::exp`
    /// element-by-element.
    #[test]
    fn test_exp_basic() {
        let values = [0.0, 1.0, -1.0, 2.0];
        let input = Tensor::from_slice(&values, vec![2, 2]).unwrap();
        let output = input.exp_optimized();
        unsafe {
            let out = std::slice::from_raw_parts(output.as_ptr(), output.size());
            for (got, v) in out.iter().zip(values.iter()) {
                assert!((got - v.exp()).abs() < 1e-6);
            }
        }
    }
}
184}