train_station/tensor/ops/relu.rs
1//! ReLU activation function
2//!
3//! Provides the Rectified Linear Unit activation function following PyTorch conventions with
4//! comprehensive automatic differentiation support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **ReLU Activation**: `relu()` - Computes max(0, x) for each element (PyTorch `relu()` equivalent)
9//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
10//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
11//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
12//! - **Mathematical Accuracy**: High-precision activation computation
13//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
14//!
15//! # Mathematical Properties
16//!
17//! The ReLU activation function has the following properties:
18//! - **Definition**: f(x) = max(0, x)
19//! - **Range**: [0, ∞) - outputs are always non-negative
20//! - **Monotonicity**: Strictly increasing for x > 0
21//! - **Continuity**: Continuous everywhere, differentiable everywhere except at x = 0
22//! - **Gradient**: f'(x) = 1 if x > 0, f'(x) = 0 if x ≤ 0
23//! - **Sparsity**: Produces sparse activations (many zeros) for negative inputs
24//!
25//! # Performance Characteristics
26//!
27//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
28//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
29//! - **Cache-friendly Access**: Linear memory access patterns
30//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
31//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
32
33use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
34use crate::tensor::core::Tensor;
35
36// SIMD optimizations for performance-critical operations
37#[cfg(target_arch = "x86_64")]
38use std::arch::x86_64::*;
39
impl Tensor {
    /// Element-wise ReLU (Rectified Linear Unit) activation.
    ///
    /// Applies ReLU to each element: `output[i] = max(0, self[i])`
    ///
    /// When gradient tracking is active (this tensor requires grad and grad
    /// mode is globally enabled), the input is saved for the backward pass,
    /// which masks the upstream gradient (f'(x) = 1 for x > 0, else 0).
    ///
    /// # Returns
    /// A new tensor with ReLU applied to each element
    ///
    /// # Examples
    ///
    /// ## Basic ReLU Activation
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[-1.0, 0.0, 2.5], vec![3]).unwrap();
    /// let b = a.relu();
    /// assert_eq!(b.shape().dims, vec![3]);
    /// assert_eq!(b.get(&[0]), 0.0); // max(0, -1.0) = 0.0
    /// assert_eq!(b.get(&[1]), 0.0); // max(0, 0.0) = 0.0
    /// assert_eq!(b.get(&[2]), 2.5); // max(0, 2.5) = 2.5
    /// ```
    ///
    /// ## Mixed Positive and Negative Values
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[-5.0, -0.1, 0.0, 0.1, 5.0], vec![5]).unwrap();
    /// let b = a.relu();
    /// assert_eq!(b.shape().dims, vec![5]);
    /// assert_eq!(b.get(&[0]), 0.0); // max(0, -5.0) = 0.0
    /// assert_eq!(b.get(&[1]), 0.0); // max(0, -0.1) = 0.0
    /// assert_eq!(b.get(&[2]), 0.0); // max(0, 0.0) = 0.0
    /// assert_eq!(b.get(&[3]), 0.1); // max(0, 0.1) = 0.1
    /// assert_eq!(b.get(&[4]), 5.0); // max(0, 5.0) = 5.0
    /// ```
    #[track_caller]
    pub fn relu(&self) -> Tensor {
        // Forward pass: SIMD-accelerated when available, scalar otherwise.
        let mut out = self.relu_optimized();

        // Register with gradtrack only when both this tensor requires grad and
        // gradient tracking is globally enabled (respects NoGradTrack scopes).
        if self.requires_grad() && is_grad_enabled() {
            out.set_requires_grad_internal(true);
            // Save a copy of the input: the backward pass needs the original
            // element signs to decide where the gradient passes through.
            let grad_fn = GradFn::Relu {
                saved_input: Box::new(self.clone()),
            };
            out.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
        }

        out
    }

    /// Internal optimized ReLU operation
    ///
    /// Performs element-wise ReLU activation using SIMD optimization when available
    /// and falling back to optimized scalar computation. This is the core implementation
    /// used by `relu()`.
    ///
    /// # Returns
    ///
    /// A new tensor containing the ReLU activation of each element
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
    /// - **Cache-friendly**: Linear memory access patterns
    /// - **Branch Prediction**: Optimized conditional logic for modern CPUs
    /// - **Mathematical Accuracy**: High-precision activation computation
    /// - **Zero-sized Handling**: Fast return for empty tensors
    ///
    /// # Implementation Details
    ///
    /// Automatically selects between SIMD and scalar implementations based on hardware
    /// capabilities. SIMD implementation uses AVX2 vector max operations for optimal
    /// performance. Scalar implementation uses 4x unrolling for better instruction-level
    /// parallelism.
    #[inline]
    pub(crate) fn relu_optimized(&self) -> Tensor {
        // Output tensor with the same shape as the input.
        let mut output = Tensor::new(self.shape().dims.clone());

        // Fast path: nothing to compute for an empty tensor.
        if self.size() == 0 {
            return output;
        }

        // SAFETY: `src` and `dst` each point to `self.size()` f32 elements:
        // `output` was freshly allocated above with the same shape as `self`,
        // so the two buffers are distinct and never alias, and both helpers
        // stay strictly within `self.size()` elements.
        unsafe {
            let src = self.as_ptr();
            let dst = output.as_mut_ptr();

            #[cfg(target_arch = "x86_64")]
            {
                // Runtime dispatch; the std `is_x86_feature_detected!` macro
                // caches its result, so repeated calls are cheap.
                if is_x86_feature_detected!("avx2") {
                    self.relu_simd_avx2_optimized(src, dst);
                    return output;
                }
            }

            // Scalar fallback
            self.relu_scalar_optimized(src, dst);
        }

        output
    }

    /// AVX2-optimized ReLU implementation
    ///
    /// Performs element-wise ReLU activation using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data. Requires AVX2 support.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses AVX2 max instructions for ReLU computation
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
    ///
    /// # Implementation Details
    ///
    /// Uses AVX2 vector max operations to compute max(0, x) efficiently.
    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
    /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn relu_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        // All-zeros vector used as the second max operand. Note: `vmaxps`
        // returns its second operand when either input is NaN, so NaN inputs
        // map to 0.0 here — consistent with the scalar `v > 0.0` check below.
        let zero_vec = _mm256_setzero_ps();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for maximum throughput
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let relu_vec1 = _mm256_max_ps(src_vec1, zero_vec);
            _mm256_storeu_ps(dst.add(offset), relu_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let relu_vec2 = _mm256_max_ps(src_vec2, zero_vec);
            _mm256_storeu_ps(dst.add(offset + 8), relu_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let relu_vec3 = _mm256_max_ps(src_vec3, zero_vec);
            _mm256_storeu_ps(dst.add(offset + 16), relu_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let relu_vec4 = _mm256_max_ps(src_vec4, zero_vec);
            _mm256_storeu_ps(dst.add(offset + 24), relu_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks (0..3 of them after the 32-wide loop)
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let relu_vec = _mm256_max_ps(src_vec, zero_vec);
            _mm256_storeu_ps(dst.add(offset), relu_vec);
            offset += 8;
        }

        // Handle remaining elements (fewer than 8) with scalar fallback
        for i in offset..size {
            let v = *src.add(i);
            *dst.add(i) = if v > 0.0 { v } else { 0.0 };
        }
    }

    /// Optimized scalar ReLU fallback
    ///
    /// Performs element-wise ReLU activation using optimized scalar operations with
    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data.
    ///
    /// # Performance Characteristics
    ///
    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
    /// - **Mathematical Accuracy**: High-precision scalar ReLU computation
    ///
    /// # Implementation Details
    ///
    /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
    /// Processes elements in groups of 4 to improve instruction-level parallelism
    /// and reduce loop overhead. NaN inputs fail the `v > 0.0` comparison and
    /// therefore map to 0.0, matching the SIMD path.
    #[inline]
    unsafe fn relu_scalar_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for better performance
        for _ in 0..unroll_count {
            let v1 = *src.add(offset);
            let v2 = *src.add(offset + 1);
            let v3 = *src.add(offset + 2);
            let v4 = *src.add(offset + 3);

            *dst.add(offset) = if v1 > 0.0 { v1 } else { 0.0 };
            *dst.add(offset + 1) = if v2 > 0.0 { v2 } else { 0.0 };
            *dst.add(offset + 2) = if v3 > 0.0 { v3 } else { 0.0 };
            *dst.add(offset + 3) = if v4 > 0.0 { v4 } else { 0.0 };

            offset += 4;
        }

        // Handle remaining elements (fewer than 4)
        for i in offset..size {
            let v = *src.add(i);
            *dst.add(i) = if v > 0.0 { v } else { 0.0 };
        }
    }
}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// Forward pass on a small mixed-sign input (scalar code path).
    /// Uses the safe `get` accessor instead of raw pointer reads.
    #[test]
    fn test_relu_forward_basic() {
        let x = Tensor::from_slice(&[-1.0, 0.0, 2.5], vec![3]).unwrap();
        let y = x.relu();
        assert_eq!(y.shape().dims, vec![3]);
        assert_eq!(y.get(&[0]), 0.0);
        assert_eq!(y.get(&[1]), 0.0);
        assert_eq!(y.get(&[2]), 2.5);
    }

    /// An all-negative input must produce an all-zero output.
    #[test]
    fn test_relu_all_negative() {
        let x = Tensor::from_slice(&[-3.0, -0.5, -1.0e6], vec![3]).unwrap();
        let y = x.relu();
        for i in 0..3 {
            assert_eq!(y.get(&[i]), 0.0);
        }
    }

    /// 67 elements: on AVX2 hardware this exercises the 4x-unrolled SIMD loop
    /// (2 x 32 elements) plus the 3-element scalar tail; on other hardware it
    /// exercises the 4x-unrolled scalar loop (16 x 4 elements) plus its tail.
    /// The tiny 3-element test above never reaches either unrolled path.
    #[test]
    fn test_relu_unrolled_blocks_and_tail() {
        let data: Vec<f32> = (0..67).map(|i| i as f32 - 33.0).collect();
        let x = Tensor::from_slice(&data, vec![67]).unwrap();
        let y = x.relu();
        assert_eq!(y.shape().dims, vec![67]);
        for (i, &v) in data.iter().enumerate() {
            let expected = if v > 0.0 { v } else { 0.0 };
            assert_eq!(y.get(&[i]), expected, "mismatch at index {}", i);
        }
    }
}