1use crate::error::{CoreError, CoreResult, ErrorContext};
30use ::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
31use num_traits::Zero;
32use std::ops::{Add, Mul};
33
34pub fn dot<T>(a: &ArrayView2<T>, b: &ArrayView2<T>) -> Array2<T>
75where
76 T: Clone + Zero + Add<Output = T> + Mul<Output = T>,
77{
78 let (m, k) = (a.nrows(), a.ncols());
79 let (k2, n) = (b.nrows(), b.ncols());
80 debug_assert_eq!(k, k2, "dot: inner dimensions must match");
81
82 let mut result = Array2::<T>::zeros((m, n));
83 for i in 0..m {
84 for j in 0..n {
85 let mut sum = T::zero();
86 for l in 0..k {
87 sum = sum + a[[i, l]].clone() * b[[l, j]].clone();
88 }
89 result[[i, j]] = sum;
90 }
91 }
92 result
93}
94
95pub fn outer<T>(u: &ArrayView1<T>, v: &ArrayView1<T>) -> Array2<T>
118where
119 T: Clone + Zero + Mul<Output = T>,
120{
121 let m = u.len();
122 let n = v.len();
123 Array2::from_shape_fn((m, n), |(i, j)| u[i].clone() * v[j].clone())
124}
125
126pub fn kron<T>(a: &ArrayView2<T>, b: &ArrayView2<T>) -> Array2<T>
150where
151 T: Clone + Zero + Mul<Output = T>,
152{
153 let (p, q) = (a.nrows(), a.ncols());
154 let (r, s) = (b.nrows(), b.ncols());
155
156 Array2::from_shape_fn((p * r, q * s), |(i, j)| {
157 let ai = i / r;
158 let bi = i % r;
159 let aj = j / s;
160 let bj = j % s;
161 a[[ai, aj]].clone() * b[[bi, bj]].clone()
162 })
163}
164
165pub fn vstack<T>(arrays: &[ArrayView2<T>]) -> CoreResult<Array2<T>>
187where
188 T: Clone + Zero,
189{
190 if arrays.is_empty() {
191 return Err(CoreError::InvalidInput(ErrorContext::new(
192 "vstack: cannot stack an empty slice of arrays",
193 )));
194 }
195
196 let ncols = arrays[0].ncols();
197 for (idx, arr) in arrays.iter().enumerate().skip(1) {
198 if arr.ncols() != ncols {
199 return Err(CoreError::InvalidInput(ErrorContext::new(format!(
200 "vstack: array at index {idx} has {cols} columns, expected {ncols}",
201 cols = arr.ncols()
202 ))));
203 }
204 }
205
206 let total_rows: usize = arrays.iter().map(|a| a.nrows()).sum();
207 let mut result = Array2::<T>::zeros((total_rows, ncols));
208
209 let mut row_offset = 0;
210 for arr in arrays {
211 let nrows = arr.nrows();
212 for r in 0..nrows {
213 for c in 0..ncols {
214 result[[row_offset + r, c]] = arr[[r, c]].clone();
215 }
216 }
217 row_offset += nrows;
218 }
219
220 Ok(result)
221}
222
223pub fn hstack<T>(arrays: &[ArrayView2<T>]) -> CoreResult<Array2<T>>
246where
247 T: Clone + Zero,
248{
249 if arrays.is_empty() {
250 return Err(CoreError::InvalidInput(ErrorContext::new(
251 "hstack: cannot stack an empty slice of arrays",
252 )));
253 }
254
255 let nrows = arrays[0].nrows();
256 for (idx, arr) in arrays.iter().enumerate().skip(1) {
257 if arr.nrows() != nrows {
258 return Err(CoreError::InvalidInput(ErrorContext::new(format!(
259 "hstack: array at index {idx} has {r} rows, expected {nrows}",
260 r = arr.nrows()
261 ))));
262 }
263 }
264
265 let total_cols: usize = arrays.iter().map(|a| a.ncols()).sum();
266 let mut result = Array2::<T>::zeros((nrows, total_cols));
267
268 let mut col_offset = 0;
269 for arr in arrays {
270 let ncols = arr.ncols();
271 for r in 0..nrows {
272 for c in 0..ncols {
273 result[[r, col_offset + c]] = arr[[r, c]].clone();
274 }
275 }
276 col_offset += ncols;
277 }
278
279 Ok(result)
280}
281
282pub fn block_diag<T>(blocks: &[ArrayView2<T>]) -> Array2<T>
309where
310 T: Clone + Zero,
311{
312 if blocks.is_empty() {
313 return Array2::<T>::zeros((0, 0));
314 }
315
316 let total_rows: usize = blocks.iter().map(|b| b.nrows()).sum();
317 let total_cols: usize = blocks.iter().map(|b| b.ncols()).sum();
318
319 let mut result = Array2::<T>::zeros((total_rows, total_cols));
320
321 let mut row_off = 0;
322 let mut col_off = 0;
323 for block in blocks {
324 let (br, bc) = (block.nrows(), block.ncols());
325 for r in 0..br {
326 for c in 0..bc {
327 result[[row_off + r, col_off + c]] = block[[r, c]].clone();
328 }
329 }
330 row_off += br;
331 col_off += bc;
332 }
333
334 result
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::array;
    use approx::assert_abs_diff_eq;

    // ---- dot ----

    #[test]
    fn test_dot_identity() {
        let eye = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let b = array![[3.0_f64, 4.0], [5.0, 6.0]];
        let c = dot(&eye.view(), &b.view());
        assert_abs_diff_eq!(c[[0, 0]], 3.0, epsilon = 1e-12);
        assert_abs_diff_eq!(c[[1, 1]], 6.0, epsilon = 1e-12);
    }

    #[test]
    fn test_dot_rectangular() {
        let a = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0_f64, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let c = dot(&a.view(), &b.view());
        assert_eq!(c.shape(), &[2, 2]);
        assert_abs_diff_eq!(c[[0, 0]], 58.0, epsilon = 1e-12);
        assert_abs_diff_eq!(c[[0, 1]], 64.0, epsilon = 1e-12);
        assert_abs_diff_eq!(c[[1, 0]], 139.0, epsilon = 1e-12);
        assert_abs_diff_eq!(c[[1, 1]], 154.0, epsilon = 1e-12);
    }

    #[test]
    fn test_dot_integers() {
        let a = array![[1_i32, 2], [3, 4]];
        let b = array![[5_i32, 6], [7, 8]];
        let c = dot(&a.view(), &b.view());
        assert_eq!(c[[0, 0]], 19);
        assert_eq!(c[[1, 1]], 50);
    }

    #[test]
    fn test_dot_empty_inner_dimension() {
        // (2 x 0) · (0 x 3) is well defined: every entry is an empty sum.
        let a = Array2::<i32>::zeros((2, 0));
        let b = Array2::<i32>::zeros((0, 3));
        let c = dot(&a.view(), &b.view());
        assert_eq!(c.shape(), &[2, 3]);
        assert!(c.iter().all(|&x| x == 0));
    }

    #[test]
    #[should_panic]
    fn test_dot_inner_dimension_mismatch_panics() {
        // `a` has 3 columns but `b` only 2 rows.
        let a = array![[1_i32, 2, 3], [4, 5, 6]];
        let b = array![[1_i32, 2], [3, 4]];
        let _ = dot(&a.view(), &b.view());
    }

    // ---- outer ----

    #[test]
    fn test_outer_basic() {
        let u = array![1.0_f64, 2.0, 3.0];
        let v = array![4.0_f64, 5.0];
        let m = outer(&u.view(), &v.view());
        assert_eq!(m.shape(), &[3, 2]);
        assert_abs_diff_eq!(m[[0, 0]], 4.0, epsilon = 1e-12);
        assert_abs_diff_eq!(m[[1, 1]], 10.0, epsilon = 1e-12);
        assert_abs_diff_eq!(m[[2, 0]], 12.0, epsilon = 1e-12);
        assert_abs_diff_eq!(m[[2, 1]], 15.0, epsilon = 1e-12);
    }

    #[test]
    fn test_outer_integers() {
        let u = array![1_i32, 2];
        let v = array![3_i32, 4, 5];
        let m = outer(&u.view(), &v.view());
        assert_eq!(m.shape(), &[2, 3]);
        assert_eq!(m[[0, 0]], 3);
        assert_eq!(m[[1, 2]], 10);
    }

    #[test]
    fn test_outer_empty() {
        let u = Array1::<i32>::zeros(0);
        let v = array![1_i32, 2];
        let m = outer(&u.view(), &v.view());
        assert_eq!(m.shape(), &[0, 2]);
    }

    // ---- kron ----

    #[test]
    fn test_kron_identity_identity() {
        let eye2 = array![[1_i32, 0], [0, 1]];
        let eye3 = array![[1_i32, 0, 0], [0, 1, 0], [0, 0, 1]];
        let k = kron(&eye2.view(), &eye3.view());
        assert_eq!(k.shape(), &[6, 6]);
        for i in 0..6 {
            for j in 0..6 {
                assert_eq!(k[[i, j]], if i == j { 1 } else { 0 });
            }
        }
    }

    #[test]
    fn test_kron_scalar() {
        let two = array![[2_i32]];
        let b = array![[1_i32, 2], [3, 4]];
        let k = kron(&two.view(), &b.view());
        assert_eq!(k.shape(), &[2, 2]);
        assert_eq!(k[[0, 0]], 2);
        assert_eq!(k[[1, 1]], 8);
    }

    #[test]
    fn test_kron_matches_expected() {
        let a = array![[1_i32, 2], [3, 4]];
        let b = array![[0_i32, 5], [6, 7]];
        let k = kron(&a.view(), &b.view());
        assert_eq!(k.shape(), &[4, 4]);
        assert_eq!(k[[0, 0]], 0);
        assert_eq!(k[[0, 1]], 5);
        assert_eq!(k[[0, 2]], 0);
        assert_eq!(k[[0, 3]], 10);
        assert_eq!(k[[3, 0]], 18);
        assert_eq!(k[[3, 1]], 21);
        assert_eq!(k[[3, 2]], 24);
        assert_eq!(k[[3, 3]], 28);
    }

    #[test]
    fn test_kron_empty_operand() {
        // A zero-row right operand must not trigger a division by zero.
        let a = array![[1_i32, 2], [3, 4]];
        let b = Array2::<i32>::zeros((0, 3));
        let k = kron(&a.view(), &b.view());
        assert_eq!(k.shape(), &[0, 6]);
    }

    // ---- vstack ----

    #[test]
    fn test_vstack_basic() {
        let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let b = array![[5.0_f64, 6.0]];
        let s = vstack(&[a.view(), b.view()]).expect("same cols");
        assert_eq!(s.shape(), &[3, 2]);
        assert_abs_diff_eq!(s[[2, 0]], 5.0, epsilon = 1e-12);
        assert_abs_diff_eq!(s[[2, 1]], 6.0, epsilon = 1e-12);
    }

    #[test]
    fn test_vstack_three_arrays() {
        let a = array![[1_i32, 2]];
        let b = array![[3_i32, 4]];
        let c = array![[5_i32, 6], [7, 8]];
        let s = vstack(&[a.view(), b.view(), c.view()]).expect("same cols");
        assert_eq!(s.shape(), &[4, 2]);
        assert_eq!(s[[0, 0]], 1);
        assert_eq!(s[[1, 1]], 4);
        assert_eq!(s[[2, 0]], 5);
        assert_eq!(s[[3, 1]], 8);
    }

    #[test]
    fn test_vstack_with_zero_row_array() {
        let a = array![[1_i32, 2]];
        let empty = Array2::<i32>::zeros((0, 2));
        let s = vstack(&[empty.view(), a.view()]).expect("same cols");
        assert_eq!(s.shape(), &[1, 2]);
        assert_eq!(s[[0, 1]], 2);
    }

    #[test]
    fn test_vstack_mismatch_error() {
        let a = array![[1.0_f64, 2.0, 3.0]];
        let b = array![[4.0_f64, 5.0]];
        assert!(vstack(&[a.view(), b.view()]).is_err());
    }

    #[test]
    fn test_vstack_empty_error() {
        let empty: &[ArrayView2<f64>] = &[];
        assert!(vstack(empty).is_err());
    }

    // ---- hstack ----

    #[test]
    fn test_hstack_basic() {
        let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let b = array![[5.0_f64], [6.0]];
        let s = hstack(&[a.view(), b.view()]).expect("same rows");
        assert_eq!(s.shape(), &[2, 3]);
        assert_abs_diff_eq!(s[[0, 2]], 5.0, epsilon = 1e-12);
        assert_abs_diff_eq!(s[[1, 2]], 6.0, epsilon = 1e-12);
    }

    #[test]
    fn test_hstack_three_arrays() {
        let a = array![[1_i32], [2]];
        let b = array![[3_i32], [4]];
        let c = array![[5_i32, 6], [7, 8]];
        let s = hstack(&[a.view(), b.view(), c.view()]).expect("same rows");
        assert_eq!(s.shape(), &[2, 4]);
        assert_eq!(s[[0, 0]], 1);
        assert_eq!(s[[0, 1]], 3);
        assert_eq!(s[[1, 3]], 8);
    }

    #[test]
    fn test_hstack_mismatch_error() {
        let a = array![[1.0_f64], [2.0], [3.0]];
        let b = array![[4.0_f64], [5.0]];
        assert!(hstack(&[a.view(), b.view()]).is_err());
    }

    #[test]
    fn test_hstack_empty_error() {
        let empty: &[ArrayView2<f64>] = &[];
        assert!(hstack(empty).is_err());
    }

    // ---- block_diag ----

    #[test]
    fn test_block_diag_square_blocks() {
        let a = array![[1_i32, 2], [3, 4]];
        let b = array![[5_i32, 6], [7, 8]];
        let bd = block_diag(&[a.view(), b.view()]);
        assert_eq!(bd.shape(), &[4, 4]);
        assert_eq!(bd[[0, 0]], 1);
        assert_eq!(bd[[1, 1]], 4);
        assert_eq!(bd[[2, 2]], 5);
        assert_eq!(bd[[3, 3]], 8);
        assert_eq!(bd[[0, 2]], 0);
        assert_eq!(bd[[3, 0]], 0);
    }

    #[test]
    fn test_block_diag_rectangular_blocks() {
        let a = array![[1_i32, 2, 3]];
        let b = array![[4_i32], [5]];
        let bd = block_diag(&[a.view(), b.view()]);
        assert_eq!(bd.shape(), &[3, 4]);
        assert_eq!(bd[[0, 2]], 3);
        assert_eq!(bd[[1, 3]], 4);
        assert_eq!(bd[[2, 3]], 5);
        assert_eq!(bd[[1, 0]], 0);
    }

    #[test]
    fn test_block_diag_single() {
        let a = array![[9_i32]];
        let bd = block_diag(&[a.view()]);
        assert_eq!(bd.shape(), &[1, 1]);
        assert_eq!(bd[[0, 0]], 9);
    }

    #[test]
    fn test_block_diag_empty() {
        let empty: &[ArrayView2<i32>] = &[];
        let bd = block_diag(empty);
        assert_eq!(bd.shape(), &[0, 0]);
    }

    #[test]
    fn test_block_diag_three_blocks() {
        let a = array![[1_i32, 2], [3, 4]];
        let b = array![[5_i32]];
        let c = array![[6_i32, 7, 8]];
        let bd = block_diag(&[a.view(), b.view(), c.view()]);
        assert_eq!(bd.shape(), &[4, 6]);
        assert_eq!(bd[[0, 0]], 1);
        assert_eq!(bd[[1, 1]], 4);
        assert_eq!(bd[[2, 2]], 5);
        assert_eq!(bd[[3, 3]], 6);
        assert_eq!(bd[[3, 5]], 8);
        assert_eq!(bd[[0, 3]], 0);
        assert_eq!(bd[[3, 0]], 0);
    }
}