1use crate::error::{LinalgError, LinalgResult};
7use scirs2_core::ndarray::{Array, ArrayBase, Data, Dimension, Ix3, IxDyn};
8use scirs2_core::numeric::{Float, NumAssign};
9use std::fmt::Debug;
10use std::iter::Sum;
11
/// Extension trait adding NumPy-style broadcasting queries to `ArrayBase`.
///
/// Broadcasting compatibility follows the standard trailing-axis rule: two
/// shapes are compatible when, aligned at their last axes, every overlapping
/// pair of dimensions is either equal or contains a 1.
pub trait BroadcastExt<A> {
    /// Returns `true` when `self` and `other` have broadcast-compatible shapes.
    fn broadcast_compatible<D2>(&self, other: &ArrayBase<D2, impl Dimension>) -> bool
    where
        D2: Data<Elem = A>;

    /// Computes the broadcast result shape of `self` and `other`, or `None`
    /// when the shapes are not broadcast-compatible.
    fn broadcastshape<D2>(&self, other: &ArrayBase<D2, impl Dimension>) -> Option<Vec<usize>>
    where
        D2: Data<Elem = A>;
}
24
25impl<A, S, D> BroadcastExt<A> for ArrayBase<S, D>
26where
27 S: Data<Elem = A>,
28 D: Dimension,
29{
30 fn broadcast_compatible<D2>(&self, other: &ArrayBase<D2, impl Dimension>) -> bool
31 where
32 D2: Data<Elem = A>,
33 {
34 let shape1 = self.shape();
35 let shape2 = other.shape();
36 let ndim1 = shape1.len();
37 let ndim2 = shape2.len();
38
39 let mut i = ndim1;
41 let mut j = ndim2;
42
43 while i > 0 && j > 0 {
44 i -= 1;
45 j -= 1;
46
47 let dim1 = shape1[i];
48 let dim2 = shape2[j];
49
50 if dim1 != dim2 && dim1 != 1 && dim2 != 1 {
52 return false;
53 }
54 }
55
56 true
57 }
58
59 fn broadcastshape<D2>(&self, other: &ArrayBase<D2, impl Dimension>) -> Option<Vec<usize>>
60 where
61 D2: Data<Elem = A>,
62 {
63 if !self.broadcast_compatible(other) {
64 return None;
65 }
66
67 let shape1 = self.shape();
68 let shape2 = other.shape();
69 let ndim1 = shape1.len();
70 let ndim2 = shape2.len();
71 let max_ndim = ndim1.max(ndim2);
72
73 let mut broadcastshape = vec![0; max_ndim];
74
75 let mut i = ndim1;
77 let mut j = ndim2;
78 let mut k = max_ndim;
79
80 while k > 0 {
81 k -= 1;
82
83 let dim1 = if i > 0 {
84 i -= 1;
85 shape1[i]
86 } else {
87 1
88 };
89
90 let dim2 = if j > 0 {
91 j -= 1;
92 shape2[j]
93 } else {
94 1
95 };
96
97 broadcastshape[k] = dim1.max(dim2);
98 }
99
100 Some(broadcastshape)
101 }
102}
103
104#[allow(dead_code)]
110pub fn broadcast_matmul_3d<A>(
111 a: &ArrayBase<impl Data<Elem = A>, Ix3>,
112 b: &ArrayBase<impl Data<Elem = A>, Ix3>,
113) -> LinalgResult<Array<A, Ix3>>
114where
115 A: Float + NumAssign + Sum + Debug + 'static,
116{
117 let ashape = a.shape();
118 let bshape = b.shape();
119
120 let a_cols = ashape[2];
122 let b_rows = bshape[1];
123
124 if a_cols != b_rows {
125 return Err(LinalgError::DimensionError(format!(
126 "Matrix dimensions don't match for multiplication: ({}, {}) x ({}, {})",
127 ashape[1], a_cols, b_rows, bshape[2]
128 )));
129 }
130
131 let batchsize = ashape[0].max(bshape[0]);
133
134 if ashape[0] != bshape[0] && ashape[0] != 1 && bshape[0] != 1 {
136 return Err(LinalgError::DimensionError(
137 "Batch dimensions must be compatible for broadcasting".to_string(),
138 ));
139 }
140
141 let a_rows = ashape[1];
143 let b_cols = bshape[2];
144 let outputshape = [batchsize, a_rows, b_cols];
145
146 let mut output = Array::zeros(outputshape);
148
149 for i in 0..batchsize {
151 let a_idx = if ashape[0] == 1 { 0 } else { i };
152 let b_idx = if bshape[0] == 1 { 0 } else { i };
153
154 let a_mat = a.index_axis(scirs2_core::ndarray::Axis(0), a_idx);
155 let b_mat = b.index_axis(scirs2_core::ndarray::Axis(0), b_idx);
156 let mut out_mat = output.index_axis_mut(scirs2_core::ndarray::Axis(0), i);
157
158 scirs2_core::ndarray::linalg::general_mat_mul(
160 A::one(),
161 &a_mat,
162 &b_mat,
163 A::one(),
164 &mut out_mat,
165 );
166 }
167
168 Ok(output)
169}
170
#[allow(dead_code)]
/// Batched matrix multiplication for dynamic-dimensional arrays.
///
/// Multiplies the trailing two axes of `a` (`..., m, k`) by the trailing two
/// axes of `b` (`..., k, n`), producing a result of shape `(..., m, n)`. The
/// leading batch axes must match exactly — full batch broadcasting is not yet
/// implemented (see the error branch below).
///
/// # Errors
///
/// Returns `LinalgError::DimensionError` when either input has fewer than 2
/// dimensions, the inner dimensions disagree, or the batch shapes differ.
pub fn broadcast_matmul<A>(
    a: &ArrayBase<impl Data<Elem = A>, IxDyn>,
    b: &ArrayBase<impl Data<Elem = A>, IxDyn>,
) -> LinalgResult<Array<A, IxDyn>>
where
    A: Float + NumAssign + Sum + Debug + 'static,
{
    if a.ndim() < 2 || b.ndim() < 2 {
        return Err(LinalgError::DimensionError(
            "Arrays must have at least 2 dimensions for matrix multiplication".to_string(),
        ));
    }

    let ashape = a.shape();
    let bshape = b.shape();

    // Inner dimensions: last axis of `a` must equal second-to-last axis of `b`.
    let a_cols = ashape[ashape.len() - 1];
    let b_rows = bshape[bshape.len() - 2];

    if a_cols != b_rows {
        return Err(LinalgError::DimensionError(format!(
            "Matrix dimensions don't match for multiplication: (..., {a_cols}) x ({b_rows}, ...)"
        )));
    }

    // All axes before the trailing two matrix axes form the batch shape.
    let a_batchshape = &ashape[..ashape.len() - 2];
    let b_batchshape = &bshape[..bshape.len() - 2];

    let batchshape = if a_batchshape == b_batchshape {
        a_batchshape.to_vec()
    } else {
        return Err(LinalgError::DimensionError(
            "Batch dimensions must match exactly (full broadcasting not yet implemented)"
                .to_string(),
        ));
    };

    let a_rows = ashape[ashape.len() - 2];
    let b_cols = bshape[bshape.len() - 1];
    let mut outputshape = batchshape;
    outputshape.push(a_rows);
    outputshape.push(b_cols);

    let mut output = Array::zeros(IxDyn(&outputshape));

    // Number of independent (m x k) * (k x n) products to perform.
    let n_batch = output.len() / (a_rows * b_cols);

    for i in 0..n_batch {
        // Scratch buffers: each batch slice is copied into a contiguous 2-D
        // matrix so `general_mat_mul` can run on standard-layout views.
        let mut a_slice = Array2::zeros((a_rows, a_cols));
        let mut b_slice = Array2::zeros((b_rows, b_cols));
        let mut out_slice = Array2::zeros((a_rows, b_cols));

        // Flat offsets of batch `i` within each array's row-major ordering.
        let a_start = i * a_rows * a_cols;
        let b_start = i * b_rows * b_cols;
        let out_start = i * a_rows * b_cols;

        for r in 0..a_rows {
            for c in 0..a_cols {
                let flat_idx = a_start + r * a_cols + c;
                // Decompose the row-major flat index into an n-d index
                // (least-significant axis last), independent of memory layout.
                let nd_idx: Vec<usize> = {
                    let mut idx = vec![0; a.ndim()];
                    let mut remaining = flat_idx;
                    for dim in (0..a.ndim()).rev() {
                        idx[dim] = remaining % ashape[dim];
                        remaining /= ashape[dim];
                    }
                    idx
                };
                a_slice[[r, c]] = a[nd_idx.as_slice()];
            }
        }

        for r in 0..b_rows {
            for c in 0..b_cols {
                let flat_idx = b_start + r * b_cols + c;
                let nd_idx: Vec<usize> = {
                    let mut idx = vec![0; b.ndim()];
                    let mut remaining = flat_idx;
                    for dim in (0..b.ndim()).rev() {
                        idx[dim] = remaining % bshape[dim];
                        remaining /= bshape[dim];
                    }
                    idx
                };
                b_slice[[r, c]] = b[nd_idx.as_slice()];
            }
        }

        // out_slice = 1*A*B + 1*out_slice; correct here because out_slice is
        // freshly zero-initialized each iteration.
        scirs2_core::ndarray::linalg::general_mat_mul(
            A::one(),
            &a_slice.view(),
            &b_slice.view(),
            A::one(),
            &mut out_slice,
        );

        // Scatter the 2-D product back into the n-d output at batch `i`.
        for r in 0..a_rows {
            for c in 0..b_cols {
                let flat_idx = out_start + r * b_cols + c;
                let nd_idx: Vec<usize> = {
                    let mut idx = vec![0; output.ndim()];
                    let mut remaining = flat_idx;
                    for dim in (0..output.ndim()).rev() {
                        idx[dim] = remaining % outputshape[dim];
                        remaining /= outputshape[dim];
                    }
                    idx
                };
                output[nd_idx.as_slice()] = out_slice[[r, c]];
            }
        }
    }

    Ok(output)
}
306
#[allow(dead_code)]
/// Batched matrix-vector multiplication for dynamic-dimensional arrays.
///
/// Multiplies the trailing two axes of `a` (`..., m, k`) by the trailing axis
/// of `x` (`..., k`), producing a result of shape `(..., m)`. The leading
/// batch axes must match exactly — full batch broadcasting is not yet
/// implemented (see the error branch below).
///
/// # Errors
///
/// Returns `LinalgError::DimensionError` when `a` has fewer than 2 dimensions,
/// `x` has fewer than 1, the inner dimensions disagree, or the batch shapes
/// differ.
pub fn broadcast_matvec<A>(
    a: &ArrayBase<impl Data<Elem = A>, IxDyn>,
    x: &ArrayBase<impl Data<Elem = A>, IxDyn>,
) -> LinalgResult<Array<A, IxDyn>>
where
    A: Float + NumAssign + Sum + Debug + 'static,
{
    if a.ndim() < 2 || x.ndim() < 1 {
        return Err(LinalgError::DimensionError(
            "Matrix must have at least 2 dimensions and vector at least 1".to_string(),
        ));
    }

    let ashape = a.shape();
    let xshape = x.shape();

    // Inner dimensions: columns of `a` must equal the vector length.
    let a_cols = ashape[ashape.len() - 1];
    let x_len = xshape[xshape.len() - 1];

    if a_cols != x_len {
        return Err(LinalgError::DimensionError(format!(
            "Matrix and vector dimensions don't match: (..., {a_cols}) x ({x_len})"
        )));
    }

    // Batch shape: `a` drops its trailing two axes, `x` its trailing one.
    let a_batchshape = &ashape[..ashape.len() - 2];
    let x_batchshape = &xshape[..xshape.len() - 1];

    let batchshape = if a_batchshape == x_batchshape {
        a_batchshape.to_vec()
    } else {
        return Err(LinalgError::DimensionError(
            "Batch dimensions must match exactly (full broadcasting not yet implemented)"
                .to_string(),
        ));
    };

    let a_rows = ashape[ashape.len() - 2];
    let mut outputshape = batchshape;
    outputshape.push(a_rows);

    let mut output = Array::zeros(IxDyn(&outputshape));

    // Number of independent (m x k) * (k) products to perform.
    let n_batch = output.len() / a_rows;

    for i in 0..n_batch {
        // Scratch buffers: each batch slice is copied into contiguous 2-D/1-D
        // arrays so `general_mat_vec_mul` can run on standard-layout views.
        let mut a_slice = Array2::zeros((a_rows, a_cols));
        let mut x_slice = Array1::zeros(x_len);
        let mut y_slice = Array1::zeros(a_rows);

        // Flat offsets of batch `i` within each array's row-major ordering.
        let a_start = i * a_rows * a_cols;
        let x_start = i * x_len;
        let y_start = i * a_rows;

        for r in 0..a_rows {
            for c in 0..a_cols {
                let flat_idx = a_start + r * a_cols + c;
                // Decompose the row-major flat index into an n-d index
                // (least-significant axis last), independent of memory layout.
                let nd_idx: Vec<usize> = {
                    let mut idx = vec![0; a.ndim()];
                    let mut remaining = flat_idx;
                    for dim in (0..a.ndim()).rev() {
                        idx[dim] = remaining % ashape[dim];
                        remaining /= ashape[dim];
                    }
                    idx
                };
                a_slice[[r, c]] = a[nd_idx.as_slice()];
            }
        }

        for j in 0..x_len {
            let flat_idx = x_start + j;
            let nd_idx: Vec<usize> = {
                let mut idx = vec![0; x.ndim()];
                let mut remaining = flat_idx;
                for dim in (0..x.ndim()).rev() {
                    idx[dim] = remaining % xshape[dim];
                    remaining /= xshape[dim];
                }
                idx
            };
            x_slice[j] = x[nd_idx.as_slice()];
        }

        // y = 1*A*x + 1*y; correct here because y_slice is freshly
        // zero-initialized each iteration.
        scirs2_core::ndarray::linalg::general_mat_vec_mul(
            A::one(),
            &a_slice.view(),
            &x_slice.view(),
            A::one(),
            &mut y_slice,
        );

        // Scatter the result vector back into the n-d output at batch `i`.
        for j in 0..a_rows {
            let flat_idx = y_start + j;
            let nd_idx: Vec<usize> = {
                let mut idx = vec![0; output.ndim()];
                let mut remaining = flat_idx;
                for dim in (0..output.ndim()).rev() {
                    idx[dim] = remaining % outputshape[dim];
                    remaining /= outputshape[dim];
                }
                idx
            };
            output[nd_idx.as_slice()] = y_slice[j];
        }
    }

    Ok(output)
}
431
432use scirs2_core::ndarray::{Array1, Array2};
433
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_broadcast_compatible() {
        // (2, 2, 2) vs (1, 2, 2): leading axes 2 and 1 broadcast.
        let a = array![[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]];
        let b = array![[[1.0, 2.0], [3.0, 4.0]]];

        assert!(a.broadcast_compatible(&b));

        // (2, 2, 2) vs (1, 2, 3): trailing axes 2 and 3 conflict.
        let c = array![[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]];
        assert!(!a.broadcast_compatible(&c));
    }

    #[test]
    fn test_broadcastshape() {
        // Broadcasting (2, 2, 2) with (1, 2, 2) yields (2, 2, 2).
        let a = array![[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]];
        let b = array![[[1.0, 2.0], [3.0, 4.0]]];

        let shape = a.broadcastshape(&b).unwrap();
        assert_eq!(shape, vec![2, 2, 2]);
    }

    #[test]
    fn test_broadcast_matmul_3d() {
        // Batch 0 multiplies by identity; batch 1 by 2*identity.
        let a = array![[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]];
        let b = array![[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]]];

        let c = broadcast_matmul_3d(&a, &b).unwrap();

        // Batch 0: identity leaves `a` unchanged.
        assert_eq!(c[[0, 0, 0]], 1.0);
        assert_eq!(c[[0, 0, 1]], 2.0);
        assert_eq!(c[[0, 1, 0]], 3.0);
        assert_eq!(c[[0, 1, 1]], 4.0);

        // Batch 1: every entry doubled.
        assert_eq!(c[[1, 0, 0]], 10.0);
        assert_eq!(c[[1, 0, 1]], 12.0);
        assert_eq!(c[[1, 1, 0]], 14.0);
        assert_eq!(c[[1, 1, 1]], 16.0);
    }

    #[test]
    fn test_broadcast_matmul_dyn() {
        // Same data as the 3-D test, but through the IxDyn entry point.
        let a = array![[[1.0_f64, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]].into_dyn();
        let b = array![[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]]].into_dyn();

        let c = broadcast_matmul(&a, &b).unwrap();

        assert_eq!(c[[0, 0, 0]], 1.0);
        assert_eq!(c[[0, 0, 1]], 2.0);
        assert_eq!(c[[0, 1, 0]], 3.0);
        assert_eq!(c[[0, 1, 1]], 4.0);

        assert_eq!(c[[1, 0, 0]], 10.0);
        assert_eq!(c[[1, 0, 1]], 12.0);
        assert_eq!(c[[1, 1, 0]], 14.0);
        assert_eq!(c[[1, 1, 1]], 16.0);
    }

    #[test]
    fn test_broadcast_matvec_dyn() {
        // Batched A(2x2x2) times x(2x2): one matvec per batch entry.
        let a = array![[[1.0_f64, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]].into_dyn();
        let x = array![[1.0, 1.0], [2.0, 1.0]].into_dyn();

        let y = broadcast_matvec(&a, &x).unwrap();

        // Batch 0: [1+2, 3+4].
        assert_eq!(y[[0, 0]], 3.0);
        assert_eq!(y[[0, 1]], 7.0);

        // Batch 1: [5*2+6, 7*2+8].
        assert_eq!(y[[1, 0]], 16.0);
        assert_eq!(y[[1, 1]], 22.0);
    }

    #[test]
    fn test_incompatible_dimensions() {
        // Inner dims 2 vs 3 cannot multiply; expect an error.
        let a = array![[[1.0_f64, 2.0], [3.0, 4.0]]].into_dyn();
        let b = array![[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]].into_dyn();

        let result = broadcast_matmul(&a, &b);
        assert!(result.is_err());
    }

    #[test]
    fn test_broadcast_3d_with_different_batch() {
        // Batch sizes 2 and 1: the single identity in `b` broadcasts.
        let a = array![[[1.0_f64, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]];
        let b = array![[[1.0, 0.0], [0.0, 1.0]]];

        let c = broadcast_matmul_3d(&a, &b).unwrap();

        assert_eq!(c[[0, 0, 0]], 1.0);
        assert_eq!(c[[0, 0, 1]], 2.0);
        assert_eq!(c[[1, 0, 0]], 5.0);
        assert_eq!(c[[1, 0, 1]], 6.0);
    }
}
542}