// train_station/tensor/ops/sqrt.rs

1//! Square root operation for tensors
2//!
3//! Provides element-wise square root following PyTorch conventions with
4//! comprehensive GradTrack support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **Square Root**: `sqrt()` - Computes square root for each element (PyTorch `sqrt()` equivalent)
9//! - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
10//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
11//! - **Mathematical Accuracy**: High-precision square root computation
12//! - **Domain Validation**: Handles negative values appropriately
13//! - **Performance Optimization**: 4x unrolled SIMD operations with scalar fallback
14//!
15//! # Mathematical Properties
16//!
17//! The square root function has the following properties:
18//! - **Definition**: f(x) = √x
19//! - **Domain**: [0, ∞) - defined for non-negative real numbers
20//! - **Range**: [0, ∞) - outputs are always non-negative
21//! - **Monotonicity**: Strictly increasing function
22//! - **Continuity**: Continuous on its domain
23//! - **Gradient**: f'(x) = 0.5 / √x for x > 0
24//! - **Special Cases**: f(0) = 0, f(1) = 1
25//!
26//! # Performance Characteristics
27//!
28//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
29//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
30//! - **Cache-friendly Access**: Linear memory access patterns
31//! - **Mathematical Accuracy**: High-precision square root computation
32//! - **GradTrack Optimization**: Efficient automatic differentiation with NoGradTrack support
33
34use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
35use crate::tensor::core::Tensor;
36
37// SIMD optimizations for performance-critical operations
38#[cfg(target_arch = "x86_64")]
39use std::arch::x86_64::*;
40
41impl Tensor {
42    /// Element-wise square root
43    ///
44    /// Computes the square root for each element: `output[i] = sqrt(self[i])`
45    ///
46    /// Uses SIMD optimization when available for maximum performance, with automatic
47    /// fallback to optimized scalar computation for non-SIMD hardware.
48    ///
49    /// # Returns
50    ///
51    /// A new tensor with the square root of each element
52    ///
53    /// # Performance Characteristics
54    ///
55    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
56    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
57    /// - **Cache-friendly**: Linear memory access patterns
58    /// - **Mathematical Accuracy**: High-precision square root computation
59    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
60    ///
61    /// # Implementation Details
62    ///
63    /// Automatically selects between SIMD and scalar implementations based on hardware
64    /// capabilities. SIMD implementation uses AVX2 vector square root operations for optimal
65    /// performance. Scalar implementation uses 4x unrolling for better instruction-level
66    /// parallelism.
67    ///
68    /// # Examples
69    ///
70    /// ## Basic Square Root
71    ///
72    /// ```
73    /// use train_station::Tensor;
74    ///
75    /// let a = Tensor::from_slice(&[1.0, 4.0, 9.0], vec![3]).unwrap();
76    /// let b = a.sqrt();
77    /// assert_eq!(b.shape().dims, vec![3]);
78    /// assert_eq!(b.get(&[0]), 1.0); // sqrt(1.0) = 1.0
79    /// assert_eq!(b.get(&[1]), 2.0); // sqrt(4.0) = 2.0
80    /// assert_eq!(b.get(&[2]), 3.0); // sqrt(9.0) = 3.0
81    /// ```
82    ///
83    /// ## Zero and Special Values
84    ///
85    /// ```
86    /// use train_station::Tensor;
87    ///
88    /// let a = Tensor::from_slice(&[0.0, 1.0, 16.0], vec![3]).unwrap();
89    /// let b = a.sqrt();
90    /// assert_eq!(b.shape().dims, vec![3]);
91    /// assert_eq!(b.get(&[0]), 0.0); // sqrt(0.0) = 0.0
92    /// assert_eq!(b.get(&[1]), 1.0); // sqrt(1.0) = 1.0
93    /// assert_eq!(b.get(&[2]), 4.0); // sqrt(16.0) = 4.0
94    /// ```
95    ///
96    /// # Note
97    /// Results are undefined for negative values (may produce NaN)
98    #[inline]
99    pub fn sqrt(&self) -> Tensor {
100        let mut result = self.sqrt_optimized();
101        if self.requires_grad() && is_grad_enabled() {
102            result.set_requires_grad_internal(true);
103            let grad_fn = GradFn::Sqrt {
104                saved_output: Box::new(result.clone()),
105            };
106            result.set_grad_fn(grad_fn.clone());
107            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
108        }
109        result
110    }
111    /// Internal optimized square root operation
112    ///
113    /// Performs element-wise square root using SIMD optimization when available
114    /// and falling back to optimized scalar computation. This is the core implementation
115    /// used by `sqrt()`.
116    ///
117    /// # Returns
118    ///
119    /// A new tensor containing the square root of each element
120    ///
121    /// # Performance Characteristics
122    ///
123    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
124    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
125    /// - **Cache-friendly**: Linear memory access patterns
126    /// - **Mathematical Accuracy**: High-precision square root computation
127    /// - **Zero-sized Handling**: Fast return for empty tensors
128    ///
129    /// # Implementation Details
130    ///
131    /// Automatically selects between SIMD and scalar implementations based on hardware
132    /// capabilities. SIMD implementation uses AVX2 vector square root operations for optimal
133    /// performance. Scalar implementation uses 4x unrolling for better instruction-level
134    /// parallelism.
135    #[inline]
136    pub(crate) fn sqrt_optimized(&self) -> Tensor {
137        let mut output = Tensor::new(self.shape().dims.clone());
138
139        if self.size() == 0 {
140            return output;
141        }
142
143        unsafe {
144            let src = self.as_ptr();
145            let dst = output.as_mut_ptr();
146
147            #[cfg(target_arch = "x86_64")]
148            {
149                if is_x86_feature_detected!("avx2") {
150                    self.sqrt_simd_avx2_optimized(src, dst);
151                    return output;
152                }
153            }
154
155            // Scalar fallback
156            self.sqrt_scalar_optimized(src, dst);
157        }
158
159        output
160    }
161
162    /// AVX2-optimized square root implementation
163    ///
164    /// Performs element-wise square root using AVX2 SIMD instructions for maximum
165    /// performance on x86_64 architectures with AVX2 support.
166    ///
167    /// # Arguments
168    ///
169    /// * `src` - Pointer to source tensor data
170    /// * `dst` - Pointer to output tensor data
171    ///
172    /// # Safety
173    ///
174    /// Requires valid pointers with sufficient memory for the tensor size.
175    /// All pointers must point to valid tensor data. Requires AVX2 support.
176    ///
177    /// # Performance Characteristics
178    ///
179    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
180    /// - **Memory Access**: Linear access patterns for cache efficiency
181    /// - **Vector Operations**: Uses AVX2 sqrt instructions for square root computation
182    /// - **Fallback**: Handles remaining elements with scalar operations
183    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
184    ///
185    /// # Implementation Details
186    ///
187    /// Uses AVX2 vector square root operations to compute sqrt(x) efficiently.
188    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
189    /// Processes remaining elements with scalar operations for complete coverage.
190    #[cfg(target_arch = "x86_64")]
191    #[inline]
192    #[target_feature(enable = "avx2")]
193    unsafe fn sqrt_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32) {
194        let size = self.size();
195        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
196        let mut offset = 0;
197
198        // Unrolled SIMD loop for maximum throughput
199        for _ in 0..simd_count {
200            // Process 4 AVX2 vectors (32 elements) per iteration
201            let src_vec1 = _mm256_loadu_ps(src.add(offset));
202            let sqrt_vec1 = _mm256_sqrt_ps(src_vec1);
203            _mm256_storeu_ps(dst.add(offset), sqrt_vec1);
204
205            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
206            let sqrt_vec2 = _mm256_sqrt_ps(src_vec2);
207            _mm256_storeu_ps(dst.add(offset + 8), sqrt_vec2);
208
209            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
210            let sqrt_vec3 = _mm256_sqrt_ps(src_vec3);
211            _mm256_storeu_ps(dst.add(offset + 16), sqrt_vec3);
212
213            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
214            let sqrt_vec4 = _mm256_sqrt_ps(src_vec4);
215            _mm256_storeu_ps(dst.add(offset + 24), sqrt_vec4);
216
217            offset += 32;
218        }
219
220        // Handle remaining 8-element blocks
221        let remaining_full_blocks = (size - offset) / 8;
222        for _ in 0..remaining_full_blocks {
223            let src_vec = _mm256_loadu_ps(src.add(offset));
224            let sqrt_vec = _mm256_sqrt_ps(src_vec);
225            _mm256_storeu_ps(dst.add(offset), sqrt_vec);
226            offset += 8;
227        }
228
229        // Handle remaining elements with scalar fallback
230        for i in offset..size {
231            *dst.add(i) = (*src.add(i)).sqrt();
232        }
233    }
234
235    /// Optimized scalar square root fallback
236    ///
237    /// Performs element-wise square root using optimized scalar operations with
238    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
239    ///
240    /// # Arguments
241    ///
242    /// * `src` - Pointer to source tensor data
243    /// * `dst` - Pointer to output tensor data
244    ///
245    /// # Safety
246    ///
247    /// Requires valid pointers with sufficient memory for the tensor size.
248    /// All pointers must point to valid tensor data.
249    ///
250    /// # Performance Characteristics
251    ///
252    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
253    /// - **Memory Access**: Linear access patterns for cache efficiency
254    /// - **Fallback**: Handles remaining elements with scalar operations
255    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
256    /// - **Mathematical Accuracy**: High-precision scalar square root computation
257    ///
258    /// # Implementation Details
259    ///
260    /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
261    /// Processes elements in groups of 4 to improve instruction-level parallelism
262    /// and reduce loop overhead.
263    #[inline]
264    unsafe fn sqrt_scalar_optimized(&self, src: *const f32, dst: *mut f32) {
265        let size = self.size();
266        let unroll_count = size / 4;
267        let mut offset = 0;
268
269        // Unrolled scalar loop for better performance
270        for _ in 0..unroll_count {
271            *dst.add(offset) = (*src.add(offset)).sqrt();
272            *dst.add(offset + 1) = (*src.add(offset + 1)).sqrt();
273            *dst.add(offset + 2) = (*src.add(offset + 2)).sqrt();
274            *dst.add(offset + 3) = (*src.add(offset + 3)).sqrt();
275            offset += 4;
276        }
277
278        // Handle remaining elements
279        for i in offset..size {
280            *dst.add(i) = (*src.add(i)).sqrt();
281        }
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Verifies element-wise sqrt on a small 2x2 tensor of perfect squares.
    #[test]
    fn test_sqrt_basic() {
        let x = Tensor::from_slice(&[0.0, 1.0, 4.0, 9.0], vec![2, 2]).unwrap();
        let y = x.sqrt_optimized();
        let expected = [0.0f32, 1.0, 2.0, 3.0];
        // SAFETY: y owns a buffer of y.size() contiguous f32 elements.
        unsafe {
            let data = std::slice::from_raw_parts(y.as_ptr(), y.size());
            for (got, want) in data.iter().zip(expected.iter()) {
                assert!((got - want).abs() < 1e-6);
            }
        }
    }
}
301}