// train_station/tensor/ops/relu.rs

1//! ReLU activation function
2//!
3//! Provides the Rectified Linear Unit activation function following PyTorch conventions with
4//! comprehensive automatic differentiation support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **ReLU Activation**: `relu()` - Computes max(0, x) for each element (PyTorch `relu()` equivalent)
9//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
10//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
11//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
12//! - **Mathematical Accuracy**: High-precision activation computation
13//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
14//!
15//! # Mathematical Properties
16//!
17//! The ReLU activation function has the following properties:
18//! - **Definition**: f(x) = max(0, x)
19//! - **Range**: [0, ∞) - outputs are always non-negative
20//! - **Monotonicity**: Strictly increasing for x > 0
21//! - **Continuity**: Continuous everywhere, differentiable everywhere except at x = 0
22//! - **Gradient**: f'(x) = 1 if x > 0, f'(x) = 0 if x ≤ 0
23//! - **Sparsity**: Produces sparse activations (many zeros) for negative inputs
24//!
25//! # Performance Characteristics
26//!
27//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
28//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
29//! - **Cache-friendly Access**: Linear memory access patterns
30//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
31//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
32
33use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
34use crate::tensor::core::Tensor;
35
36// SIMD optimizations for performance-critical operations
37#[cfg(target_arch = "x86_64")]
38use std::arch::x86_64::*;
39
40impl Tensor {
41    /// Element-wise ReLU (Rectified Linear Unit) activation.
42    ///
43    /// Applies ReLU to each element: `output[i] = max(0, self[i])`
44    ///
45    /// # Returns
46    /// A new tensor with ReLU applied to each element
47    ///
48    /// # Examples
49    ///
50    /// ## Basic ReLU Activation
51    ///
52    /// ```
53    /// use train_station::Tensor;
54    ///
55    /// let a = Tensor::from_slice(&[-1.0, 0.0, 2.5], vec![3]).unwrap();
56    /// let b = a.relu();
57    /// assert_eq!(b.shape().dims, vec![3]);
58    /// assert_eq!(b.get(&[0]), 0.0); // max(0, -1.0) = 0.0
59    /// assert_eq!(b.get(&[1]), 0.0); // max(0, 0.0) = 0.0
60    /// assert_eq!(b.get(&[2]), 2.5); // max(0, 2.5) = 2.5
61    /// ```
62    ///
63    /// ## Mixed Positive and Negative Values
64    ///
65    /// ```
66    /// use train_station::Tensor;
67    ///
68    /// let a = Tensor::from_slice(&[-5.0, -0.1, 0.0, 0.1, 5.0], vec![5]).unwrap();
69    /// let b = a.relu();
70    /// assert_eq!(b.shape().dims, vec![5]);
71    /// assert_eq!(b.get(&[0]), 0.0); // max(0, -5.0) = 0.0
72    /// assert_eq!(b.get(&[1]), 0.0); // max(0, -0.1) = 0.0
73    /// assert_eq!(b.get(&[2]), 0.0); // max(0, 0.0) = 0.0
74    /// assert_eq!(b.get(&[3]), 0.1); // max(0, 0.1) = 0.1
75    /// assert_eq!(b.get(&[4]), 5.0); // max(0, 5.0) = 5.0
76    /// ```
77    #[track_caller]
78    pub fn relu(&self) -> Tensor {
79        let mut out = self.relu_optimized();
80
81        if self.requires_grad() && is_grad_enabled() {
82            out.set_requires_grad_internal(true);
83            let grad_fn = GradFn::Relu {
84                saved_input: Box::new(self.clone()),
85            };
86            out.set_grad_fn(grad_fn.clone());
87            GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
88        }
89
90        out
91    }
92
93    /// Internal optimized ReLU operation
94    ///
95    /// Performs element-wise ReLU activation using SIMD optimization when available
96    /// and falling back to optimized scalar computation. This is the core implementation
97    /// used by `relu()`.
98    ///
99    /// # Returns
100    ///
101    /// A new tensor containing the ReLU activation of each element
102    ///
103    /// # Performance Characteristics
104    ///
105    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
106    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
107    /// - **Cache-friendly**: Linear memory access patterns
108    /// - **Branch Prediction**: Optimized conditional logic for modern CPUs
109    /// - **Mathematical Accuracy**: High-precision activation computation
110    /// - **Zero-sized Handling**: Fast return for empty tensors
111    ///
112    /// # Implementation Details
113    ///
114    /// Automatically selects between SIMD and scalar implementations based on hardware
115    /// capabilities. SIMD implementation uses AVX2 vector max operations for optimal
116    /// performance. Scalar implementation uses 4x unrolling for better instruction-level
117    /// parallelism.
118    #[inline]
119    pub(crate) fn relu_optimized(&self) -> Tensor {
120        let mut output = Tensor::new(self.shape().dims.clone());
121
122        if self.size() == 0 {
123            return output;
124        }
125
126        unsafe {
127            let src = self.as_ptr();
128            let dst = output.as_mut_ptr();
129
130            #[cfg(target_arch = "x86_64")]
131            {
132                if is_x86_feature_detected!("avx2") {
133                    self.relu_simd_avx2_optimized(src, dst);
134                    return output;
135                }
136            }
137
138            // Scalar fallback
139            self.relu_scalar_optimized(src, dst);
140        }
141
142        output
143    }
144
145    /// AVX2-optimized ReLU implementation
146    ///
147    /// Performs element-wise ReLU activation using AVX2 SIMD instructions for maximum
148    /// performance on x86_64 architectures with AVX2 support.
149    ///
150    /// # Arguments
151    ///
152    /// * `src` - Pointer to source tensor data
153    /// * `dst` - Pointer to output tensor data
154    ///
155    /// # Safety
156    ///
157    /// Requires valid pointers with sufficient memory for the tensor size.
158    /// All pointers must point to valid tensor data. Requires AVX2 support.
159    ///
160    /// # Performance Characteristics
161    ///
162    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
163    /// - **Memory Access**: Linear access patterns for cache efficiency
164    /// - **Vector Operations**: Uses AVX2 max instructions for ReLU computation
165    /// - **Fallback**: Handles remaining elements with scalar operations
166    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
167    ///
168    /// # Implementation Details
169    ///
170    /// Uses AVX2 vector max operations to compute max(0, x) efficiently.
171    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
172    /// Processes remaining elements with scalar operations for complete coverage.
173    #[cfg(target_arch = "x86_64")]
174    #[inline]
175    #[target_feature(enable = "avx2")]
176    unsafe fn relu_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32) {
177        let size = self.size();
178        let zero_vec = _mm256_setzero_ps();
179        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
180        let mut offset = 0;
181
182        // Unrolled SIMD loop for maximum throughput
183        for _ in 0..simd_count {
184            // Process 4 AVX2 vectors (32 elements) per iteration
185            let src_vec1 = _mm256_loadu_ps(src.add(offset));
186            let relu_vec1 = _mm256_max_ps(src_vec1, zero_vec);
187            _mm256_storeu_ps(dst.add(offset), relu_vec1);
188
189            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
190            let relu_vec2 = _mm256_max_ps(src_vec2, zero_vec);
191            _mm256_storeu_ps(dst.add(offset + 8), relu_vec2);
192
193            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
194            let relu_vec3 = _mm256_max_ps(src_vec3, zero_vec);
195            _mm256_storeu_ps(dst.add(offset + 16), relu_vec3);
196
197            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
198            let relu_vec4 = _mm256_max_ps(src_vec4, zero_vec);
199            _mm256_storeu_ps(dst.add(offset + 24), relu_vec4);
200
201            offset += 32;
202        }
203
204        // Handle remaining 8-element blocks
205        let remaining_full_blocks = (size - offset) / 8;
206        for _ in 0..remaining_full_blocks {
207            let src_vec = _mm256_loadu_ps(src.add(offset));
208            let relu_vec = _mm256_max_ps(src_vec, zero_vec);
209            _mm256_storeu_ps(dst.add(offset), relu_vec);
210            offset += 8;
211        }
212
213        // Handle remaining elements with scalar fallback
214        for i in offset..size {
215            let v = *src.add(i);
216            *dst.add(i) = if v > 0.0 { v } else { 0.0 };
217        }
218    }
219
220    /// Optimized scalar ReLU fallback
221    ///
222    /// Performs element-wise ReLU activation using optimized scalar operations with
223    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
224    ///
225    /// # Arguments
226    ///
227    /// * `src` - Pointer to source tensor data
228    /// * `dst` - Pointer to output tensor data
229    ///
230    /// # Safety
231    ///
232    /// Requires valid pointers with sufficient memory for the tensor size.
233    /// All pointers must point to valid tensor data.
234    ///
235    /// # Performance Characteristics
236    ///
237    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
238    /// - **Memory Access**: Linear access patterns for cache efficiency
239    /// - **Fallback**: Handles remaining elements with scalar operations
240    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
241    /// - **Mathematical Accuracy**: High-precision scalar ReLU computation
242    ///
243    /// # Implementation Details
244    ///
245    /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
246    /// Processes elements in groups of 4 to improve instruction-level parallelism
247    /// and reduce loop overhead.
248    #[inline]
249    unsafe fn relu_scalar_optimized(&self, src: *const f32, dst: *mut f32) {
250        let size = self.size();
251        let unroll_count = size / 4;
252        let mut offset = 0;
253
254        // Unrolled scalar loop for better performance
255        for _ in 0..unroll_count {
256            let v1 = *src.add(offset);
257            let v2 = *src.add(offset + 1);
258            let v3 = *src.add(offset + 2);
259            let v4 = *src.add(offset + 3);
260
261            *dst.add(offset) = if v1 > 0.0 { v1 } else { 0.0 };
262            *dst.add(offset + 1) = if v2 > 0.0 { v2 } else { 0.0 };
263            *dst.add(offset + 2) = if v3 > 0.0 { v3 } else { 0.0 };
264            *dst.add(offset + 3) = if v4 > 0.0 { v4 } else { 0.0 };
265
266            offset += 4;
267        }
268
269        // Handle remaining elements
270        for i in offset..size {
271            let v = *src.add(i);
272            *dst.add(i) = if v > 0.0 { v } else { 0.0 };
273        }
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// ReLU clamps negatives and zero to 0.0 and passes positives through.
    #[test]
    fn test_relu_forward_basic() {
        let x = Tensor::from_slice(&[-1.0, 0.0, 2.5], vec![3]).unwrap();
        let y = x.relu();
        let expected = [0.0f32, 0.0, 2.5];
        for (i, want) in expected.iter().enumerate() {
            // SAFETY: `y` holds exactly `expected.len()` elements, so every
            // offset `i` dereferenced here is in bounds.
            unsafe {
                assert_eq!(*y.as_ptr().add(i), *want);
            }
        }
    }
}
291}