train_station/tensor/ops/
leaky_relu.rs

//! Leaky ReLU activation operations for tensors
//!
//! Provides the Leaky ReLU activation function following PyTorch conventions with
//! comprehensive automatic differentiation support and SIMD-optimized computation.
//!
//! # Key Features
//!
//! - **Leaky ReLU Activation**: `leaky_relu(negative_slope)` - Computes max(0, x) + negative_slope * min(0, x) (PyTorch `leaky_relu()` equivalent)
//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
//! - **Scalar Fallback**: Optimized scalar implementation for non-SIMD hardware
//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
//! - **Mathematical Accuracy**: High-precision activation computation
//!
//! # Mathematical Properties
//!
//! The Leaky ReLU function f(x) = max(0, x) + negative_slope * min(0, x) has the following properties:
//! - For x > 0: f(x) = x (identity function)
//! - For x ≤ 0: f(x) = negative_slope * x (small negative gradient)
//! - Gradient: f'(x) = 1 for x > 0, f'(x) = negative_slope for x ≤ 0
//! - Continuous at x = 0: f(0) = 0
//! - Monotonic: f'(x) > 0 for all x (when negative_slope > 0)
//!
//! # Performance Characteristics
//!
//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
//! - **Cache-friendly Access**: Linear memory access patterns
//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support

32use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
33use crate::tensor::core::Tensor;
34
35// SIMD optimizations for performance-critical operations
36#[cfg(target_arch = "x86_64")]
37use std::arch::x86_64::*;
38
39impl Tensor {
40    /// Element-wise Leaky ReLU activation.
41    ///
42    /// Applies Leaky ReLU to each element: `output[i] = max(0, x) + negative_slope * min(0, x)`
43    ///
44    /// Unlike standard ReLU, allows a small gradient when the unit is not active.
45    ///
46    /// # Arguments
47    /// * `negative_slope` - Slope for negative values (typically small, e.g., 0.01 or 0.1)
48    ///
49    /// # Returns
50    /// A new tensor with Leaky ReLU applied to each element
51    ///
52    /// # Examples
53    ///
54    /// ## Basic Leaky ReLU
55    ///
56    /// ```
57    /// use train_station::Tensor;
58    ///
59    /// let a = Tensor::from_slice(&[-2.0, -1.0, 0.0, 1.0], vec![4]).unwrap();
60    /// let b = a.leaky_relu(0.1);
61    /// assert_eq!(b.shape().dims, vec![4]);
62    /// assert!((b.get(&[0]) - (-0.2)).abs() < 1e-6); // -2.0 * 0.1 = -0.2
63    /// assert!((b.get(&[1]) - (-0.1)).abs() < 1e-6); // -1.0 * 0.1 = -0.1
64    /// assert_eq!(b.get(&[2]), 0.0); // max(0, 0) = 0
65    /// assert_eq!(b.get(&[3]), 1.0); // max(0, 1) = 1
66    /// ```
67    ///
68    /// ## Different Negative Slopes
69    ///
70    /// ```
71    /// use train_station::Tensor;
72    ///
73    /// let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
74    /// let b = a.leaky_relu(0.01); // Smaller negative slope
75    /// assert_eq!(b.shape().dims, vec![3]);
76    /// assert!((b.get(&[0]) - (-0.01)).abs() < 1e-6); // -1.0 * 0.01 = -0.01
77    /// assert_eq!(b.get(&[1]), 0.0); // max(0, 0) = 0
78    /// assert_eq!(b.get(&[2]), 1.0); // max(0, 1) = 1
79    /// ```
80    pub fn leaky_relu(&self, negative_slope: f32) -> Tensor {
81        let mut out = self.leaky_relu_optimized(negative_slope);
82
83        if self.requires_grad() && is_grad_enabled() {
84            out.set_requires_grad_internal(true);
85            let grad_fn = GradFn::LeakyRelu {
86                negative_slope,
87                saved_input: Box::new(self.clone()),
88            };
89            out.set_grad_fn(grad_fn.clone());
90            GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
91        }
92
93        out
94    }
95
96    /// Internal optimized Leaky ReLU operation
97    ///
98    /// Performs element-wise Leaky ReLU computation using SIMD optimization when available
99    /// and falling back to optimized scalar computation. This is the core implementation
100    /// used by `leaky_relu()`.
101    ///
102    /// # Arguments
103    ///
104    /// * `negative_slope` - Slope for negative values (typically small, e.g., 0.01 or 0.1)
105    ///
106    /// # Returns
107    ///
108    /// A new tensor containing the Leaky ReLU activation of each element
109    ///
110    /// # Performance Characteristics
111    ///
112    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
113    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
114    /// - **Cache-friendly**: Linear memory access patterns
115    /// - **Branch Prediction**: Optimized conditional logic for modern CPUs
116    /// - **Zero-sized Handling**: Fast return for empty tensors
117    ///
118    /// # Implementation Details
119    ///
120    /// Automatically selects between SIMD and scalar implementations based on hardware
121    /// capabilities. SIMD implementation processes 32 elements per iteration with 4x
122    /// unrolling for maximum throughput.
123    #[inline]
124    pub(crate) fn leaky_relu_optimized(&self, negative_slope: f32) -> Tensor {
125        let mut output = Tensor::new(self.shape().dims.clone());
126
127        if self.size() == 0 {
128            return output;
129        }
130
131        unsafe {
132            let src = self.as_ptr();
133            let dst = output.as_mut_ptr();
134
135            #[cfg(target_arch = "x86_64")]
136            {
137                if is_x86_feature_detected!("avx2") {
138                    self.leaky_relu_simd_avx2_optimized(src, dst, negative_slope);
139                    return output;
140                }
141            }
142
143            // Scalar fallback
144            self.leaky_relu_scalar_optimized(src, dst, negative_slope);
145        }
146
147        output
148    }
149
150    /// AVX2-optimized Leaky ReLU implementation
151    ///
152    /// Performs element-wise Leaky ReLU using AVX2 SIMD instructions for maximum
153    /// performance on x86_64 architectures with AVX2 support.
154    ///
155    /// # Arguments
156    ///
157    /// * `src` - Pointer to source tensor data
158    /// * `dst` - Pointer to output tensor data
159    /// * `negative_slope` - Slope for negative values
160    ///
161    /// # Safety
162    ///
163    /// Requires valid pointers with sufficient memory for the tensor size.
164    /// All pointers must point to valid tensor data. Requires AVX2 support.
165    ///
166    /// # Performance Characteristics
167    ///
168    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
169    /// - **Memory Access**: Linear access patterns for cache efficiency
170    /// - **Branch Prediction**: Optimized conditional logic using SIMD masks
171    /// - **Fallback**: Handles remaining elements with scalar operations
172    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
173    ///
174    /// # Implementation Details
175    ///
176    /// Uses AVX2 vector instructions to process 8 elements simultaneously.
177    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
178    #[cfg(target_arch = "x86_64")]
179    #[inline]
180    #[target_feature(enable = "avx2")]
181    unsafe fn leaky_relu_simd_avx2_optimized(
182        &self,
183        src: *const f32,
184        dst: *mut f32,
185        negative_slope: f32,
186    ) {
187        let size = self.size();
188        let zero_vec = _mm256_setzero_ps();
189        let slope_vec = _mm256_set1_ps(negative_slope);
190        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
191        let mut offset = 0;
192
193        // Unrolled SIMD loop for maximum throughput
194        for _ in 0..simd_count {
195            // Process 4 AVX2 vectors (32 elements) per iteration
196            self.leaky_relu_simd_block(src, dst, offset, zero_vec, slope_vec);
197            self.leaky_relu_simd_block(src, dst, offset + 8, zero_vec, slope_vec);
198            self.leaky_relu_simd_block(src, dst, offset + 16, zero_vec, slope_vec);
199            self.leaky_relu_simd_block(src, dst, offset + 24, zero_vec, slope_vec);
200            offset += 32;
201        }
202
203        // Handle remaining 8-element blocks
204        let remaining_full_blocks = (size - offset) / 8;
205        for _ in 0..remaining_full_blocks {
206            self.leaky_relu_simd_block(src, dst, offset, zero_vec, slope_vec);
207            offset += 8;
208        }
209
210        // Handle remaining elements with scalar fallback
211        for i in offset..size {
212            let x = *src.add(i);
213            *dst.add(i) = if x > 0.0 { x } else { negative_slope * x };
214        }
215    }
216
217    /// AVX2 SIMD block processing for Leaky ReLU
218    ///
219    /// Processes a single 8-element block using AVX2 vector instructions.
220    /// This is a helper function for the main SIMD implementation.
221    ///
222    /// # Arguments
223    ///
224    /// * `src` - Pointer to source tensor data
225    /// * `dst` - Pointer to output tensor data
226    /// * `offset` - Offset into the tensor data
227    /// * `zero_vec` - AVX2 vector containing zeros
228    /// * `slope_vec` - AVX2 vector containing the negative slope value
229    ///
230    /// # Safety
231    ///
232    /// Requires valid pointers with sufficient memory for 8 elements starting at offset.
233    /// All pointers must point to valid tensor data. Requires AVX2 support.
234    ///
235    /// # Performance Characteristics
236    ///
237    /// - **SIMD Processing**: Processes 8 elements in a single vector operation
238    /// - **Vector Operations**: Uses AVX2 comparison, multiplication, and blending
239    /// - **Branch-free**: No conditional branches in the SIMD path
240    /// - **Memory Access**: Single load and store operation per block
241    ///
242    /// # Implementation Details
243    ///
244    /// Uses AVX2 vector comparison to create a mask for positive values,
245    /// then blends between the original values and scaled negative values
246    /// based on the comparison result.
247    #[cfg(target_arch = "x86_64")]
248    #[inline]
249    #[target_feature(enable = "avx2")]
250    unsafe fn leaky_relu_simd_block(
251        &self,
252        src: *const f32,
253        dst: *mut f32,
254        offset: usize,
255        zero_vec: __m256,
256        slope_vec: __m256,
257    ) {
258        let src_vec = _mm256_loadu_ps(src.add(offset));
259
260        // Create mask for positive values
261        let pos_mask = _mm256_cmp_ps(src_vec, zero_vec, _CMP_GT_OQ);
262
263        // Compute negative part: negative_slope * x
264        let neg_part = _mm256_mul_ps(src_vec, slope_vec);
265
266        // Blend: use src_vec where positive, neg_part where negative
267        let result = _mm256_blendv_ps(neg_part, src_vec, pos_mask);
268
269        _mm256_storeu_ps(dst.add(offset), result);
270    }
271
272    /// Optimized scalar Leaky ReLU fallback
273    ///
274    /// Performs element-wise Leaky ReLU using optimized scalar operations with
275    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
276    ///
277    /// # Arguments
278    ///
279    /// * `src` - Pointer to source tensor data
280    /// * `dst` - Pointer to output tensor data
281    /// * `negative_slope` - Slope for negative values
282    ///
283    /// # Safety
284    ///
285    /// Requires valid pointers with sufficient memory for the tensor size.
286    /// All pointers must point to valid tensor data.
287    ///
288    /// # Performance Characteristics
289    ///
290    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
291    /// - **Memory Access**: Linear access patterns for cache efficiency
292    /// - **Fallback**: Handles remaining elements with scalar operations
293    /// - **Branch Prediction**: Optimized conditional logic for modern CPUs
294    /// - **Mathematical Accuracy**: High-precision scalar computation
295    ///
296    /// # Implementation Details
297    ///
298    /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
299    /// Processes elements in groups of 4 to improve instruction-level parallelism
300    /// and reduce loop overhead.
301    #[inline]
302    unsafe fn leaky_relu_scalar_optimized(
303        &self,
304        src: *const f32,
305        dst: *mut f32,
306        negative_slope: f32,
307    ) {
308        let size = self.size();
309        let unroll_count = size / 4;
310        let mut offset = 0;
311
312        // Unrolled scalar loop for better performance
313        for _ in 0..unroll_count {
314            let x1 = *src.add(offset);
315            let x2 = *src.add(offset + 1);
316            let x3 = *src.add(offset + 2);
317            let x4 = *src.add(offset + 3);
318
319            *dst.add(offset) = if x1 > 0.0 { x1 } else { negative_slope * x1 };
320            *dst.add(offset + 1) = if x2 > 0.0 { x2 } else { negative_slope * x2 };
321            *dst.add(offset + 2) = if x3 > 0.0 { x3 } else { negative_slope * x3 };
322            *dst.add(offset + 3) = if x4 > 0.0 { x4 } else { negative_slope * x4 };
323
324            offset += 4;
325        }
326
327        // Handle remaining elements
328        for i in offset..size {
329            let x = *src.add(i);
330            *dst.add(i) = if x > 0.0 { x } else { negative_slope * x };
331        }
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_leaky_relu_forward_basic() {
        let x = Tensor::from_slice(&[-2.0, -1.0, 0.0, 1.5], vec![4]).unwrap();
        let y = x.leaky_relu(0.1);
        // Use the safe element accessor instead of raw pointer reads.
        assert!((y.get(&[0]) + 0.2).abs() < 1e-6); // -2.0 * 0.1 = -0.2
        assert!((y.get(&[1]) + 0.1).abs() < 1e-6); // -1.0 * 0.1 = -0.1
        assert!((y.get(&[2]) - 0.0).abs() < 1e-6); // f(0) = 0
        assert!((y.get(&[3]) - 1.5).abs() < 1e-6); // identity for positive input
    }

    #[test]
    fn test_leaky_relu_forward_simd_tails() {
        // 37 elements is not a multiple of 32, 8, or 4, so this exercises the
        // unrolled SIMD loop, the 8-wide remainder loop, and the scalar tails.
        let vals: Vec<f32> = (0..37).map(|i| i as f32 - 18.0).collect();
        let x = Tensor::from_slice(&vals, vec![37]).unwrap();
        let y = x.leaky_relu(0.01);
        for (i, &v) in vals.iter().enumerate() {
            let expected = if v > 0.0 { v } else { 0.01 * v };
            assert!((y.get(&[i]) - expected).abs() < 1e-6);
        }
    }
}