train_station/tensor/ops/sqrt.rs
1//! Square root operation for tensors
2//!
3//! Provides element-wise square root following PyTorch conventions with
4//! comprehensive GradTrack support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **Square Root**: `sqrt()` - Computes square root for each element (PyTorch `sqrt()` equivalent)
9//! - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
10//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
11//! - **Mathematical Accuracy**: Correctly rounded IEEE 754 single-precision results on every code path
12//! - **Domain Handling**: No validation is performed; negative inputs yield NaN (IEEE 754 semantics)
13//! - **Performance Optimization**: 4x unrolled SIMD operations with scalar fallback
14//!
15//! # Mathematical Properties
16//!
17//! The square root function has the following properties:
18//! - **Definition**: f(x) = √x
19//! - **Domain**: [0, ∞) - defined for non-negative real numbers
20//! - **Range**: [0, ∞) - outputs are always non-negative
21//! - **Monotonicity**: Strictly increasing function
22//! - **Continuity**: Continuous on its domain
23//! - **Gradient**: f'(x) = 0.5 / √x for x > 0
24//! - **Special Cases**: f(0) = 0, f(-0) = -0, f(1) = 1, f(+∞) = +∞, and f(x) = NaN for x < 0
25//!
26//! # Performance Characteristics
27//!
28//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
29//! - **Scalar Fallback**: Portable scalar implementation used whenever AVX2 is unavailable (including non-x86_64 targets)
30//! - **Cache-friendly Access**: Linear memory access patterns
31//! - **Mathematical Accuracy**: High-precision square root computation
32//! - **GradTrack Optimization**: Efficient automatic differentiation with NoGradTrack support
33
34use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
35use crate::tensor::core::Tensor;
36
37// SIMD optimizations for performance-critical operations
38#[cfg(target_arch = "x86_64")]
39use std::arch::x86_64::*;
40
41impl Tensor {
42 /// Element-wise square root
43 ///
44 /// Computes the square root for each element: `output[i] = sqrt(self[i])`
45 ///
46 /// Uses SIMD optimization when available for maximum performance, with automatic
47 /// fallback to optimized scalar computation for non-SIMD hardware.
48 ///
49 /// # Returns
50 ///
51 /// A new tensor with the square root of each element
52 ///
53 /// # Performance Characteristics
54 ///
55 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
56 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
57 /// - **Cache-friendly**: Linear memory access patterns
58 /// - **Mathematical Accuracy**: High-precision square root computation
59 /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
60 ///
61 /// # Implementation Details
62 ///
63 /// Automatically selects between SIMD and scalar implementations based on hardware
64 /// capabilities. SIMD implementation uses AVX2 vector square root operations for optimal
65 /// performance. Scalar implementation uses 4x unrolling for better instruction-level
66 /// parallelism.
67 ///
68 /// # Examples
69 ///
70 /// ## Basic Square Root
71 ///
72 /// ```
73 /// use train_station::Tensor;
74 ///
75 /// let a = Tensor::from_slice(&[1.0, 4.0, 9.0], vec![3]).unwrap();
76 /// let b = a.sqrt();
77 /// assert_eq!(b.shape().dims, vec![3]);
78 /// assert_eq!(b.get(&[0]), 1.0); // sqrt(1.0) = 1.0
79 /// assert_eq!(b.get(&[1]), 2.0); // sqrt(4.0) = 2.0
80 /// assert_eq!(b.get(&[2]), 3.0); // sqrt(9.0) = 3.0
81 /// ```
82 ///
83 /// ## Zero and Special Values
84 ///
85 /// ```
86 /// use train_station::Tensor;
87 ///
88 /// let a = Tensor::from_slice(&[0.0, 1.0, 16.0], vec![3]).unwrap();
89 /// let b = a.sqrt();
90 /// assert_eq!(b.shape().dims, vec![3]);
91 /// assert_eq!(b.get(&[0]), 0.0); // sqrt(0.0) = 0.0
92 /// assert_eq!(b.get(&[1]), 1.0); // sqrt(1.0) = 1.0
93 /// assert_eq!(b.get(&[2]), 4.0); // sqrt(16.0) = 4.0
94 /// ```
95 ///
96 /// # Note
97 /// Results are undefined for negative values (may produce NaN)
98 #[inline]
99 pub fn sqrt(&self) -> Tensor {
100 let mut result = self.sqrt_optimized();
101 if self.requires_grad() && is_grad_enabled() {
102 result.set_requires_grad_internal(true);
103 let grad_fn = GradFn::Sqrt {
104 saved_output: Box::new(result.clone()),
105 };
106 result.set_grad_fn(grad_fn.clone());
107 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
108 }
109 result
110 }
111 /// Internal optimized square root operation
112 ///
113 /// Performs element-wise square root using SIMD optimization when available
114 /// and falling back to optimized scalar computation. This is the core implementation
115 /// used by `sqrt()`.
116 ///
117 /// # Returns
118 ///
119 /// A new tensor containing the square root of each element
120 ///
121 /// # Performance Characteristics
122 ///
123 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
124 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
125 /// - **Cache-friendly**: Linear memory access patterns
126 /// - **Mathematical Accuracy**: High-precision square root computation
127 /// - **Zero-sized Handling**: Fast return for empty tensors
128 ///
129 /// # Implementation Details
130 ///
131 /// Automatically selects between SIMD and scalar implementations based on hardware
132 /// capabilities. SIMD implementation uses AVX2 vector square root operations for optimal
133 /// performance. Scalar implementation uses 4x unrolling for better instruction-level
134 /// parallelism.
135 #[inline]
136 pub(crate) fn sqrt_optimized(&self) -> Tensor {
137 let mut output = Tensor::new(self.shape().dims.clone());
138
139 if self.size() == 0 {
140 return output;
141 }
142
143 unsafe {
144 let src = self.as_ptr();
145 let dst = output.as_mut_ptr();
146
147 #[cfg(target_arch = "x86_64")]
148 {
149 if is_x86_feature_detected!("avx2") {
150 self.sqrt_simd_avx2_optimized(src, dst);
151 return output;
152 }
153 }
154
155 // Scalar fallback
156 self.sqrt_scalar_optimized(src, dst);
157 }
158
159 output
160 }
161
162 /// AVX2-optimized square root implementation
163 ///
164 /// Performs element-wise square root using AVX2 SIMD instructions for maximum
165 /// performance on x86_64 architectures with AVX2 support.
166 ///
167 /// # Arguments
168 ///
169 /// * `src` - Pointer to source tensor data
170 /// * `dst` - Pointer to output tensor data
171 ///
172 /// # Safety
173 ///
174 /// Requires valid pointers with sufficient memory for the tensor size.
175 /// All pointers must point to valid tensor data. Requires AVX2 support.
176 ///
177 /// # Performance Characteristics
178 ///
179 /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
180 /// - **Memory Access**: Linear access patterns for cache efficiency
181 /// - **Vector Operations**: Uses AVX2 sqrt instructions for square root computation
182 /// - **Fallback**: Handles remaining elements with scalar operations
183 /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
184 ///
185 /// # Implementation Details
186 ///
187 /// Uses AVX2 vector square root operations to compute sqrt(x) efficiently.
188 /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
189 /// Processes remaining elements with scalar operations for complete coverage.
190 #[cfg(target_arch = "x86_64")]
191 #[inline]
192 #[target_feature(enable = "avx2")]
193 unsafe fn sqrt_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32) {
194 let size = self.size();
195 let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
196 let mut offset = 0;
197
198 // Unrolled SIMD loop for maximum throughput
199 for _ in 0..simd_count {
200 // Process 4 AVX2 vectors (32 elements) per iteration
201 let src_vec1 = _mm256_loadu_ps(src.add(offset));
202 let sqrt_vec1 = _mm256_sqrt_ps(src_vec1);
203 _mm256_storeu_ps(dst.add(offset), sqrt_vec1);
204
205 let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
206 let sqrt_vec2 = _mm256_sqrt_ps(src_vec2);
207 _mm256_storeu_ps(dst.add(offset + 8), sqrt_vec2);
208
209 let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
210 let sqrt_vec3 = _mm256_sqrt_ps(src_vec3);
211 _mm256_storeu_ps(dst.add(offset + 16), sqrt_vec3);
212
213 let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
214 let sqrt_vec4 = _mm256_sqrt_ps(src_vec4);
215 _mm256_storeu_ps(dst.add(offset + 24), sqrt_vec4);
216
217 offset += 32;
218 }
219
220 // Handle remaining 8-element blocks
221 let remaining_full_blocks = (size - offset) / 8;
222 for _ in 0..remaining_full_blocks {
223 let src_vec = _mm256_loadu_ps(src.add(offset));
224 let sqrt_vec = _mm256_sqrt_ps(src_vec);
225 _mm256_storeu_ps(dst.add(offset), sqrt_vec);
226 offset += 8;
227 }
228
229 // Handle remaining elements with scalar fallback
230 for i in offset..size {
231 *dst.add(i) = (*src.add(i)).sqrt();
232 }
233 }
234
235 /// Optimized scalar square root fallback
236 ///
237 /// Performs element-wise square root using optimized scalar operations with
238 /// 4x unrolling for better instruction-level parallelism and cache efficiency.
239 ///
240 /// # Arguments
241 ///
242 /// * `src` - Pointer to source tensor data
243 /// * `dst` - Pointer to output tensor data
244 ///
245 /// # Safety
246 ///
247 /// Requires valid pointers with sufficient memory for the tensor size.
248 /// All pointers must point to valid tensor data.
249 ///
250 /// # Performance Characteristics
251 ///
252 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
253 /// - **Memory Access**: Linear access patterns for cache efficiency
254 /// - **Fallback**: Handles remaining elements with scalar operations
255 /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
256 /// - **Mathematical Accuracy**: High-precision scalar square root computation
257 ///
258 /// # Implementation Details
259 ///
260 /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
261 /// Processes elements in groups of 4 to improve instruction-level parallelism
262 /// and reduce loop overhead.
263 #[inline]
264 unsafe fn sqrt_scalar_optimized(&self, src: *const f32, dst: *mut f32) {
265 let size = self.size();
266 let unroll_count = size / 4;
267 let mut offset = 0;
268
269 // Unrolled scalar loop for better performance
270 for _ in 0..unroll_count {
271 *dst.add(offset) = (*src.add(offset)).sqrt();
272 *dst.add(offset + 1) = (*src.add(offset + 1)).sqrt();
273 *dst.add(offset + 2) = (*src.add(offset + 2)).sqrt();
274 *dst.add(offset + 3) = (*src.add(offset + 3)).sqrt();
275 offset += 4;
276 }
277
278 // Handle remaining elements
279 for i in offset..size {
280 *dst.add(i) = (*src.add(i)).sqrt();
281 }
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Copies a non-empty tensor's elements into a `Vec` for comparison.
    ///
    /// Assumes the contiguous layout produced by `from_slice`/`new` in these tests.
    fn values_of(t: &Tensor) -> Vec<f32> {
        // SAFETY: `as_ptr` points at `t.size()` initialized f32 elements of a
        // non-empty tensor built in this module.
        unsafe { std::slice::from_raw_parts(t.as_ptr(), t.size()).to_vec() }
    }

    /// Deterministic, strictly positive inputs of length `n`.
    fn positive_inputs(n: usize) -> Vec<f32> {
        (0..n).map(|i| i as f32 * 1.37 + 0.25).collect()
    }

    #[test]
    fn test_sqrt_basic() {
        let x = Tensor::from_slice(&[0.0, 1.0, 4.0, 9.0], vec![2, 2]).unwrap();
        let y = values_of(&x.sqrt_optimized());
        let expected: [f32; 4] = [0.0, 1.0, 2.0, 3.0];
        assert_eq!(y.len(), expected.len());
        for (got, want) in y.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6, "got {got}, want {want}");
        }
    }

    #[test]
    fn test_sqrt_matches_std_across_simd_boundaries() {
        // Sizes chosen to exercise, alone and in combination:
        // - the scalar tail (< 8 elements)
        // - whole 8-lane blocks
        // - the 32-element unrolled loop
        for &n in &[1usize, 7, 8, 9, 31, 32, 33, 45, 77] {
            let data = positive_inputs(n);
            let x = Tensor::from_slice(&data, vec![n]).unwrap();
            let y = values_of(&x.sqrt_optimized());
            assert_eq!(y.len(), n);
            for (i, (&got, &input)) in y.iter().zip(data.iter()).enumerate() {
                // Both kernels are correctly rounded, so results are bit-exact.
                assert_eq!(
                    got.to_bits(),
                    input.sqrt().to_bits(),
                    "n={n}, i={i}, input={input}"
                );
            }
        }
    }

    #[test]
    fn test_sqrt_special_values_in_vector_and_tail_lanes() {
        // 43 = 32 + 8 + 3, so the repeated pattern lands in the unrolled loop,
        // the 8-lane loop, and the scalar tail.
        let pattern = [-1.0f32, -0.0, 4.0, f32::INFINITY];
        let data: Vec<f32> = (0..43).map(|i| pattern[i % pattern.len()]).collect();
        let x = Tensor::from_slice(&data, vec![data.len()]).unwrap();
        let y = values_of(&x.sqrt_optimized());
        assert_eq!(y.len(), data.len());
        for (i, &v) in y.iter().enumerate() {
            match i % pattern.len() {
                0 => assert!(v.is_nan(), "i={i}: sqrt(-1) should be NaN, got {v}"),
                1 => assert!(
                    v == 0.0 && v.is_sign_negative(),
                    "i={i}: sqrt(-0) should be -0, got {v}"
                ),
                2 => assert_eq!(v, 2.0, "i={i}"),
                _ => assert_eq!(v, f32::INFINITY, "i={i}"),
            }
        }
    }

    #[test]
    fn test_scalar_path_matches_dispatched_path() {
        let n = 45;
        let data = positive_inputs(n);
        let x = Tensor::from_slice(&data, vec![n]).unwrap();

        let mut scalar = Tensor::new(vec![n]);
        // SAFETY: `x` and `scalar` each hold `n` elements in separate buffers.
        unsafe { x.sqrt_scalar_optimized(x.as_ptr(), scalar.as_mut_ptr()) };

        let dispatched: Vec<u32> = values_of(&x.sqrt_optimized())
            .iter()
            .map(|v| v.to_bits())
            .collect();
        let scalar: Vec<u32> = values_of(&scalar).iter().map(|v| v.to_bits()).collect();
        assert_eq!(dispatched, scalar);
    }

    #[test]
    fn test_sqrt_preserves_shape() {
        let x = Tensor::from_slice(&[1.0, 4.0, 9.0, 16.0, 25.0, 36.0], vec![2, 3]).unwrap();
        let y = x.sqrt();
        assert_eq!(y.shape().dims, vec![2, 3]);
        assert_eq!(values_of(&y), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }
}
301}