// train_station/tensor/ops/tanh.rs

1//! Hyperbolic tangent activation function
2//!
3//! Provides element-wise hyperbolic tangent activation following PyTorch conventions with
4//! comprehensive GradTrack support and high-precision scalar computation.
5//!
6//! # Key Features
7//!
8//! - **Hyperbolic Tangent**: `tanh()` - Element-wise hyperbolic tangent activation
9//! - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
10//! - **High Precision**: Accurate scalar implementation for mathematical validation
11//! - **Performance Optimization**: 4x unrolled scalar operations for better throughput
12//! - **Numerical Stability**: Robust implementation for extreme input values
13//!
14//! # Mathematical Properties
15//!
16//! The hyperbolic tangent function has the following properties:
17//! - **Range**: Output values are in the range (-1, 1)
18//! - **Symmetry**: tanh(-x) = -tanh(x) (odd function)
19//! - **Asymptotes**: Approaches ±1 as x approaches ±∞
20//! - **Zero**: tanh(0) = 0
21//! - **Gradient**: ∂tanh(x)/∂x = 1 - tanh²(x) = sech²(x)
22//! - **Monotonic**: Strictly increasing function
23//!
24//! # Performance Characteristics
25//!
26//! - **Scalar Implementation**: High-precision scalar computation for mathematical accuracy
27//! - **4x Unrolling**: Optimized scalar operations with instruction-level parallelism
28//! - **Cache-friendly**: Linear memory access patterns
29//! - **Numerical Stability**: Robust handling of extreme input values
30//! - **GradTrack Optimization**: Efficient automatic differentiation with gradient computation
31
32use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
33use crate::tensor::core::Tensor;
34
35impl Tensor {
36    /// Element-wise hyperbolic tangent activation
37    ///
38    /// Computes hyperbolic tangent for each element: `output[i] = tanh(self[i])`
39    ///
40    /// The hyperbolic tangent function maps any real number to the range (-1, 1),
41    /// making it useful as an activation function in neural networks.
42    ///
43    /// # Returns
44    ///
45    /// A new tensor with tanh applied to each element, values in range (-1, 1)
46    ///
47    /// # Performance Characteristics
48    ///
49    /// - **High Precision**: Accurate scalar implementation for mathematical validation
50    /// - **4x Unrolling**: Optimized scalar operations with instruction-level parallelism
51    /// - **Cache-friendly**: Linear memory access patterns
52    /// - **Numerical Stability**: Robust handling of extreme input values
53    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
54    ///
55    /// # Mathematical Properties
56    ///
57    /// - **Range**: Output values are in the range (-1, 1)
58    /// - **Symmetry**: tanh(-x) = -tanh(x) (odd function)
59    /// - **Zero**: tanh(0) = 0
60    /// - **Gradient**: ∂tanh(x)/∂x = 1 - tanh²(x) = sech²(x)
61    ///
62    /// # Examples
63    ///
64    /// ## Basic Hyperbolic Tangent
65    ///
66    /// ```
67    /// use train_station::Tensor;
68    ///
69    /// let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
70    /// let b = a.tanh();
71    /// assert_eq!(b.shape().dims, vec![3]);
72    /// assert!((b.get(&[0]) - (-0.7615942)).abs() < 1e-6); // tanh(-1.0)
73    /// assert!((b.get(&[1]) - 0.0).abs() < 1e-6); // tanh(0.0)
74    /// assert!((b.get(&[2]) - 0.7615942).abs() < 1e-6); // tanh(1.0)
75    /// ```
76    ///
77    /// ## Extreme Values
78    ///
79    /// ```
80    /// use train_station::Tensor;
81    ///
82    /// let a = Tensor::from_slice(&[-10.0, 10.0], vec![2]).unwrap();
83    /// let b = a.tanh();
84    /// assert_eq!(b.shape().dims, vec![2]);
85    /// assert!((b.get(&[0]) - (-1.0)).abs() < 1e-6); // tanh(-10.0) ≈ -1
86    /// assert!((b.get(&[1]) - 1.0).abs() < 1e-6); // tanh(10.0) ≈ 1
87    /// ```
88    #[track_caller]
89    pub fn tanh(&self) -> Tensor {
90        let mut out = self.tanh_optimized();
91
92        if self.requires_grad() && is_grad_enabled() {
93            out.set_requires_grad_internal(true);
94            let grad_fn = GradFn::Tanh {
95                saved_output: Box::new(out.clone()),
96            };
97            out.set_grad_fn(grad_fn.clone());
98            GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
99        }
100
101        out
102    }
103
104    /// Internal optimized tanh operation
105    ///
106    /// Performs element-wise hyperbolic tangent computation using high-precision
107    /// scalar implementation for mathematical accuracy and validation.
108    ///
109    /// # Returns
110    ///
111    /// A new tensor with tanh applied to each element
112    ///
113    /// # Performance Characteristics
114    ///
115    /// - **High Precision**: Accurate scalar implementation for mathematical validation
116    /// - **4x Unrolling**: Optimized scalar operations with instruction-level parallelism
117    /// - **Cache-friendly**: Linear memory access patterns
118    /// - **Zero-sized Handling**: Fast return for empty tensors
119    /// - **Numerical Stability**: Robust handling of extreme input values
120    ///
121    /// # Implementation Details
122    ///
123    /// Uses high-precision scalar implementation rather than SIMD approximations
124    /// to ensure mathematical accuracy for validation against reference implementations.
125    /// Implements 4x unrolling for better instruction-level parallelism and cache utilization.
126    #[inline]
127    pub(crate) fn tanh_optimized(&self) -> Tensor {
128        let mut output = Tensor::new(self.shape().dims.clone());
129
130        if self.size() == 0 {
131            return output;
132        }
133
134        unsafe {
135            let src = self.as_ptr();
136            let dst = output.as_mut_ptr();
137
138            // Use scalar implementation for accuracy
139            // SIMD approximations for tanh introduce too much error for validation
140            self.tanh_scalar_optimized(src, dst);
141        }
142
143        output
144    }
145
146    /// Optimized scalar hyperbolic tangent implementation
147    ///
148    /// Performs element-wise hyperbolic tangent computation using optimized scalar
149    /// operations with 4x unrolling for better instruction-level parallelism.
150    ///
151    /// # Arguments
152    ///
153    /// * `src` - Pointer to source tensor data
154    /// * `dst` - Pointer to output tensor data
155    ///
156    /// # Safety
157    ///
158    /// Requires valid pointers with sufficient memory for the tensor size.
159    /// All pointers must point to valid tensor data.
160    ///
161    /// # Performance Characteristics
162    ///
163    /// - **High Precision**: Accurate scalar implementation for mathematical validation
164    /// - **4x Unrolling**: Optimized scalar operations with instruction-level parallelism
165    /// - **Memory Access**: Linear access patterns for cache efficiency
166    /// - **Fallback**: Handles remaining elements with scalar operations
167    /// - **Mathematical Accuracy**: High-precision hyperbolic tangent computation
168    ///
169    /// # Implementation Details
170    ///
171    /// Uses 4x unrolled scalar operations for optimal performance while maintaining
172    /// high mathematical accuracy. Processes elements in groups of 4 to improve
173    /// instruction-level parallelism and reduce loop overhead.
174    #[inline]
175    unsafe fn tanh_scalar_optimized(&self, src: *const f32, dst: *mut f32) {
176        let size = self.size();
177        let unroll_count = size / 4;
178        let mut offset = 0;
179
180        // Unrolled scalar loop for better performance
181        for _ in 0..unroll_count {
182            *dst.add(offset) = (*src.add(offset)).tanh();
183            *dst.add(offset + 1) = (*src.add(offset + 1)).tanh();
184            *dst.add(offset + 2) = (*src.add(offset + 2)).tanh();
185            *dst.add(offset + 3) = (*src.add(offset + 3)).tanh();
186            offset += 4;
187        }
188
189        // Handle remaining elements
190        for i in offset..size {
191            *dst.add(i) = (*src.add(i)).tanh();
192        }
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    /// Verifies tanh forward values at -1, 0, and 1, using the safe `get`
    /// accessor instead of raw pointer reads. Exploits oddness:
    /// tanh(-1) = -tanh(1) ≈ ∓0.7615942, tanh(0) = 0.
    #[test]
    fn test_tanh_forward_basic() {
        let x = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
        let y = x.tanh();
        assert_eq!(y.shape().dims, vec![3]);
        assert!((y.get(&[0]) + 0.7615942).abs() < 1e-6);
        assert!(y.get(&[1]).abs() < 1e-6);
        assert!((y.get(&[2]) - 0.7615942).abs() < 1e-6);
    }
}