train_station/tensor/ops/relu.rs
//! ReLU activation function
//!
//! Provides the Rectified Linear Unit activation function following PyTorch conventions, with
//! comprehensive automatic differentiation support and SIMD-optimized computation.
//!
//! # Key Features
//!
//! - **ReLU Activation**: `relu()` - Computes max(0, x) for each element (PyTorch `relu()` equivalent)
//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
//! - **Mathematical Accuracy**: High-precision activation computation
//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
//!
//! # Mathematical Properties
//!
//! The ReLU activation function has the following properties:
//! - **Definition**: f(x) = max(0, x)
//! - **Range**: [0, ∞) - outputs are always non-negative
//! - **Monotonicity**: Monotonically non-decreasing everywhere; strictly increasing for x > 0
//! - **Continuity**: Continuous everywhere, differentiable everywhere except at x = 0
//! - **Gradient**: f'(x) = 1 if x > 0, f'(x) = 0 if x ≤ 0
//! - **Sparsity**: Produces sparse activations (many zeros) for negative inputs
//!
//! # Performance Characteristics
//!
//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
//! - **Cache-friendly Access**: Linear memory access patterns
//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
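//!
//! # Examples
//!
//! Basic usage (mirrors the method-level examples below, which use the same APIs):
//!
//! ```
//! use train_station::Tensor;
//!
//! let x = Tensor::from_slice(&[-2.0, 3.0], vec![2]).unwrap();
//! let y = x.relu();
//! assert_eq!(y.get(&[0]), 0.0); // negative input clamps to zero
//! assert_eq!(y.get(&[1]), 3.0); // positive input passes through
//! ```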

use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
use crate::tensor::core::Tensor;

// SIMD optimizations for performance-critical operations
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

impl Tensor {
    /// Element-wise ReLU (Rectified Linear Unit) activation.
    ///
    /// Applies ReLU to each element: `output[i] = max(0, self[i])`
    ///
    /// # Returns
    ///
    /// A new tensor with ReLU applied to each element
    ///
    /// # Examples
    ///
    /// ## Basic ReLU Activation
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[-1.0, 0.0, 2.5], vec![3]).unwrap();
    /// let b = a.relu();
    /// assert_eq!(b.shape().dims, vec![3]);
    /// assert_eq!(b.get(&[0]), 0.0); // max(0, -1.0) = 0.0
    /// assert_eq!(b.get(&[1]), 0.0); // max(0, 0.0) = 0.0
    /// assert_eq!(b.get(&[2]), 2.5); // max(0, 2.5) = 2.5
    /// ```
    ///
    /// ## Mixed Positive and Negative Values
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[-5.0, -0.1, 0.0, 0.1, 5.0], vec![5]).unwrap();
    /// let b = a.relu();
    /// assert_eq!(b.shape().dims, vec![5]);
    /// assert_eq!(b.get(&[0]), 0.0); // max(0, -5.0) = 0.0
    /// assert_eq!(b.get(&[1]), 0.0); // max(0, -0.1) = 0.0
    /// assert_eq!(b.get(&[2]), 0.0); // max(0, 0.0) = 0.0
    /// assert_eq!(b.get(&[3]), 0.1); // max(0, 0.1) = 0.1
    /// assert_eq!(b.get(&[4]), 5.0); // max(0, 5.0) = 5.0
    /// ```
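    ///
    /// ## Gradient Tracking (illustrative)
    ///
    /// When the input has `requires_grad` set and gradient tracking is enabled,
    /// the output registers a `GradFn::Relu` node so backpropagation produces
    /// f'(x) = 1 for x > 0 and 0 otherwise. A minimal sketch; the
    /// `set_requires_grad`, `sum`, and `backward` names below are assumed
    /// conventions, not confirmed by this file:
    ///
    /// ```ignore
    /// use train_station::Tensor;
    ///
    /// let mut a = Tensor::from_slice(&[-1.0, 2.0], vec![2]).unwrap();
    /// a.set_requires_grad(true); // assumed public setter
    /// let b = a.relu();
    /// b.sum().backward();        // assumed reduction + backward API
    /// // Expected gradient for `a`: [0.0, 1.0] (zero for the negative input)
    /// ```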
    pub fn relu(&self) -> Tensor {
        let mut out = self.relu_optimized();

        if self.requires_grad() && is_grad_enabled() {
            out.set_requires_grad_internal(true);
            let grad_fn = GradFn::Relu {
                saved_input: Box::new(self.clone()),
            };
            out.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
        }

        out
    }

    /// Internal optimized ReLU operation
    ///
    /// Performs element-wise ReLU activation using SIMD instructions when available,
    /// falling back to an optimized scalar implementation otherwise. This is the core
    /// implementation used by `relu()`.
    ///
    /// # Returns
    ///
    /// A new tensor containing the ReLU activation of each element
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
    /// - **Cache-friendly**: Linear memory access patterns
    /// - **Branch Prediction**: Optimized conditional logic for modern CPUs
    /// - **Mathematical Accuracy**: High-precision activation computation
    /// - **Zero-sized Handling**: Fast return for empty tensors
    ///
    /// # Implementation Details
    ///
    /// Automatically selects between SIMD and scalar implementations based on hardware
    /// capabilities. The SIMD implementation uses AVX2 vector max operations for optimal
    /// performance; the scalar implementation uses 4x unrolling for better
    /// instruction-level parallelism.
    #[inline]
    pub(crate) fn relu_optimized(&self) -> Tensor {
        let mut output = Tensor::new(self.shape().dims.clone());

        if self.size() == 0 {
            return output;
        }

        unsafe {
            let src = self.as_ptr();
            let dst = output.as_mut_ptr();

            #[cfg(target_arch = "x86_64")]
            {
                if is_x86_feature_detected!("avx2") {
                    self.relu_simd_avx2_optimized(src, dst);
                    return output;
                }
            }

            // Scalar fallback
            self.relu_scalar_optimized(src, dst);
        }

        output
    }

    /// AVX2-optimized ReLU implementation
    ///
    /// Performs element-wise ReLU activation using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data. Requires AVX2 support.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses AVX2 max instructions for ReLU computation
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
    ///
    /// # Implementation Details
    ///
    /// Uses AVX2 vector max operations to compute max(0, x) efficiently.
    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
    /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn relu_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let zero_vec = _mm256_setzero_ps();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for maximum throughput
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let relu_vec1 = _mm256_max_ps(src_vec1, zero_vec);
            _mm256_storeu_ps(dst.add(offset), relu_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let relu_vec2 = _mm256_max_ps(src_vec2, zero_vec);
            _mm256_storeu_ps(dst.add(offset + 8), relu_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let relu_vec3 = _mm256_max_ps(src_vec3, zero_vec);
            _mm256_storeu_ps(dst.add(offset + 16), relu_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let relu_vec4 = _mm256_max_ps(src_vec4, zero_vec);
            _mm256_storeu_ps(dst.add(offset + 24), relu_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let relu_vec = _mm256_max_ps(src_vec, zero_vec);
            _mm256_storeu_ps(dst.add(offset), relu_vec);
            offset += 8;
        }

        // Handle remaining elements with scalar fallback
        for i in offset..size {
            let v = *src.add(i);
            *dst.add(i) = if v > 0.0 { v } else { 0.0 };
        }
    }

    /// Optimized scalar ReLU fallback
    ///
    /// Performs element-wise ReLU activation using optimized scalar operations with
    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data.
    ///
    /// # Performance Characteristics
    ///
    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
    /// - **Mathematical Accuracy**: High-precision scalar ReLU computation
    ///
    /// # Implementation Details
    ///
    /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
    /// Processes elements in groups of 4 to improve instruction-level parallelism
    /// and reduce loop overhead.
    #[inline]
    unsafe fn relu_scalar_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for better performance
        for _ in 0..unroll_count {
            let v1 = *src.add(offset);
            let v2 = *src.add(offset + 1);
            let v3 = *src.add(offset + 2);
            let v4 = *src.add(offset + 3);

            *dst.add(offset) = if v1 > 0.0 { v1 } else { 0.0 };
            *dst.add(offset + 1) = if v2 > 0.0 { v2 } else { 0.0 };
            *dst.add(offset + 2) = if v3 > 0.0 { v3 } else { 0.0 };
            *dst.add(offset + 3) = if v4 > 0.0 { v4 } else { 0.0 };

            offset += 4;
        }

        // Handle remaining elements
        for i in offset..size {
            let v = *src.add(i);
            *dst.add(i) = if v > 0.0 { v } else { 0.0 };
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relu_forward_basic() {
        let x = Tensor::from_slice(&[-1.0, 0.0, 2.5], vec![3]).unwrap();
        let y = x.relu();
        unsafe {
            assert_eq!(*y.as_ptr(), 0.0);
            assert_eq!(*y.as_ptr().add(1), 0.0);
            assert_eq!(*y.as_ptr().add(2), 2.5);
        }
    }
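
    // Additional coverage sketches. These use only APIs that appear elsewhere
    // in this file (`from_slice`, `relu`, `get`, `Tensor::new`, `size`).

    /// Exercises all three code paths in the AVX2 kernel (32-element unrolled
    /// blocks, trailing 8-element blocks, and the scalar tail) by using a
    /// length (41) that is not a multiple of 8. On non-AVX2 hardware the same
    /// data covers the 4x-unrolled scalar path and its remainder loop.
    #[test]
    fn test_relu_covers_simd_blocks_and_tails() {
        let data: Vec<f32> = (0..41).map(|i| i as f32 - 20.0).collect();
        let x = Tensor::from_slice(&data, vec![41]).unwrap();
        let y = x.relu();
        for (i, &v) in data.iter().enumerate() {
            let expected = if v > 0.0 { v } else { 0.0 };
            assert_eq!(y.get(&[i]), expected);
        }
    }

    /// ReLU is idempotent: applying it twice equals applying it once.
    #[test]
    fn test_relu_idempotent() {
        let x = Tensor::from_slice(&[-1.5, 0.0, 2.0], vec![3]).unwrap();
        let once = x.relu();
        let twice = once.relu();
        for i in 0..3 {
            assert_eq!(once.get(&[i]), twice.get(&[i]));
        }
    }

    /// Zero-sized tensors take the early-return path in `relu_optimized`.
    /// Assumes `Tensor::new(vec![0])` constructs an empty tensor, matching
    /// the size-zero check in the implementation.
    #[test]
    fn test_relu_empty_tensor() {
        let x = Tensor::new(vec![0]);
        let y = x.relu();
        assert_eq!(y.size(), 0);
    }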
}