train_station/tensor/ops/leaky_relu.rs
1//! Leaky ReLU activation operations for tensors
2//!
3//! Provides the Leaky ReLU activation function following PyTorch conventions with
4//! comprehensive automatic differentiation support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **Leaky ReLU Activation**: `leaky_relu(negative_slope)` - Computes max(0, x) + negative_slope * min(0, x) (PyTorch `leaky_relu()` equivalent)
9//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
10//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
11//! - **Scalar Fallback**: Optimized scalar implementation for non-SIMD hardware
12//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
13//! - **Mathematical Accuracy**: High-precision activation computation
14//!
15//! # Mathematical Properties
16//!
17//! The Leaky ReLU function f(x) = max(0, x) + negative_slope * min(0, x) has the following properties:
18//! - For x > 0: f(x) = x (identity function)
19//! - For x ≤ 0: f(x) = negative_slope * x (small negative gradient)
20//! - Gradient: f'(x) = 1 for x > 0, f'(x) = negative_slope for x ≤ 0
21//! - Continuous at x = 0: f(0) = 0
22//! - Monotonic: f'(x) > 0 for all x (when negative_slope > 0)
23//!
24//! # Performance Characteristics
25//!
26//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
27//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
28//! - **Cache-friendly Access**: Linear memory access patterns
29//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
30//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
31
32use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
33use crate::tensor::core::Tensor;
34
35// SIMD optimizations for performance-critical operations
36#[cfg(target_arch = "x86_64")]
37use std::arch::x86_64::*;
38
39impl Tensor {
40 /// Element-wise Leaky ReLU activation.
41 ///
42 /// Applies Leaky ReLU to each element: `output[i] = max(0, x) + negative_slope * min(0, x)`
43 ///
44 /// Unlike standard ReLU, allows a small gradient when the unit is not active.
45 ///
46 /// # Arguments
47 /// * `negative_slope` - Slope for negative values (typically small, e.g., 0.01 or 0.1)
48 ///
49 /// # Returns
50 /// A new tensor with Leaky ReLU applied to each element
51 ///
52 /// # Examples
53 ///
54 /// ## Basic Leaky ReLU
55 ///
56 /// ```
57 /// use train_station::Tensor;
58 ///
59 /// let a = Tensor::from_slice(&[-2.0, -1.0, 0.0, 1.0], vec![4]).unwrap();
60 /// let b = a.leaky_relu(0.1);
61 /// assert_eq!(b.shape().dims, vec![4]);
62 /// assert!((b.get(&[0]) - (-0.2)).abs() < 1e-6); // -2.0 * 0.1 = -0.2
63 /// assert!((b.get(&[1]) - (-0.1)).abs() < 1e-6); // -1.0 * 0.1 = -0.1
64 /// assert_eq!(b.get(&[2]), 0.0); // max(0, 0) = 0
65 /// assert_eq!(b.get(&[3]), 1.0); // max(0, 1) = 1
66 /// ```
67 ///
68 /// ## Different Negative Slopes
69 ///
70 /// ```
71 /// use train_station::Tensor;
72 ///
73 /// let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
74 /// let b = a.leaky_relu(0.01); // Smaller negative slope
75 /// assert_eq!(b.shape().dims, vec![3]);
76 /// assert!((b.get(&[0]) - (-0.01)).abs() < 1e-6); // -1.0 * 0.01 = -0.01
77 /// assert_eq!(b.get(&[1]), 0.0); // max(0, 0) = 0
78 /// assert_eq!(b.get(&[2]), 1.0); // max(0, 1) = 1
79 /// ```
80 pub fn leaky_relu(&self, negative_slope: f32) -> Tensor {
81 let mut out = self.leaky_relu_optimized(negative_slope);
82
83 if self.requires_grad() && is_grad_enabled() {
84 out.set_requires_grad_internal(true);
85 let grad_fn = GradFn::LeakyRelu {
86 negative_slope,
87 saved_input: Box::new(self.clone()),
88 };
89 out.set_grad_fn(grad_fn.clone());
90 GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
91 }
92
93 out
94 }
95
96 /// Internal optimized Leaky ReLU operation
97 ///
98 /// Performs element-wise Leaky ReLU computation using SIMD optimization when available
99 /// and falling back to optimized scalar computation. This is the core implementation
100 /// used by `leaky_relu()`.
101 ///
102 /// # Arguments
103 ///
104 /// * `negative_slope` - Slope for negative values (typically small, e.g., 0.01 or 0.1)
105 ///
106 /// # Returns
107 ///
108 /// A new tensor containing the Leaky ReLU activation of each element
109 ///
110 /// # Performance Characteristics
111 ///
112 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
113 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
114 /// - **Cache-friendly**: Linear memory access patterns
115 /// - **Branch Prediction**: Optimized conditional logic for modern CPUs
116 /// - **Zero-sized Handling**: Fast return for empty tensors
117 ///
118 /// # Implementation Details
119 ///
120 /// Automatically selects between SIMD and scalar implementations based on hardware
121 /// capabilities. SIMD implementation processes 32 elements per iteration with 4x
122 /// unrolling for maximum throughput.
123 #[inline]
124 pub(crate) fn leaky_relu_optimized(&self, negative_slope: f32) -> Tensor {
125 let mut output = Tensor::new(self.shape().dims.clone());
126
127 if self.size() == 0 {
128 return output;
129 }
130
131 unsafe {
132 let src = self.as_ptr();
133 let dst = output.as_mut_ptr();
134
135 #[cfg(target_arch = "x86_64")]
136 {
137 if is_x86_feature_detected!("avx2") {
138 self.leaky_relu_simd_avx2_optimized(src, dst, negative_slope);
139 return output;
140 }
141 }
142
143 // Scalar fallback
144 self.leaky_relu_scalar_optimized(src, dst, negative_slope);
145 }
146
147 output
148 }
149
150 /// AVX2-optimized Leaky ReLU implementation
151 ///
152 /// Performs element-wise Leaky ReLU using AVX2 SIMD instructions for maximum
153 /// performance on x86_64 architectures with AVX2 support.
154 ///
155 /// # Arguments
156 ///
157 /// * `src` - Pointer to source tensor data
158 /// * `dst` - Pointer to output tensor data
159 /// * `negative_slope` - Slope for negative values
160 ///
161 /// # Safety
162 ///
163 /// Requires valid pointers with sufficient memory for the tensor size.
164 /// All pointers must point to valid tensor data. Requires AVX2 support.
165 ///
166 /// # Performance Characteristics
167 ///
168 /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
169 /// - **Memory Access**: Linear access patterns for cache efficiency
170 /// - **Branch Prediction**: Optimized conditional logic using SIMD masks
171 /// - **Fallback**: Handles remaining elements with scalar operations
172 /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
173 ///
174 /// # Implementation Details
175 ///
176 /// Uses AVX2 vector instructions to process 8 elements simultaneously.
177 /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
178 #[cfg(target_arch = "x86_64")]
179 #[inline]
180 #[target_feature(enable = "avx2")]
181 unsafe fn leaky_relu_simd_avx2_optimized(
182 &self,
183 src: *const f32,
184 dst: *mut f32,
185 negative_slope: f32,
186 ) {
187 let size = self.size();
188 let zero_vec = _mm256_setzero_ps();
189 let slope_vec = _mm256_set1_ps(negative_slope);
190 let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
191 let mut offset = 0;
192
193 // Unrolled SIMD loop for maximum throughput
194 for _ in 0..simd_count {
195 // Process 4 AVX2 vectors (32 elements) per iteration
196 self.leaky_relu_simd_block(src, dst, offset, zero_vec, slope_vec);
197 self.leaky_relu_simd_block(src, dst, offset + 8, zero_vec, slope_vec);
198 self.leaky_relu_simd_block(src, dst, offset + 16, zero_vec, slope_vec);
199 self.leaky_relu_simd_block(src, dst, offset + 24, zero_vec, slope_vec);
200 offset += 32;
201 }
202
203 // Handle remaining 8-element blocks
204 let remaining_full_blocks = (size - offset) / 8;
205 for _ in 0..remaining_full_blocks {
206 self.leaky_relu_simd_block(src, dst, offset, zero_vec, slope_vec);
207 offset += 8;
208 }
209
210 // Handle remaining elements with scalar fallback
211 for i in offset..size {
212 let x = *src.add(i);
213 *dst.add(i) = if x > 0.0 { x } else { negative_slope * x };
214 }
215 }
216
217 /// AVX2 SIMD block processing for Leaky ReLU
218 ///
219 /// Processes a single 8-element block using AVX2 vector instructions.
220 /// This is a helper function for the main SIMD implementation.
221 ///
222 /// # Arguments
223 ///
224 /// * `src` - Pointer to source tensor data
225 /// * `dst` - Pointer to output tensor data
226 /// * `offset` - Offset into the tensor data
227 /// * `zero_vec` - AVX2 vector containing zeros
228 /// * `slope_vec` - AVX2 vector containing the negative slope value
229 ///
230 /// # Safety
231 ///
232 /// Requires valid pointers with sufficient memory for 8 elements starting at offset.
233 /// All pointers must point to valid tensor data. Requires AVX2 support.
234 ///
235 /// # Performance Characteristics
236 ///
237 /// - **SIMD Processing**: Processes 8 elements in a single vector operation
238 /// - **Vector Operations**: Uses AVX2 comparison, multiplication, and blending
239 /// - **Branch-free**: No conditional branches in the SIMD path
240 /// - **Memory Access**: Single load and store operation per block
241 ///
242 /// # Implementation Details
243 ///
244 /// Uses AVX2 vector comparison to create a mask for positive values,
245 /// then blends between the original values and scaled negative values
246 /// based on the comparison result.
247 #[cfg(target_arch = "x86_64")]
248 #[inline]
249 #[target_feature(enable = "avx2")]
250 unsafe fn leaky_relu_simd_block(
251 &self,
252 src: *const f32,
253 dst: *mut f32,
254 offset: usize,
255 zero_vec: __m256,
256 slope_vec: __m256,
257 ) {
258 let src_vec = _mm256_loadu_ps(src.add(offset));
259
260 // Create mask for positive values
261 let pos_mask = _mm256_cmp_ps(src_vec, zero_vec, _CMP_GT_OQ);
262
263 // Compute negative part: negative_slope * x
264 let neg_part = _mm256_mul_ps(src_vec, slope_vec);
265
266 // Blend: use src_vec where positive, neg_part where negative
267 let result = _mm256_blendv_ps(neg_part, src_vec, pos_mask);
268
269 _mm256_storeu_ps(dst.add(offset), result);
270 }
271
272 /// Optimized scalar Leaky ReLU fallback
273 ///
274 /// Performs element-wise Leaky ReLU using optimized scalar operations with
275 /// 4x unrolling for better instruction-level parallelism and cache efficiency.
276 ///
277 /// # Arguments
278 ///
279 /// * `src` - Pointer to source tensor data
280 /// * `dst` - Pointer to output tensor data
281 /// * `negative_slope` - Slope for negative values
282 ///
283 /// # Safety
284 ///
285 /// Requires valid pointers with sufficient memory for the tensor size.
286 /// All pointers must point to valid tensor data.
287 ///
288 /// # Performance Characteristics
289 ///
290 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
291 /// - **Memory Access**: Linear access patterns for cache efficiency
292 /// - **Fallback**: Handles remaining elements with scalar operations
293 /// - **Branch Prediction**: Optimized conditional logic for modern CPUs
294 /// - **Mathematical Accuracy**: High-precision scalar computation
295 ///
296 /// # Implementation Details
297 ///
298 /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
299 /// Processes elements in groups of 4 to improve instruction-level parallelism
300 /// and reduce loop overhead.
301 #[inline]
302 unsafe fn leaky_relu_scalar_optimized(
303 &self,
304 src: *const f32,
305 dst: *mut f32,
306 negative_slope: f32,
307 ) {
308 let size = self.size();
309 let unroll_count = size / 4;
310 let mut offset = 0;
311
312 // Unrolled scalar loop for better performance
313 for _ in 0..unroll_count {
314 let x1 = *src.add(offset);
315 let x2 = *src.add(offset + 1);
316 let x3 = *src.add(offset + 2);
317 let x4 = *src.add(offset + 3);
318
319 *dst.add(offset) = if x1 > 0.0 { x1 } else { negative_slope * x1 };
320 *dst.add(offset + 1) = if x2 > 0.0 { x2 } else { negative_slope * x2 };
321 *dst.add(offset + 2) = if x3 > 0.0 { x3 } else { negative_slope * x3 };
322 *dst.add(offset + 3) = if x4 > 0.0 { x4 } else { negative_slope * x4 };
323
324 offset += 4;
325 }
326
327 // Handle remaining elements
328 for i in offset..size {
329 let x = *src.add(i);
330 *dst.add(i) = if x > 0.0 { x } else { negative_slope * x };
331 }
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Straightforward scalar Leaky ReLU used as the expected value in checks.
    fn reference_leaky_relu(x: f32, negative_slope: f32) -> f32 {
        if x > 0.0 {
            x
        } else {
            negative_slope * x
        }
    }

    #[test]
    fn test_leaky_relu_forward_basic() {
        let x = Tensor::from_slice(&[-2.0, -1.0, 0.0, 1.5], vec![4]).unwrap();
        let y = x.leaky_relu(0.1);
        assert_eq!(y.shape().dims, vec![4]);
        assert!((y.get(&[0]) + 0.2).abs() < 1e-6);
        assert!((y.get(&[1]) + 0.1).abs() < 1e-6);
        assert!((y.get(&[2]) - 0.0).abs() < 1e-6);
        assert!((y.get(&[3]) - 1.5).abs() < 1e-6);
    }

    #[test]
    fn test_leaky_relu_forward_matches_reference_across_simd_blocks_and_tail() {
        // 45 = 32 (one unrolled AVX2 iteration) + 8 (one single SIMD block) + 5 (scalar
        // tail), so every AVX2 code path runs when AVX2 is available. On the scalar
        // fallback it covers 11 unrolled groups of 4 plus 1 leftover element.
        // Alternating signs put negatives, positives and a zero in every region.
        let data: Vec<f32> = (0..45)
            .map(|i| (i as f32 - 22.0) * if i % 2 == 0 { 0.25 } else { -0.25 })
            .collect();

        for &negative_slope in &[0.0_f32, 0.01, 0.2] {
            let x = Tensor::from_slice(&data, vec![data.len()]).unwrap();
            let y = x.leaky_relu(negative_slope);
            assert_eq!(y.shape().dims, vec![data.len()]);

            for (i, &v) in data.iter().enumerate() {
                let expected = reference_leaky_relu(v, negative_slope);
                let actual = y.get(&[i]);
                assert!(
                    (actual - expected).abs() < 1e-6,
                    "slope {negative_slope}, index {i}: input {v}, got {actual}, expected {expected}"
                );
            }
        }
    }
}