train_station/tensor/ops/
pow.rs

//! Power operations for tensors
//!
//! Provides element-wise power functions following PyTorch conventions with
//! comprehensive automatic differentiation support and SIMD-optimized computation.
//!
//! # Key Features
//!
//! - **Scalar Power**: `pow_scalar(exponent)` - Raises each element to a scalar power (PyTorch `pow(tensor, scalar)` equivalent)
//! - **Tensor Power**: `pow_tensor(exponent)` - Element-wise power with tensor exponents (PyTorch `pow(tensor, tensor)` equivalent)
//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
//! - **SIMD Optimization**: AVX2-optimized implementation for common cases (x^2, x^0.5)
//! - **Smart Dispatch**: Optimized paths for common exponents (2.0, 0.5) with scalar fallback for others
//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
//! - **Mathematical Accuracy**: High-precision power computation
//!
//! # Mathematical Properties
//!
//! The power operations have the following properties:
//! - **Power Laws**: (x^a)^b = x^(a*b), x^a * x^b = x^(a+b)
//! - **Special Cases**: x^0 = 1, x^1 = x, x^2 = x*x, x^0.5 = sqrt(x)
//! - **Domain**: x^a is defined for x > 0 when a is not an integer
//! - **Gradient**: d/dx(x^a) = a * x^(a-1) for scalar power
//! - **Gradient**: d/dx(x^y) = y * x^(y-1), d/dy(x^y) = x^y * ln(x) for tensor power
//!
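//! For instance, the power-law identity `(x^a)^b = x^(a*b)` can be checked
//! numerically with the public API (illustrative; the tolerance reflects `f32`
//! rounding in `powf`):
//!
//! ```
//! use train_station::Tensor;
//!
//! let x = Tensor::from_slice(&[2.0, 3.0], vec![2]).unwrap();
//! let lhs = x.pow_scalar(2.0).pow_scalar(3.0); // (x^2)^3
//! let rhs = x.pow_scalar(6.0); // x^6
//! assert!((lhs.get(&[0]) - rhs.get(&[0])).abs() < 1e-3);
//! assert!((lhs.get(&[1]) - rhs.get(&[1])).abs() < 1e-3);
//! ```
//!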
//! # Performance Characteristics
//!
//! - **SIMD Optimization**: AVX2-optimized for x^2 and x^0.5 with 32-element blocks and 4x unrolling
//! - **Smart Dispatch**: Fast paths for common exponents (2.0, 0.5) with scalar fallback for others
//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware and general exponents
//! - **Cache-friendly Access**: Linear memory access patterns
//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support

use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
use crate::tensor::core::Tensor;

// SIMD optimizations for performance-critical operations
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

impl Tensor {
    /// Raises each element to a scalar power.
    ///
    /// Computes element-wise power: `output[i] = self[i]^exponent`
    ///
    /// # Arguments
    /// * `exponent` - The scalar exponent to raise each element to
    ///
    /// # Returns
    /// A new tensor with each element raised to the given power
    ///
    /// # Examples
    ///
    /// ## Basic Scalar Power
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let b = a.pow_scalar(2.0);
    /// assert_eq!(b.shape().dims, vec![3]);
    /// assert_eq!(b.get(&[0]), 1.0); // 1.0^2 = 1.0
    /// assert_eq!(b.get(&[1]), 4.0); // 2.0^2 = 4.0
    /// assert_eq!(b.get(&[2]), 9.0); // 3.0^2 = 9.0
    /// ```
    ///
    /// ## Square Root (Power 0.5)
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[1.0, 4.0, 9.0], vec![3]).unwrap();
    /// let b = a.pow_scalar(0.5);
    /// assert_eq!(b.shape().dims, vec![3]);
    /// assert_eq!(b.get(&[0]), 1.0); // sqrt(1.0) = 1.0
    /// assert_eq!(b.get(&[1]), 2.0); // sqrt(4.0) = 2.0
    /// assert_eq!(b.get(&[2]), 3.0); // sqrt(9.0) = 3.0
    /// ```
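    ///
    /// ## Negative Bases with Integer Exponents
    ///
    /// Illustrative: squaring a negative value is well-defined; the module-level
    /// `x > 0` domain note applies only to non-integer exponents:
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[-2.0, -3.0], vec![2]).unwrap();
    /// let b = a.pow_scalar(2.0);
    /// assert_eq!(b.get(&[0]), 4.0); // (-2.0)^2 = 4.0
    /// assert_eq!(b.get(&[1]), 9.0); // (-3.0)^2 = 9.0
    /// ```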
    #[track_caller]
    pub fn pow_scalar(&self, exponent: f32) -> Tensor {
        let mut out = self.pow_scalar_optimized(exponent);

        if self.requires_grad() && is_grad_enabled() {
            out.set_requires_grad_internal(true);
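            // Save the input so the backward pass can form d/dx(x^a) = a * x^(a-1).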
            let grad_fn = GradFn::PowScalar {
                exponent,
                saved_input: Box::new(self.clone()),
            };
            out.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
        }

        out
    }

    /// Internal optimized scalar power operation
    ///
    /// Performs element-wise scalar power computation using smart dispatch for common
    /// exponents and optimized scalar computation. This is the core implementation
    /// used by `pow_scalar()`.
    ///
    /// # Arguments
    ///
    /// * `exponent` - The scalar exponent to raise each element to
    ///
    /// # Returns
    ///
    /// A new tensor containing each element raised to the given power
    ///
    /// # Performance Characteristics
    ///
    /// - **Smart Dispatch**: Fast paths for common exponents (2.0, 0.5) with SIMD optimization
    /// - **SIMD Optimization**: AVX2-optimized for x^2 and x^0.5 when available
    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware and general exponents
    /// - **Cache-friendly**: Linear memory access patterns
    /// - **Mathematical Accuracy**: High-precision power computation
    /// - **Zero-sized Handling**: Fast return for empty tensors
    ///
    /// # Implementation Details
    ///
    /// Uses smart dispatch to optimize common cases:
    /// - `exponent == 2.0`: Uses SIMD multiplication for x^2
    /// - `exponent == 0.5`: Uses SIMD square root for x^0.5
    /// - Other exponents: Uses scalar `powf()` for accuracy
    #[inline]
    pub(crate) fn pow_scalar_optimized(&self, exponent: f32) -> Tensor {
        let mut output = Tensor::new(self.shape().dims.clone());

        if self.size() == 0 {
            return output;
        }

        unsafe {
            let src = self.as_ptr();
            let dst = output.as_mut_ptr();

            // Handle common cases with SIMD optimizations
            if exponent == 2.0 {
                #[cfg(target_arch = "x86_64")]
                {
                    if is_x86_feature_detected!("avx2") {
                        self.pow_square_simd_avx2_optimized(src, dst);
                        return output;
                    }
                }
                self.pow_square_scalar_optimized(src, dst);
            } else if exponent == 0.5 {
                #[cfg(target_arch = "x86_64")]
                {
                    if is_x86_feature_detected!("avx2") {
                        self.pow_sqrt_simd_avx2_optimized(src, dst);
                        return output;
                    }
                }
                self.pow_sqrt_scalar_optimized(src, dst);
            } else {
                // General case - use scalar fallback for accuracy
                self.pow_general_scalar_optimized(src, dst, exponent);
            }
        }

        output
    }

    /// AVX2-optimized square implementation (x^2)
    ///
    /// Performs element-wise squaring using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data. Requires AVX2 support.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses AVX2 multiplication instructions for x^2
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
    ///
    /// # Implementation Details
    ///
    /// Uses AVX2 vector multiplication to compute x^2 efficiently.
    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
    /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn pow_square_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for x^2
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let square_vec1 = _mm256_mul_ps(src_vec1, src_vec1);
            _mm256_storeu_ps(dst.add(offset), square_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let square_vec2 = _mm256_mul_ps(src_vec2, src_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), square_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let square_vec3 = _mm256_mul_ps(src_vec3, src_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), square_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let square_vec4 = _mm256_mul_ps(src_vec4, src_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), square_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let square_vec = _mm256_mul_ps(src_vec, src_vec);
            _mm256_storeu_ps(dst.add(offset), square_vec);
            offset += 8;
        }

        // Handle remaining elements
        for i in offset..size {
            let v = *src.add(i);
            *dst.add(i) = v * v;
        }
    }

    /// AVX2-optimized square root implementation (x^0.5)
    ///
    /// Performs element-wise square root using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data. Requires AVX2 support.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses AVX2 square root instructions for x^0.5
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
    ///
    /// # Implementation Details
    ///
    /// Uses AVX2 vector square root instructions to compute x^0.5 efficiently.
    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
    /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn pow_sqrt_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for x^0.5 (sqrt)
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let sqrt_vec1 = _mm256_sqrt_ps(src_vec1);
            _mm256_storeu_ps(dst.add(offset), sqrt_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let sqrt_vec2 = _mm256_sqrt_ps(src_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), sqrt_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let sqrt_vec3 = _mm256_sqrt_ps(src_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), sqrt_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let sqrt_vec4 = _mm256_sqrt_ps(src_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), sqrt_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let sqrt_vec = _mm256_sqrt_ps(src_vec);
            _mm256_storeu_ps(dst.add(offset), sqrt_vec);
            offset += 8;
        }

        // Handle remaining elements
        for i in offset..size {
            *dst.add(i) = (*src.add(i)).sqrt();
        }
    }

    /// Optimized scalar square fallback (x^2)
    ///
    /// Performs element-wise squaring using optimized scalar operations with
    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data.
    ///
    /// # Performance Characteristics
    ///
    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
    /// - **Mathematical Accuracy**: High-precision scalar multiplication
    ///
    /// # Implementation Details
    ///
    /// Uses 4x unrolled scalar multiplication for optimal performance on non-SIMD hardware.
    /// Processes elements in groups of 4 to improve instruction-level parallelism
    /// and reduce loop overhead.
    #[inline]
    unsafe fn pow_square_scalar_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for x^2
        for _ in 0..unroll_count {
            let v1 = *src.add(offset);
            let v2 = *src.add(offset + 1);
            let v3 = *src.add(offset + 2);
            let v4 = *src.add(offset + 3);

            *dst.add(offset) = v1 * v1;
            *dst.add(offset + 1) = v2 * v2;
            *dst.add(offset + 2) = v3 * v3;
            *dst.add(offset + 3) = v4 * v4;

            offset += 4;
        }

        // Handle remaining elements
        for i in offset..size {
            let v = *src.add(i);
            *dst.add(i) = v * v;
        }
    }

    /// Optimized scalar square root fallback (x^0.5)
    ///
    /// Performs element-wise square root using optimized scalar operations with
    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data.
    ///
    /// # Performance Characteristics
    ///
    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
    /// - **Mathematical Accuracy**: High-precision scalar square root
    ///
    /// # Implementation Details
    ///
    /// Uses 4x unrolled scalar square root for optimal performance on non-SIMD hardware.
    /// Processes elements in groups of 4 to improve instruction-level parallelism
    /// and reduce loop overhead.
    #[inline]
    unsafe fn pow_sqrt_scalar_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for x^0.5
        for _ in 0..unroll_count {
            *dst.add(offset) = (*src.add(offset)).sqrt();
            *dst.add(offset + 1) = (*src.add(offset + 1)).sqrt();
            *dst.add(offset + 2) = (*src.add(offset + 2)).sqrt();
            *dst.add(offset + 3) = (*src.add(offset + 3)).sqrt();
            offset += 4;
        }

        // Handle remaining elements
        for i in offset..size {
            *dst.add(i) = (*src.add(i)).sqrt();
        }
    }

    /// Optimized scalar general power fallback (x^exponent)
    ///
    /// Performs element-wise power computation using optimized scalar operations with
    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    /// * `exponent` - The scalar exponent to raise each element to
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data.
    ///
    /// # Performance Characteristics
    ///
    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
    /// - **Mathematical Accuracy**: High-precision scalar power computation
    ///
    /// # Implementation Details
    ///
    /// Uses 4x unrolled scalar power for optimal performance on non-SIMD hardware.
    /// Processes elements in groups of 4 to improve instruction-level parallelism
    /// and reduce loop overhead. Uses `powf()` for general exponent support.
    #[inline]
    unsafe fn pow_general_scalar_optimized(&self, src: *const f32, dst: *mut f32, exponent: f32) {
        let size = self.size();
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for general exponent
        for _ in 0..unroll_count {
            *dst.add(offset) = (*src.add(offset)).powf(exponent);
            *dst.add(offset + 1) = (*src.add(offset + 1)).powf(exponent);
            *dst.add(offset + 2) = (*src.add(offset + 2)).powf(exponent);
            *dst.add(offset + 3) = (*src.add(offset + 3)).powf(exponent);
            offset += 4;
        }

        // Handle remaining elements
        for i in offset..size {
            *dst.add(i) = (*src.add(i)).powf(exponent);
        }
    }

    /// Element-wise power with tensor exponents.
    ///
    /// Computes element-wise power: `output[i] = self[i]^exponent[i]`
    ///
    /// # Arguments
    /// * `exponent` - Tensor of exponents, must have the same shape as self
    ///
    /// # Returns
    /// A new tensor with each element raised to the corresponding power
    ///
    /// # Examples
    ///
    /// ## Basic Tensor Power
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let base = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![3]).unwrap();
    /// let exp = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let result = base.pow_tensor(&exp);
    /// assert_eq!(result.shape().dims, vec![3]);
    /// assert_eq!(result.get(&[0]), 2.0); // 2.0^1.0 = 2.0
    /// assert_eq!(result.get(&[1]), 9.0); // 3.0^2.0 = 9.0
    /// assert_eq!(result.get(&[2]), 64.0); // 4.0^3.0 = 64.0
    /// ```
    ///
    /// ## Mixed Exponents
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let base = Tensor::from_slice(&[4.0, 9.0, 16.0], vec![3]).unwrap();
    /// let exp = Tensor::from_slice(&[0.5, 1.0, 2.0], vec![3]).unwrap();
    /// let result = base.pow_tensor(&exp);
    /// assert_eq!(result.shape().dims, vec![3]);
    /// assert_eq!(result.get(&[0]), 2.0); // sqrt(4.0) = 2.0
    /// assert_eq!(result.get(&[1]), 9.0); // 9.0^1.0 = 9.0
    /// assert_eq!(result.get(&[2]), 256.0); // 16.0^2.0 = 256.0
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if tensor shapes don't match.
    #[track_caller]
    pub fn pow_tensor(&self, exponent: &Tensor) -> Tensor {
        assert_eq!(
            self.shape().dims,
            exponent.shape().dims,
            "pow_tensor requires identical shapes"
        );
        let mut out = Tensor::new(self.shape().dims.clone());
        unsafe {
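            // Per-element exponents rule out the fixed-exponent SIMD fast paths
            // above, so each element goes through scalar `powf`.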
            let x = self.as_ptr();
            let a = exponent.as_ptr();
            let y = out.as_mut_ptr();
            let n = out.size();
            for i in 0..n {
                *y.add(i) = (*x.add(i)).powf(*a.add(i));
            }
        }

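        // Backward needs both operands: d/dx(x^y) = y * x^(y-1) and d/dy(x^y) = x^y * ln(x).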
        if (self.requires_grad() || exponent.requires_grad()) && is_grad_enabled() {
            out.set_requires_grad_internal(true);
            let grad_fn = GradFn::PowTensor {
                saved_base: Box::new(self.clone()),
                saved_exponent: Box::new(exponent.clone()),
            };
            out.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(out.id(), vec![self.id(), exponent.id()], grad_fn);
        }

        out
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pow_scalar_forward() {
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let y = x.pow_scalar(2.0);
        assert_eq!(y.shape().dims, vec![4]);
        unsafe {
            assert_eq!(*y.as_ptr(), 1.0);
            assert_eq!(*y.as_ptr().add(1), 4.0);
            assert_eq!(*y.as_ptr().add(2), 9.0);
            assert_eq!(*y.as_ptr().add(3), 16.0);
        }
    }
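
    // Additional forward-path checks (illustrative; they exercise the sqrt fast
    // path, the general `powf` fallback, and the tensor-exponent loop using only
    // APIs already used in this module).
    #[test]
    fn test_pow_sqrt_general_and_tensor_forward() {
        let x = Tensor::from_slice(&[1.0, 4.0, 9.0, 16.0], vec![4]).unwrap();

        // Sqrt fast path (exponent == 0.5)
        let y = x.pow_scalar(0.5);
        assert_eq!(y.get(&[0]), 1.0);
        assert_eq!(y.get(&[1]), 2.0);
        assert_eq!(y.get(&[3]), 4.0);

        // General `powf` fallback: 4.0^1.5 = 8.0
        let z = x.pow_scalar(1.5);
        assert!((z.get(&[1]) - 8.0).abs() < 1e-5);

        // Element-wise tensor exponents
        let e = Tensor::from_slice(&[2.0, 0.5, 1.0, 0.0], vec![4]).unwrap();
        let w = x.pow_tensor(&e);
        assert_eq!(w.get(&[0]), 1.0); // 1.0^2.0 = 1.0
        assert!((w.get(&[1]) - 2.0).abs() < 1e-6); // 4.0^0.5 = 2.0
        assert_eq!(w.get(&[2]), 9.0); // 9.0^1.0 = 9.0
        assert_eq!(w.get(&[3]), 1.0); // 16.0^0.0 = 1.0
    }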
}