train_station/tensor/ops/pow.rs

//! Power operations for tensors
//!
//! Provides element-wise power functions following PyTorch conventions with
//! comprehensive automatic differentiation support and SIMD-optimized computation.
//!
//! # Key Features
//!
//! - **Scalar Power**: `pow_scalar(exponent)` - Raises each element to a scalar power (PyTorch `pow(tensor, scalar)` equivalent)
//! - **Tensor Power**: `pow_tensor(exponent)` - Element-wise power with tensor exponents (PyTorch `pow(tensor, tensor)` equivalent)
//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
//! - **SIMD Optimization**: AVX2-optimized implementation for common cases (x^2, x^0.5)
//! - **Smart Dispatch**: Optimized paths for common exponents (2.0, 0.5) with scalar fallback for others
//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
//! - **Mathematical Accuracy**: High-precision power computation
//!
//! # Mathematical Properties
//!
//! The power operations have the following properties:
//! - **Power Laws**: (x^a)^b = x^(a*b), x^a * x^b = x^(a+b)
//! - **Special Cases**: x^0 = 1, x^1 = x, x^2 = x*x, x^0.5 = sqrt(x)
//! - **Domain**: for non-integer a, x^a is real-valued only for x >= 0; negative bases with non-integer exponents yield NaN
//! - **Gradient**: d/dx(x^a) = a * x^(a-1) for scalar power
//! - **Gradient**: d/dx(x^y) = y * x^(y-1), d/dy(x^y) = x^y * ln(x) for tensor power
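//!
//! For example, with x = 3 and a = 2, d/dx(x^2) = 2 * 3^1 = 6; this is the
//! per-element value the scalar-power backward pass accumulates.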
//!
//! # Performance Characteristics
//!
//! - **SIMD Optimization**: AVX2-optimized for x^2 and x^0.5 with 32-element blocks and 4x unrolling
//! - **Smart Dispatch**: Fast paths for common exponents (2.0, 0.5) with scalar fallback for others
//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware and general exponents
//! - **Cache-friendly Access**: Linear memory access patterns
//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
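//!
//! # Examples
//!
//! A minimal usage sketch of both entry points (see the method docs below
//! for more detail):
//!
//! ```
//! use train_station::Tensor;
//!
//! let x = Tensor::from_slice(&[2.0, 4.0], vec![2]).unwrap();
//!
//! // Scalar exponent: dispatches to the optimized x^2 path.
//! let squared = x.pow_scalar(2.0);
//! assert_eq!(squared.get(&[0]), 4.0);
//! assert_eq!(squared.get(&[1]), 16.0);
//!
//! // Tensor exponent: element-wise powf, compared with a tolerance.
//! let e = Tensor::from_slice(&[3.0, 0.5], vec![2]).unwrap();
//! let y = x.pow_tensor(&e);
//! assert!((y.get(&[0]) - 8.0).abs() < 1e-5); // 2^3
//! assert!((y.get(&[1]) - 2.0).abs() < 1e-5); // sqrt(4)
//! ```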

use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
use crate::tensor::core::Tensor;

// SIMD optimizations for performance-critical operations
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

impl Tensor {
    /// Raises each element to a scalar power.
    ///
    /// Computes element-wise power: `output[i] = self[i]^exponent`
    ///
    /// # Arguments
    /// * `exponent` - The scalar exponent to raise each element to
    ///
    /// # Returns
    /// A new tensor with each element raised to the given power
    ///
    /// # Examples
    ///
    /// ## Basic Scalar Power
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let b = a.pow_scalar(2.0);
    /// assert_eq!(b.shape().dims, vec![3]);
    /// assert_eq!(b.get(&[0]), 1.0); // 1.0^2 = 1.0
    /// assert_eq!(b.get(&[1]), 4.0); // 2.0^2 = 4.0
    /// assert_eq!(b.get(&[2]), 9.0); // 3.0^2 = 9.0
    /// ```
    ///
    /// ## Square Root (Power 0.5)
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[1.0, 4.0, 9.0], vec![3]).unwrap();
    /// let b = a.pow_scalar(0.5);
    /// assert_eq!(b.shape().dims, vec![3]);
    /// assert_eq!(b.get(&[0]), 1.0); // sqrt(1.0) = 1.0
    /// assert_eq!(b.get(&[1]), 2.0); // sqrt(4.0) = 2.0
    /// assert_eq!(b.get(&[2]), 3.0); // sqrt(9.0) = 3.0
    /// ```
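    ///
    /// ## General Exponent
    ///
    /// Exponents other than 2.0 and 0.5 take the scalar `powf` fallback; the
    /// comparison below uses a tolerance since `powf` is not guaranteed to be
    /// exactly rounded:
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let b = a.pow_scalar(3.0);
    /// assert!((b.get(&[1]) - 8.0).abs() < 1e-5); // 2.0^3 = 8.0
    /// assert!((b.get(&[2]) - 27.0).abs() < 1e-5); // 3.0^3 = 27.0
    /// ```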
    pub fn pow_scalar(&self, exponent: f32) -> Tensor {
        let mut out = self.pow_scalar_optimized(exponent);

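        // Record the op for gradtrack: the backward pass needs the original
        // input to evaluate d/dx(x^a) = a * x^(a-1), so it is saved here.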
        if self.requires_grad() && is_grad_enabled() {
            out.set_requires_grad_internal(true);
            let grad_fn = GradFn::PowScalar {
                exponent,
                saved_input: Box::new(self.clone()),
            };
            out.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
        }

        out
    }

    /// Internal optimized scalar power operation
    ///
    /// Performs element-wise scalar power computation using smart dispatch for common
    /// exponents and optimized scalar computation. This is the core implementation
    /// used by `pow_scalar()`.
    ///
    /// # Arguments
    ///
    /// * `exponent` - The scalar exponent to raise each element to
    ///
    /// # Returns
    ///
    /// A new tensor containing each element raised to the given power
    ///
    /// # Performance Characteristics
    ///
    /// - **Smart Dispatch**: Fast paths for common exponents (2.0, 0.5) with SIMD optimization
    /// - **SIMD Optimization**: AVX2-optimized for x^2 and x^0.5 when available
    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware and general exponents
    /// - **Cache-friendly**: Linear memory access patterns
    /// - **Mathematical Accuracy**: High-precision power computation
    /// - **Zero-sized Handling**: Fast return for empty tensors
    ///
    /// # Implementation Details
    ///
    /// Uses smart dispatch to optimize common cases:
    /// - `exponent == 2.0`: Uses SIMD multiplication for x^2
    /// - `exponent == 0.5`: Uses SIMD square root for x^0.5
    /// - Other exponents: Uses scalar `powf()` for accuracy
    #[inline]
    pub(crate) fn pow_scalar_optimized(&self, exponent: f32) -> Tensor {
        let mut output = Tensor::new(self.shape().dims.clone());

        if self.size() == 0 {
            return output;
        }

        unsafe {
            let src = self.as_ptr();
            let dst = output.as_mut_ptr();

            // Handle common cases with SIMD optimizations
            if exponent == 2.0 {
                #[cfg(target_arch = "x86_64")]
                {
                    if is_x86_feature_detected!("avx2") {
                        self.pow_square_simd_avx2_optimized(src, dst);
                        return output;
                    }
                }
                self.pow_square_scalar_optimized(src, dst);
            } else if exponent == 0.5 {
                #[cfg(target_arch = "x86_64")]
                {
                    if is_x86_feature_detected!("avx2") {
                        self.pow_sqrt_simd_avx2_optimized(src, dst);
                        return output;
                    }
                }
                self.pow_sqrt_scalar_optimized(src, dst);
            } else {
                // General case - use scalar fallback for accuracy
                self.pow_general_scalar_optimized(src, dst, exponent);
            }
        }

        output
    }

    /// AVX2-optimized square implementation (x^2)
    ///
    /// Performs element-wise squaring using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data. Requires AVX2 support.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses AVX2 multiplication instructions for x^2
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
    ///
    /// # Implementation Details
    ///
    /// Uses AVX2 vector multiplication to compute x^2 efficiently.
    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
    /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn pow_square_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for x^2
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let square_vec1 = _mm256_mul_ps(src_vec1, src_vec1);
            _mm256_storeu_ps(dst.add(offset), square_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let square_vec2 = _mm256_mul_ps(src_vec2, src_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), square_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let square_vec3 = _mm256_mul_ps(src_vec3, src_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), square_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let square_vec4 = _mm256_mul_ps(src_vec4, src_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), square_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let square_vec = _mm256_mul_ps(src_vec, src_vec);
            _mm256_storeu_ps(dst.add(offset), square_vec);
            offset += 8;
        }

        // Handle remaining elements
        for i in offset..size {
            let v = *src.add(i);
            *dst.add(i) = v * v;
        }
    }

    /// AVX2-optimized square root implementation (x^0.5)
    ///
    /// Performs element-wise square root using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data. Requires AVX2 support.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses AVX2 square root instructions for x^0.5
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
    ///
    /// # Implementation Details
    ///
    /// Uses AVX2 vector square root instructions to compute x^0.5 efficiently.
    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
    /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn pow_sqrt_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for x^0.5 (sqrt)
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let sqrt_vec1 = _mm256_sqrt_ps(src_vec1);
            _mm256_storeu_ps(dst.add(offset), sqrt_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let sqrt_vec2 = _mm256_sqrt_ps(src_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), sqrt_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let sqrt_vec3 = _mm256_sqrt_ps(src_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), sqrt_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let sqrt_vec4 = _mm256_sqrt_ps(src_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), sqrt_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let sqrt_vec = _mm256_sqrt_ps(src_vec);
            _mm256_storeu_ps(dst.add(offset), sqrt_vec);
            offset += 8;
        }

        // Handle remaining elements
        for i in offset..size {
            *dst.add(i) = (*src.add(i)).sqrt();
        }
    }

    /// Optimized scalar square fallback (x^2)
    ///
    /// Performs element-wise squaring using optimized scalar operations with
    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data.
    ///
    /// # Performance Characteristics
    ///
    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
    /// - **Mathematical Accuracy**: High-precision scalar multiplication
    ///
    /// # Implementation Details
    ///
    /// Uses 4x unrolled scalar multiplication for optimal performance on non-SIMD hardware.
    /// Processes elements in groups of 4 to improve instruction-level parallelism
    /// and reduce loop overhead.
    #[inline]
    unsafe fn pow_square_scalar_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for x^2
        for _ in 0..unroll_count {
            let v1 = *src.add(offset);
            let v2 = *src.add(offset + 1);
            let v3 = *src.add(offset + 2);
            let v4 = *src.add(offset + 3);

            *dst.add(offset) = v1 * v1;
            *dst.add(offset + 1) = v2 * v2;
            *dst.add(offset + 2) = v3 * v3;
            *dst.add(offset + 3) = v4 * v4;

            offset += 4;
        }

        // Handle remaining elements
        for i in offset..size {
            let v = *src.add(i);
            *dst.add(i) = v * v;
        }
    }

    /// Optimized scalar square root fallback (x^0.5)
    ///
    /// Performs element-wise square root using optimized scalar operations with
    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data.
    ///
    /// # Performance Characteristics
    ///
    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
    /// - **Mathematical Accuracy**: High-precision scalar square root
    ///
    /// # Implementation Details
    ///
    /// Uses 4x unrolled scalar square root for optimal performance on non-SIMD hardware.
    /// Processes elements in groups of 4 to improve instruction-level parallelism
    /// and reduce loop overhead.
    #[inline]
    unsafe fn pow_sqrt_scalar_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for x^0.5
        for _ in 0..unroll_count {
            *dst.add(offset) = (*src.add(offset)).sqrt();
            *dst.add(offset + 1) = (*src.add(offset + 1)).sqrt();
            *dst.add(offset + 2) = (*src.add(offset + 2)).sqrt();
            *dst.add(offset + 3) = (*src.add(offset + 3)).sqrt();
            offset += 4;
        }

        // Handle remaining elements
        for i in offset..size {
            *dst.add(i) = (*src.add(i)).sqrt();
        }
    }

    /// Optimized scalar general power fallback (x^exponent)
    ///
    /// Performs element-wise power computation using optimized scalar operations with
    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    /// * `exponent` - The scalar exponent to raise each element to
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data.
    ///
    /// # Performance Characteristics
    ///
    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
    /// - **Mathematical Accuracy**: High-precision scalar power computation
    ///
    /// # Implementation Details
    ///
    /// Uses 4x unrolled scalar power for optimal performance on non-SIMD hardware.
    /// Processes elements in groups of 4 to improve instruction-level parallelism
    /// and reduce loop overhead. Uses `powf()` for general exponent support.
    #[inline]
    unsafe fn pow_general_scalar_optimized(&self, src: *const f32, dst: *mut f32, exponent: f32) {
        let size = self.size();
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for general exponent
        for _ in 0..unroll_count {
            *dst.add(offset) = (*src.add(offset)).powf(exponent);
            *dst.add(offset + 1) = (*src.add(offset + 1)).powf(exponent);
            *dst.add(offset + 2) = (*src.add(offset + 2)).powf(exponent);
            *dst.add(offset + 3) = (*src.add(offset + 3)).powf(exponent);
            offset += 4;
        }

        // Handle remaining elements
        for i in offset..size {
            *dst.add(i) = (*src.add(i)).powf(exponent);
        }
    }

    /// Element-wise power with tensor exponents.
    ///
    /// Computes element-wise power: `output[i] = self[i]^exponent[i]`
    ///
    /// # Arguments
    /// * `exponent` - Tensor of exponents, must have the same shape as self
    ///
    /// # Returns
    /// A new tensor with each element raised to the corresponding power
    ///
    /// # Examples
    ///
    /// ## Basic Tensor Power
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let base = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![3]).unwrap();
    /// let exp = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let result = base.pow_tensor(&exp);
    /// assert_eq!(result.shape().dims, vec![3]);
    /// assert_eq!(result.get(&[0]), 2.0); // 2.0^1.0 = 2.0
    /// assert_eq!(result.get(&[1]), 9.0); // 3.0^2.0 = 9.0
    /// assert_eq!(result.get(&[2]), 64.0); // 4.0^3.0 = 64.0
    /// ```
    ///
    /// ## Mixed Exponents
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let base = Tensor::from_slice(&[4.0, 9.0, 16.0], vec![3]).unwrap();
    /// let exp = Tensor::from_slice(&[0.5, 1.0, 2.0], vec![3]).unwrap();
    /// let result = base.pow_tensor(&exp);
    /// assert_eq!(result.shape().dims, vec![3]);
    /// assert_eq!(result.get(&[0]), 2.0); // sqrt(4.0) = 2.0
    /// assert_eq!(result.get(&[1]), 9.0); // 9.0^1.0 = 9.0
    /// assert_eq!(result.get(&[2]), 256.0); // 16.0^2.0 = 256.0
    /// ```
    ///
    /// # Panics
    /// Panics if tensor shapes don't match
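    ///
    /// ## Shape Mismatch
    ///
    /// A sketch of the failure mode: mismatched shapes panic rather than
    /// broadcast.
    ///
    /// ```should_panic
    /// use train_station::Tensor;
    ///
    /// let base = Tensor::from_slice(&[2.0, 3.0], vec![2]).unwrap();
    /// let exp = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let _ = base.pow_tensor(&exp); // panics: shapes differ
    /// ```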
    pub fn pow_tensor(&self, exponent: &Tensor) -> Tensor {
        assert_eq!(
            self.shape().dims,
            exponent.shape().dims,
            "pow_tensor requires identical shapes"
        );
        let mut out = Tensor::new(self.shape().dims.clone());
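        // Forward pass: element-wise powf over the contiguous buffers
        // (identical shapes were asserted above, so a flat loop is safe).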
        unsafe {
            let x = self.as_ptr();
            let a = exponent.as_ptr();
            let y = out.as_mut_ptr();
            let n = out.size();
            for i in 0..n {
                *y.add(i) = (*x.add(i)).powf(*a.add(i));
            }
        }

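        // Register gradients for both parents: d/dx(x^y) = y * x^(y-1) and
        // d/dy(x^y) = x^y * ln(x), so both base and exponent must be saved.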
        if (self.requires_grad() || exponent.requires_grad()) && is_grad_enabled() {
            let mut result = out.clone();
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::PowTensor {
                saved_base: Box::new(self.clone()),
                saved_exponent: Box::new(exponent.clone()),
            };
            result.set_grad_fn(grad_fn.clone());
            let parents = vec![self.id(), exponent.id()];
            GradEngine::register_operation(result.id(), parents, grad_fn);
            return result;
        }

        out
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pow_scalar_forward() {
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let y = x.pow_scalar(2.0);
        assert_eq!(y.shape().dims, vec![4]);
        unsafe {
            assert_eq!(*y.as_ptr(), 1.0);
            assert_eq!(*y.as_ptr().add(1), 4.0);
            assert_eq!(*y.as_ptr().add(2), 9.0);
            assert_eq!(*y.as_ptr().add(3), 16.0);
        }
    }
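
    // Additional coverage sketches for the remaining dispatch paths. Element
    // counts are chosen so they are not multiples of 4 or 32, exercising the
    // unrolled-loop remainder handling as well.
    #[test]
    fn test_pow_scalar_sqrt_forward() {
        let data: Vec<f32> = (0..37).map(|i| (i * i) as f32).collect();
        let x = Tensor::from_slice(&data, vec![37]).unwrap();
        let y = x.pow_scalar(0.5);
        for i in 0..37 {
            // sqrt of a small perfect square is exact in f32
            assert_eq!(y.get(&[i]), i as f32);
        }
    }

    #[test]
    fn test_pow_scalar_general_forward() {
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let y = x.pow_scalar(3.0);
        for i in 0..5 {
            let base = (i + 1) as f32;
            // powf is not guaranteed exactly rounded, so compare with tolerance
            assert!((y.get(&[i]) - base * base * base).abs() < 1e-3);
        }
    }

    #[test]
    fn test_pow_tensor_forward() {
        let base = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![3]).unwrap();
        let exp = Tensor::from_slice(&[2.0, 1.0, 0.5], vec![3]).unwrap();
        let y = base.pow_tensor(&exp);
        assert!((y.get(&[0]) - 4.0).abs() < 1e-5); // 2^2
        assert!((y.get(&[1]) - 3.0).abs() < 1e-5); // 3^1
        assert!((y.get(&[2]) - 2.0).abs() < 1e-5); // 4^0.5
    }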
}