train_station/tensor/ops/pow.rs
//! Power operations for tensors
//!
//! Provides element-wise power functions following PyTorch conventions with
//! comprehensive automatic differentiation support and SIMD-optimized computation.
//!
//! # Key Features
//!
//! - **Scalar Power**: `pow_scalar(exponent)` - Raises each element to a scalar power (PyTorch `pow(tensor, scalar)` equivalent)
//! - **Tensor Power**: `pow_tensor(exponent)` - Element-wise power with tensor exponents (PyTorch `pow(tensor, tensor)` equivalent)
//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
//! - **SIMD Optimization**: AVX2-optimized implementation for common cases (x^2, x^0.5)
//! - **Smart Dispatch**: Optimized paths for common exponents (2.0, 0.5) with scalar fallback for others
//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
//! - **Mathematical Accuracy**: High-precision power computation
//!
//! # Mathematical Properties
//!
//! The power operations have the following properties:
//! - **Power Laws**: (x^a)^b = x^(a*b), x^a * x^b = x^(a+b)
//! - **Special Cases**: x^0 = 1, x^1 = x, x^2 = x*x, x^0.5 = sqrt(x)
//! - **Domain**: x^a is defined for x > 0 when a is not an integer
//! - **Gradient**: d/dx(x^a) = a * x^(a-1) for scalar power
//! - **Gradient**: d/dx(x^y) = y * x^(y-1), d/dy(x^y) = x^y * ln(x) for tensor power
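//!
//! For example, the power law (x^a)^b = x^(a*b) can be checked directly with
//! `pow_scalar` (an illustrative check; results are compared approximately since
//! general exponents go through `powf`):
//!
//! ```
//! use train_station::Tensor;
//!
//! let x = Tensor::from_slice(&[2.0, 3.0], vec![2]).unwrap();
//! let lhs = x.pow_scalar(3.0).pow_scalar(2.0); // (x^3)^2
//! let rhs = x.pow_scalar(6.0);                 // x^(3*2)
//! assert!((lhs.get(&[0]) - rhs.get(&[0])).abs() < 1e-3);
//! assert!((lhs.get(&[1]) - rhs.get(&[1])).abs() < 1e-3);
//! ```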
//!
//! # Performance Characteristics
//!
//! - **SIMD Optimization**: AVX2-optimized for x^2 and x^0.5 with 32-element blocks and 4x unrolling
//! - **Smart Dispatch**: Fast paths for common exponents (2.0, 0.5) with scalar fallback for others
//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware and general exponents
//! - **Cache-friendly Access**: Linear memory access patterns
//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support

use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
use crate::tensor::core::Tensor;

// SIMD optimizations for performance-critical operations
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

impl Tensor {
    /// Raises each element to a scalar power.
    ///
    /// Computes element-wise power: `output[i] = self[i]^exponent`
    ///
    /// # Arguments
    /// * `exponent` - The scalar exponent to raise each element to
    ///
    /// # Returns
    /// A new tensor with each element raised to the given power
    ///
    /// # Examples
    ///
    /// ## Basic Scalar Power
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let b = a.pow_scalar(2.0);
    /// assert_eq!(b.shape().dims, vec![3]);
    /// assert_eq!(b.get(&[0]), 1.0); // 1.0^2 = 1.0
    /// assert_eq!(b.get(&[1]), 4.0); // 2.0^2 = 4.0
    /// assert_eq!(b.get(&[2]), 9.0); // 3.0^2 = 9.0
    /// ```
    ///
    /// ## Square Root (Power 0.5)
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[1.0, 4.0, 9.0], vec![3]).unwrap();
    /// let b = a.pow_scalar(0.5);
    /// assert_eq!(b.shape().dims, vec![3]);
    /// assert_eq!(b.get(&[0]), 1.0); // sqrt(1.0) = 1.0
    /// assert_eq!(b.get(&[1]), 2.0); // sqrt(4.0) = 2.0
    /// assert_eq!(b.get(&[2]), 3.0); // sqrt(9.0) = 3.0
    /// ```
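    ///
    /// ## General Exponent (Scalar Fallback)
    ///
    /// Exponents other than `2.0` and `0.5` use the scalar `powf` fallback; this
    /// illustrative check therefore compares results approximately:
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[2.0, 3.0], vec![2]).unwrap();
    /// let b = a.pow_scalar(3.0);
    /// assert!((b.get(&[0]) - 8.0).abs() < 1e-4);  // 2.0^3 = 8.0
    /// assert!((b.get(&[1]) - 27.0).abs() < 1e-4); // 3.0^3 = 27.0
    /// ```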
    #[track_caller]
    pub fn pow_scalar(&self, exponent: f32) -> Tensor {
        let mut out = self.pow_scalar_optimized(exponent);

        if self.requires_grad() && is_grad_enabled() {
            out.set_requires_grad_internal(true);
            let grad_fn = GradFn::PowScalar {
                exponent,
                saved_input: Box::new(self.clone()),
            };
            out.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
        }

        out
    }

    /// Internal optimized scalar power operation
    ///
    /// Performs element-wise scalar power computation using smart dispatch for common
    /// exponents and optimized scalar computation. This is the core implementation
    /// used by `pow_scalar()`.
    ///
    /// # Arguments
    ///
    /// * `exponent` - The scalar exponent to raise each element to
    ///
    /// # Returns
    ///
    /// A new tensor containing each element raised to the given power
    ///
    /// # Performance Characteristics
    ///
    /// - **Smart Dispatch**: Fast paths for common exponents (2.0, 0.5) with SIMD optimization
    /// - **SIMD Optimization**: AVX2-optimized for x^2 and x^0.5 when available
    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware and general exponents
    /// - **Cache-friendly**: Linear memory access patterns
    /// - **Mathematical Accuracy**: High-precision power computation
    /// - **Zero-sized Handling**: Fast return for empty tensors
    ///
    /// # Implementation Details
    ///
    /// Uses smart dispatch to optimize common cases:
    /// - `exponent == 2.0`: Uses SIMD multiplication for x^2
    /// - `exponent == 0.5`: Uses SIMD square root for x^0.5
    /// - Other exponents: Uses scalar `powf()` for accuracy
    #[inline]
    pub(crate) fn pow_scalar_optimized(&self, exponent: f32) -> Tensor {
        let mut output = Tensor::new(self.shape().dims.clone());

        if self.size() == 0 {
            return output;
        }

        unsafe {
            let src = self.as_ptr();
            let dst = output.as_mut_ptr();

            // Handle common cases with SIMD optimizations
            if exponent == 2.0 {
                #[cfg(target_arch = "x86_64")]
                {
                    if is_x86_feature_detected!("avx2") {
                        self.pow_square_simd_avx2_optimized(src, dst);
                        return output;
                    }
                }
                self.pow_square_scalar_optimized(src, dst);
            } else if exponent == 0.5 {
                #[cfg(target_arch = "x86_64")]
                {
                    if is_x86_feature_detected!("avx2") {
                        self.pow_sqrt_simd_avx2_optimized(src, dst);
                        return output;
                    }
                }
                self.pow_sqrt_scalar_optimized(src, dst);
            } else {
                // General case - use scalar fallback for accuracy
                self.pow_general_scalar_optimized(src, dst, exponent);
            }
        }

        output
    }

    /// AVX2-optimized square implementation (x^2)
    ///
    /// Performs element-wise squaring using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data. Requires AVX2 support.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses AVX2 multiplication instructions for x^2
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
    ///
    /// # Implementation Details
    ///
    /// Uses AVX2 vector multiplication to compute x^2 efficiently.
    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
    /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn pow_square_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for x^2
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let square_vec1 = _mm256_mul_ps(src_vec1, src_vec1);
            _mm256_storeu_ps(dst.add(offset), square_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let square_vec2 = _mm256_mul_ps(src_vec2, src_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), square_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let square_vec3 = _mm256_mul_ps(src_vec3, src_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), square_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let square_vec4 = _mm256_mul_ps(src_vec4, src_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), square_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let square_vec = _mm256_mul_ps(src_vec, src_vec);
            _mm256_storeu_ps(dst.add(offset), square_vec);
            offset += 8;
        }

        // Handle remaining elements
        for i in offset..size {
            let v = *src.add(i);
            *dst.add(i) = v * v;
        }
    }

    /// AVX2-optimized square root implementation (x^0.5)
    ///
    /// Performs element-wise square root using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data. Requires AVX2 support.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses AVX2 square root instructions for x^0.5
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
    ///
    /// # Implementation Details
    ///
    /// Uses AVX2 vector square root instructions to compute x^0.5 efficiently.
    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
    /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn pow_sqrt_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for x^0.5 (sqrt)
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let sqrt_vec1 = _mm256_sqrt_ps(src_vec1);
            _mm256_storeu_ps(dst.add(offset), sqrt_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let sqrt_vec2 = _mm256_sqrt_ps(src_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), sqrt_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let sqrt_vec3 = _mm256_sqrt_ps(src_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), sqrt_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let sqrt_vec4 = _mm256_sqrt_ps(src_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), sqrt_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let sqrt_vec = _mm256_sqrt_ps(src_vec);
            _mm256_storeu_ps(dst.add(offset), sqrt_vec);
            offset += 8;
        }

        // Handle remaining elements
        for i in offset..size {
            *dst.add(i) = (*src.add(i)).sqrt();
        }
    }

    /// Optimized scalar square fallback (x^2)
    ///
    /// Performs element-wise squaring using optimized scalar operations with
    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data.
    ///
    /// # Performance Characteristics
    ///
    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
    /// - **Mathematical Accuracy**: High-precision scalar multiplication
    ///
    /// # Implementation Details
    ///
    /// Uses 4x unrolled scalar multiplication for optimal performance on non-SIMD hardware.
    /// Processes elements in groups of 4 to improve instruction-level parallelism
    /// and reduce loop overhead.
    #[inline]
    unsafe fn pow_square_scalar_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for x^2
        for _ in 0..unroll_count {
            let v1 = *src.add(offset);
            let v2 = *src.add(offset + 1);
            let v3 = *src.add(offset + 2);
            let v4 = *src.add(offset + 3);

            *dst.add(offset) = v1 * v1;
            *dst.add(offset + 1) = v2 * v2;
            *dst.add(offset + 2) = v3 * v3;
            *dst.add(offset + 3) = v4 * v4;

            offset += 4;
        }

        // Handle remaining elements
        for i in offset..size {
            let v = *src.add(i);
            *dst.add(i) = v * v;
        }
    }

    /// Optimized scalar square root fallback (x^0.5)
    ///
    /// Performs element-wise square root using optimized scalar operations with
    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data.
    ///
    /// # Performance Characteristics
    ///
    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
    /// - **Mathematical Accuracy**: High-precision scalar square root
    ///
    /// # Implementation Details
    ///
    /// Uses 4x unrolled scalar square root for optimal performance on non-SIMD hardware.
    /// Processes elements in groups of 4 to improve instruction-level parallelism
    /// and reduce loop overhead.
    #[inline]
    unsafe fn pow_sqrt_scalar_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for x^0.5
        for _ in 0..unroll_count {
            *dst.add(offset) = (*src.add(offset)).sqrt();
            *dst.add(offset + 1) = (*src.add(offset + 1)).sqrt();
            *dst.add(offset + 2) = (*src.add(offset + 2)).sqrt();
            *dst.add(offset + 3) = (*src.add(offset + 3)).sqrt();
            offset += 4;
        }

        // Handle remaining elements
        for i in offset..size {
            *dst.add(i) = (*src.add(i)).sqrt();
        }
    }

    /// Optimized scalar general power fallback (x^exponent)
    ///
    /// Performs element-wise power computation using optimized scalar operations with
    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    /// * `exponent` - The scalar exponent to raise each element to
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data.
    ///
    /// # Performance Characteristics
    ///
    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
    /// - **Mathematical Accuracy**: High-precision scalar power computation
    ///
    /// # Implementation Details
    ///
    /// Uses 4x unrolled scalar power for optimal performance on non-SIMD hardware.
    /// Processes elements in groups of 4 to improve instruction-level parallelism
    /// and reduce loop overhead. Uses `powf()` for general exponent support.
    #[inline]
    unsafe fn pow_general_scalar_optimized(&self, src: *const f32, dst: *mut f32, exponent: f32) {
        let size = self.size();
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for general exponent
        for _ in 0..unroll_count {
            *dst.add(offset) = (*src.add(offset)).powf(exponent);
            *dst.add(offset + 1) = (*src.add(offset + 1)).powf(exponent);
            *dst.add(offset + 2) = (*src.add(offset + 2)).powf(exponent);
            *dst.add(offset + 3) = (*src.add(offset + 3)).powf(exponent);
            offset += 4;
        }

        // Handle remaining elements
        for i in offset..size {
            *dst.add(i) = (*src.add(i)).powf(exponent);
        }
    }

    /// Element-wise power with tensor exponents.
    ///
    /// Computes element-wise power: `output[i] = self[i]^exponent[i]`
    ///
    /// # Arguments
    /// * `exponent` - Tensor of exponents, must have the same shape as `self`
    ///
    /// # Returns
    /// A new tensor with each element raised to the corresponding power
    ///
    /// # Examples
    ///
    /// ## Basic Tensor Power
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let base = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![3]).unwrap();
    /// let exp = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let result = base.pow_tensor(&exp);
    /// assert_eq!(result.shape().dims, vec![3]);
    /// assert_eq!(result.get(&[0]), 2.0); // 2.0^1.0 = 2.0
    /// assert_eq!(result.get(&[1]), 9.0); // 3.0^2.0 = 9.0
    /// assert_eq!(result.get(&[2]), 64.0); // 4.0^3.0 = 64.0
    /// ```
    ///
    /// ## Mixed Exponents
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let base = Tensor::from_slice(&[4.0, 9.0, 16.0], vec![3]).unwrap();
    /// let exp = Tensor::from_slice(&[0.5, 1.0, 2.0], vec![3]).unwrap();
    /// let result = base.pow_tensor(&exp);
    /// assert_eq!(result.shape().dims, vec![3]);
    /// assert_eq!(result.get(&[0]), 2.0); // sqrt(4.0) = 2.0
    /// assert_eq!(result.get(&[1]), 9.0); // 9.0^1.0 = 9.0
    /// assert_eq!(result.get(&[2]), 256.0); // 16.0^2.0 = 256.0
    /// ```
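    ///
    /// ## Reciprocal (Negative Exponents)
    ///
    /// Negative exponents follow the same element-wise rule (x^-1 = 1/x); this
    /// illustrative check compares approximately since the computation goes
    /// through `powf`:
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let base = Tensor::from_slice(&[2.0, 4.0, 10.0], vec![3]).unwrap();
    /// let exp = Tensor::from_slice(&[-1.0, -1.0, -1.0], vec![3]).unwrap();
    /// let result = base.pow_tensor(&exp);
    /// assert!((result.get(&[0]) - 0.5).abs() < 1e-6);  // 2.0^-1 = 0.5
    /// assert!((result.get(&[1]) - 0.25).abs() < 1e-6); // 4.0^-1 = 0.25
    /// assert!((result.get(&[2]) - 0.1).abs() < 1e-6);  // 10.0^-1 = 0.1
    /// ```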
    ///
    /// # Panics
    /// Panics if tensor shapes don't match
    #[track_caller]
    pub fn pow_tensor(&self, exponent: &Tensor) -> Tensor {
        assert_eq!(
            self.shape().dims,
            exponent.shape().dims,
            "pow_tensor requires identical shapes"
        );
        let mut out = Tensor::new(self.shape().dims.clone());
        unsafe {
            let x = self.as_ptr();
            let a = exponent.as_ptr();
            let y = out.as_mut_ptr();
            let n = out.size();
            for i in 0..n {
                *y.add(i) = (*x.add(i)).powf(*a.add(i));
            }
        }

        if (self.requires_grad() || exponent.requires_grad()) && is_grad_enabled() {
            let mut result = out.clone();
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::PowTensor {
                saved_base: Box::new(self.clone()),
                saved_exponent: Box::new(exponent.clone()),
            };
            result.set_grad_fn(grad_fn.clone());
            let parents = vec![self.id(), exponent.id()];
            GradEngine::register_operation(result.id(), parents, grad_fn);
            return result;
        }

        out
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pow_scalar_forward() {
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let y = x.pow_scalar(2.0);
        assert_eq!(y.shape().dims, vec![4]);
        unsafe {
            assert_eq!(*y.as_ptr(), 1.0);
            assert_eq!(*y.as_ptr().add(1), 4.0);
            assert_eq!(*y.as_ptr().add(2), 9.0);
            assert_eq!(*y.as_ptr().add(3), 16.0);
        }
    }
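
    /// Sketch of additional coverage for the general scalar-exponent fallback,
    /// assuming the same public API used above (`from_slice`, `pow_scalar`, `get`).
    #[test]
    fn test_pow_scalar_general_exponent() {
        let x = Tensor::from_slice(&[2.0, 3.0], vec![2]).unwrap();
        let y = x.pow_scalar(3.0);
        assert_eq!(y.shape().dims, vec![2]);
        // Compare approximately: the general case goes through `powf`.
        assert!((y.get(&[0]) - 8.0).abs() < 1e-4);
        assert!((y.get(&[1]) - 27.0).abs() < 1e-4);
    }

    /// Sketch of a tensor-exponent test, assuming the `pow_tensor` API documented above.
    #[test]
    fn test_pow_tensor_forward() {
        let base = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![3]).unwrap();
        let exp = Tensor::from_slice(&[3.0, 2.0, 0.5], vec![3]).unwrap();
        let y = base.pow_tensor(&exp);
        assert_eq!(y.shape().dims, vec![3]);
        // Compare approximately: element-wise powers use `powf`.
        assert!((y.get(&[0]) - 8.0).abs() < 1e-4); // 2.0^3.0
        assert!((y.get(&[1]) - 9.0).abs() < 1e-4); // 3.0^2.0
        assert!((y.get(&[2]) - 2.0).abs() < 1e-4); // 4.0^0.5
    }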
560}