train_station/tensor/ops/
leaky_relu.rs

1//! Leaky ReLU activation operations for tensors
2//!
3//! Provides the Leaky ReLU activation function following PyTorch conventions with
4//! comprehensive automatic differentiation support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
//! - **Leaky ReLU Activation**: `leaky_relu(negative_slope)` - Computes max(0, x) + negative_slope * min(0, x) (PyTorch `leaky_relu()` equivalent)
//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
//! - **SIMD Optimization**: AVX2 implementation, selected at runtime via CPU feature detection
//! - **Scalar Fallback**: 4x unrolled scalar implementation used when AVX2 is unavailable
//! - **Linear Memory Access**: Input and output buffers are traversed sequentially
//! - **Consistent Results**: SIMD and scalar paths apply the same rule (`x` if `x > 0`, otherwise `negative_slope * x`)
14//!
15//! # Mathematical Properties
16//!
17//! The Leaky ReLU function f(x) = max(0, x) + negative_slope * min(0, x) has the following properties:
18//! - For x > 0: f(x) = x (identity function)
//! - For x ≤ 0: f(x) = negative_slope * x (the input is scaled by a small slope instead of being zeroed)
//! - Gradient: f'(x) = 1 for x > 0, f'(x) = negative_slope for x ≤ 0
//! - Continuous at x = 0: f(0) = 0
//! - Strictly increasing when negative_slope > 0, since then f'(x) > 0 for all x
23//!
24//! # Performance Characteristics
25//!
//! - **SIMD Optimization**: The AVX2 kernel processes 32 elements per iteration (4 × 8-wide vectors), then remaining 8-wide blocks, then a scalar tail
//! - **Scalar Fallback**: 4x unrolled scalar implementation for CPUs without AVX2
//! - **Cache-friendly Access**: Linear memory access patterns
//! - **Branch-free SIMD**: The AVX2 path selects each result with a compare mask and blend instead of branching
//! - **Gradient Optimization**: A gradient node is recorded only when the input requires gradients and gradient tracking is enabled (NoGradTrack support)
31
32use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
33use crate::tensor::core::Tensor;
34
35// SIMD optimizations for performance-critical operations
36#[cfg(target_arch = "x86_64")]
37use std::arch::x86_64::*;
38
39impl Tensor {
40    /// Element-wise Leaky ReLU activation.
41    ///
42    /// Applies Leaky ReLU to each element: `output[i] = max(0, x) + negative_slope * min(0, x)`
43    ///
44    /// Unlike standard ReLU, allows a small gradient when the unit is not active.
45    ///
46    /// # Arguments
47    /// * `negative_slope` - Slope for negative values (typically small, e.g., 0.01 or 0.1)
48    ///
49    /// # Returns
50    /// A new tensor with Leaky ReLU applied to each element
51    ///
52    /// # Examples
53    ///
54    /// ## Basic Leaky ReLU
55    ///
56    /// ```
57    /// use train_station::Tensor;
58    ///
59    /// let a = Tensor::from_slice(&[-2.0, -1.0, 0.0, 1.0], vec![4]).unwrap();
60    /// let b = a.leaky_relu(0.1);
61    /// assert_eq!(b.shape().dims, vec![4]);
62    /// assert!((b.get(&[0]) - (-0.2)).abs() < 1e-6); // -2.0 * 0.1 = -0.2
63    /// assert!((b.get(&[1]) - (-0.1)).abs() < 1e-6); // -1.0 * 0.1 = -0.1
64    /// assert_eq!(b.get(&[2]), 0.0); // max(0, 0) = 0
65    /// assert_eq!(b.get(&[3]), 1.0); // max(0, 1) = 1
66    /// ```
67    ///
68    /// ## Different Negative Slopes
69    ///
70    /// ```
71    /// use train_station::Tensor;
72    ///
73    /// let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
74    /// let b = a.leaky_relu(0.01); // Smaller negative slope
75    /// assert_eq!(b.shape().dims, vec![3]);
76    /// assert!((b.get(&[0]) - (-0.01)).abs() < 1e-6); // -1.0 * 0.01 = -0.01
77    /// assert_eq!(b.get(&[1]), 0.0); // max(0, 0) = 0
78    /// assert_eq!(b.get(&[2]), 1.0); // max(0, 1) = 1
79    /// ```
80    #[track_caller]
81    pub fn leaky_relu(&self, negative_slope: f32) -> Tensor {
82        let mut out = self.leaky_relu_optimized(negative_slope);
83
84        if self.requires_grad() && is_grad_enabled() {
85            out.set_requires_grad_internal(true);
86            let grad_fn = GradFn::LeakyRelu {
87                negative_slope,
88                saved_input: Box::new(self.clone()),
89            };
90            out.set_grad_fn(grad_fn.clone());
91            GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
92        }
93
94        out
95    }
96
97    /// Internal optimized Leaky ReLU operation
98    ///
99    /// Performs element-wise Leaky ReLU computation using SIMD optimization when available
100    /// and falling back to optimized scalar computation. This is the core implementation
101    /// used by `leaky_relu()`.
102    ///
103    /// # Arguments
104    ///
105    /// * `negative_slope` - Slope for negative values (typically small, e.g., 0.01 or 0.1)
106    ///
107    /// # Returns
108    ///
109    /// A new tensor containing the Leaky ReLU activation of each element
110    ///
111    /// # Performance Characteristics
112    ///
113    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
114    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
115    /// - **Cache-friendly**: Linear memory access patterns
116    /// - **Branch Prediction**: Optimized conditional logic for modern CPUs
117    /// - **Zero-sized Handling**: Fast return for empty tensors
118    ///
119    /// # Implementation Details
120    ///
121    /// Automatically selects between SIMD and scalar implementations based on hardware
122    /// capabilities. SIMD implementation processes 32 elements per iteration with 4x
123    /// unrolling for maximum throughput.
124    #[inline]
125    pub(crate) fn leaky_relu_optimized(&self, negative_slope: f32) -> Tensor {
126        let mut output = Tensor::new(self.shape().dims.clone());
127
128        if self.size() == 0 {
129            return output;
130        }
131
132        unsafe {
133            let src = self.as_ptr();
134            let dst = output.as_mut_ptr();
135
136            #[cfg(target_arch = "x86_64")]
137            {
138                if is_x86_feature_detected!("avx2") {
139                    self.leaky_relu_simd_avx2_optimized(src, dst, negative_slope);
140                    return output;
141                }
142            }
143
144            // Scalar fallback
145            self.leaky_relu_scalar_optimized(src, dst, negative_slope);
146        }
147
148        output
149    }
150
151    /// AVX2-optimized Leaky ReLU implementation
152    ///
153    /// Performs element-wise Leaky ReLU using AVX2 SIMD instructions for maximum
154    /// performance on x86_64 architectures with AVX2 support.
155    ///
156    /// # Arguments
157    ///
158    /// * `src` - Pointer to source tensor data
159    /// * `dst` - Pointer to output tensor data
160    /// * `negative_slope` - Slope for negative values
161    ///
162    /// # Safety
163    ///
164    /// Requires valid pointers with sufficient memory for the tensor size.
165    /// All pointers must point to valid tensor data. Requires AVX2 support.
166    ///
167    /// # Performance Characteristics
168    ///
169    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
170    /// - **Memory Access**: Linear access patterns for cache efficiency
171    /// - **Branch Prediction**: Optimized conditional logic using SIMD masks
172    /// - **Fallback**: Handles remaining elements with scalar operations
173    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
174    ///
175    /// # Implementation Details
176    ///
177    /// Uses AVX2 vector instructions to process 8 elements simultaneously.
178    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
179    #[cfg(target_arch = "x86_64")]
180    #[inline]
181    #[target_feature(enable = "avx2")]
182    unsafe fn leaky_relu_simd_avx2_optimized(
183        &self,
184        src: *const f32,
185        dst: *mut f32,
186        negative_slope: f32,
187    ) {
188        let size = self.size();
189        let zero_vec = _mm256_setzero_ps();
190        let slope_vec = _mm256_set1_ps(negative_slope);
191        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
192        let mut offset = 0;
193
194        // Unrolled SIMD loop for maximum throughput
195        for _ in 0..simd_count {
196            // Process 4 AVX2 vectors (32 elements) per iteration
197            self.leaky_relu_simd_block(src, dst, offset, zero_vec, slope_vec);
198            self.leaky_relu_simd_block(src, dst, offset + 8, zero_vec, slope_vec);
199            self.leaky_relu_simd_block(src, dst, offset + 16, zero_vec, slope_vec);
200            self.leaky_relu_simd_block(src, dst, offset + 24, zero_vec, slope_vec);
201            offset += 32;
202        }
203
204        // Handle remaining 8-element blocks
205        let remaining_full_blocks = (size - offset) / 8;
206        for _ in 0..remaining_full_blocks {
207            self.leaky_relu_simd_block(src, dst, offset, zero_vec, slope_vec);
208            offset += 8;
209        }
210
211        // Handle remaining elements with scalar fallback
212        for i in offset..size {
213            let x = *src.add(i);
214            *dst.add(i) = if x > 0.0 { x } else { negative_slope * x };
215        }
216    }
217
218    /// AVX2 SIMD block processing for Leaky ReLU
219    ///
220    /// Processes a single 8-element block using AVX2 vector instructions.
221    /// This is a helper function for the main SIMD implementation.
222    ///
223    /// # Arguments
224    ///
225    /// * `src` - Pointer to source tensor data
226    /// * `dst` - Pointer to output tensor data
227    /// * `offset` - Offset into the tensor data
228    /// * `zero_vec` - AVX2 vector containing zeros
229    /// * `slope_vec` - AVX2 vector containing the negative slope value
230    ///
231    /// # Safety
232    ///
233    /// Requires valid pointers with sufficient memory for 8 elements starting at offset.
234    /// All pointers must point to valid tensor data. Requires AVX2 support.
235    ///
236    /// # Performance Characteristics
237    ///
238    /// - **SIMD Processing**: Processes 8 elements in a single vector operation
239    /// - **Vector Operations**: Uses AVX2 comparison, multiplication, and blending
240    /// - **Branch-free**: No conditional branches in the SIMD path
241    /// - **Memory Access**: Single load and store operation per block
242    ///
243    /// # Implementation Details
244    ///
245    /// Uses AVX2 vector comparison to create a mask for positive values,
246    /// then blends between the original values and scaled negative values
247    /// based on the comparison result.
248    #[cfg(target_arch = "x86_64")]
249    #[inline]
250    #[target_feature(enable = "avx2")]
251    unsafe fn leaky_relu_simd_block(
252        &self,
253        src: *const f32,
254        dst: *mut f32,
255        offset: usize,
256        zero_vec: __m256,
257        slope_vec: __m256,
258    ) {
259        let src_vec = _mm256_loadu_ps(src.add(offset));
260
261        // Create mask for positive values
262        let pos_mask = _mm256_cmp_ps(src_vec, zero_vec, _CMP_GT_OQ);
263
264        // Compute negative part: negative_slope * x
265        let neg_part = _mm256_mul_ps(src_vec, slope_vec);
266
267        // Blend: use src_vec where positive, neg_part where negative
268        let result = _mm256_blendv_ps(neg_part, src_vec, pos_mask);
269
270        _mm256_storeu_ps(dst.add(offset), result);
271    }
272
273    /// Optimized scalar Leaky ReLU fallback
274    ///
275    /// Performs element-wise Leaky ReLU using optimized scalar operations with
276    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
277    ///
278    /// # Arguments
279    ///
280    /// * `src` - Pointer to source tensor data
281    /// * `dst` - Pointer to output tensor data
282    /// * `negative_slope` - Slope for negative values
283    ///
284    /// # Safety
285    ///
286    /// Requires valid pointers with sufficient memory for the tensor size.
287    /// All pointers must point to valid tensor data.
288    ///
289    /// # Performance Characteristics
290    ///
291    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
292    /// - **Memory Access**: Linear access patterns for cache efficiency
293    /// - **Fallback**: Handles remaining elements with scalar operations
294    /// - **Branch Prediction**: Optimized conditional logic for modern CPUs
295    /// - **Mathematical Accuracy**: High-precision scalar computation
296    ///
297    /// # Implementation Details
298    ///
299    /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
300    /// Processes elements in groups of 4 to improve instruction-level parallelism
301    /// and reduce loop overhead.
302    #[inline]
303    unsafe fn leaky_relu_scalar_optimized(
304        &self,
305        src: *const f32,
306        dst: *mut f32,
307        negative_slope: f32,
308    ) {
309        let size = self.size();
310        let unroll_count = size / 4;
311        let mut offset = 0;
312
313        // Unrolled scalar loop for better performance
314        for _ in 0..unroll_count {
315            let x1 = *src.add(offset);
316            let x2 = *src.add(offset + 1);
317            let x3 = *src.add(offset + 2);
318            let x4 = *src.add(offset + 3);
319
320            *dst.add(offset) = if x1 > 0.0 { x1 } else { negative_slope * x1 };
321            *dst.add(offset + 1) = if x2 > 0.0 { x2 } else { negative_slope * x2 };
322            *dst.add(offset + 2) = if x3 > 0.0 { x3 } else { negative_slope * x3 };
323            *dst.add(offset + 3) = if x4 > 0.0 { x4 } else { negative_slope * x4 };
324
325            offset += 4;
326        }
327
328        // Handle remaining elements
329        for i in offset..size {
330            let x = *src.add(i);
331            *dst.add(i) = if x > 0.0 { x } else { negative_slope * x };
332        }
333    }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// Straightforward reference for the Leaky ReLU rule, independent of the kernels.
    fn reference(x: f32, negative_slope: f32) -> f32 {
        if x > 0.0 {
            x
        } else {
            negative_slope * x
        }
    }

    #[test]
    fn test_leaky_relu_forward_basic() {
        let x = Tensor::from_slice(&[-2.0, -1.0, 0.0, 1.5], vec![4]).unwrap();
        let y = x.leaky_relu(0.1);
        unsafe {
            assert!((*y.as_ptr() + 0.2).abs() < 1e-6);
            assert!((*y.as_ptr().add(1) + 0.1).abs() < 1e-6);
            assert!((*y.as_ptr().add(2) - 0.0).abs() < 1e-6);
            assert!((*y.as_ptr().add(3) - 1.5).abs() < 1e-6);
        }
    }

    #[test]
    fn test_leaky_relu_forward_all_kernel_paths() {
        // The sizes exercise every branch of both kernels:
        // - tail-only inputs (< 4 and < 8 elements);
        // - a single 8-wide block;
        // - exactly one 32-wide unrolled AVX2 iteration;
        // - mixes of 32-wide, 8-wide and scalar-tail work.
        // On CPUs without AVX2 they cover the 4x-unrolled scalar loop and its remainder.
        let negative_slope = 0.2;
        for &n in &[1usize, 3, 4, 8, 9, 31, 32, 33, 40, 45, 77] {
            let data: Vec<f32> = (0..n)
                .map(|i| (i as f32 - n as f32 / 2.0) * 0.5)
                .collect();
            let x = Tensor::from_slice(&data, vec![n]).unwrap();
            let y = x.leaky_relu(negative_slope);
            assert_eq!(y.shape().dims, vec![n]);
            for (i, &v) in data.iter().enumerate() {
                let expected = reference(v, negative_slope);
                let got = y.get(&[i]);
                assert!(
                    (got - expected).abs() < 1e-6,
                    "n={}, i={}: got {}, expected {}",
                    n,
                    i,
                    got,
                    expected
                );
            }
        }
    }

    #[test]
    fn test_leaky_relu_zero_slope_is_relu() {
        let data = [-3.0f32, -0.5, 0.0, 0.5, 3.0];
        let x = Tensor::from_slice(&data, vec![5]).unwrap();
        let y = x.leaky_relu(0.0);
        for (i, &v) in data.iter().enumerate() {
            assert!((y.get(&[i]) - v.max(0.0)).abs() < 1e-6);
        }
    }
}