// train_station/tensor/ops/tanh.rs

1//! Hyperbolic tangent activation function
2//!
3//! Provides element-wise hyperbolic tangent activation following PyTorch conventions with
4//! comprehensive GradTrack support and high-precision scalar computation.
5//!
6//! # Key Features
7//!
8//! - **Hyperbolic Tangent**: `tanh()` - Element-wise hyperbolic tangent activation
9//! - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
10//! - **High Precision**: Accurate scalar implementation for mathematical validation
11//! - **Performance Optimization**: 4x unrolled scalar operations for better throughput
12//! - **Numerical Stability**: Robust implementation for extreme input values
13//!
14//! # Mathematical Properties
15//!
16//! The hyperbolic tangent function has the following properties:
17//! - **Range**: Output values are in the range (-1, 1)
18//! - **Symmetry**: tanh(-x) = -tanh(x) (odd function)
19//! - **Asymptotes**: Approaches ±1 as x approaches ±∞
20//! - **Zero**: tanh(0) = 0
21//! - **Gradient**: ∂tanh(x)/∂x = 1 - tanh²(x) = sech²(x)
22//! - **Monotonic**: Strictly increasing function
23//!
24//! # Performance Characteristics
25//!
26//! - **Scalar Implementation**: High-precision scalar computation for mathematical accuracy
27//! - **4x Unrolling**: Optimized scalar operations with instruction-level parallelism
28//! - **Cache-friendly**: Linear memory access patterns
29//! - **Numerical Stability**: Robust handling of extreme input values
30//! - **GradTrack Optimization**: Efficient automatic differentiation with gradient computation
31
32use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
33use crate::tensor::core::Tensor;
34
35impl Tensor {
36    /// Element-wise hyperbolic tangent activation
37    ///
38    /// Computes hyperbolic tangent for each element: `output[i] = tanh(self[i])`
39    ///
40    /// The hyperbolic tangent function maps any real number to the range (-1, 1),
41    /// making it useful as an activation function in neural networks.
42    ///
43    /// # Returns
44    ///
45    /// A new tensor with tanh applied to each element, values in range (-1, 1)
46    ///
47    /// # Performance Characteristics
48    ///
49    /// - **High Precision**: Accurate scalar implementation for mathematical validation
50    /// - **4x Unrolling**: Optimized scalar operations with instruction-level parallelism
51    /// - **Cache-friendly**: Linear memory access patterns
52    /// - **Numerical Stability**: Robust handling of extreme input values
53    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
54    ///
55    /// # Mathematical Properties
56    ///
57    /// - **Range**: Output values are in the range (-1, 1)
58    /// - **Symmetry**: tanh(-x) = -tanh(x) (odd function)
59    /// - **Zero**: tanh(0) = 0
60    /// - **Gradient**: ∂tanh(x)/∂x = 1 - tanh²(x) = sech²(x)
61    ///
62    /// # Examples
63    ///
64    /// ## Basic Hyperbolic Tangent
65    ///
66    /// ```
67    /// use train_station::Tensor;
68    ///
69    /// let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
70    /// let b = a.tanh();
71    /// assert_eq!(b.shape().dims, vec![3]);
72    /// assert!((b.get(&[0]) - (-0.7615942)).abs() < 1e-6); // tanh(-1.0)
73    /// assert!((b.get(&[1]) - 0.0).abs() < 1e-6); // tanh(0.0)
74    /// assert!((b.get(&[2]) - 0.7615942).abs() < 1e-6); // tanh(1.0)
75    /// ```
76    ///
77    /// ## Extreme Values
78    ///
79    /// ```
80    /// use train_station::Tensor;
81    ///
82    /// let a = Tensor::from_slice(&[-10.0, 10.0], vec![2]).unwrap();
83    /// let b = a.tanh();
84    /// assert_eq!(b.shape().dims, vec![2]);
85    /// assert!((b.get(&[0]) - (-1.0)).abs() < 1e-6); // tanh(-10.0) ≈ -1
86    /// assert!((b.get(&[1]) - 1.0).abs() < 1e-6); // tanh(10.0) ≈ 1
87    /// ```
88    pub fn tanh(&self) -> Tensor {
89        let mut out = self.tanh_optimized();
90
91        if self.requires_grad() && is_grad_enabled() {
92            out.set_requires_grad_internal(true);
93            let grad_fn = GradFn::Tanh {
94                saved_output: Box::new(out.clone()),
95            };
96            out.set_grad_fn(grad_fn.clone());
97            GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
98        }
99
100        out
101    }
102
103    /// Internal optimized tanh operation
104    ///
105    /// Performs element-wise hyperbolic tangent computation using high-precision
106    /// scalar implementation for mathematical accuracy and validation.
107    ///
108    /// # Returns
109    ///
110    /// A new tensor with tanh applied to each element
111    ///
112    /// # Performance Characteristics
113    ///
114    /// - **High Precision**: Accurate scalar implementation for mathematical validation
115    /// - **4x Unrolling**: Optimized scalar operations with instruction-level parallelism
116    /// - **Cache-friendly**: Linear memory access patterns
117    /// - **Zero-sized Handling**: Fast return for empty tensors
118    /// - **Numerical Stability**: Robust handling of extreme input values
119    ///
120    /// # Implementation Details
121    ///
122    /// Uses high-precision scalar implementation rather than SIMD approximations
123    /// to ensure mathematical accuracy for validation against reference implementations.
124    /// Implements 4x unrolling for better instruction-level parallelism and cache utilization.
125    #[inline]
126    pub(crate) fn tanh_optimized(&self) -> Tensor {
127        let mut output = Tensor::new(self.shape().dims.clone());
128
129        if self.size() == 0 {
130            return output;
131        }
132
133        unsafe {
134            let src = self.as_ptr();
135            let dst = output.as_mut_ptr();
136
137            // Use scalar implementation for accuracy
138            // SIMD approximations for tanh introduce too much error for validation
139            self.tanh_scalar_optimized(src, dst);
140        }
141
142        output
143    }
144
145    /// Optimized scalar hyperbolic tangent implementation
146    ///
147    /// Performs element-wise hyperbolic tangent computation using optimized scalar
148    /// operations with 4x unrolling for better instruction-level parallelism.
149    ///
150    /// # Arguments
151    ///
152    /// * `src` - Pointer to source tensor data
153    /// * `dst` - Pointer to output tensor data
154    ///
155    /// # Safety
156    ///
157    /// Requires valid pointers with sufficient memory for the tensor size.
158    /// All pointers must point to valid tensor data.
159    ///
160    /// # Performance Characteristics
161    ///
162    /// - **High Precision**: Accurate scalar implementation for mathematical validation
163    /// - **4x Unrolling**: Optimized scalar operations with instruction-level parallelism
164    /// - **Memory Access**: Linear access patterns for cache efficiency
165    /// - **Fallback**: Handles remaining elements with scalar operations
166    /// - **Mathematical Accuracy**: High-precision hyperbolic tangent computation
167    ///
168    /// # Implementation Details
169    ///
170    /// Uses 4x unrolled scalar operations for optimal performance while maintaining
171    /// high mathematical accuracy. Processes elements in groups of 4 to improve
172    /// instruction-level parallelism and reduce loop overhead.
173    #[inline]
174    unsafe fn tanh_scalar_optimized(&self, src: *const f32, dst: *mut f32) {
175        let size = self.size();
176        let unroll_count = size / 4;
177        let mut offset = 0;
178
179        // Unrolled scalar loop for better performance
180        for _ in 0..unroll_count {
181            *dst.add(offset) = (*src.add(offset)).tanh();
182            *dst.add(offset + 1) = (*src.add(offset + 1)).tanh();
183            *dst.add(offset + 2) = (*src.add(offset + 2)).tanh();
184            *dst.add(offset + 3) = (*src.add(offset + 3)).tanh();
185            offset += 4;
186        }
187
188        // Handle remaining elements
189        for i in offset..size {
190            *dst.add(i) = (*src.add(i)).tanh();
191        }
192    }
193}
194
195#[cfg(test)]
196mod tests {
197    use super::*;
198
199    #[test]
200    fn test_tanh_forward_basic() {
201        let x = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
202        let y = x.tanh();
203        unsafe {
204            assert!((*y.as_ptr() + 0.7615942).abs() < 1e-6);
205            assert!((*y.as_ptr().add(1) - 0.0).abs() < 1e-6);
206            assert!((*y.as_ptr().add(2) - 0.7615942).abs() < 1e-6);
207        }
208    }
209}