train_station/tensor/ops/sqrt.rs
1//! Square root operation for tensors
2//!
3//! Provides element-wise square root following PyTorch conventions with
4//! comprehensive GradTrack support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **Square Root**: `sqrt()` - Computes square root for each element (PyTorch `sqrt()` equivalent)
9//! - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
10//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
11//! - **Mathematical Accuracy**: High-precision square root computation
12//! - **Domain Validation**: Handles negative values appropriately
13//! - **Performance Optimization**: 4x unrolled SIMD operations with scalar fallback
14//!
15//! # Mathematical Properties
16//!
17//! The square root function has the following properties:
18//! - **Definition**: f(x) = √x
19//! - **Domain**: [0, ∞) - defined for non-negative real numbers
20//! - **Range**: [0, ∞) - outputs are always non-negative
21//! - **Monotonicity**: Strictly increasing function
22//! - **Continuity**: Continuous on its domain
23//! - **Gradient**: f'(x) = 0.5 / √x for x > 0
24//! - **Special Cases**: f(0) = 0, f(1) = 1
25//!
26//! # Performance Characteristics
27//!
28//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
29//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
30//! - **Cache-friendly Access**: Linear memory access patterns
31//! - **Mathematical Accuracy**: High-precision square root computation
32//! - **GradTrack Optimization**: Efficient automatic differentiation with NoGradTrack support
33
34use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
35use crate::tensor::core::Tensor;
36
37// SIMD optimizations for performance-critical operations
38#[cfg(target_arch = "x86_64")]
39use std::arch::x86_64::*;
40
41impl Tensor {
42 /// Element-wise square root
43 ///
44 /// Computes the square root for each element: `output[i] = sqrt(self[i])`
45 ///
46 /// Uses SIMD optimization when available for maximum performance, with automatic
47 /// fallback to optimized scalar computation for non-SIMD hardware.
48 ///
49 /// # Returns
50 ///
51 /// A new tensor with the square root of each element
52 ///
53 /// # Performance Characteristics
54 ///
55 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
56 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
57 /// - **Cache-friendly**: Linear memory access patterns
58 /// - **Mathematical Accuracy**: High-precision square root computation
59 /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
60 ///
61 /// # Implementation Details
62 ///
63 /// Automatically selects between SIMD and scalar implementations based on hardware
64 /// capabilities. SIMD implementation uses AVX2 vector square root operations for optimal
65 /// performance. Scalar implementation uses 4x unrolling for better instruction-level
66 /// parallelism.
67 ///
68 /// # Examples
69 ///
70 /// ## Basic Square Root
71 ///
72 /// ```
73 /// use train_station::Tensor;
74 ///
75 /// let a = Tensor::from_slice(&[1.0, 4.0, 9.0], vec![3]).unwrap();
76 /// let b = a.sqrt();
77 /// assert_eq!(b.shape().dims(), vec![3]);
78 /// assert_eq!(b.get(&[0]), 1.0); // sqrt(1.0) = 1.0
79 /// assert_eq!(b.get(&[1]), 2.0); // sqrt(4.0) = 2.0
80 /// assert_eq!(b.get(&[2]), 3.0); // sqrt(9.0) = 3.0
81 /// ```
82 ///
83 /// ## Zero and Special Values
84 ///
85 /// ```
86 /// use train_station::Tensor;
87 ///
88 /// let a = Tensor::from_slice(&[0.0, 1.0, 16.0], vec![3]).unwrap();
89 /// let b = a.sqrt();
90 /// assert_eq!(b.shape().dims(), vec![3]);
91 /// assert_eq!(b.get(&[0]), 0.0); // sqrt(0.0) = 0.0
92 /// assert_eq!(b.get(&[1]), 1.0); // sqrt(1.0) = 1.0
93 /// assert_eq!(b.get(&[2]), 4.0); // sqrt(16.0) = 4.0
94 /// ```
95 ///
96 /// # Note
97 /// Results are undefined for negative values (may produce NaN)
98 #[inline]
99 #[track_caller]
100 pub fn sqrt(&self) -> Tensor {
101 let mut result = self.sqrt_optimized();
102 if self.requires_grad() && is_grad_enabled() {
103 result.set_requires_grad_internal(true);
104 let grad_fn = GradFn::Sqrt {
105 saved_output: Box::new(result.clone()),
106 };
107 result.set_grad_fn(grad_fn.clone());
108 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
109 }
110 result
111 }
112 /// Internal optimized square root operation
113 ///
114 /// Performs element-wise square root using SIMD optimization when available
115 /// and falling back to optimized scalar computation. This is the core implementation
116 /// used by `sqrt()`.
117 ///
118 /// # Returns
119 ///
120 /// A new tensor containing the square root of each element
121 ///
122 /// # Performance Characteristics
123 ///
124 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
125 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
126 /// - **Cache-friendly**: Linear memory access patterns
127 /// - **Mathematical Accuracy**: High-precision square root computation
128 /// - **Zero-sized Handling**: Fast return for empty tensors
129 ///
130 /// # Implementation Details
131 ///
132 /// Automatically selects between SIMD and scalar implementations based on hardware
133 /// capabilities. SIMD implementation uses AVX2 vector square root operations for optimal
134 /// performance. Scalar implementation uses 4x unrolling for better instruction-level
135 /// parallelism.
136 #[inline]
137 pub(crate) fn sqrt_optimized(&self) -> Tensor {
138 let mut output = Tensor::new(self.shape().dims().to_vec());
139
140 if self.size() == 0 {
141 return output;
142 }
143
144 unsafe {
145 let src = self.as_ptr();
146 let dst = output.as_mut_ptr();
147
148 #[cfg(target_arch = "x86_64")]
149 {
150 if is_x86_feature_detected!("avx2") {
151 self.sqrt_simd_avx2_optimized(src, dst);
152 return output;
153 }
154 }
155
156 // Scalar fallback
157 self.sqrt_scalar_optimized(src, dst);
158 }
159
160 output
161 }
162
163 /// AVX2-optimized square root implementation
164 ///
165 /// Performs element-wise square root using AVX2 SIMD instructions for maximum
166 /// performance on x86_64 architectures with AVX2 support.
167 ///
168 /// # Arguments
169 ///
170 /// * `src` - Pointer to source tensor data
171 /// * `dst` - Pointer to output tensor data
172 ///
173 /// # Safety
174 ///
175 /// Requires valid pointers with sufficient memory for the tensor size.
176 /// All pointers must point to valid tensor data. Requires AVX2 support.
177 ///
178 /// # Performance Characteristics
179 ///
180 /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
181 /// - **Memory Access**: Linear access patterns for cache efficiency
182 /// - **Vector Operations**: Uses AVX2 sqrt instructions for square root computation
183 /// - **Fallback**: Handles remaining elements with scalar operations
184 /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
185 ///
186 /// # Implementation Details
187 ///
188 /// Uses AVX2 vector square root operations to compute sqrt(x) efficiently.
189 /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
190 /// Processes remaining elements with scalar operations for complete coverage.
191 #[cfg(target_arch = "x86_64")]
192 #[inline]
193 #[target_feature(enable = "avx2")]
194 unsafe fn sqrt_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32) {
195 let size = self.size();
196 let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
197 let mut offset = 0;
198
199 // Unrolled SIMD loop for maximum throughput
200 for _ in 0..simd_count {
201 // Process 4 AVX2 vectors (32 elements) per iteration
202 let src_vec1 = _mm256_loadu_ps(src.add(offset));
203 let sqrt_vec1 = _mm256_sqrt_ps(src_vec1);
204 _mm256_storeu_ps(dst.add(offset), sqrt_vec1);
205
206 let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
207 let sqrt_vec2 = _mm256_sqrt_ps(src_vec2);
208 _mm256_storeu_ps(dst.add(offset + 8), sqrt_vec2);
209
210 let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
211 let sqrt_vec3 = _mm256_sqrt_ps(src_vec3);
212 _mm256_storeu_ps(dst.add(offset + 16), sqrt_vec3);
213
214 let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
215 let sqrt_vec4 = _mm256_sqrt_ps(src_vec4);
216 _mm256_storeu_ps(dst.add(offset + 24), sqrt_vec4);
217
218 offset += 32;
219 }
220
221 // Handle remaining 8-element blocks
222 let remaining_full_blocks = (size - offset) / 8;
223 for _ in 0..remaining_full_blocks {
224 let src_vec = _mm256_loadu_ps(src.add(offset));
225 let sqrt_vec = _mm256_sqrt_ps(src_vec);
226 _mm256_storeu_ps(dst.add(offset), sqrt_vec);
227 offset += 8;
228 }
229
230 // Handle remaining elements with scalar fallback
231 for i in offset..size {
232 *dst.add(i) = (*src.add(i)).sqrt();
233 }
234 }
235
236 /// Optimized scalar square root fallback
237 ///
238 /// Performs element-wise square root using optimized scalar operations with
239 /// 4x unrolling for better instruction-level parallelism and cache efficiency.
240 ///
241 /// # Arguments
242 ///
243 /// * `src` - Pointer to source tensor data
244 /// * `dst` - Pointer to output tensor data
245 ///
246 /// # Safety
247 ///
248 /// Requires valid pointers with sufficient memory for the tensor size.
249 /// All pointers must point to valid tensor data.
250 ///
251 /// # Performance Characteristics
252 ///
253 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
254 /// - **Memory Access**: Linear access patterns for cache efficiency
255 /// - **Fallback**: Handles remaining elements with scalar operations
256 /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
257 /// - **Mathematical Accuracy**: High-precision scalar square root computation
258 ///
259 /// # Implementation Details
260 ///
261 /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
262 /// Processes elements in groups of 4 to improve instruction-level parallelism
263 /// and reduce loop overhead.
264 #[inline]
265 unsafe fn sqrt_scalar_optimized(&self, src: *const f32, dst: *mut f32) {
266 let size = self.size();
267 let unroll_count = size / 4;
268 let mut offset = 0;
269
270 // Unrolled scalar loop for better performance
271 for _ in 0..unroll_count {
272 *dst.add(offset) = (*src.add(offset)).sqrt();
273 *dst.add(offset + 1) = (*src.add(offset + 1)).sqrt();
274 *dst.add(offset + 2) = (*src.add(offset + 2)).sqrt();
275 *dst.add(offset + 3) = (*src.add(offset + 3)).sqrt();
276 offset += 4;
277 }
278
279 // Handle remaining elements
280 for i in offset..size {
281 *dst.add(i) = (*src.add(i)).sqrt();
282 }
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Verifies sqrt_optimized against known perfect squares.
    #[test]
    fn test_sqrt_basic() {
        let input = Tensor::from_slice(&[0.0, 1.0, 4.0, 9.0], vec![2, 2]).unwrap();
        let output = input.sqrt_optimized();
        let expected = [0.0f32, 1.0, 2.0, 3.0];
        unsafe {
            let values = std::slice::from_raw_parts(output.as_ptr(), output.size());
            for (&got, &want) in values.iter().zip(expected.iter()) {
                assert!((got - want).abs() < 1e-6);
            }
        }
    }
}