// train_station/tensor/ops/relu.rs

1//! ReLU activation function
2//!
3//! Provides the Rectified Linear Unit activation function following PyTorch conventions with
4//! comprehensive automatic differentiation support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **ReLU Activation**: `relu()` - Computes max(0, x) for each element (PyTorch `relu()` equivalent)
9//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
10//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
11//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
12//! - **Mathematical Accuracy**: High-precision activation computation
13//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
14//!
15//! # Mathematical Properties
16//!
17//! The ReLU activation function has the following properties:
18//! - **Definition**: f(x) = max(0, x)
19//! - **Range**: [0, ∞) - outputs are always non-negative
20//! - **Monotonicity**: Strictly increasing for x > 0
21//! - **Continuity**: Continuous everywhere, differentiable everywhere except at x = 0
22//! - **Gradient**: f'(x) = 1 if x > 0, f'(x) = 0 if x ≤ 0
23//! - **Sparsity**: Produces sparse activations (many zeros) for negative inputs
24//!
25//! # Performance Characteristics
26//!
27//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
28//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
29//! - **Cache-friendly Access**: Linear memory access patterns
30//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
31//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
32
33use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
34use crate::tensor::core::Tensor;
35
36// SIMD optimizations for performance-critical operations
37#[cfg(target_arch = "x86_64")]
38use std::arch::x86_64::*;
39
impl Tensor {
    /// Element-wise ReLU (Rectified Linear Unit) activation.
    ///
    /// Applies ReLU to each element: `output[i] = max(0, self[i])`
    ///
    /// # Returns
    /// A new tensor with ReLU applied to each element
    ///
    /// # Examples
    ///
    /// ## Basic ReLU Activation
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[-1.0, 0.0, 2.5], vec![3]).unwrap();
    /// let b = a.relu();
    /// assert_eq!(b.shape().dims, vec![3]);
    /// assert_eq!(b.get(&[0]), 0.0); // max(0, -1.0) = 0.0
    /// assert_eq!(b.get(&[1]), 0.0); // max(0, 0.0) = 0.0
    /// assert_eq!(b.get(&[2]), 2.5); // max(0, 2.5) = 2.5
    /// ```
    ///
    /// ## Mixed Positive and Negative Values
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[-5.0, -0.1, 0.0, 0.1, 5.0], vec![5]).unwrap();
    /// let b = a.relu();
    /// assert_eq!(b.shape().dims, vec![5]);
    /// assert_eq!(b.get(&[0]), 0.0); // max(0, -5.0) = 0.0
    /// assert_eq!(b.get(&[1]), 0.0); // max(0, -0.1) = 0.0
    /// assert_eq!(b.get(&[2]), 0.0); // max(0, 0.0) = 0.0
    /// assert_eq!(b.get(&[3]), 0.1); // max(0, 0.1) = 0.1
    /// assert_eq!(b.get(&[4]), 5.0); // max(0, 5.0) = 5.0
    /// ```
    pub fn relu(&self) -> Tensor {
        let mut out = self.relu_optimized();

        // Hook into gradtrack only when this tensor participates in gradient
        // tracking AND gradients are globally enabled (respects NoGradTrack).
        if self.requires_grad() && is_grad_enabled() {
            out.set_requires_grad_internal(true);
            // The backward pass needs the ORIGINAL input values to build the
            // gradient mask (f'(x) = 1 for x > 0, else 0), so the input is
            // saved here, not the output.
            let grad_fn = GradFn::Relu {
                saved_input: Box::new(self.clone()),
            };
            out.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
        }

        out
    }

    /// Internal optimized ReLU operation
    ///
    /// Performs element-wise ReLU activation using SIMD optimization when available
    /// and falling back to optimized scalar computation. This is the core implementation
    /// used by `relu()`. It performs no gradtrack bookkeeping; the caller is
    /// responsible for gradient registration.
    ///
    /// # Returns
    ///
    /// A new tensor containing the ReLU activation of each element
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
    /// - **Cache-friendly**: Linear memory access patterns
    /// - **Branch Prediction**: Optimized conditional logic for modern CPUs
    /// - **Mathematical Accuracy**: High-precision activation computation
    /// - **Zero-sized Handling**: Fast return for empty tensors
    ///
    /// # Implementation Details
    ///
    /// Automatically selects between SIMD and scalar implementations based on hardware
    /// capabilities. SIMD implementation uses AVX2 vector max operations for optimal
    /// performance. Scalar implementation uses 4x unrolling for better instruction-level
    /// parallelism.
    #[inline]
    pub(crate) fn relu_optimized(&self) -> Tensor {
        // NOTE(review): assumes Tensor::new allocates an uninitialized (or
        // zeroed) buffer of self.size() f32 elements that the kernels below
        // fully overwrite — confirm against Tensor::new's contract.
        let mut output = Tensor::new(self.shape().dims.clone());

        // Fast path: nothing to compute for empty tensors, and it also keeps
        // the raw pointers below from being dereferenced with size == 0.
        if self.size() == 0 {
            return output;
        }

        // SAFETY: `src` and `dst` each point to `self.size()` contiguous f32
        // elements (output was allocated with the same shape as self), and
        // both kernels below read/write strictly within [0, size).
        unsafe {
            let src = self.as_ptr();
            let dst = output.as_mut_ptr();

            // Runtime dispatch: prefer the AVX2 kernel when the CPU supports
            // it; `is_x86_feature_detected!` makes the `target_feature` call
            // sound.
            #[cfg(target_arch = "x86_64")]
            {
                if is_x86_feature_detected!("avx2") {
                    self.relu_simd_avx2_optimized(src, dst);
                    return output;
                }
            }

            // Scalar fallback
            self.relu_scalar_optimized(src, dst);
        }

        output
    }

    /// AVX2-optimized ReLU implementation
    ///
    /// Performs element-wise ReLU activation using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data. Requires AVX2 support
    /// (callers must check `is_x86_feature_detected!("avx2")` first).
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses AVX2 max instructions for ReLU computation
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
    ///
    /// # Implementation Details
    ///
    /// Uses AVX2 vector max operations to compute max(0, x) efficiently.
    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
    /// Processes remaining elements with scalar operations for complete coverage.
    ///
    /// NaN handling matches the scalar kernel: `maxps` returns its second
    /// operand when either input is NaN, so `_mm256_max_ps(x, 0)` yields 0.0
    /// for NaN lanes — the same result as the scalar `v > 0.0` test.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn relu_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let zero_vec = _mm256_setzero_ps();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for maximum throughput
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration.
            // Unaligned loads/stores (`loadu`/`storeu`) are used because the
            // tensor buffers carry no 32-byte alignment guarantee here.
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let relu_vec1 = _mm256_max_ps(src_vec1, zero_vec);
            _mm256_storeu_ps(dst.add(offset), relu_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let relu_vec2 = _mm256_max_ps(src_vec2, zero_vec);
            _mm256_storeu_ps(dst.add(offset + 8), relu_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let relu_vec3 = _mm256_max_ps(src_vec3, zero_vec);
            _mm256_storeu_ps(dst.add(offset + 16), relu_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let relu_vec4 = _mm256_max_ps(src_vec4, zero_vec);
            _mm256_storeu_ps(dst.add(offset + 24), relu_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks (up to 3 after the 32-wide loop)
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let relu_vec = _mm256_max_ps(src_vec, zero_vec);
            _mm256_storeu_ps(dst.add(offset), relu_vec);
            offset += 8;
        }

        // Handle remaining elements (< 8) with scalar fallback
        for i in offset..size {
            let v = *src.add(i);
            *dst.add(i) = if v > 0.0 { v } else { 0.0 };
        }
    }

    /// Optimized scalar ReLU fallback
    ///
    /// Performs element-wise ReLU activation using optimized scalar operations with
    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data.
    ///
    /// # Performance Characteristics
    ///
    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
    /// - **Mathematical Accuracy**: High-precision scalar ReLU computation
    ///
    /// # Implementation Details
    ///
    /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
    /// Processes elements in groups of 4 to improve instruction-level parallelism
    /// and reduce loop overhead. NaN inputs map to 0.0 (`v > 0.0` is false for
    /// NaN), consistent with the AVX2 kernel.
    #[inline]
    unsafe fn relu_scalar_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for better performance: the four loads are
        // independent, letting the CPU overlap them.
        for _ in 0..unroll_count {
            let v1 = *src.add(offset);
            let v2 = *src.add(offset + 1);
            let v3 = *src.add(offset + 2);
            let v4 = *src.add(offset + 3);

            *dst.add(offset) = if v1 > 0.0 { v1 } else { 0.0 };
            *dst.add(offset + 1) = if v2 > 0.0 { v2 } else { 0.0 };
            *dst.add(offset + 2) = if v3 > 0.0 { v3 } else { 0.0 };
            *dst.add(offset + 3) = if v4 > 0.0 { v4 } else { 0.0 };

            offset += 4;
        }

        // Handle remaining elements (< 4)
        for i in offset..size {
            let v = *src.add(i);
            *dst.add(i) = if v > 0.0 { v } else { 0.0 };
        }
    }
}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Forward pass: negatives and zero clamp to 0.0, positives pass through.
    #[test]
    fn test_relu_forward_basic() {
        let input = Tensor::from_slice(&[-1.0, 0.0, 2.5], vec![3]).unwrap();
        let output = input.relu();

        let expected = [0.0f32, 0.0, 2.5];
        for (i, &want) in expected.iter().enumerate() {
            unsafe {
                assert_eq!(*output.as_ptr().add(i), want);
            }
        }
    }
}