train_station/tensor/ops/pow.rs
//! Power operations for tensors
//!
//! Provides element-wise power functions following PyTorch conventions with
//! comprehensive automatic differentiation support and SIMD-optimized computation.
//!
//! # Key Features
//!
//! - **Scalar Power**: `pow_scalar(exponent)` - Raises each element to a scalar power (PyTorch `pow(tensor, scalar)` equivalent)
//! - **Tensor Power**: `pow_tensor(exponent)` - Element-wise power with tensor exponents (PyTorch `pow(tensor, tensor)` equivalent)
//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
//! - **SIMD Optimization**: AVX2-optimized implementation for common cases (x^2, x^0.5)
//! - **Smart Dispatch**: Optimized paths for common exponents (2.0, 0.5) with scalar fallback for others
//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
//! - **Mathematical Accuracy**: High-precision power computation
//!
//! # Mathematical Properties
//!
//! The power operations have the following properties:
//! - **Power Laws**: (x^a)^b = x^(a*b), x^a * x^b = x^(a+b)
//! - **Special Cases**: x^0 = 1, x^1 = x, x^2 = x*x, x^0.5 = sqrt(x)
//! - **Domain**: x^a with a non-integer exponent requires x >= 0; negative bases produce NaN
//! - **Gradient**: d/dx(x^a) = a * x^(a-1) for scalar power
//! - **Gradient**: d/dx(x^y) = y * x^(y-1), d/dy(x^y) = x^y * ln(x) for tensor power
//!
//! # Performance Characteristics
//!
//! - **SIMD Optimization**: AVX2-optimized for x^2 and x^0.5 with 32-element blocks and 4x unrolling
//! - **Smart Dispatch**: Fast paths for common exponents (2.0, 0.5) with scalar fallback for others
//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware and general exponents
//! - **Cache-friendly Access**: Linear memory access patterns
//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
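//!
//! # Examples
//!
//! A minimal usage sketch of both entry points, relying only on the public
//! constructors and accessors shown in the per-method documentation below:
//!
//! ```
//! use train_station::Tensor;
//!
//! let x = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
//!
//! // Scalar exponent: x^2 (SIMD fast path)
//! let squared = x.pow_scalar(2.0);
//! assert_eq!(squared.get(&[1]), 4.0);
//!
//! // Tensor exponent: element-wise x[i]^e[i]
//! let e = Tensor::from_slice(&[2.0, 1.0, 0.5], vec![3]).unwrap();
//! let mixed = x.pow_tensor(&e);
//! assert_eq!(mixed.get(&[1]), 2.0); // 2.0^1.0
//!
//! // Power law (x^a)^b = x^(a*b), checked with a small tolerance because
//! // general exponents go through `powf`
//! let lhs = x.pow_scalar(2.0).pow_scalar(3.0);
//! let rhs = x.pow_scalar(6.0);
//! assert!((lhs.get(&[2]) - rhs.get(&[2])).abs() < 1e-3);
//! ```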
use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
use crate::tensor::core::Tensor;

// SIMD optimizations for performance-critical operations
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

impl Tensor {
    /// Raises each element to a scalar power.
    ///
    /// Computes element-wise power: `output[i] = self[i]^exponent`
    ///
    /// # Arguments
    /// * `exponent` - The scalar exponent to raise each element to
    ///
    /// # Returns
    /// A new tensor with each element raised to the given power
    ///
    /// # Examples
    ///
    /// ## Basic Scalar Power
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let b = a.pow_scalar(2.0);
    /// assert_eq!(b.shape().dims, vec![3]);
    /// assert_eq!(b.get(&[0]), 1.0); // 1.0^2 = 1.0
    /// assert_eq!(b.get(&[1]), 4.0); // 2.0^2 = 4.0
    /// assert_eq!(b.get(&[2]), 9.0); // 3.0^2 = 9.0
    /// ```
    ///
    /// ## Square Root (Power 0.5)
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[1.0, 4.0, 9.0], vec![3]).unwrap();
    /// let b = a.pow_scalar(0.5);
    /// assert_eq!(b.shape().dims, vec![3]);
    /// assert_eq!(b.get(&[0]), 1.0); // sqrt(1.0) = 1.0
    /// assert_eq!(b.get(&[1]), 2.0); // sqrt(4.0) = 2.0
    /// assert_eq!(b.get(&[2]), 3.0); // sqrt(9.0) = 3.0
    /// ```
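    ///
    /// ## General Exponents
    ///
    /// A small sketch for exponents other than `2.0` and `0.5`, which fall back
    /// to scalar `powf` and therefore follow `f32::powf` semantics (negative
    /// bases with non-integer exponents yield NaN):
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let b = a.pow_scalar(3.0);
    /// assert!((b.get(&[1]) - 8.0).abs() < 1e-4);  // 2.0^3.0 = 8.0
    /// assert!((b.get(&[2]) - 27.0).abs() < 1e-3); // 3.0^3.0 = 27.0
    /// ```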
    pub fn pow_scalar(&self, exponent: f32) -> Tensor {
        let mut out = self.pow_scalar_optimized(exponent);

        if self.requires_grad() && is_grad_enabled() {
            out.set_requires_grad_internal(true);
            let grad_fn = GradFn::PowScalar {
                exponent,
                saved_input: Box::new(self.clone()),
            };
            out.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
        }

        out
    }

    /// Internal optimized scalar power operation
    ///
    /// Performs element-wise scalar power computation using smart dispatch for common
    /// exponents and optimized scalar computation. This is the core implementation
    /// used by `pow_scalar()`.
    ///
    /// # Arguments
    ///
    /// * `exponent` - The scalar exponent to raise each element to
    ///
    /// # Returns
    ///
    /// A new tensor containing each element raised to the given power
    ///
    /// # Performance Characteristics
    ///
    /// - **Smart Dispatch**: Fast paths for common exponents (2.0, 0.5) with SIMD optimization
    /// - **SIMD Optimization**: AVX2-optimized for x^2 and x^0.5 when available
    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware and general exponents
    /// - **Cache-friendly**: Linear memory access patterns
    /// - **Mathematical Accuracy**: High-precision power computation
    /// - **Zero-sized Handling**: Fast return for empty tensors
    ///
    /// # Implementation Details
    ///
    /// Uses smart dispatch to optimize common cases:
    /// - `exponent == 2.0`: Uses SIMD multiplication for x^2
    /// - `exponent == 0.5`: Uses SIMD square root for x^0.5
    /// - Other exponents: Uses scalar `powf()` for accuracy
    #[inline]
    pub(crate) fn pow_scalar_optimized(&self, exponent: f32) -> Tensor {
        let mut output = Tensor::new(self.shape().dims.clone());

        if self.size() == 0 {
            return output;
        }

        unsafe {
            let src = self.as_ptr();
            let dst = output.as_mut_ptr();

            // Handle common cases with SIMD optimizations
            if exponent == 2.0 {
                #[cfg(target_arch = "x86_64")]
                {
                    if is_x86_feature_detected!("avx2") {
                        self.pow_square_simd_avx2_optimized(src, dst);
                        return output;
                    }
                }
                self.pow_square_scalar_optimized(src, dst);
            } else if exponent == 0.5 {
                #[cfg(target_arch = "x86_64")]
                {
                    if is_x86_feature_detected!("avx2") {
                        self.pow_sqrt_simd_avx2_optimized(src, dst);
                        return output;
                    }
                }
                self.pow_sqrt_scalar_optimized(src, dst);
            } else {
                // General case - use scalar fallback for accuracy
                self.pow_general_scalar_optimized(src, dst, exponent);
            }
        }

        output
    }

    /// AVX2-optimized square implementation (x^2)
    ///
    /// Performs element-wise squaring using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data. Requires AVX2 support.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses AVX2 multiplication instructions for x^2
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
    ///
    /// # Implementation Details
    ///
    /// Uses AVX2 vector multiplication to compute x^2 efficiently.
    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
    /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn pow_square_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for x^2
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let square_vec1 = _mm256_mul_ps(src_vec1, src_vec1);
            _mm256_storeu_ps(dst.add(offset), square_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let square_vec2 = _mm256_mul_ps(src_vec2, src_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), square_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let square_vec3 = _mm256_mul_ps(src_vec3, src_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), square_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let square_vec4 = _mm256_mul_ps(src_vec4, src_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), square_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let square_vec = _mm256_mul_ps(src_vec, src_vec);
            _mm256_storeu_ps(dst.add(offset), square_vec);
            offset += 8;
        }

        // Handle remaining elements
        for i in offset..size {
            let v = *src.add(i);
            *dst.add(i) = v * v;
        }
    }

    /// AVX2-optimized square root implementation (x^0.5)
    ///
    /// Performs element-wise square root using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data. Requires AVX2 support.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses AVX2 square root instructions for x^0.5
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
    ///
    /// # Implementation Details
    ///
    /// Uses AVX2 vector square root instructions to compute x^0.5 efficiently.
    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
    /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn pow_sqrt_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for x^0.5 (sqrt)
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let sqrt_vec1 = _mm256_sqrt_ps(src_vec1);
            _mm256_storeu_ps(dst.add(offset), sqrt_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let sqrt_vec2 = _mm256_sqrt_ps(src_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), sqrt_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let sqrt_vec3 = _mm256_sqrt_ps(src_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), sqrt_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let sqrt_vec4 = _mm256_sqrt_ps(src_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), sqrt_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let sqrt_vec = _mm256_sqrt_ps(src_vec);
            _mm256_storeu_ps(dst.add(offset), sqrt_vec);
            offset += 8;
        }

        // Handle remaining elements
        for i in offset..size {
            *dst.add(i) = (*src.add(i)).sqrt();
        }
    }

    /// Optimized scalar square fallback (x^2)
    ///
    /// Performs element-wise squaring using optimized scalar operations with
    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data.
    ///
    /// # Performance Characteristics
    ///
    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
    /// - **Mathematical Accuracy**: High-precision scalar multiplication
    ///
    /// # Implementation Details
    ///
    /// Uses 4x unrolled scalar multiplication for optimal performance on non-SIMD hardware.
    /// Processes elements in groups of 4 to improve instruction-level parallelism
    /// and reduce loop overhead.
    #[inline]
    unsafe fn pow_square_scalar_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for x^2
        for _ in 0..unroll_count {
            let v1 = *src.add(offset);
            let v2 = *src.add(offset + 1);
            let v3 = *src.add(offset + 2);
            let v4 = *src.add(offset + 3);

            *dst.add(offset) = v1 * v1;
            *dst.add(offset + 1) = v2 * v2;
            *dst.add(offset + 2) = v3 * v3;
            *dst.add(offset + 3) = v4 * v4;

            offset += 4;
        }

        // Handle remaining elements
        for i in offset..size {
            let v = *src.add(i);
            *dst.add(i) = v * v;
        }
    }

    /// Optimized scalar square root fallback (x^0.5)
    ///
    /// Performs element-wise square root using optimized scalar operations with
    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data.
    ///
    /// # Performance Characteristics
    ///
    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
    /// - **Mathematical Accuracy**: High-precision scalar square root
    ///
    /// # Implementation Details
    ///
    /// Uses 4x unrolled scalar square root for optimal performance on non-SIMD hardware.
    /// Processes elements in groups of 4 to improve instruction-level parallelism
    /// and reduce loop overhead.
    #[inline]
    unsafe fn pow_sqrt_scalar_optimized(&self, src: *const f32, dst: *mut f32) {
        let size = self.size();
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for x^0.5
        for _ in 0..unroll_count {
            *dst.add(offset) = (*src.add(offset)).sqrt();
            *dst.add(offset + 1) = (*src.add(offset + 1)).sqrt();
            *dst.add(offset + 2) = (*src.add(offset + 2)).sqrt();
            *dst.add(offset + 3) = (*src.add(offset + 3)).sqrt();
            offset += 4;
        }

        // Handle remaining elements
        for i in offset..size {
            *dst.add(i) = (*src.add(i)).sqrt();
        }
    }

    /// Optimized scalar general power fallback (x^exponent)
    ///
    /// Performs element-wise power computation using optimized scalar operations with
    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    /// * `exponent` - The scalar exponent to raise each element to
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data.
    ///
    /// # Performance Characteristics
    ///
    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
    /// - **Mathematical Accuracy**: High-precision scalar power computation
    ///
    /// # Implementation Details
    ///
    /// Uses 4x unrolled scalar power for optimal performance on non-SIMD hardware.
    /// Processes elements in groups of 4 to improve instruction-level parallelism
    /// and reduce loop overhead. Uses `powf()` for general exponent support.
    #[inline]
    unsafe fn pow_general_scalar_optimized(&self, src: *const f32, dst: *mut f32, exponent: f32) {
        let size = self.size();
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for general exponent
        for _ in 0..unroll_count {
            *dst.add(offset) = (*src.add(offset)).powf(exponent);
            *dst.add(offset + 1) = (*src.add(offset + 1)).powf(exponent);
            *dst.add(offset + 2) = (*src.add(offset + 2)).powf(exponent);
            *dst.add(offset + 3) = (*src.add(offset + 3)).powf(exponent);
            offset += 4;
        }

        // Handle remaining elements
        for i in offset..size {
            *dst.add(i) = (*src.add(i)).powf(exponent);
        }
    }

    /// Element-wise power with tensor exponents.
    ///
    /// Computes element-wise power: `output[i] = self[i]^exponent[i]`
    ///
    /// # Arguments
    /// * `exponent` - Tensor of exponents, must have the same shape as self
    ///
    /// # Returns
    /// A new tensor with each element raised to the corresponding power
    ///
    /// # Examples
    ///
    /// ## Basic Tensor Power
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let base = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![3]).unwrap();
    /// let exp = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let result = base.pow_tensor(&exp);
    /// assert_eq!(result.shape().dims, vec![3]);
    /// assert_eq!(result.get(&[0]), 2.0); // 2.0^1.0 = 2.0
    /// assert_eq!(result.get(&[1]), 9.0); // 3.0^2.0 = 9.0
    /// assert_eq!(result.get(&[2]), 64.0); // 4.0^3.0 = 64.0
    /// ```
    ///
    /// ## Mixed Exponents
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let base = Tensor::from_slice(&[4.0, 9.0, 16.0], vec![3]).unwrap();
    /// let exp = Tensor::from_slice(&[0.5, 1.0, 2.0], vec![3]).unwrap();
    /// let result = base.pow_tensor(&exp);
    /// assert_eq!(result.shape().dims, vec![3]);
    /// assert_eq!(result.get(&[0]), 2.0); // sqrt(4.0) = 2.0
    /// assert_eq!(result.get(&[1]), 9.0); // 9.0^1.0 = 9.0
    /// assert_eq!(result.get(&[2]), 256.0); // 16.0^2.0 = 256.0
    /// ```
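    ///
    /// ## Consistency with `pow_scalar`
    ///
    /// A small sketch: a constant exponent tensor matches the scalar form, so the
    /// two entry points can be cross-checked (compared with a tolerance since the
    /// tensor path always goes through `powf`):
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let base = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let exp = Tensor::from_slice(&[2.0, 2.0, 2.0], vec![3]).unwrap();
    /// let a = base.pow_tensor(&exp);
    /// let b = base.pow_scalar(2.0);
    /// assert!((a.get(&[1]) - b.get(&[1])).abs() < 1e-5);
    /// assert!((a.get(&[2]) - b.get(&[2])).abs() < 1e-5);
    /// ```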
    ///
    /// # Panics
    /// Panics if tensor shapes don't match
    pub fn pow_tensor(&self, exponent: &Tensor) -> Tensor {
        assert_eq!(
            self.shape().dims,
            exponent.shape().dims,
            "pow_tensor requires identical shapes"
        );
        let mut out = Tensor::new(self.shape().dims.clone());
        unsafe {
            let x = self.as_ptr();
            let a = exponent.as_ptr();
            let y = out.as_mut_ptr();
            let n = out.size();
            for i in 0..n {
                *y.add(i) = (*x.add(i)).powf(*a.add(i));
            }
        }

        if (self.requires_grad() || exponent.requires_grad()) && is_grad_enabled() {
            let mut result = out.clone();
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::PowTensor {
                saved_base: Box::new(self.clone()),
                saved_exponent: Box::new(exponent.clone()),
            };
            result.set_grad_fn(grad_fn.clone());
            let parents = vec![self.id(), exponent.id()];
            GradEngine::register_operation(result.id(), parents, grad_fn);
            return result;
        }

        out
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pow_scalar_forward() {
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let y = x.pow_scalar(2.0);
        assert_eq!(y.shape().dims, vec![4]);
        unsafe {
            assert_eq!(*y.as_ptr(), 1.0);
            assert_eq!(*y.as_ptr().add(1), 4.0);
            assert_eq!(*y.as_ptr().add(2), 9.0);
            assert_eq!(*y.as_ptr().add(3), 16.0);
        }
    }
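
    // Additional forward-path sketches added alongside the test above. They use
    // only constructors and accessors already exercised in this file's doc
    // examples (`from_slice`, `pow_scalar`, `pow_tensor`, `shape`, `get`).
    #[test]
    fn test_pow_scalar_sqrt_and_general_exponent() {
        let x = Tensor::from_slice(&[1.0, 4.0, 9.0, 16.0], vec![4]).unwrap();

        // Exponent 0.5 dispatches to the sqrt fast path
        let s = x.pow_scalar(0.5);
        assert_eq!(s.get(&[1]), 2.0);
        assert_eq!(s.get(&[3]), 4.0);

        // Any other exponent falls back to scalar powf
        let c = x.pow_scalar(3.0);
        assert!((c.get(&[1]) - 64.0).abs() < 1e-3);
    }

    #[test]
    fn test_pow_scalar_square_large_tensor() {
        // 107 elements cover the unrolled blocks and their remainders on both
        // the AVX2 path (3 x 32-element blocks, one 8-element block, 3-element
        // tail) and the scalar fallback (4x unrolled blocks plus a tail).
        let data: Vec<f32> = (0..107).map(|i| i as f32).collect();
        let x = Tensor::from_slice(&data, vec![107]).unwrap();
        let y = x.pow_scalar(2.0);
        for i in 0..107 {
            let v = i as f32;
            assert_eq!(y.get(&[i]), v * v);
        }
    }

    #[test]
    fn test_pow_tensor_forward() {
        let base = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![3]).unwrap();
        let exp = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let out = base.pow_tensor(&exp);
        assert_eq!(out.shape().dims, vec![3]);
        assert!((out.get(&[0]) - 2.0).abs() < 1e-5);
        assert!((out.get(&[1]) - 9.0).abs() < 1e-4);
        assert!((out.get(&[2]) - 64.0).abs() < 1e-3);
    }

    #[test]
    fn test_pow_scalar_matches_documented_derivative() {
        // Numerical sketch of the documented gradient d/dx(x^a) = a * x^(a-1),
        // checked with central differences on the forward op only.
        let a = 3.0_f32;
        let h = 1e-3_f32;
        let x = 2.0_f32;
        let plus = Tensor::from_slice(&[x + h], vec![1]).unwrap().pow_scalar(a);
        let minus = Tensor::from_slice(&[x - h], vec![1]).unwrap().pow_scalar(a);
        let numeric = (plus.get(&[0]) - minus.get(&[0])) / (2.0 * h);
        let analytic = a * x.powf(a - 1.0); // 12.0
        assert!((numeric - analytic).abs() < 1e-1);
    }
}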
558}