// scirs2_linalg/broadcast.rs

1//! NumPy-style broadcasting for linear algebra operations on higher-dimensional arrays
2//!
3//! This module provides broadcasting support for operations on arrays with
4//! more than 2 dimensions, following NumPy's broadcasting rules.
5
6use crate::error::{LinalgError, LinalgResult};
7use scirs2_core::ndarray::{Array, ArrayBase, Data, Dimension, Ix3, IxDyn};
8use scirs2_core::numeric::{Float, NumAssign};
9use std::fmt::Debug;
10use std::iter::Sum;
11
/// Trait for broadcasting support
///
/// Adds NumPy-style broadcast queries to `ndarray` arrays: two shapes are
/// broadcast-compatible when, aligned at their trailing dimensions, each
/// pair of dimensions is either equal or contains a 1. Missing leading
/// dimensions are treated as 1.
pub trait BroadcastExt<A> {
    /// Check if two arrays are compatible for broadcasting
    ///
    /// Compares shapes from the trailing dimensions; returns `true` when
    /// every aligned dimension pair is equal or one of the pair is 1.
    fn broadcast_compatible<D2>(&self, other: &ArrayBase<D2, impl Dimension>) -> bool
    where
        D2: Data<Elem = A>;

    /// Get the shape after broadcasting
    ///
    /// Returns `None` when the two shapes are not broadcast-compatible.
    fn broadcastshape<D2>(&self, other: &ArrayBase<D2, impl Dimension>) -> Option<Vec<usize>>
    where
        D2: Data<Elem = A>;
}
24
25impl<A, S, D> BroadcastExt<A> for ArrayBase<S, D>
26where
27    S: Data<Elem = A>,
28    D: Dimension,
29{
30    fn broadcast_compatible<D2>(&self, other: &ArrayBase<D2, impl Dimension>) -> bool
31    where
32        D2: Data<Elem = A>,
33    {
34        let shape1 = self.shape();
35        let shape2 = other.shape();
36        let ndim1 = shape1.len();
37        let ndim2 = shape2.len();
38
39        // Start from the trailing dimensions
40        let mut i = ndim1;
41        let mut j = ndim2;
42
43        while i > 0 && j > 0 {
44            i -= 1;
45            j -= 1;
46
47            let dim1 = shape1[i];
48            let dim2 = shape2[j];
49
50            // Dimensions are compatible if they are equal or one of them is 1
51            if dim1 != dim2 && dim1 != 1 && dim2 != 1 {
52                return false;
53            }
54        }
55
56        true
57    }
58
59    fn broadcastshape<D2>(&self, other: &ArrayBase<D2, impl Dimension>) -> Option<Vec<usize>>
60    where
61        D2: Data<Elem = A>,
62    {
63        if !self.broadcast_compatible(other) {
64            return None;
65        }
66
67        let shape1 = self.shape();
68        let shape2 = other.shape();
69        let ndim1 = shape1.len();
70        let ndim2 = shape2.len();
71        let max_ndim = ndim1.max(ndim2);
72
73        let mut broadcastshape = vec![0; max_ndim];
74
75        // Fill from the trailing dimensions
76        let mut i = ndim1;
77        let mut j = ndim2;
78        let mut k = max_ndim;
79
80        while k > 0 {
81            k -= 1;
82
83            let dim1 = if i > 0 {
84                i -= 1;
85                shape1[i]
86            } else {
87                1
88            };
89
90            let dim2 = if j > 0 {
91                j -= 1;
92                shape2[j]
93            } else {
94                1
95            };
96
97            broadcastshape[k] = dim1.max(dim2);
98        }
99
100        Some(broadcastshape)
101    }
102}
103
104/// Broadcasting matrix multiplication for 3D arrays
105///
106/// This function implements NumPy-style broadcasting for matrix multiplication
107/// on 3D arrays. The last two dimensions are treated as matrices, and the
108/// first dimension is broadcast.
109#[allow(dead_code)]
110pub fn broadcast_matmul_3d<A>(
111    a: &ArrayBase<impl Data<Elem = A>, Ix3>,
112    b: &ArrayBase<impl Data<Elem = A>, Ix3>,
113) -> LinalgResult<Array<A, Ix3>>
114where
115    A: Float + NumAssign + Sum + Debug + 'static,
116{
117    let ashape = a.shape();
118    let bshape = b.shape();
119
120    // Check matrix dimensions are compatible
121    let a_cols = ashape[2];
122    let b_rows = bshape[1];
123
124    if a_cols != b_rows {
125        return Err(LinalgError::DimensionError(format!(
126            "Matrix dimensions don't match for multiplication: ({}, {}) x ({}, {})",
127            ashape[1], a_cols, b_rows, bshape[2]
128        )));
129    }
130
131    // Get the batch dimension
132    let batchsize = ashape[0].max(bshape[0]);
133
134    // Check if batch dimensions can be broadcast
135    if ashape[0] != bshape[0] && ashape[0] != 1 && bshape[0] != 1 {
136        return Err(LinalgError::DimensionError(
137            "Batch dimensions must be compatible for broadcasting".to_string(),
138        ));
139    }
140
141    // Compute output shape
142    let a_rows = ashape[1];
143    let b_cols = bshape[2];
144    let outputshape = [batchsize, a_rows, b_cols];
145
146    // Create output array
147    let mut output = Array::zeros(outputshape);
148
149    // Perform batched matrix multiplication
150    for i in 0..batchsize {
151        let a_idx = if ashape[0] == 1 { 0 } else { i };
152        let b_idx = if bshape[0] == 1 { 0 } else { i };
153
154        let a_mat = a.index_axis(scirs2_core::ndarray::Axis(0), a_idx);
155        let b_mat = b.index_axis(scirs2_core::ndarray::Axis(0), b_idx);
156        let mut out_mat = output.index_axis_mut(scirs2_core::ndarray::Axis(0), i);
157
158        // Standard matrix multiplication for this batch
159        scirs2_core::ndarray::linalg::general_mat_mul(
160            A::one(),
161            &a_mat,
162            &b_mat,
163            A::one(),
164            &mut out_mat,
165        );
166    }
167
168    Ok(output)
169}
170
/// Broadcasting matrix multiplication for dynamic dimensional arrays
///
/// This function implements NumPy-style broadcasting for matrix multiplication
/// on arrays with arbitrary dimensions. The last two dimensions are treated
/// as matrices, and the leading dimensions are broadcast together.
///
/// # Errors
///
/// Returns `LinalgError::DimensionError` when either input has fewer than 2
/// dimensions, when the inner matrix dimensions disagree, or when the leading
/// (batch) shapes differ — only an exact batch-shape match is supported;
/// full batch broadcasting is not yet implemented here.
///
/// NOTE(review): each batch copies its operands element-by-element into
/// temporary `Array2`s (one div/mod unravel per element), which is much
/// slower than a view-based implementation would be.
#[allow(dead_code)]
pub fn broadcast_matmul<A>(
    a: &ArrayBase<impl Data<Elem = A>, IxDyn>,
    b: &ArrayBase<impl Data<Elem = A>, IxDyn>,
) -> LinalgResult<Array<A, IxDyn>>
where
    A: Float + NumAssign + Sum + Debug + 'static,
{
    // Check that arrays have at least 2 dimensions
    if a.ndim() < 2 || b.ndim() < 2 {
        return Err(LinalgError::DimensionError(
            "Arrays must have at least 2 dimensions for matrix multiplication".to_string(),
        ));
    }

    let ashape = a.shape();
    let bshape = b.shape();

    // Check matrix dimensions are compatible: (..., m, k) x (..., k, n)
    let a_cols = ashape[ashape.len() - 1];
    let b_rows = bshape[bshape.len() - 2];

    if a_cols != b_rows {
        return Err(LinalgError::DimensionError(format!(
            "Matrix dimensions don't match for multiplication: (..., {a_cols}) x ({b_rows}, ...)"
        )));
    }

    // Get the batch dimensions (all but the last 2)
    let a_batchshape = &ashape[..ashape.len() - 2];
    let b_batchshape = &bshape[..bshape.len() - 2];

    // Check if batch dimensions can be broadcast
    let batchshape = if a_batchshape == b_batchshape {
        a_batchshape.to_vec()
    } else {
        // For now, we don't support full broadcasting - require exact match
        return Err(LinalgError::DimensionError(
            "Batch dimensions must match exactly (full broadcasting not yet implemented)"
                .to_string(),
        ));
    };

    // Compute output shape: batch dims followed by (a_rows, b_cols)
    let a_rows = ashape[ashape.len() - 2];
    let b_cols = bshape[bshape.len() - 1];
    let mut outputshape = batchshape;
    outputshape.push(a_rows);
    outputshape.push(b_cols);

    // Create output array
    let mut output = Array::zeros(IxDyn(&outputshape));

    // Number of (a_rows, b_cols) matrices in the output, i.e. the product of
    // the batch dimensions.
    let n_batch = output.len() / (a_rows * b_cols);

    // Perform batched matrix multiplication
    // Need to reshape in steps to avoid borrowing issues
    for i in 0..n_batch {
        // Extract 2D slices for this batch
        let mut a_slice = Array2::zeros((a_rows, a_cols));
        let mut b_slice = Array2::zeros((b_rows, b_cols));
        let mut out_slice = Array2::zeros((a_rows, b_cols));

        // Copy data into slices. Because the batch shapes match exactly
        // (checked above), batch `i` begins at the same logical offset in
        // a, b, and output.
        let a_start = i * a_rows * a_cols;
        let b_start = i * b_rows * b_cols;
        let out_start = i * a_rows * b_cols;

        for r in 0..a_rows {
            for c in 0..a_cols {
                let flat_idx = a_start + r * a_cols + c;
                // Unravel the row-major logical flat index into an n-d index
                // (last axis varies fastest). Indexing goes through the n-d
                // index, so this is correct regardless of memory layout.
                let nd_idx: Vec<usize> = {
                    let mut idx = vec![0; a.ndim()];
                    let mut remaining = flat_idx;
                    for dim in (0..a.ndim()).rev() {
                        idx[dim] = remaining % ashape[dim];
                        remaining /= ashape[dim];
                    }
                    idx
                };
                a_slice[[r, c]] = a[nd_idx.as_slice()];
            }
        }

        for r in 0..b_rows {
            for c in 0..b_cols {
                let flat_idx = b_start + r * b_cols + c;
                // Same row-major unravel, against b's shape.
                let nd_idx: Vec<usize> = {
                    let mut idx = vec![0; b.ndim()];
                    let mut remaining = flat_idx;
                    for dim in (0..b.ndim()).rev() {
                        idx[dim] = remaining % bshape[dim];
                        remaining /= bshape[dim];
                    }
                    idx
                };
                b_slice[[r, c]] = b[nd_idx.as_slice()];
            }
        }

        // Perform matrix multiplication; out_slice is freshly zeroed, so
        // beta = 1 still yields exactly a_slice * b_slice.
        scirs2_core::ndarray::linalg::general_mat_mul(
            A::one(),
            &a_slice.view(),
            &b_slice.view(),
            A::one(),
            &mut out_slice,
        );

        // Copy result back into this batch's block of the output
        for r in 0..a_rows {
            for c in 0..b_cols {
                let flat_idx = out_start + r * b_cols + c;
                let nd_idx: Vec<usize> = {
                    let mut idx = vec![0; output.ndim()];
                    let mut remaining = flat_idx;
                    for dim in (0..output.ndim()).rev() {
                        idx[dim] = remaining % outputshape[dim];
                        remaining /= outputshape[dim];
                    }
                    idx
                };
                output[nd_idx.as_slice()] = out_slice[[r, c]];
            }
        }
    }

    Ok(output)
}
306
/// Broadcasting matrix-vector multiplication for dynamic dimensional arrays
///
/// Treats the last two axes of `a` as matrices and the last axis of `x` as a
/// vector; the leading (batch) axes of both must match exactly.
///
/// # Errors
///
/// Returns `LinalgError::DimensionError` when `a` has fewer than 2 dimensions
/// or `x` fewer than 1, when the matrix columns don't match the vector
/// length, or when the batch shapes differ (full broadcasting is not yet
/// implemented).
///
/// NOTE(review): like `broadcast_matmul`, this copies operands
/// element-by-element per batch, which is slow compared to using views.
#[allow(dead_code)]
pub fn broadcast_matvec<A>(
    a: &ArrayBase<impl Data<Elem = A>, IxDyn>,
    x: &ArrayBase<impl Data<Elem = A>, IxDyn>,
) -> LinalgResult<Array<A, IxDyn>>
where
    A: Float + NumAssign + Sum + Debug + 'static,
{
    // Check that matrix has at least 2 dimensions and vector has at least 1
    if a.ndim() < 2 || x.ndim() < 1 {
        return Err(LinalgError::DimensionError(
            "Matrix must have at least 2 dimensions and vector at least 1".to_string(),
        ));
    }

    let ashape = a.shape();
    let xshape = x.shape();

    // Check dimensions are compatible: (..., m, k) * (..., k)
    let a_cols = ashape[ashape.len() - 1];
    let x_len = xshape[xshape.len() - 1];

    if a_cols != x_len {
        return Err(LinalgError::DimensionError(format!(
            "Matrix and vector dimensions don't match: (..., {a_cols}) x ({x_len})"
        )));
    }

    // Get the batch dimensions (matrix drops 2 trailing axes, vector drops 1)
    let a_batchshape = &ashape[..ashape.len() - 2];
    let x_batchshape = &xshape[..xshape.len() - 1];

    // Check if batch dimensions can be broadcast
    let batchshape = if a_batchshape == x_batchshape {
        a_batchshape.to_vec()
    } else {
        // For now, we don't support full broadcasting
        return Err(LinalgError::DimensionError(
            "Batch dimensions must match exactly (full broadcasting not yet implemented)"
                .to_string(),
        ));
    };

    // Compute output shape: batch dims followed by the matrix row count
    let a_rows = ashape[ashape.len() - 2];
    let mut outputshape = batchshape;
    outputshape.push(a_rows);

    // Create output array
    let mut output = Array::zeros(IxDyn(&outputshape));

    // Number of length-a_rows result vectors, i.e. the product of the batch
    // dimensions.
    let n_batch = output.len() / a_rows;

    // Perform batched matrix-vector multiplication
    for i in 0..n_batch {
        // Extract slices for this batch
        let mut a_slice = Array2::zeros((a_rows, a_cols));
        let mut x_slice = Array1::zeros(x_len);
        let mut y_slice = Array1::zeros(a_rows);

        // Copy data into slices. Batch shapes match exactly (checked above),
        // so batch `i` starts at the same logical offset in a, x, and output.
        let a_start = i * a_rows * a_cols;
        let x_start = i * x_len;
        let y_start = i * a_rows;

        for r in 0..a_rows {
            for c in 0..a_cols {
                let flat_idx = a_start + r * a_cols + c;
                // Unravel the row-major logical flat index into an n-d index
                // (last axis varies fastest); layout-independent because
                // indexing goes through the n-d index.
                let nd_idx: Vec<usize> = {
                    let mut idx = vec![0; a.ndim()];
                    let mut remaining = flat_idx;
                    for dim in (0..a.ndim()).rev() {
                        idx[dim] = remaining % ashape[dim];
                        remaining /= ashape[dim];
                    }
                    idx
                };
                a_slice[[r, c]] = a[nd_idx.as_slice()];
            }
        }

        for j in 0..x_len {
            let flat_idx = x_start + j;
            // Same row-major unravel, against x's shape.
            let nd_idx: Vec<usize> = {
                let mut idx = vec![0; x.ndim()];
                let mut remaining = flat_idx;
                for dim in (0..x.ndim()).rev() {
                    idx[dim] = remaining % xshape[dim];
                    remaining /= xshape[dim];
                }
                idx
            };
            x_slice[j] = x[nd_idx.as_slice()];
        }

        // Perform matrix-vector multiplication; y_slice is freshly zeroed, so
        // beta = 1 still yields exactly a_slice * x_slice.
        scirs2_core::ndarray::linalg::general_mat_vec_mul(
            A::one(),
            &a_slice.view(),
            &x_slice.view(),
            A::one(),
            &mut y_slice,
        );

        // Copy result back into this batch's block of the output
        for j in 0..a_rows {
            let flat_idx = y_start + j;
            let nd_idx: Vec<usize> = {
                let mut idx = vec![0; output.ndim()];
                let mut remaining = flat_idx;
                for dim in (0..output.ndim()).rev() {
                    idx[dim] = remaining % outputshape[dim];
                    remaining /= outputshape[dim];
                }
                idx
            };
            output[nd_idx.as_slice()] = y_slice[j];
        }
    }

    Ok(output)
}
431
432use scirs2_core::ndarray::{Array1, Array2};
433
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_broadcast_compatible() {
        // (2, 2, 2) against (1, 2, 2): the leading axis broadcasts.
        let lhs = array![[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]];
        let rhs = array![[[1.0, 2.0], [3.0, 4.0]]];
        assert!(lhs.broadcast_compatible(&rhs));

        // Trailing axes 2 and 3 can never broadcast.
        let mismatched = array![[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]];
        assert!(!lhs.broadcast_compatible(&mismatched));
    }

    #[test]
    fn test_broadcastshape() {
        let lhs = array![[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]];
        let rhs = array![[[1.0, 2.0], [3.0, 4.0]]];

        // (2, 2, 2) broadcast with (1, 2, 2) yields (2, 2, 2).
        assert_eq!(lhs.broadcastshape(&rhs).unwrap(), vec![2, 2, 2]);
    }

    #[test]
    fn test_broadcast_matmul_3d() {
        // Batch of two 2x2 matrices; the second operand is I then 2*I.
        let lhs = array![[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]];
        let rhs = array![[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]]];

        let product = broadcast_matmul_3d(&lhs, &rhs).unwrap();

        // Batch 0: A * I leaves A unchanged.
        assert_eq!(product[[0, 0, 0]], 1.0);
        assert_eq!(product[[0, 0, 1]], 2.0);
        assert_eq!(product[[0, 1, 0]], 3.0);
        assert_eq!(product[[0, 1, 1]], 4.0);

        // Batch 1: A * (2*I) doubles every entry.
        assert_eq!(product[[1, 0, 0]], 10.0);
        assert_eq!(product[[1, 0, 1]], 12.0);
        assert_eq!(product[[1, 1, 0]], 14.0);
        assert_eq!(product[[1, 1, 1]], 16.0);
    }

    #[test]
    fn test_broadcast_matmul_dyn() {
        // Same data as the 3D test, but through the dynamic-dimension path.
        let lhs = array![[[1.0_f64, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]].into_dyn();
        let rhs = array![[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]]].into_dyn();

        let product = broadcast_matmul(&lhs, &rhs).unwrap();

        // Batch 0: multiplication by the identity.
        assert_eq!(product[[0, 0, 0]], 1.0);
        assert_eq!(product[[0, 0, 1]], 2.0);
        assert_eq!(product[[0, 1, 0]], 3.0);
        assert_eq!(product[[0, 1, 1]], 4.0);

        // Batch 1: multiplication by 2*I.
        assert_eq!(product[[1, 0, 0]], 10.0);
        assert_eq!(product[[1, 0, 1]], 12.0);
        assert_eq!(product[[1, 1, 0]], 14.0);
        assert_eq!(product[[1, 1, 1]], 16.0);
    }

    #[test]
    fn test_broadcast_matvec_dyn() {
        // Batch of two 2x2 matrices against a batch of two length-2 vectors.
        let mats = array![[[1.0_f64, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]].into_dyn();
        let vecs = array![[1.0, 1.0], [2.0, 1.0]].into_dyn();

        let result = broadcast_matvec(&mats, &vecs).unwrap();

        // Batch 0: [1,2;3,4] * [1,1] = [3,7]
        assert_eq!(result[[0, 0]], 3.0);
        assert_eq!(result[[0, 1]], 7.0);

        // Batch 1: [5,6;7,8] * [2,1] = [16,22]
        assert_eq!(result[[1, 0]], 16.0);
        assert_eq!(result[[1, 1]], 22.0);
    }

    #[test]
    fn test_incompatible_dimensions() {
        // Inner dimensions disagree: (2, 2) x (3, 2) must be rejected.
        let lhs = array![[[1.0_f64, 2.0], [3.0, 4.0]]].into_dyn();
        let rhs = array![[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]].into_dyn();

        assert!(broadcast_matmul(&lhs, &rhs).is_err());
    }

    #[test]
    fn test_broadcast_3d_with_different_batch() {
        // Batch sizes 2 and 1: the single identity matrix is reused for both.
        let lhs = array![[[1.0_f64, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]];
        let rhs = array![[[1.0, 0.0], [0.0, 1.0]]];

        let product = broadcast_matmul_3d(&lhs, &rhs).unwrap();

        assert_eq!(product[[0, 0, 0]], 1.0);
        assert_eq!(product[[0, 0, 1]], 2.0);
        assert_eq!(product[[1, 0, 0]], 5.0);
        assert_eq!(product[[1, 0, 1]], 6.0);
    }
}