rustorch/tensor/
parallel_ops.rs

1use super::Tensor;
2use crate::error::{RusTorchError, RusTorchResult};
3type ParallelResult<T> = RusTorchResult<T>;
4// Parallel operations implementation
5use num_traits::Float;
6use rayon::prelude::*;
7use std::sync::Arc;
8
9/// Parallel batch operations for tensors
10/// テンソルの並列バッチ演算
11impl<T: Float + Send + Sync + Clone + 'static> Tensor<T> {
12    /// Parallel batch matrix multiplication
13    /// 並列バッチ行列乗算
14    pub fn batch_matmul_parallel(&self, other: &Tensor<T>) -> ParallelResult<Tensor<T>> {
15        let self_shape = self.data.shape();
16        let other_shape = other.data.shape();
17
18        if self_shape.len() < 3 || other_shape.len() < 3 {
19            return Err(RusTorchError::parallel("Insufficient dimensions"));
20        }
21
22        let batch_size = self_shape[0];
23        if batch_size != other_shape[0] {
24            return Err(RusTorchError::parallel("Batch size mismatch"));
25        }
26
27        let m = self_shape[1];
28        let k = self_shape[2];
29        let n = other_shape[2];
30
31        if k != other_shape[1] {
32            return Err(RusTorchError::parallel("Matrix dimension mismatch"));
33        }
34
35        let result_shape = vec![batch_size, m, n];
36        let mut result = Self::zeros(&result_shape);
37
38        // Parallel processing over batch dimension
39        let self_data = Arc::new(self.data.clone());
40        let other_data = Arc::new(other.data.clone());
41
42        let results: Vec<_> = (0..batch_size)
43            .into_par_iter()
44            .map(|b| {
45                let mut batch_result = vec![T::zero(); m * n];
46
47                // Extract batch matrices
48                for i in 0..m {
49                    for j in 0..n {
50                        let mut sum = T::zero();
51                        for l in 0..k {
52                            let a_idx = b * m * k + i * k + l;
53                            let b_idx = b * k * n + l * n + j;
54
55                            if let (Some(a_val), Some(b_val)) = (
56                                self_data.as_slice().and_then(|s| s.get(a_idx)),
57                                other_data.as_slice().and_then(|s| s.get(b_idx)),
58                            ) {
59                                sum = sum + *a_val * *b_val;
60                            }
61                        }
62                        batch_result[i * n + j] = sum;
63                    }
64                }
65                batch_result
66            })
67            .collect();
68
69        // Combine results
70        if let Some(result_slice) = result.data.as_slice_mut() {
71            for (b, batch_result) in results.iter().enumerate() {
72                let start_idx = b * m * n;
73                for (i, &val) in batch_result.iter().enumerate() {
74                    if let Some(dest) = result_slice.get_mut(start_idx + i) {
75                        *dest = val;
76                    }
77                }
78            }
79        }
80
81        Ok(result)
82    }
83
84    /// Parallel batch element-wise operations
85    /// 並列バッチ要素ごと演算
86    pub fn batch_add_parallel(&self, other: &Tensor<T>) -> ParallelResult<Tensor<T>> {
87        if self.data.shape() != other.data.shape() {
88            return Err(RusTorchError::parallel("Shape mismatch"));
89        }
90
91        let mut result = Self::zeros(self.data.shape());
92
93        if let (Some(self_slice), Some(other_slice), Some(result_slice)) = (
94            self.data.as_slice(),
95            other.data.as_slice(),
96            result.data.as_slice_mut(),
97        ) {
98            result_slice
99                .par_iter_mut()
100                .zip(self_slice.par_iter())
101                .zip(other_slice.par_iter())
102                .for_each(|((r, &a), &b)| {
103                    *r = a + b;
104                });
105        }
106
107        Ok(result)
108    }
109
110    /// Parallel batch scalar multiplication
111    /// 並列バッチスカラー乗算
112    pub fn batch_mul_scalar_parallel(&self, scalar: T) -> Tensor<T> {
113        let mut result = Self::zeros(self.data.shape());
114
115        if let (Some(self_slice), Some(result_slice)) =
116            (self.data.as_slice(), result.data.as_slice_mut())
117        {
118            result_slice
119                .par_iter_mut()
120                .zip(self_slice.par_iter())
121                .for_each(|(r, &a)| {
122                    *r = a * scalar;
123                });
124        }
125
126        result
127    }
128
129    /// Parallel batch normalization
130    /// 並列バッチ正規化
131    pub fn batch_normalize_parallel(&self, epsilon: T) -> Tensor<T> {
132        let shape = self.data.shape();
133        if shape.len() < 2 {
134            return self.clone();
135        }
136
137        let batch_size = shape[0];
138        let feature_size: usize = shape[1..].iter().product();
139
140        let mut result = Self::zeros(shape);
141
142        if let (Some(self_slice), Some(result_slice)) =
143            (self.data.as_slice(), result.data.as_slice_mut())
144        {
145            // Parallel computation of batch statistics and normalization
146            let batch_results: Vec<_> = (0..batch_size)
147                .into_par_iter()
148                .map(|b| {
149                    let start_idx = b * feature_size;
150                    let end_idx = start_idx + feature_size;
151                    let batch_data = &self_slice[start_idx..end_idx];
152
153                    // Compute mean
154                    let mean = batch_data.iter().fold(T::zero(), |acc, &x| acc + x)
155                        / T::from(feature_size).unwrap();
156
157                    // Compute variance
158                    let variance = batch_data.iter().fold(T::zero(), |acc, &x| {
159                        let diff = x - mean;
160                        acc + diff * diff
161                    }) / T::from(feature_size).unwrap();
162
163                    let std_dev = (variance + epsilon).sqrt();
164
165                    // Normalize
166                    let normalized: Vec<T> =
167                        batch_data.iter().map(|&x| (x - mean) / std_dev).collect();
168
169                    normalized
170                })
171                .collect();
172
173            // Copy results back
174            for (b, normalized) in batch_results.iter().enumerate() {
175                let start_idx = b * feature_size;
176                for (i, &val) in normalized.iter().enumerate() {
177                    if let Some(dest) = result_slice.get_mut(start_idx + i) {
178                        *dest = val;
179                    }
180                }
181            }
182        }
183
184        result
185    }
186
187    /// Parallel batch convolution (simplified 2D)
188    /// 並列バッチ畳み込み(簡略化2D)
189    pub fn batch_conv2d_parallel(
190        &self,
191        kernel: &Tensor<T>,
192        stride: usize,
193        padding: usize,
194    ) -> ParallelResult<Tensor<T>> {
195        let input_shape = self.data.shape();
196        let kernel_shape = kernel.data.shape();
197
198        if input_shape.len() != 4 || kernel_shape.len() != 4 {
199            return Err(RusTorchError::parallel("Insufficient dimensions"));
200        }
201
202        let batch_size = input_shape[0];
203        let in_channels = input_shape[1];
204        let in_height = input_shape[2];
205        let in_width = input_shape[3];
206
207        let out_channels = kernel_shape[0];
208        let kernel_height = kernel_shape[2];
209        let kernel_width = kernel_shape[3];
210
211        if in_channels != kernel_shape[1] {
212            return Err(RusTorchError::parallel("Convolution error"));
213        }
214
215        let out_height = (in_height + 2 * padding - kernel_height) / stride + 1;
216        let out_width = (in_width + 2 * padding - kernel_width) / stride + 1;
217
218        let result_shape = vec![batch_size, out_channels, out_height, out_width];
219        let mut result = Self::zeros(&result_shape);
220
221        // Parallel processing over batch and output channels
222        let self_data = Arc::new(self.data.clone());
223        let kernel_data = Arc::new(kernel.data.clone());
224
225        let batch_channel_pairs: Vec<(usize, usize)> = (0..batch_size)
226            .flat_map(|b| (0..out_channels).map(move |oc| (b, oc)))
227            .collect();
228
229        let results: Vec<_> = batch_channel_pairs
230            .into_par_iter()
231            .map(|(b, oc)| {
232                let mut channel_result = vec![T::zero(); out_height * out_width];
233
234                for oh in 0..out_height {
235                    for ow in 0..out_width {
236                        let mut sum = T::zero();
237
238                        for ic in 0..in_channels {
239                            for kh in 0..kernel_height {
240                                for kw in 0..kernel_width {
241                                    let ih = oh * stride + kh;
242                                    let iw = ow * stride + kw;
243
244                                    if ih >= padding && iw >= padding {
245                                        let ih = ih - padding;
246                                        let iw = iw - padding;
247
248                                        if ih < in_height && iw < in_width {
249                                            let input_idx = b * in_channels * in_height * in_width
250                                                + ic * in_height * in_width
251                                                + ih * in_width
252                                                + iw;
253                                            let kernel_idx =
254                                                oc * in_channels * kernel_height * kernel_width
255                                                    + ic * kernel_height * kernel_width
256                                                    + kh * kernel_width
257                                                    + kw;
258
259                                            if let (Some(input_val), Some(kernel_val)) = (
260                                                self_data.as_slice().and_then(|s| s.get(input_idx)),
261                                                kernel_data
262                                                    .as_slice()
263                                                    .and_then(|s| s.get(kernel_idx)),
264                                            ) {
265                                                sum = sum + *input_val * *kernel_val;
266                                            }
267                                        }
268                                    }
269                                }
270                            }
271                        }
272
273                        channel_result[oh * out_width + ow] = sum;
274                    }
275                }
276
277                (b, oc, channel_result)
278            })
279            .collect();
280
281        // Combine results
282        if let Some(result_slice) = result.data.as_slice_mut() {
283            for (b, oc, channel_result) in results {
284                let start_idx =
285                    b * out_channels * out_height * out_width + oc * out_height * out_width;
286
287                for (i, &val) in channel_result.iter().enumerate() {
288                    if let Some(dest) = result_slice.get_mut(start_idx + i) {
289                        *dest = val;
290                    }
291                }
292            }
293        }
294
295        Ok(result)
296    }
297
298    /// Parallel batch reduction operations
299    /// 並列バッチリダクション演算
300    pub fn batch_sum_parallel(&self, dim: usize) -> ParallelResult<Tensor<T>> {
301        let shape = self.data.shape();
302        if dim >= shape.len() {
303            return Err(RusTorchError::parallel("Dimension error"));
304        }
305
306        let mut result_shape = shape.to_vec();
307        result_shape.remove(dim);
308
309        if result_shape.is_empty() {
310            // Scalar result
311            if let Some(slice) = self.data.as_slice() {
312                let sum = slice
313                    .par_iter()
314                    .fold(|| T::zero(), |acc, &x| acc + x)
315                    .reduce(|| T::zero(), |a, b| a + b);
316                return Ok(Tensor::from_vec(vec![sum], vec![]));
317            }
318        }
319
320        let mut result = Self::zeros(&result_shape);
321
322        // Parallel reduction along specified dimension
323        let _total_elements = shape.iter().product::<usize>();
324        let dim_size = shape[dim];
325        let _stride_before: usize = shape[..dim].iter().product();
326        let stride_after: usize = shape[dim + 1..].iter().product();
327
328        if let Some(self_slice) = self.data.as_slice() {
329            let result_elements = result_shape.iter().product::<usize>();
330
331            let computed_results: Vec<_> = (0..result_elements)
332                .into_par_iter()
333                .map(|result_idx| {
334                    let before_idx = result_idx / stride_after;
335                    let after_idx = result_idx % stride_after;
336
337                    let mut sum = T::zero();
338                    for d in 0..dim_size {
339                        let source_idx =
340                            before_idx * dim_size * stride_after + d * stride_after + after_idx;
341                        if let Some(&val) = self_slice.get(source_idx) {
342                            sum = sum + val;
343                        }
344                    }
345                    (result_idx, sum)
346                })
347                .collect();
348
349            // Copy results back
350            if let Some(result_slice) = result.data.as_slice_mut() {
351                for (idx, val) in computed_results {
352                    if let Some(dest) = result_slice.get_mut(idx) {
353                        *dest = val;
354                    }
355                }
356            }
357        }
358
359        Ok(result)
360    }
361
362    /// Parallel batch mean computation
363    /// 並列バッチ平均計算
364    pub fn batch_mean_parallel(&self, dim: usize) -> ParallelResult<Tensor<T>> {
365        let shape = self.data.shape();
366        if dim >= shape.len() {
367            return Err(RusTorchError::parallel("Dimension error"));
368        }
369
370        let sum_result = self.batch_sum_parallel(dim)?;
371        let dim_size = T::from(shape[dim]).unwrap();
372
373        Ok(sum_result.batch_mul_scalar_parallel(T::one() / dim_size))
374    }
375}
376
377/// Specialized f32 implementations with SIMD integration
378/// SIMD統合を含むf32特殊化実装
379impl Tensor<f32> {
380    /// High-performance parallel batch operations for f32
381    /// f32用高性能並列バッチ演算
382    pub fn batch_simd_add_parallel(&self, other: &Tensor<f32>) -> ParallelResult<Tensor<f32>> {
383        if self.data.shape() != other.data.shape() {
384            return Err(RusTorchError::parallel("Shape mismatch"));
385        }
386
387        let mut result = Self::zeros(self.data.shape());
388
389        if let (Some(self_slice), Some(other_slice), Some(result_slice)) = (
390            self.data.as_slice(),
391            other.data.as_slice(),
392            result.data.as_slice_mut(),
393        ) {
394            // Use chunked parallel processing for better SIMD utilization
395            const CHUNK_SIZE: usize = 1024;
396
397            self_slice
398                .par_chunks(CHUNK_SIZE)
399                .zip(other_slice.par_chunks(CHUNK_SIZE))
400                .zip(result_slice.par_chunks_mut(CHUNK_SIZE))
401                .for_each(|((a_chunk, b_chunk), r_chunk)| {
402                    // Use SIMD operations for each chunk
403                    #[cfg(not(target_arch = "wasm32"))]
404                    {
405                        crate::simd::ops::add_optimized(a_chunk, b_chunk, r_chunk);
406                    }
407                    #[cfg(target_arch = "wasm32")]
408                    {
409                        // Fallback for WASM
410                        for ((a_elem, b_elem), r_elem) in
411                            a_chunk.iter().zip(b_chunk.iter()).zip(r_chunk.iter_mut())
412                        {
413                            *r_elem = *a_elem + *b_elem;
414                        }
415                    }
416                });
417        }
418
419        Ok(result)
420    }
421
422    /// Parallel batch matrix multiplication with SIMD optimization
423    /// SIMD最適化を含む並列バッチ行列乗算
424    pub fn batch_simd_matmul_parallel(&self, other: &Tensor<f32>) -> ParallelResult<Tensor<f32>> {
425        let self_shape = self.data.shape();
426        let other_shape = other.data.shape();
427
428        if self_shape.len() < 3 || other_shape.len() < 3 {
429            return Err(RusTorchError::parallel("Insufficient dimensions"));
430        }
431
432        let batch_size = self_shape[0];
433        let m = self_shape[1];
434        let k = self_shape[2];
435        let n = other_shape[2];
436
437        let result_shape = vec![batch_size, m, n];
438        let mut result = Self::zeros(&result_shape);
439
440        // Parallel processing with SIMD matrix multiplication
441        if let (Some(self_slice), Some(other_slice)) = (self.data.as_slice(), other.data.as_slice())
442        {
443            let batch_results: Vec<_> = (0..batch_size)
444                .into_par_iter()
445                .map(|b| {
446                    let self_batch = &self_slice[b * m * k..(b + 1) * m * k];
447                    let other_batch = &other_slice[b * k * n..(b + 1) * k * n];
448
449                    // Create a temporary result for this batch
450                    let mut batch_result = vec![0.0f32; m * n];
451
452                    // Use SIMD-optimized matrix multiplication
453                    #[cfg(not(target_arch = "wasm32"))]
454                    {
455                        crate::simd::ops::matmul_optimized(
456                            self_batch,
457                            m,
458                            k,
459                            other_batch,
460                            k,
461                            n,
462                            &mut batch_result,
463                        );
464                    }
465                    #[cfg(target_arch = "wasm32")]
466                    {
467                        // Fallback for WASM: simple matrix multiplication
468                        for i in 0..m {
469                            for j in 0..n {
470                                let mut sum = 0.0f32;
471                                for p in 0..k {
472                                    sum += self_batch[i * k + p] * other_batch[p * n + j];
473                                }
474                                batch_result[i * n + j] = sum;
475                            }
476                        }
477                    }
478
479                    batch_result
480                })
481                .collect();
482
483            // Copy results back
484            if let Some(result_slice) = result.data.as_slice_mut() {
485                for (b, batch_result) in batch_results.iter().enumerate() {
486                    let start_idx = b * m * n;
487                    for (i, &val) in batch_result.iter().enumerate() {
488                        if let Some(dest) = result_slice.get_mut(start_idx + i) {
489                            *dest = val;
490                        }
491                    }
492                }
493            }
494        }
495
496        Ok(result)
497    }
498}
499
#[cfg(test)]
mod tests {
    use super::*;

    /// Contiguous view of a test tensor's data. Panics (failing the test) rather
    /// than silently skipping assertions when no slice view exists.
    fn values(tensor: &Tensor<f32>) -> &[f32] {
        tensor
            .data
            .as_slice()
            .expect("test tensors should have contiguous storage")
    }

    /// Asserts that `actual` and `expected` agree element-wise within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len());
        for (i, (&a, &e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() <= tol, "index {}: {} vs {}", i, a, e);
        }
    }

    #[test]
    fn test_batch_add_parallel() {
        let a =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2]);
        let b = Tensor::<f32>::from_vec(vec![1.0; 8], vec![2, 2, 2]);

        let result = a.batch_add_parallel(&b).unwrap();
        let expected: &[f32] = &[2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        assert_eq!(values(&result), expected);
    }

    #[test]
    fn test_batch_add_parallel_shape_mismatch() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let b = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![4]);
        assert!(a.batch_add_parallel(&b).is_err());
    }

    #[test]
    fn test_batch_mul_scalar_parallel() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
        let result = a.batch_mul_scalar_parallel(2.0);
        let expected: &[f32] = &[2.0, 4.0, 6.0];
        assert_eq!(values(&result), expected);
    }

    #[test]
    fn test_batch_matmul_parallel() {
        let a =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2]);
        let b =
            Tensor::<f32>::from_vec(vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0], vec![2, 2, 2]);

        let result = a.batch_matmul_parallel(&b).unwrap();

        // Identity multiplication for each batch
        assert_eq!(result.size(), vec![2, 2, 2]);
        assert_eq!(values(&result), values(&a));
    }

    #[test]
    fn test_batch_matmul_parallel_values() {
        // [[1, 2], [3, 4]] x [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![1, 2, 2]);
        let b = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![1, 2, 2]);

        let result = a.batch_matmul_parallel(&b).unwrap();
        assert_eq!(result.size(), vec![1, 2, 2]);
        let expected: &[f32] = &[19.0, 22.0, 43.0, 50.0];
        assert_eq!(values(&result), expected);
    }

    #[test]
    fn test_batch_matmul_parallel_rejects_mismatches() {
        let a = Tensor::<f32>::from_vec(vec![0.0; 8], vec![2, 2, 2]);

        let other_batch = Tensor::<f32>::from_vec(vec![0.0; 4], vec![1, 2, 2]);
        assert!(a.batch_matmul_parallel(&other_batch).is_err());

        let wide = Tensor::<f32>::from_vec(vec![0.0; 12], vec![2, 2, 3]);
        assert!(wide.batch_matmul_parallel(&a).is_ok());
        assert!(a.batch_matmul_parallel(&wide.batch_mul_scalar_parallel(1.0)).is_ok());
        let tall = Tensor::<f32>::from_vec(vec![0.0; 12], vec![2, 3, 2]);
        assert!(a.batch_matmul_parallel(&tall).is_err());

        let matrix = Tensor::<f32>::from_vec(vec![0.0; 4], vec![2, 2]);
        assert!(a.batch_matmul_parallel(&matrix).is_err());
    }

    #[test]
    fn test_batch_normalize_parallel() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 4]);

        let result = a.batch_normalize_parallel(1e-5);
        assert_eq!(result.size(), vec![2, 4]);

        // Each sample is normalized: mean ~ 0 and variance ~ 1.
        let slice = values(&result);
        for sample in slice.chunks(4) {
            let mean: f32 = sample.iter().sum::<f32>() / 4.0;
            assert!(mean.abs() < 1e-5);
            let variance: f32 = sample.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / 4.0;
            assert!((variance - 1.0).abs() < 1e-3);
        }
    }

    #[test]
    fn test_batch_normalize_parallel_rank1_is_unchanged() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
        let result = a.batch_normalize_parallel(1e-5);
        assert_eq!(values(&result), values(&a));
    }

    #[test]
    fn test_batch_conv2d_parallel() {
        // 3x3 input [[1,2,3],[4,5,6],[7,8,9]] with a 2x2 kernel of ones.
        let input = Tensor::<f32>::from_vec((1..=9).map(|v| v as f32).collect(), vec![1, 1, 3, 3]);
        let kernel = Tensor::<f32>::from_vec(vec![1.0; 4], vec![1, 1, 2, 2]);

        let result = input.batch_conv2d_parallel(&kernel, 1, 0).unwrap();
        assert_eq!(result.size(), vec![1, 1, 2, 2]);
        let expected: &[f32] = &[12.0, 16.0, 24.0, 28.0];
        assert_eq!(values(&result), expected);
    }

    #[test]
    fn test_batch_conv2d_parallel_channel_mismatch() {
        let input = Tensor::<f32>::from_vec(vec![0.0; 18], vec![1, 2, 3, 3]);
        let kernel = Tensor::<f32>::from_vec(vec![1.0; 4], vec![1, 1, 2, 2]);
        assert!(input.batch_conv2d_parallel(&kernel, 1, 0).is_err());
    }

    #[test]
    fn test_batch_sum_parallel() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);

        // Sum along dimension 1 (columns)
        let result = a.batch_sum_parallel(1).unwrap();
        assert_eq!(result.size(), vec![2]);
        let expected: &[f32] = &[6.0, 15.0]; // [1+2+3, 4+5+6]
        assert_eq!(values(&result), expected);

        // Sum along dimension 0 (rows)
        let result = a.batch_sum_parallel(0).unwrap();
        assert_eq!(result.size(), vec![3]);
        let expected: &[f32] = &[5.0, 7.0, 9.0];
        assert_eq!(values(&result), expected);

        assert!(a.batch_sum_parallel(2).is_err());
    }

    #[test]
    fn test_batch_mean_parallel() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let result = a.batch_mean_parallel(1).unwrap();
        assert_eq!(result.size(), vec![2]);
        assert_close(values(&result), &[2.0, 5.0], 1e-5);
    }

    #[test]
    fn test_batch_simd_add_parallel() {
        let size = 1000;
        let a = Tensor::<f32>::from_vec((0..size).map(|i| i as f32).collect(), vec![10, 100]);
        let b = Tensor::<f32>::from_vec(vec![1.0; size], vec![10, 100]);

        let result = a.batch_simd_add_parallel(&b).unwrap();

        for (i, &val) in values(&result).iter().enumerate() {
            assert_eq!(val, i as f32 + 1.0);
        }
    }

    #[test]
    fn test_batch_simd_matmul_parallel() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![1, 2, 2]);
        let b = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![1, 2, 2]);

        let result = a.batch_simd_matmul_parallel(&b).unwrap();
        assert_eq!(result.size(), vec![1, 2, 2]);
        assert_close(values(&result), &[19.0, 22.0, 43.0, 50.0], 1e-4);
    }

    #[test]
    fn test_batch_add_parallel_large_input() {
        let batch_size = 100;
        let feature_size = 1000;

        let a = Tensor::<f32>::from_vec(
            (0..batch_size * feature_size)
                .map(|i| (i % 100) as f32)
                .collect(),
            vec![batch_size, feature_size],
        );
        let b = Tensor::<f32>::from_vec(
            vec![0.1; batch_size * feature_size],
            vec![batch_size, feature_size],
        );

        let result = a.batch_add_parallel(&b).unwrap();
        assert_eq!(result.size(), vec![batch_size, feature_size]);

        // Verify correctness element by element
        let (a_slice, b_slice, result_slice) = (values(&a), values(&b), values(&result));
        for i in 0..batch_size * feature_size {
            assert!((result_slice[i] - (a_slice[i] + b_slice[i])).abs() < 1e-6);
        }
    }
}