1use super::Tensor;
2use crate::error::{RusTorchError, RusTorchResult};
/// Result type returned by the fallible parallel batch operations in this module;
/// an alias for the crate-wide `RusTorchResult`.
type ParallelResult<T> = RusTorchResult<T>;
4use num_traits::Float;
6use rayon::prelude::*;
7use std::sync::Arc;
8
9impl<T: Float + Send + Sync + Clone + 'static> Tensor<T> {
12 pub fn batch_matmul_parallel(&self, other: &Tensor<T>) -> ParallelResult<Tensor<T>> {
15 let self_shape = self.data.shape();
16 let other_shape = other.data.shape();
17
18 if self_shape.len() < 3 || other_shape.len() < 3 {
19 return Err(RusTorchError::parallel("Insufficient dimensions"));
20 }
21
22 let batch_size = self_shape[0];
23 if batch_size != other_shape[0] {
24 return Err(RusTorchError::parallel("Batch size mismatch"));
25 }
26
27 let m = self_shape[1];
28 let k = self_shape[2];
29 let n = other_shape[2];
30
31 if k != other_shape[1] {
32 return Err(RusTorchError::parallel("Matrix dimension mismatch"));
33 }
34
35 let result_shape = vec![batch_size, m, n];
36 let mut result = Self::zeros(&result_shape);
37
38 let self_data = Arc::new(self.data.clone());
40 let other_data = Arc::new(other.data.clone());
41
42 let results: Vec<_> = (0..batch_size)
43 .into_par_iter()
44 .map(|b| {
45 let mut batch_result = vec![T::zero(); m * n];
46
47 for i in 0..m {
49 for j in 0..n {
50 let mut sum = T::zero();
51 for l in 0..k {
52 let a_idx = b * m * k + i * k + l;
53 let b_idx = b * k * n + l * n + j;
54
55 if let (Some(a_val), Some(b_val)) = (
56 self_data.as_slice().and_then(|s| s.get(a_idx)),
57 other_data.as_slice().and_then(|s| s.get(b_idx)),
58 ) {
59 sum = sum + *a_val * *b_val;
60 }
61 }
62 batch_result[i * n + j] = sum;
63 }
64 }
65 batch_result
66 })
67 .collect();
68
69 if let Some(result_slice) = result.data.as_slice_mut() {
71 for (b, batch_result) in results.iter().enumerate() {
72 let start_idx = b * m * n;
73 for (i, &val) in batch_result.iter().enumerate() {
74 if let Some(dest) = result_slice.get_mut(start_idx + i) {
75 *dest = val;
76 }
77 }
78 }
79 }
80
81 Ok(result)
82 }
83
84 pub fn batch_add_parallel(&self, other: &Tensor<T>) -> ParallelResult<Tensor<T>> {
87 if self.data.shape() != other.data.shape() {
88 return Err(RusTorchError::parallel("Shape mismatch"));
89 }
90
91 let mut result = Self::zeros(self.data.shape());
92
93 if let (Some(self_slice), Some(other_slice), Some(result_slice)) = (
94 self.data.as_slice(),
95 other.data.as_slice(),
96 result.data.as_slice_mut(),
97 ) {
98 result_slice
99 .par_iter_mut()
100 .zip(self_slice.par_iter())
101 .zip(other_slice.par_iter())
102 .for_each(|((r, &a), &b)| {
103 *r = a + b;
104 });
105 }
106
107 Ok(result)
108 }
109
110 pub fn batch_mul_scalar_parallel(&self, scalar: T) -> Tensor<T> {
113 let mut result = Self::zeros(self.data.shape());
114
115 if let (Some(self_slice), Some(result_slice)) =
116 (self.data.as_slice(), result.data.as_slice_mut())
117 {
118 result_slice
119 .par_iter_mut()
120 .zip(self_slice.par_iter())
121 .for_each(|(r, &a)| {
122 *r = a * scalar;
123 });
124 }
125
126 result
127 }
128
129 pub fn batch_normalize_parallel(&self, epsilon: T) -> Tensor<T> {
132 let shape = self.data.shape();
133 if shape.len() < 2 {
134 return self.clone();
135 }
136
137 let batch_size = shape[0];
138 let feature_size: usize = shape[1..].iter().product();
139
140 let mut result = Self::zeros(shape);
141
142 if let (Some(self_slice), Some(result_slice)) =
143 (self.data.as_slice(), result.data.as_slice_mut())
144 {
145 let batch_results: Vec<_> = (0..batch_size)
147 .into_par_iter()
148 .map(|b| {
149 let start_idx = b * feature_size;
150 let end_idx = start_idx + feature_size;
151 let batch_data = &self_slice[start_idx..end_idx];
152
153 let mean = batch_data.iter().fold(T::zero(), |acc, &x| acc + x)
155 / T::from(feature_size).unwrap();
156
157 let variance = batch_data.iter().fold(T::zero(), |acc, &x| {
159 let diff = x - mean;
160 acc + diff * diff
161 }) / T::from(feature_size).unwrap();
162
163 let std_dev = (variance + epsilon).sqrt();
164
165 let normalized: Vec<T> =
167 batch_data.iter().map(|&x| (x - mean) / std_dev).collect();
168
169 normalized
170 })
171 .collect();
172
173 for (b, normalized) in batch_results.iter().enumerate() {
175 let start_idx = b * feature_size;
176 for (i, &val) in normalized.iter().enumerate() {
177 if let Some(dest) = result_slice.get_mut(start_idx + i) {
178 *dest = val;
179 }
180 }
181 }
182 }
183
184 result
185 }
186
187 pub fn batch_conv2d_parallel(
190 &self,
191 kernel: &Tensor<T>,
192 stride: usize,
193 padding: usize,
194 ) -> ParallelResult<Tensor<T>> {
195 let input_shape = self.data.shape();
196 let kernel_shape = kernel.data.shape();
197
198 if input_shape.len() != 4 || kernel_shape.len() != 4 {
199 return Err(RusTorchError::parallel("Insufficient dimensions"));
200 }
201
202 let batch_size = input_shape[0];
203 let in_channels = input_shape[1];
204 let in_height = input_shape[2];
205 let in_width = input_shape[3];
206
207 let out_channels = kernel_shape[0];
208 let kernel_height = kernel_shape[2];
209 let kernel_width = kernel_shape[3];
210
211 if in_channels != kernel_shape[1] {
212 return Err(RusTorchError::parallel("Convolution error"));
213 }
214
215 let out_height = (in_height + 2 * padding - kernel_height) / stride + 1;
216 let out_width = (in_width + 2 * padding - kernel_width) / stride + 1;
217
218 let result_shape = vec![batch_size, out_channels, out_height, out_width];
219 let mut result = Self::zeros(&result_shape);
220
221 let self_data = Arc::new(self.data.clone());
223 let kernel_data = Arc::new(kernel.data.clone());
224
225 let batch_channel_pairs: Vec<(usize, usize)> = (0..batch_size)
226 .flat_map(|b| (0..out_channels).map(move |oc| (b, oc)))
227 .collect();
228
229 let results: Vec<_> = batch_channel_pairs
230 .into_par_iter()
231 .map(|(b, oc)| {
232 let mut channel_result = vec![T::zero(); out_height * out_width];
233
234 for oh in 0..out_height {
235 for ow in 0..out_width {
236 let mut sum = T::zero();
237
238 for ic in 0..in_channels {
239 for kh in 0..kernel_height {
240 for kw in 0..kernel_width {
241 let ih = oh * stride + kh;
242 let iw = ow * stride + kw;
243
244 if ih >= padding && iw >= padding {
245 let ih = ih - padding;
246 let iw = iw - padding;
247
248 if ih < in_height && iw < in_width {
249 let input_idx = b * in_channels * in_height * in_width
250 + ic * in_height * in_width
251 + ih * in_width
252 + iw;
253 let kernel_idx =
254 oc * in_channels * kernel_height * kernel_width
255 + ic * kernel_height * kernel_width
256 + kh * kernel_width
257 + kw;
258
259 if let (Some(input_val), Some(kernel_val)) = (
260 self_data.as_slice().and_then(|s| s.get(input_idx)),
261 kernel_data
262 .as_slice()
263 .and_then(|s| s.get(kernel_idx)),
264 ) {
265 sum = sum + *input_val * *kernel_val;
266 }
267 }
268 }
269 }
270 }
271 }
272
273 channel_result[oh * out_width + ow] = sum;
274 }
275 }
276
277 (b, oc, channel_result)
278 })
279 .collect();
280
281 if let Some(result_slice) = result.data.as_slice_mut() {
283 for (b, oc, channel_result) in results {
284 let start_idx =
285 b * out_channels * out_height * out_width + oc * out_height * out_width;
286
287 for (i, &val) in channel_result.iter().enumerate() {
288 if let Some(dest) = result_slice.get_mut(start_idx + i) {
289 *dest = val;
290 }
291 }
292 }
293 }
294
295 Ok(result)
296 }
297
298 pub fn batch_sum_parallel(&self, dim: usize) -> ParallelResult<Tensor<T>> {
301 let shape = self.data.shape();
302 if dim >= shape.len() {
303 return Err(RusTorchError::parallel("Dimension error"));
304 }
305
306 let mut result_shape = shape.to_vec();
307 result_shape.remove(dim);
308
309 if result_shape.is_empty() {
310 if let Some(slice) = self.data.as_slice() {
312 let sum = slice
313 .par_iter()
314 .fold(|| T::zero(), |acc, &x| acc + x)
315 .reduce(|| T::zero(), |a, b| a + b);
316 return Ok(Tensor::from_vec(vec![sum], vec![]));
317 }
318 }
319
320 let mut result = Self::zeros(&result_shape);
321
322 let _total_elements = shape.iter().product::<usize>();
324 let dim_size = shape[dim];
325 let _stride_before: usize = shape[..dim].iter().product();
326 let stride_after: usize = shape[dim + 1..].iter().product();
327
328 if let Some(self_slice) = self.data.as_slice() {
329 let result_elements = result_shape.iter().product::<usize>();
330
331 let computed_results: Vec<_> = (0..result_elements)
332 .into_par_iter()
333 .map(|result_idx| {
334 let before_idx = result_idx / stride_after;
335 let after_idx = result_idx % stride_after;
336
337 let mut sum = T::zero();
338 for d in 0..dim_size {
339 let source_idx =
340 before_idx * dim_size * stride_after + d * stride_after + after_idx;
341 if let Some(&val) = self_slice.get(source_idx) {
342 sum = sum + val;
343 }
344 }
345 (result_idx, sum)
346 })
347 .collect();
348
349 if let Some(result_slice) = result.data.as_slice_mut() {
351 for (idx, val) in computed_results {
352 if let Some(dest) = result_slice.get_mut(idx) {
353 *dest = val;
354 }
355 }
356 }
357 }
358
359 Ok(result)
360 }
361
362 pub fn batch_mean_parallel(&self, dim: usize) -> ParallelResult<Tensor<T>> {
365 let shape = self.data.shape();
366 if dim >= shape.len() {
367 return Err(RusTorchError::parallel("Dimension error"));
368 }
369
370 let sum_result = self.batch_sum_parallel(dim)?;
371 let dim_size = T::from(shape[dim]).unwrap();
372
373 Ok(sum_result.batch_mul_scalar_parallel(T::one() / dim_size))
374 }
375}
376
377impl Tensor<f32> {
380 pub fn batch_simd_add_parallel(&self, other: &Tensor<f32>) -> ParallelResult<Tensor<f32>> {
383 if self.data.shape() != other.data.shape() {
384 return Err(RusTorchError::parallel("Shape mismatch"));
385 }
386
387 let mut result = Self::zeros(self.data.shape());
388
389 if let (Some(self_slice), Some(other_slice), Some(result_slice)) = (
390 self.data.as_slice(),
391 other.data.as_slice(),
392 result.data.as_slice_mut(),
393 ) {
394 const CHUNK_SIZE: usize = 1024;
396
397 self_slice
398 .par_chunks(CHUNK_SIZE)
399 .zip(other_slice.par_chunks(CHUNK_SIZE))
400 .zip(result_slice.par_chunks_mut(CHUNK_SIZE))
401 .for_each(|((a_chunk, b_chunk), r_chunk)| {
402 #[cfg(not(target_arch = "wasm32"))]
404 {
405 crate::simd::ops::add_optimized(a_chunk, b_chunk, r_chunk);
406 }
407 #[cfg(target_arch = "wasm32")]
408 {
409 for ((a_elem, b_elem), r_elem) in
411 a_chunk.iter().zip(b_chunk.iter()).zip(r_chunk.iter_mut())
412 {
413 *r_elem = *a_elem + *b_elem;
414 }
415 }
416 });
417 }
418
419 Ok(result)
420 }
421
422 pub fn batch_simd_matmul_parallel(&self, other: &Tensor<f32>) -> ParallelResult<Tensor<f32>> {
425 let self_shape = self.data.shape();
426 let other_shape = other.data.shape();
427
428 if self_shape.len() < 3 || other_shape.len() < 3 {
429 return Err(RusTorchError::parallel("Insufficient dimensions"));
430 }
431
432 let batch_size = self_shape[0];
433 let m = self_shape[1];
434 let k = self_shape[2];
435 let n = other_shape[2];
436
437 let result_shape = vec![batch_size, m, n];
438 let mut result = Self::zeros(&result_shape);
439
440 if let (Some(self_slice), Some(other_slice)) = (self.data.as_slice(), other.data.as_slice())
442 {
443 let batch_results: Vec<_> = (0..batch_size)
444 .into_par_iter()
445 .map(|b| {
446 let self_batch = &self_slice[b * m * k..(b + 1) * m * k];
447 let other_batch = &other_slice[b * k * n..(b + 1) * k * n];
448
449 let mut batch_result = vec![0.0f32; m * n];
451
452 #[cfg(not(target_arch = "wasm32"))]
454 {
455 crate::simd::ops::matmul_optimized(
456 self_batch,
457 m,
458 k,
459 other_batch,
460 k,
461 n,
462 &mut batch_result,
463 );
464 }
465 #[cfg(target_arch = "wasm32")]
466 {
467 for i in 0..m {
469 for j in 0..n {
470 let mut sum = 0.0f32;
471 for p in 0..k {
472 sum += self_batch[i * k + p] * other_batch[p * n + j];
473 }
474 batch_result[i * n + j] = sum;
475 }
476 }
477 }
478
479 batch_result
480 })
481 .collect();
482
483 if let Some(result_slice) = result.data.as_slice_mut() {
485 for (b, batch_result) in batch_results.iter().enumerate() {
486 let start_idx = b * m * n;
487 for (i, &val) in batch_result.iter().enumerate() {
488 if let Some(dest) = result_slice.get_mut(start_idx + i) {
489 *dest = val;
490 }
491 }
492 }
493 }
494 }
495
496 Ok(result)
497 }
498}
499
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_add_parallel() {
        let a =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2]);
        let b =
            Tensor::<f32>::from_vec(vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], vec![2, 2, 2]);

        let result = a.batch_add_parallel(&b).unwrap();
        let expected = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        assert_eq!(result.data.as_slice().unwrap(), &expected);
    }

    #[test]
    fn test_batch_add_parallel_shape_mismatch() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let b = Tensor::<f32>::from_vec(vec![1.0; 6], vec![2, 3]);

        assert!(a.batch_add_parallel(&b).is_err());
    }

    #[test]
    fn test_batch_mul_scalar_parallel() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], vec![3]);

        let result = a.batch_mul_scalar_parallel(2.0);

        assert_eq!(result.data.as_slice().unwrap(), &vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_batch_matmul_parallel() {
        let a =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2]);
        let b =
            Tensor::<f32>::from_vec(vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0], vec![2, 2, 2]);

        let result = a.batch_matmul_parallel(&b).unwrap();

        assert_eq!(result.size(), vec![2, 2, 2]);
        // Multiplying by the identity must reproduce the left operand.
        assert_eq!(result.data.as_slice().unwrap(), a.data.as_slice().unwrap());
    }

    #[test]
    fn test_batch_matmul_parallel_values() {
        // [[1, 2], [3, 4]] x [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![1, 2, 2]);
        let b = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![1, 2, 2]);
        let result = a.batch_matmul_parallel(&b).unwrap();
        assert_eq!(result.data.as_slice().unwrap(), &vec![19.0, 22.0, 43.0, 50.0]);

        // Non-square: [2 x 3] x [3 x 1] row sums.
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![1, 2, 3]);
        let b = Tensor::<f32>::from_vec(vec![1.0, 1.0, 1.0], vec![1, 3, 1]);
        let result = a.batch_matmul_parallel(&b).unwrap();
        assert_eq!(result.size(), vec![1, 2, 1]);
        assert_eq!(result.data.as_slice().unwrap(), &vec![6.0, 15.0]);
    }

    #[test]
    fn test_batch_matmul_parallel_inner_dim_mismatch() {
        let a = Tensor::<f32>::from_vec(vec![0.0; 12], vec![2, 2, 3]);
        let b = Tensor::<f32>::from_vec(vec![0.0; 8], vec![2, 2, 2]);

        assert!(a.batch_matmul_parallel(&b).is_err());
    }

    #[test]
    fn test_batch_normalize_parallel() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 4]);

        let result = a.batch_normalize_parallel(1e-5);

        assert_eq!(result.size(), vec![2, 4]);

        let slice = result.data.as_slice().expect("result should be contiguous");
        for row in slice.chunks(4) {
            let mean: f32 = row.iter().sum::<f32>() / 4.0;
            assert!(mean.abs() < 1e-5);

            // Population variance should be ~1 (slightly below due to epsilon).
            let variance: f32 = row.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / 4.0;
            assert!((variance - 1.0).abs() < 1e-4);
        }
    }

    #[test]
    fn test_batch_sum_parallel() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);

        let result = a.batch_sum_parallel(1).unwrap();
        assert_eq!(result.size(), vec![2]);

        let expected = vec![6.0, 15.0];
        assert_eq!(result.data.as_slice().unwrap(), &expected);
    }

    #[test]
    fn test_batch_sum_parallel_dim0() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);

        let result = a.batch_sum_parallel(0).unwrap();

        assert_eq!(result.size(), vec![3]);
        assert_eq!(result.data.as_slice().unwrap(), &vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_batch_sum_parallel_full_reduction() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], vec![3]);

        let result = a.batch_sum_parallel(0).unwrap();

        assert_eq!(result.data.as_slice().unwrap(), &vec![6.0]);
        assert!(a.batch_sum_parallel(1).is_err());
    }

    #[test]
    fn test_batch_mean_parallel() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);

        let result = a.batch_mean_parallel(1).unwrap();
        let slice = result.data.as_slice().unwrap();

        assert_eq!(slice.len(), 2);
        assert!((slice[0] - 2.0).abs() < 1e-6);
        assert!((slice[1] - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_batch_conv2d_parallel() {
        // 3x3 image 1..=9 convolved with a 2x2 all-ones kernel.
        let input =
            Tensor::<f32>::from_vec((1..=9).map(|x| x as f32).collect(), vec![1, 1, 3, 3]);
        let kernel = Tensor::<f32>::from_vec(vec![1.0; 4], vec![1, 1, 2, 2]);

        let result = input.batch_conv2d_parallel(&kernel, 1, 0).unwrap();
        assert_eq!(result.size(), vec![1, 1, 2, 2]);
        assert_eq!(result.data.as_slice().unwrap(), &vec![12.0, 16.0, 24.0, 28.0]);

        // Padding 1, stride 2: windows straddle the zero border.
        let result = input.batch_conv2d_parallel(&kernel, 2, 1).unwrap();
        assert_eq!(result.size(), vec![1, 1, 2, 2]);
        assert_eq!(result.data.as_slice().unwrap(), &vec![1.0, 5.0, 11.0, 28.0]);
    }

    #[test]
    fn test_batch_simd_add_parallel() {
        let size = 1000;
        let a = Tensor::<f32>::from_vec((0..size).map(|i| i as f32).collect(), vec![10, 100]);
        let b = Tensor::<f32>::from_vec(vec![1.0; size], vec![10, 100]);

        let result = a.batch_simd_add_parallel(&b).unwrap();

        let slice = result.data.as_slice().expect("result should be contiguous");
        assert_eq!(slice.len(), size);
        for (i, &val) in slice.iter().enumerate() {
            assert_eq!(val, i as f32 + 1.0);
        }
    }

    #[test]
    fn test_large_batch_performance() {
        let batch_size = 100;
        let feature_size = 1000;

        let a = Tensor::<f32>::from_vec(
            (0..batch_size * feature_size)
                .map(|i| (i % 100) as f32)
                .collect(),
            vec![batch_size, feature_size],
        );
        let b = Tensor::<f32>::from_vec(
            vec![0.1; batch_size * feature_size],
            vec![batch_size, feature_size],
        );

        let result = a.batch_add_parallel(&b).unwrap();
        assert_eq!(result.size(), vec![batch_size, feature_size]);

        let a_slice = a.data.as_slice().unwrap();
        let b_slice = b.data.as_slice().unwrap();
        let result_slice = result.data.as_slice().unwrap();
        for i in 0..batch_size * feature_size {
            assert!((result_slice[i] - (a_slice[i] + b_slice[i])).abs() < 1e-6);
        }
    }
}