train_station/tensor/ops/leaky_relu.rs
1//! Leaky ReLU activation operations for tensors
2//!
3//! Provides the Leaky ReLU activation function following PyTorch conventions with
4//! comprehensive automatic differentiation support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **Leaky ReLU Activation**: `leaky_relu(negative_slope)` - Computes max(0, x) + negative_slope * min(0, x) (PyTorch `leaky_relu()` equivalent)
9//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
10//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
11//! - **Scalar Fallback**: Optimized scalar implementation for non-SIMD hardware
12//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
13//! - **Mathematical Accuracy**: High-precision activation computation
14//!
15//! # Mathematical Properties
16//!
17//! The Leaky ReLU function f(x) = max(0, x) + negative_slope * min(0, x) has the following properties:
18//! - For x > 0: f(x) = x (identity function)
//! - For x ≤ 0: f(x) = negative_slope * x (negative inputs scaled by the small slope)
20//! - Gradient: f'(x) = 1 for x > 0, f'(x) = negative_slope for x ≤ 0
21//! - Continuous at x = 0: f(0) = 0
22//! - Monotonic: f'(x) > 0 for all x (when negative_slope > 0)
23//!
24//! # Performance Characteristics
25//!
26//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
27//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
28//! - **Cache-friendly Access**: Linear memory access patterns
29//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
30//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
31
32use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
33use crate::tensor::core::Tensor;
34
35// SIMD optimizations for performance-critical operations
36#[cfg(target_arch = "x86_64")]
37use std::arch::x86_64::*;
38
39impl Tensor {
40 /// Element-wise Leaky ReLU activation.
41 ///
42 /// Applies Leaky ReLU to each element: `output[i] = max(0, x) + negative_slope * min(0, x)`
43 ///
44 /// Unlike standard ReLU, allows a small gradient when the unit is not active.
45 ///
46 /// # Arguments
47 /// * `negative_slope` - Slope for negative values (typically small, e.g., 0.01 or 0.1)
48 ///
49 /// # Returns
50 /// A new tensor with Leaky ReLU applied to each element
51 ///
52 /// # Examples
53 ///
54 /// ## Basic Leaky ReLU
55 ///
56 /// ```
57 /// use train_station::Tensor;
58 ///
59 /// let a = Tensor::from_slice(&[-2.0, -1.0, 0.0, 1.0], vec![4]).unwrap();
60 /// let b = a.leaky_relu(0.1);
61 /// assert_eq!(b.shape().dims, vec![4]);
62 /// assert!((b.get(&[0]) - (-0.2)).abs() < 1e-6); // -2.0 * 0.1 = -0.2
63 /// assert!((b.get(&[1]) - (-0.1)).abs() < 1e-6); // -1.0 * 0.1 = -0.1
64 /// assert_eq!(b.get(&[2]), 0.0); // max(0, 0) = 0
65 /// assert_eq!(b.get(&[3]), 1.0); // max(0, 1) = 1
66 /// ```
67 ///
68 /// ## Different Negative Slopes
69 ///
70 /// ```
71 /// use train_station::Tensor;
72 ///
73 /// let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
74 /// let b = a.leaky_relu(0.01); // Smaller negative slope
75 /// assert_eq!(b.shape().dims, vec![3]);
76 /// assert!((b.get(&[0]) - (-0.01)).abs() < 1e-6); // -1.0 * 0.01 = -0.01
77 /// assert_eq!(b.get(&[1]), 0.0); // max(0, 0) = 0
78 /// assert_eq!(b.get(&[2]), 1.0); // max(0, 1) = 1
79 /// ```
80 #[track_caller]
81 pub fn leaky_relu(&self, negative_slope: f32) -> Tensor {
82 let mut out = self.leaky_relu_optimized(negative_slope);
83
84 if self.requires_grad() && is_grad_enabled() {
85 out.set_requires_grad_internal(true);
86 let grad_fn = GradFn::LeakyRelu {
87 negative_slope,
88 saved_input: Box::new(self.clone()),
89 };
90 out.set_grad_fn(grad_fn.clone());
91 GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
92 }
93
94 out
95 }
96
97 /// Internal optimized Leaky ReLU operation
98 ///
99 /// Performs element-wise Leaky ReLU computation using SIMD optimization when available
100 /// and falling back to optimized scalar computation. This is the core implementation
101 /// used by `leaky_relu()`.
102 ///
103 /// # Arguments
104 ///
105 /// * `negative_slope` - Slope for negative values (typically small, e.g., 0.01 or 0.1)
106 ///
107 /// # Returns
108 ///
109 /// A new tensor containing the Leaky ReLU activation of each element
110 ///
111 /// # Performance Characteristics
112 ///
113 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
114 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
115 /// - **Cache-friendly**: Linear memory access patterns
116 /// - **Branch Prediction**: Optimized conditional logic for modern CPUs
117 /// - **Zero-sized Handling**: Fast return for empty tensors
118 ///
119 /// # Implementation Details
120 ///
121 /// Automatically selects between SIMD and scalar implementations based on hardware
122 /// capabilities. SIMD implementation processes 32 elements per iteration with 4x
123 /// unrolling for maximum throughput.
124 #[inline]
125 pub(crate) fn leaky_relu_optimized(&self, negative_slope: f32) -> Tensor {
126 let mut output = Tensor::new(self.shape().dims.clone());
127
128 if self.size() == 0 {
129 return output;
130 }
131
132 unsafe {
133 let src = self.as_ptr();
134 let dst = output.as_mut_ptr();
135
136 #[cfg(target_arch = "x86_64")]
137 {
138 if is_x86_feature_detected!("avx2") {
139 self.leaky_relu_simd_avx2_optimized(src, dst, negative_slope);
140 return output;
141 }
142 }
143
144 // Scalar fallback
145 self.leaky_relu_scalar_optimized(src, dst, negative_slope);
146 }
147
148 output
149 }
150
151 /// AVX2-optimized Leaky ReLU implementation
152 ///
153 /// Performs element-wise Leaky ReLU using AVX2 SIMD instructions for maximum
154 /// performance on x86_64 architectures with AVX2 support.
155 ///
156 /// # Arguments
157 ///
158 /// * `src` - Pointer to source tensor data
159 /// * `dst` - Pointer to output tensor data
160 /// * `negative_slope` - Slope for negative values
161 ///
162 /// # Safety
163 ///
164 /// Requires valid pointers with sufficient memory for the tensor size.
165 /// All pointers must point to valid tensor data. Requires AVX2 support.
166 ///
167 /// # Performance Characteristics
168 ///
169 /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
170 /// - **Memory Access**: Linear access patterns for cache efficiency
171 /// - **Branch Prediction**: Optimized conditional logic using SIMD masks
172 /// - **Fallback**: Handles remaining elements with scalar operations
173 /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
174 ///
175 /// # Implementation Details
176 ///
177 /// Uses AVX2 vector instructions to process 8 elements simultaneously.
178 /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
179 #[cfg(target_arch = "x86_64")]
180 #[inline]
181 #[target_feature(enable = "avx2")]
182 unsafe fn leaky_relu_simd_avx2_optimized(
183 &self,
184 src: *const f32,
185 dst: *mut f32,
186 negative_slope: f32,
187 ) {
188 let size = self.size();
189 let zero_vec = _mm256_setzero_ps();
190 let slope_vec = _mm256_set1_ps(negative_slope);
191 let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
192 let mut offset = 0;
193
194 // Unrolled SIMD loop for maximum throughput
195 for _ in 0..simd_count {
196 // Process 4 AVX2 vectors (32 elements) per iteration
197 self.leaky_relu_simd_block(src, dst, offset, zero_vec, slope_vec);
198 self.leaky_relu_simd_block(src, dst, offset + 8, zero_vec, slope_vec);
199 self.leaky_relu_simd_block(src, dst, offset + 16, zero_vec, slope_vec);
200 self.leaky_relu_simd_block(src, dst, offset + 24, zero_vec, slope_vec);
201 offset += 32;
202 }
203
204 // Handle remaining 8-element blocks
205 let remaining_full_blocks = (size - offset) / 8;
206 for _ in 0..remaining_full_blocks {
207 self.leaky_relu_simd_block(src, dst, offset, zero_vec, slope_vec);
208 offset += 8;
209 }
210
211 // Handle remaining elements with scalar fallback
212 for i in offset..size {
213 let x = *src.add(i);
214 *dst.add(i) = if x > 0.0 { x } else { negative_slope * x };
215 }
216 }
217
218 /// AVX2 SIMD block processing for Leaky ReLU
219 ///
220 /// Processes a single 8-element block using AVX2 vector instructions.
221 /// This is a helper function for the main SIMD implementation.
222 ///
223 /// # Arguments
224 ///
225 /// * `src` - Pointer to source tensor data
226 /// * `dst` - Pointer to output tensor data
227 /// * `offset` - Offset into the tensor data
228 /// * `zero_vec` - AVX2 vector containing zeros
229 /// * `slope_vec` - AVX2 vector containing the negative slope value
230 ///
231 /// # Safety
232 ///
233 /// Requires valid pointers with sufficient memory for 8 elements starting at offset.
234 /// All pointers must point to valid tensor data. Requires AVX2 support.
235 ///
236 /// # Performance Characteristics
237 ///
238 /// - **SIMD Processing**: Processes 8 elements in a single vector operation
239 /// - **Vector Operations**: Uses AVX2 comparison, multiplication, and blending
240 /// - **Branch-free**: No conditional branches in the SIMD path
241 /// - **Memory Access**: Single load and store operation per block
242 ///
243 /// # Implementation Details
244 ///
245 /// Uses AVX2 vector comparison to create a mask for positive values,
246 /// then blends between the original values and scaled negative values
247 /// based on the comparison result.
248 #[cfg(target_arch = "x86_64")]
249 #[inline]
250 #[target_feature(enable = "avx2")]
251 unsafe fn leaky_relu_simd_block(
252 &self,
253 src: *const f32,
254 dst: *mut f32,
255 offset: usize,
256 zero_vec: __m256,
257 slope_vec: __m256,
258 ) {
259 let src_vec = _mm256_loadu_ps(src.add(offset));
260
261 // Create mask for positive values
262 let pos_mask = _mm256_cmp_ps(src_vec, zero_vec, _CMP_GT_OQ);
263
264 // Compute negative part: negative_slope * x
265 let neg_part = _mm256_mul_ps(src_vec, slope_vec);
266
267 // Blend: use src_vec where positive, neg_part where negative
268 let result = _mm256_blendv_ps(neg_part, src_vec, pos_mask);
269
270 _mm256_storeu_ps(dst.add(offset), result);
271 }
272
273 /// Optimized scalar Leaky ReLU fallback
274 ///
275 /// Performs element-wise Leaky ReLU using optimized scalar operations with
276 /// 4x unrolling for better instruction-level parallelism and cache efficiency.
277 ///
278 /// # Arguments
279 ///
280 /// * `src` - Pointer to source tensor data
281 /// * `dst` - Pointer to output tensor data
282 /// * `negative_slope` - Slope for negative values
283 ///
284 /// # Safety
285 ///
286 /// Requires valid pointers with sufficient memory for the tensor size.
287 /// All pointers must point to valid tensor data.
288 ///
289 /// # Performance Characteristics
290 ///
291 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
292 /// - **Memory Access**: Linear access patterns for cache efficiency
293 /// - **Fallback**: Handles remaining elements with scalar operations
294 /// - **Branch Prediction**: Optimized conditional logic for modern CPUs
295 /// - **Mathematical Accuracy**: High-precision scalar computation
296 ///
297 /// # Implementation Details
298 ///
299 /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
300 /// Processes elements in groups of 4 to improve instruction-level parallelism
301 /// and reduce loop overhead.
302 #[inline]
303 unsafe fn leaky_relu_scalar_optimized(
304 &self,
305 src: *const f32,
306 dst: *mut f32,
307 negative_slope: f32,
308 ) {
309 let size = self.size();
310 let unroll_count = size / 4;
311 let mut offset = 0;
312
313 // Unrolled scalar loop for better performance
314 for _ in 0..unroll_count {
315 let x1 = *src.add(offset);
316 let x2 = *src.add(offset + 1);
317 let x3 = *src.add(offset + 2);
318 let x4 = *src.add(offset + 3);
319
320 *dst.add(offset) = if x1 > 0.0 { x1 } else { negative_slope * x1 };
321 *dst.add(offset + 1) = if x2 > 0.0 { x2 } else { negative_slope * x2 };
322 *dst.add(offset + 2) = if x3 > 0.0 { x3 } else { negative_slope * x3 };
323 *dst.add(offset + 3) = if x4 > 0.0 { x4 } else { negative_slope * x4 };
324
325 offset += 4;
326 }
327
328 // Handle remaining elements
329 for i in offset..size {
330 let x = *src.add(i);
331 *dst.add(i) = if x > 0.0 { x } else { negative_slope * x };
332 }
333 }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// Forward pass: negatives are scaled by the slope, positives pass through.
    ///
    /// Uses the safe `get` accessor (as the doctests do) instead of the
    /// previous unsafe raw-pointer reads, which added an `unsafe` block for
    /// no benefit.
    #[test]
    fn test_leaky_relu_forward_basic() {
        let x = Tensor::from_slice(&[-2.0, -1.0, 0.0, 1.5], vec![4]).unwrap();
        let y = x.leaky_relu(0.1);
        assert_eq!(y.shape().dims, vec![4]);
        assert!((y.get(&[0]) + 0.2).abs() < 1e-6); // -2.0 * 0.1
        assert!((y.get(&[1]) + 0.1).abs() < 1e-6); // -1.0 * 0.1
        assert!((y.get(&[2]) - 0.0).abs() < 1e-6); // boundary: f(0) = 0
        assert!((y.get(&[3]) - 1.5).abs() < 1e-6); // positive passes through
    }

    /// With a slope of 1.0 the activation degenerates to the identity map.
    #[test]
    fn test_leaky_relu_identity_slope() {
        let x = Tensor::from_slice(&[-3.0, 0.0, 2.5], vec![3]).unwrap();
        let y = x.leaky_relu(1.0);
        for i in 0..3 {
            assert!((y.get(&[i]) - x.get(&[i])).abs() < 1e-6);
        }
    }
}
351}