// train_station/tensor/ops/tanh.rs
//! Hyperbolic tangent activation function
//!
//! Provides element-wise hyperbolic tangent activation following PyTorch conventions with
//! comprehensive GradTrack support and high-precision scalar computation.
//!
//! # Key Features
//!
//! - **Hyperbolic Tangent**: `tanh()` - Element-wise hyperbolic tangent activation
//! - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
//! - **High Precision**: Accurate scalar implementation for mathematical validation
//! - **Performance Optimization**: 4x unrolled scalar operations for better throughput
//! - **Numerical Stability**: Robust implementation for extreme input values
//!
//! # Mathematical Properties
//!
//! The hyperbolic tangent function has the following properties:
//! - **Range**: Output values are in the range (-1, 1)
//! - **Symmetry**: tanh(-x) = -tanh(x) (odd function)
//! - **Asymptotes**: Approaches ±1 as x approaches ±∞
//! - **Zero**: tanh(0) = 0
//! - **Gradient**: ∂tanh(x)/∂x = 1 - tanh²(x) = sech²(x)
//! - **Monotonic**: Strictly increasing function
//!
//! # Performance Characteristics
//!
//! - **Scalar Implementation**: High-precision scalar computation for mathematical accuracy
//! - **4x Unrolling**: Optimized scalar operations with instruction-level parallelism
//! - **Cache-friendly**: Linear memory access patterns
//! - **Numerical Stability**: Robust handling of extreme input values
//! - **GradTrack Optimization**: Efficient automatic differentiation with gradient computation
use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
use crate::tensor::core::Tensor;
34
35impl Tensor {
36 /// Element-wise hyperbolic tangent activation
37 ///
38 /// Computes hyperbolic tangent for each element: `output[i] = tanh(self[i])`
39 ///
40 /// The hyperbolic tangent function maps any real number to the range (-1, 1),
41 /// making it useful as an activation function in neural networks.
42 ///
43 /// # Returns
44 ///
45 /// A new tensor with tanh applied to each element, values in range (-1, 1)
46 ///
47 /// # Performance Characteristics
48 ///
49 /// - **High Precision**: Accurate scalar implementation for mathematical validation
50 /// - **4x Unrolling**: Optimized scalar operations with instruction-level parallelism
51 /// - **Cache-friendly**: Linear memory access patterns
52 /// - **Numerical Stability**: Robust handling of extreme input values
53 /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
54 ///
55 /// # Mathematical Properties
56 ///
57 /// - **Range**: Output values are in the range (-1, 1)
58 /// - **Symmetry**: tanh(-x) = -tanh(x) (odd function)
59 /// - **Zero**: tanh(0) = 0
60 /// - **Gradient**: ∂tanh(x)/∂x = 1 - tanh²(x) = sech²(x)
61 ///
62 /// # Examples
63 ///
64 /// ## Basic Hyperbolic Tangent
65 ///
66 /// ```
67 /// use train_station::Tensor;
68 ///
69 /// let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
70 /// let b = a.tanh();
71 /// assert_eq!(b.shape().dims, vec![3]);
72 /// assert!((b.get(&[0]) - (-0.7615942)).abs() < 1e-6); // tanh(-1.0)
73 /// assert!((b.get(&[1]) - 0.0).abs() < 1e-6); // tanh(0.0)
74 /// assert!((b.get(&[2]) - 0.7615942).abs() < 1e-6); // tanh(1.0)
75 /// ```
76 ///
77 /// ## Extreme Values
78 ///
79 /// ```
80 /// use train_station::Tensor;
81 ///
82 /// let a = Tensor::from_slice(&[-10.0, 10.0], vec![2]).unwrap();
83 /// let b = a.tanh();
84 /// assert_eq!(b.shape().dims, vec![2]);
85 /// assert!((b.get(&[0]) - (-1.0)).abs() < 1e-6); // tanh(-10.0) ≈ -1
86 /// assert!((b.get(&[1]) - 1.0).abs() < 1e-6); // tanh(10.0) ≈ 1
87 /// ```
88 #[track_caller]
89 pub fn tanh(&self) -> Tensor {
90 let mut out = self.tanh_optimized();
91
92 if self.requires_grad() && is_grad_enabled() {
93 out.set_requires_grad_internal(true);
94 let grad_fn = GradFn::Tanh {
95 saved_output: Box::new(out.clone()),
96 };
97 out.set_grad_fn(grad_fn.clone());
98 GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
99 }
100
101 out
102 }
103
104 /// Internal optimized tanh operation
105 ///
106 /// Performs element-wise hyperbolic tangent computation using high-precision
107 /// scalar implementation for mathematical accuracy and validation.
108 ///
109 /// # Returns
110 ///
111 /// A new tensor with tanh applied to each element
112 ///
113 /// # Performance Characteristics
114 ///
115 /// - **High Precision**: Accurate scalar implementation for mathematical validation
116 /// - **4x Unrolling**: Optimized scalar operations with instruction-level parallelism
117 /// - **Cache-friendly**: Linear memory access patterns
118 /// - **Zero-sized Handling**: Fast return for empty tensors
119 /// - **Numerical Stability**: Robust handling of extreme input values
120 ///
121 /// # Implementation Details
122 ///
123 /// Uses high-precision scalar implementation rather than SIMD approximations
124 /// to ensure mathematical accuracy for validation against reference implementations.
125 /// Implements 4x unrolling for better instruction-level parallelism and cache utilization.
126 #[inline]
127 pub(crate) fn tanh_optimized(&self) -> Tensor {
128 let mut output = Tensor::new(self.shape().dims.clone());
129
130 if self.size() == 0 {
131 return output;
132 }
133
134 unsafe {
135 let src = self.as_ptr();
136 let dst = output.as_mut_ptr();
137
138 // Use scalar implementation for accuracy
139 // SIMD approximations for tanh introduce too much error for validation
140 self.tanh_scalar_optimized(src, dst);
141 }
142
143 output
144 }
145
146 /// Optimized scalar hyperbolic tangent implementation
147 ///
148 /// Performs element-wise hyperbolic tangent computation using optimized scalar
149 /// operations with 4x unrolling for better instruction-level parallelism.
150 ///
151 /// # Arguments
152 ///
153 /// * `src` - Pointer to source tensor data
154 /// * `dst` - Pointer to output tensor data
155 ///
156 /// # Safety
157 ///
158 /// Requires valid pointers with sufficient memory for the tensor size.
159 /// All pointers must point to valid tensor data.
160 ///
161 /// # Performance Characteristics
162 ///
163 /// - **High Precision**: Accurate scalar implementation for mathematical validation
164 /// - **4x Unrolling**: Optimized scalar operations with instruction-level parallelism
165 /// - **Memory Access**: Linear access patterns for cache efficiency
166 /// - **Fallback**: Handles remaining elements with scalar operations
167 /// - **Mathematical Accuracy**: High-precision hyperbolic tangent computation
168 ///
169 /// # Implementation Details
170 ///
171 /// Uses 4x unrolled scalar operations for optimal performance while maintaining
172 /// high mathematical accuracy. Processes elements in groups of 4 to improve
173 /// instruction-level parallelism and reduce loop overhead.
174 #[inline]
175 unsafe fn tanh_scalar_optimized(&self, src: *const f32, dst: *mut f32) {
176 let size = self.size();
177 let unroll_count = size / 4;
178 let mut offset = 0;
179
180 // Unrolled scalar loop for better performance
181 for _ in 0..unroll_count {
182 *dst.add(offset) = (*src.add(offset)).tanh();
183 *dst.add(offset + 1) = (*src.add(offset + 1)).tanh();
184 *dst.add(offset + 2) = (*src.add(offset + 2)).tanh();
185 *dst.add(offset + 3) = (*src.add(offset + 3)).tanh();
186 offset += 4;
187 }
188
189 // Handle remaining elements
190 for i in offset..size {
191 *dst.add(i) = (*src.add(i)).tanh();
192 }
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    /// Forward-pass sanity check against the reference value tanh(1) ≈ 0.7615942,
    /// also exercising the odd-symmetry property tanh(-1) = -tanh(1).
    #[test]
    fn test_tanh_forward_basic() {
        let x = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
        let y = x.tanh();
        // Use the safe public accessor instead of raw pointer reads;
        // no unsafe is needed to inspect tensor elements here.
        assert!((y.get(&[0]) + 0.7615942).abs() < 1e-6); // tanh(-1.0)
        assert!((y.get(&[1]) - 0.0).abs() < 1e-6); // tanh(0.0)
        assert!((y.get(&[2]) - 0.7615942).abs() < 1e-6); // tanh(1.0)
    }
}
210}