Skip to main content

shape_runtime/intrinsics/
matrix_kernels.rs

//! SIMD-accelerated matrix kernels operating on `MatrixData`.
//!
//! All functions operate directly on `MatrixData` (row-major, SIMD-aligned)
//! to avoid the overhead of nested-array extraction in the old `matrix.rs`.

6use shape_value::aligned_vec::AlignedVec;
7use shape_value::heap_value::MatrixData;
8use wide::f64x4;
9
/// Size at which the kernels switch from plain scalar loops to 4-lane
/// `f64x4` processing.
///
/// For the element-wise kernels this is compared against the total element
/// count; for `matrix_matmul` / `matrix_matvec` it is compared against the
/// row length. Inputs below it take the scalar loop only.
const SIMD_THRESHOLD: usize = 16;
11
12/// Element-wise matrix addition: C = A + B (same dimensions).
13pub fn matrix_add(a: &MatrixData, b: &MatrixData) -> Result<MatrixData, String> {
14    if a.rows != b.rows || a.cols != b.cols {
15        return Err(format!(
16            "Matrix dimension mismatch for add: {}x{} vs {}x{}",
17            a.rows, a.cols, b.rows, b.cols
18        ));
19    }
20    let len = a.data.len();
21    let mut result = AlignedVec::with_capacity(len);
22
23    if len >= SIMD_THRESHOLD {
24        let chunks = len / 4;
25        let a_ptr = a.data.as_ptr();
26        let b_ptr = b.data.as_ptr();
27        for i in 0..chunks {
28            let offset = i * 4;
29            let va = f64x4::from(unsafe { *(a_ptr.add(offset) as *const [f64; 4]) });
30            let vb = f64x4::from(unsafe { *(b_ptr.add(offset) as *const [f64; 4]) });
31            let vc = va + vb;
32            let arr: [f64; 4] = vc.into();
33            for v in arr {
34                result.push(v);
35            }
36        }
37        for i in (chunks * 4)..len {
38            result.push(a.data[i] + b.data[i]);
39        }
40    } else {
41        for i in 0..len {
42            result.push(a.data[i] + b.data[i]);
43        }
44    }
45
46    Ok(MatrixData::from_flat(result, a.rows, a.cols))
47}
48
49/// Element-wise matrix subtraction: C = A - B (same dimensions).
50pub fn matrix_sub(a: &MatrixData, b: &MatrixData) -> Result<MatrixData, String> {
51    if a.rows != b.rows || a.cols != b.cols {
52        return Err(format!(
53            "Matrix dimension mismatch for sub: {}x{} vs {}x{}",
54            a.rows, a.cols, b.rows, b.cols
55        ));
56    }
57    let len = a.data.len();
58    let mut result = AlignedVec::with_capacity(len);
59
60    if len >= SIMD_THRESHOLD {
61        let chunks = len / 4;
62        let a_ptr = a.data.as_ptr();
63        let b_ptr = b.data.as_ptr();
64        for i in 0..chunks {
65            let offset = i * 4;
66            let va = f64x4::from(unsafe { *(a_ptr.add(offset) as *const [f64; 4]) });
67            let vb = f64x4::from(unsafe { *(b_ptr.add(offset) as *const [f64; 4]) });
68            let vc = va - vb;
69            let arr: [f64; 4] = vc.into();
70            for v in arr {
71                result.push(v);
72            }
73        }
74        for i in (chunks * 4)..len {
75            result.push(a.data[i] - b.data[i]);
76        }
77    } else {
78        for i in 0..len {
79            result.push(a.data[i] - b.data[i]);
80        }
81    }
82
83    Ok(MatrixData::from_flat(result, a.rows, a.cols))
84}
85
86/// Scalar multiplication: C = A * scalar.
87pub fn matrix_scale(a: &MatrixData, scalar: f64) -> MatrixData {
88    let len = a.data.len();
89    let mut result = AlignedVec::with_capacity(len);
90
91    if len >= SIMD_THRESHOLD {
92        let chunks = len / 4;
93        let s = f64x4::splat(scalar);
94        let a_ptr = a.data.as_ptr();
95        for i in 0..chunks {
96            let offset = i * 4;
97            let va = f64x4::from(unsafe { *(a_ptr.add(offset) as *const [f64; 4]) });
98            let vc = va * s;
99            let arr: [f64; 4] = vc.into();
100            for v in arr {
101                result.push(v);
102            }
103        }
104        for i in (chunks * 4)..len {
105            result.push(a.data[i] * scalar);
106        }
107    } else {
108        for i in 0..len {
109            result.push(a.data[i] * scalar);
110        }
111    }
112
113    MatrixData::from_flat(result, a.rows, a.cols)
114}
115
116/// Element-wise (Hadamard) multiplication: C[i,j] = A[i,j] * B[i,j].
117pub fn matrix_element_mul(a: &MatrixData, b: &MatrixData) -> Result<MatrixData, String> {
118    if a.rows != b.rows || a.cols != b.cols {
119        return Err(format!(
120            "Matrix dimension mismatch for element-wise mul: {}x{} vs {}x{}",
121            a.rows, a.cols, b.rows, b.cols
122        ));
123    }
124    let len = a.data.len();
125    let mut result = AlignedVec::with_capacity(len);
126
127    if len >= SIMD_THRESHOLD {
128        let chunks = len / 4;
129        let a_ptr = a.data.as_ptr();
130        let b_ptr = b.data.as_ptr();
131        for i in 0..chunks {
132            let offset = i * 4;
133            let va = f64x4::from(unsafe { *(a_ptr.add(offset) as *const [f64; 4]) });
134            let vb = f64x4::from(unsafe { *(b_ptr.add(offset) as *const [f64; 4]) });
135            let vc = va * vb;
136            let arr: [f64; 4] = vc.into();
137            for v in arr {
138                result.push(v);
139            }
140        }
141        for i in (chunks * 4)..len {
142            result.push(a.data[i] * b.data[i]);
143        }
144    } else {
145        for i in 0..len {
146            result.push(a.data[i] * b.data[i]);
147        }
148    }
149
150    Ok(MatrixData::from_flat(result, a.rows, a.cols))
151}
152
153/// Matrix multiplication: C = A * B.
154/// A is (m x k), B is (k x n), result is (m x n).
155pub fn matrix_matmul(a: &MatrixData, b: &MatrixData) -> Result<MatrixData, String> {
156    if a.cols != b.rows {
157        return Err(format!(
158            "Matrix dimension mismatch for matmul: {}x{} * {}x{}",
159            a.rows, a.cols, b.rows, b.cols
160        ));
161    }
162    let m = a.rows as usize;
163    let k = a.cols as usize;
164    let n = b.cols as usize;
165    let mut result = AlignedVec::with_capacity(m * n);
166    for _ in 0..(m * n) {
167        result.push(0.0);
168    }
169
170    // ikj loop order for better cache behavior
171    for i in 0..m {
172        let a_row_base = i * k;
173        let out_row_base = i * n;
174        for kk in 0..k {
175            let a_ik = a.data[a_row_base + kk];
176            let b_row_base = kk * n;
177            if n >= SIMD_THRESHOLD {
178                let chunks = n / 4;
179                let sa = f64x4::splat(a_ik);
180                for j in 0..chunks {
181                    let offset = j * 4;
182                    let vb = f64x4::from(unsafe {
183                        *(b.data.as_ptr().add(b_row_base + offset) as *const [f64; 4])
184                    });
185                    let vc = f64x4::from(unsafe {
186                        *(result.as_ptr().add(out_row_base + offset) as *const [f64; 4])
187                    });
188                    let vr = vc + sa * vb;
189                    let arr: [f64; 4] = vr.into();
190                    for (idx, v) in arr.iter().enumerate() {
191                        result[out_row_base + offset + idx] = *v;
192                    }
193                }
194                for j in (chunks * 4)..n {
195                    result[out_row_base + j] += a_ik * b.data[b_row_base + j];
196                }
197            } else {
198                for j in 0..n {
199                    result[out_row_base + j] += a_ik * b.data[b_row_base + j];
200                }
201            }
202        }
203    }
204
205    Ok(MatrixData::from_flat(result, a.rows as u32, b.cols as u32))
206}
207
208/// Matrix-vector multiplication: y = A * v.
209/// A is (m x n), v has length n, result has length m.
210pub fn matrix_matvec(a: &MatrixData, v: &[f64]) -> Result<AlignedVec<f64>, String> {
211    let n = a.cols as usize;
212    if n != v.len() {
213        return Err(format!(
214            "Matrix-vector dimension mismatch: {}x{} * vec({})",
215            a.rows,
216            a.cols,
217            v.len()
218        ));
219    }
220    let m = a.rows as usize;
221    let mut result = AlignedVec::with_capacity(m);
222
223    for i in 0..m {
224        let row_base = i * n;
225        let mut acc = 0.0;
226        if n >= SIMD_THRESHOLD {
227            let chunks = n / 4;
228            let mut vacc = f64x4::splat(0.0);
229            for j in 0..chunks {
230                let offset = j * 4;
231                let va = f64x4::from(unsafe {
232                    *(a.data.as_ptr().add(row_base + offset) as *const [f64; 4])
233                });
234                let vv = f64x4::from(unsafe { *(v.as_ptr().add(offset) as *const [f64; 4]) });
235                vacc = vacc + va * vv;
236            }
237            let arr: [f64; 4] = vacc.into();
238            acc = arr[0] + arr[1] + arr[2] + arr[3];
239            for j in (chunks * 4)..n {
240                acc += a.data[row_base + j] * v[j];
241            }
242        } else {
243            for j in 0..n {
244                acc += a.data[row_base + j] * v[j];
245            }
246        }
247        result.push(acc);
248    }
249
250    Ok(result)
251}
252
253/// Matrix transpose: B = A^T.
254pub fn matrix_transpose(m: &MatrixData) -> MatrixData {
255    let rows = m.rows as usize;
256    let cols = m.cols as usize;
257    let mut result = AlignedVec::with_capacity(rows * cols);
258    for _ in 0..(rows * cols) {
259        result.push(0.0);
260    }
261
262    for i in 0..rows {
263        for j in 0..cols {
264            result[j * rows + i] = m.data[i * cols + j];
265        }
266    }
267
268    MatrixData::from_flat(result, m.cols, m.rows)
269}
270
271/// Matrix inverse via Gauss-Jordan elimination.
272/// Only works for square matrices.
273pub fn matrix_inverse(m: &MatrixData) -> Result<MatrixData, String> {
274    if m.rows != m.cols {
275        return Err(format!(
276            "Cannot invert non-square matrix: {}x{}",
277            m.rows, m.cols
278        ));
279    }
280    let n = m.rows as usize;
281    if n == 0 {
282        return Ok(MatrixData::new(0, 0));
283    }
284
285    // Build augmented matrix [A | I]
286    let mut aug = vec![0.0f64; n * 2 * n];
287    for i in 0..n {
288        for j in 0..n {
289            aug[i * 2 * n + j] = m.data[i * n + j];
290        }
291        aug[i * 2 * n + n + i] = 1.0;
292    }
293
294    // Forward elimination with partial pivoting
295    for col in 0..n {
296        // Find pivot
297        let mut max_val = aug[col * 2 * n + col].abs();
298        let mut max_row = col;
299        for row in (col + 1)..n {
300            let val = aug[row * 2 * n + col].abs();
301            if val > max_val {
302                max_val = val;
303                max_row = row;
304            }
305        }
306
307        if max_val < 1e-14 {
308            return Err("Matrix is singular and cannot be inverted".to_string());
309        }
310
311        // Swap rows
312        if max_row != col {
313            for j in 0..(2 * n) {
314                aug.swap(col * 2 * n + j, max_row * 2 * n + j);
315            }
316        }
317
318        // Scale pivot row
319        let pivot = aug[col * 2 * n + col];
320        for j in 0..(2 * n) {
321            aug[col * 2 * n + j] /= pivot;
322        }
323
324        // Eliminate column
325        for row in 0..n {
326            if row != col {
327                let factor = aug[row * 2 * n + col];
328                for j in 0..(2 * n) {
329                    aug[row * 2 * n + j] -= factor * aug[col * 2 * n + j];
330                }
331            }
332        }
333    }
334
335    // Extract inverse from right half
336    let mut result = AlignedVec::with_capacity(n * n);
337    for i in 0..n {
338        for j in 0..n {
339            result.push(aug[i * 2 * n + n + j]);
340        }
341    }
342
343    Ok(MatrixData::from_flat(result, m.rows, m.cols))
344}
345
346/// Matrix determinant via LU decomposition (partial pivoting).
347pub fn matrix_determinant(m: &MatrixData) -> Result<f64, String> {
348    if m.rows != m.cols {
349        return Err(format!(
350            "Cannot compute determinant of non-square matrix: {}x{}",
351            m.rows, m.cols
352        ));
353    }
354    let n = m.rows as usize;
355    if n == 0 {
356        return Ok(1.0);
357    }
358    if n == 1 {
359        return Ok(m.data[0]);
360    }
361    if n == 2 {
362        return Ok(m.data[0] * m.data[3] - m.data[1] * m.data[2]);
363    }
364
365    // Work on a copy
366    let mut a: Vec<f64> = m.data.iter().copied().collect();
367    let mut det = 1.0f64;
368
369    for col in 0..n {
370        // Partial pivoting
371        let mut max_val = a[col * n + col].abs();
372        let mut max_row = col;
373        for row in (col + 1)..n {
374            let val = a[row * n + col].abs();
375            if val > max_val {
376                max_val = val;
377                max_row = row;
378            }
379        }
380
381        if max_val < 1e-14 {
382            return Ok(0.0);
383        }
384
385        if max_row != col {
386            for j in 0..n {
387                a.swap(col * n + j, max_row * n + j);
388            }
389            det = -det;
390        }
391
392        det *= a[col * n + col];
393
394        let pivot = a[col * n + col];
395        for row in (col + 1)..n {
396            let factor = a[row * n + col] / pivot;
397            for j in (col + 1)..n {
398                a[row * n + j] -= factor * a[col * n + j];
399            }
400        }
401    }
402
403    Ok(det)
404}
405
406/// Matrix trace: sum of diagonal elements.
407pub fn matrix_trace(m: &MatrixData) -> Result<f64, String> {
408    if m.rows != m.cols {
409        return Err(format!(
410            "Cannot compute trace of non-square matrix: {}x{}",
411            m.rows, m.cols
412        ));
413    }
414    let n = m.rows as usize;
415    let mut sum = 0.0;
416    for i in 0..n {
417        sum += m.data[i * n + i];
418    }
419    Ok(sum)
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `MatrixData` from a row-major slice.
    fn mat(data: &[f64], rows: u32, cols: u32) -> MatrixData {
        let mut aligned = AlignedVec::with_capacity(data.len());
        for &v in data {
            aligned.push(v);
        }
        MatrixData::from_flat(aligned, rows, cols)
    }

    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-10
    }

    fn mat_approx_eq(a: &MatrixData, b: &MatrixData) -> bool {
        a.rows == b.rows
            && a.cols == b.cols
            && a.data
                .iter()
                .zip(b.data.iter())
                .all(|(x, y)| approx_eq(*x, *y))
    }

    /// Reference triple-loop product used to cross-check the SIMD matmul path.
    fn naive_matmul(a: &[f64], b: &[f64], m: usize, k: usize, n: usize) -> Vec<f64> {
        let mut out = vec![0.0; m * n];
        for i in 0..m {
            for j in 0..n {
                for kk in 0..k {
                    out[i * n + j] += a[i * k + kk] * b[kk * n + j];
                }
            }
        }
        out
    }

    #[test]
    fn test_matrix_add_2x2() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = mat(&[5.0, 6.0, 7.0, 8.0], 2, 2);
        let c = matrix_add(&a, &b).unwrap();
        assert_eq!(c.data.as_slice(), &[6.0, 8.0, 10.0, 12.0]);
    }

    #[test]
    fn test_matrix_sub_2x2() {
        let a = mat(&[5.0, 6.0, 7.0, 8.0], 2, 2);
        let b = mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let c = matrix_sub(&a, &b).unwrap();
        assert_eq!(c.data.as_slice(), &[4.0, 4.0, 4.0, 4.0]);
    }

    #[test]
    fn test_matrix_scale() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let c = matrix_scale(&a, 3.0);
        assert_eq!(c.data.as_slice(), &[3.0, 6.0, 9.0, 12.0]);
    }

    #[test]
    fn test_matrix_element_mul() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = mat(&[5.0, 6.0, 7.0, 8.0], 2, 2);
        let c = matrix_element_mul(&a, &b).unwrap();
        assert_eq!(c.data.as_slice(), &[5.0, 12.0, 21.0, 32.0]);
    }

    #[test]
    fn test_matrix_matmul_2x2() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = mat(&[5.0, 6.0, 7.0, 8.0], 2, 2);
        let c = matrix_matmul(&a, &b).unwrap();
        assert_eq!(c.data.as_slice(), &[19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_matrix_matmul_3x3() {
        let a = mat(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], 3, 3);
        let b = mat(&[2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 3, 3);
        let c = matrix_matmul(&a, &b).unwrap();
        assert_eq!(c.data.as_slice(), b.data.as_slice());
    }

    #[test]
    fn test_matrix_matmul_2x3_3x2() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = mat(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], 3, 2);
        let c = matrix_matmul(&a, &b).unwrap();
        assert_eq!(c.rows, 2);
        assert_eq!(c.cols, 2);
        // [1*7+2*9+3*11, 1*8+2*10+3*12] = [58, 64]
        // [4*7+5*9+6*11, 4*8+5*10+6*12] = [139, 154]
        assert_eq!(c.data.as_slice(), &[58.0, 64.0, 139.0, 154.0]);
    }

    #[test]
    fn test_matrix_matvec() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let v = [5.0, 6.0];
        let result = matrix_matvec(&a, &v).unwrap();
        assert_eq!(result.as_slice(), &[17.0, 39.0]);
    }

    #[test]
    fn test_matrix_transpose() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let t = matrix_transpose(&a);
        assert_eq!(t.rows, 3);
        assert_eq!(t.cols, 2);
        assert_eq!(t.data.as_slice(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_matrix_inverse_2x2() {
        let a = mat(&[4.0, 7.0, 2.0, 6.0], 2, 2);
        let inv = matrix_inverse(&a).unwrap();
        // Verify A * A^-1 = I
        let identity = matrix_matmul(&a, &inv).unwrap();
        assert!(approx_eq(identity.get(0, 0), 1.0));
        assert!(approx_eq(identity.get(0, 1), 0.0));
        assert!(approx_eq(identity.get(1, 0), 0.0));
        assert!(approx_eq(identity.get(1, 1), 1.0));
    }

    #[test]
    fn test_matrix_inverse_3x3() {
        let a = mat(&[1.0, 2.0, 3.0, 0.0, 1.0, 4.0, 5.0, 6.0, 0.0], 3, 3);
        let inv = matrix_inverse(&a).unwrap();
        let identity = matrix_matmul(&a, &inv).unwrap();
        for i in 0..3u32 {
            for j in 0..3u32 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    approx_eq(identity.get(i, j), expected),
                    "identity[{},{}] = {} (expected {})",
                    i,
                    j,
                    identity.get(i, j),
                    expected
                );
            }
        }
    }

    #[test]
    fn test_matrix_inverse_singular() {
        let a = mat(&[1.0, 2.0, 2.0, 4.0], 2, 2);
        assert!(matrix_inverse(&a).is_err());
    }

    #[test]
    fn test_matrix_determinant_2x2() {
        let a = mat(&[3.0, 8.0, 4.0, 6.0], 2, 2);
        let det = matrix_determinant(&a).unwrap();
        assert!(approx_eq(det, -14.0));
    }

    #[test]
    fn test_matrix_determinant_3x3() {
        let a = mat(&[6.0, 1.0, 1.0, 4.0, -2.0, 5.0, 2.0, 8.0, 7.0], 3, 3);
        let det = matrix_determinant(&a).unwrap();
        assert!(approx_eq(det, -306.0));
    }

    #[test]
    fn test_matrix_determinant_singular() {
        let a = mat(&[1.0, 2.0, 2.0, 4.0], 2, 2);
        let det = matrix_determinant(&a).unwrap();
        assert!(approx_eq(det, 0.0));
    }

    #[test]
    fn test_matrix_determinant_row_swap() {
        // A single row swap must flip the sign.
        let p = mat(&[0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0], 3, 3);
        assert!(approx_eq(matrix_determinant(&p).unwrap(), -1.0));
    }

    #[test]
    fn test_matrix_determinant_tiny_cases() {
        let empty = mat(&[], 0, 0);
        assert!(approx_eq(matrix_determinant(&empty).unwrap(), 1.0));
        let single = mat(&[-3.5], 1, 1);
        assert!(approx_eq(matrix_determinant(&single).unwrap(), -3.5));
    }

    #[test]
    fn test_matrix_trace() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], 3, 3);
        let tr = matrix_trace(&a).unwrap();
        assert!(approx_eq(tr, 15.0));
    }

    #[test]
    fn test_matrix_non_square_errors() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        assert!(matrix_inverse(&a).is_err());
        assert!(matrix_determinant(&a).is_err());
        assert!(matrix_trace(&a).is_err());
    }

    #[test]
    fn test_matrix_add_dimension_mismatch() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = mat(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        assert!(matrix_add(&a, &b).is_err());
    }

    #[test]
    fn test_matrix_matmul_dimension_mismatch() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = mat(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
        assert!(matrix_matmul(&a, &b).is_err());
    }

    #[test]
    fn test_matrix_matvec_dimension_mismatch() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        assert!(matrix_matvec(&a, &[1.0, 2.0, 3.0]).is_err());
    }

    #[test]
    fn test_matrix_add_large_simd() {
        // Test SIMD path (>= 16 elements)
        let n = 20;
        let data_a: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let data_b: Vec<f64> = (0..n).map(|i| (i * 2) as f64).collect();
        let a = mat(&data_a, 4, 5);
        let b = mat(&data_b, 4, 5);
        let c = matrix_add(&a, &b).unwrap();
        for i in 0..n {
            assert!(approx_eq(c.data[i], data_a[i] + data_b[i]));
        }
    }

    #[test]
    fn test_matrix_sub_and_element_mul_large_simd() {
        // 22 elements: SIMD path plus a 2-element scalar tail.
        let data_a: Vec<f64> = (0..22).map(|i| (i * 3) as f64).collect();
        let data_b: Vec<f64> = (0..22).map(|i| (i + 1) as f64).collect();
        let a = mat(&data_a, 2, 11);
        let b = mat(&data_b, 2, 11);
        let diff = matrix_sub(&a, &b).unwrap();
        let prod = matrix_element_mul(&a, &b).unwrap();
        for i in 0..22 {
            assert_eq!(diff.data[i], data_a[i] - data_b[i]);
            assert_eq!(prod.data[i], data_a[i] * data_b[i]);
        }
    }

    #[test]
    fn test_matrix_scale_large_simd() {
        // 21 elements: SIMD path plus a 1-element scalar tail.
        let data: Vec<f64> = (0..21).map(|i| i as f64).collect();
        let a = mat(&data, 3, 7);
        let c = matrix_scale(&a, -0.5);
        assert_eq!(c.rows, 3);
        assert_eq!(c.cols, 7);
        for (i, &x) in data.iter().enumerate() {
            assert_eq!(c.data[i], x * -0.5);
        }
    }

    #[test]
    fn test_matrix_matmul_4x4() {
        let a = mat(
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
                16.0,
            ],
            4,
            4,
        );
        let identity = mat(
            &[
                1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
            ],
            4,
            4,
        );
        let c = matrix_matmul(&a, &identity).unwrap();
        assert!(mat_approx_eq(&c, &a));
    }

    #[test]
    fn test_matrix_matmul_wide_simd() {
        // n = 18 >= SIMD_THRESHOLD: vector path with a 2-column scalar tail.
        let (m, k, n) = (2usize, 3usize, 18usize);
        let data_a: Vec<f64> = (0..m * k).map(|i| i as f64 - 2.0).collect();
        let data_b: Vec<f64> = (0..k * n).map(|i| (i % 7) as f64 * 0.5).collect();
        let a = mat(&data_a, m as u32, k as u32);
        let b = mat(&data_b, k as u32, n as u32);
        let c = matrix_matmul(&a, &b).unwrap();
        let expected = mat(&naive_matmul(&data_a, &data_b, m, k, n), m as u32, n as u32);
        assert!(mat_approx_eq(&c, &expected));
    }

    #[test]
    fn test_matrix_matvec_simd() {
        // Row length 18 >= SIMD_THRESHOLD: lane reduction plus scalar tail.
        let n = 18usize;
        let data_a: Vec<f64> = (0..2 * n).map(|i| i as f64).collect();
        let v: Vec<f64> = (0..n).map(|j| (j % 3) as f64).collect();
        let a = mat(&data_a, 2, n as u32);
        let result = matrix_matvec(&a, &v).unwrap();
        assert_eq!(result.len(), 2);
        for i in 0..2 {
            let expected: f64 = (0..n).map(|j| data_a[i * n + j] * v[j]).sum();
            assert!(approx_eq(result[i], expected));
        }
    }

    #[test]
    fn test_matrix_transpose_roundtrip() {
        let data: Vec<f64> = (0..12).map(|i| i as f64).collect();
        let a = mat(&data, 3, 4);
        let back = matrix_transpose(&matrix_transpose(&a));
        assert!(mat_approx_eq(&back, &a));
    }
}
637}