1use shape_value::aligned_vec::AlignedVec;
7use shape_value::heap_value::MatrixData;
8use wide::f64x4;
9
/// Element count at or above which the routines in this file process data four
/// `f64` lanes at a time via `f64x4`; smaller inputs take the plain scalar loop.
10const SIMD_THRESHOLD: usize = 16;
11
12pub fn matrix_add(a: &MatrixData, b: &MatrixData) -> Result<MatrixData, String> {
14 if a.rows != b.rows || a.cols != b.cols {
15 return Err(format!(
16 "Matrix dimension mismatch for add: {}x{} vs {}x{}",
17 a.rows, a.cols, b.rows, b.cols
18 ));
19 }
20 let len = a.data.len();
21 let mut result = AlignedVec::with_capacity(len);
22
23 if len >= SIMD_THRESHOLD {
24 let chunks = len / 4;
25 let a_ptr = a.data.as_ptr();
26 let b_ptr = b.data.as_ptr();
27 for i in 0..chunks {
28 let offset = i * 4;
29 let va = f64x4::from(unsafe { *(a_ptr.add(offset) as *const [f64; 4]) });
30 let vb = f64x4::from(unsafe { *(b_ptr.add(offset) as *const [f64; 4]) });
31 let vc = va + vb;
32 let arr: [f64; 4] = vc.into();
33 for v in arr {
34 result.push(v);
35 }
36 }
37 for i in (chunks * 4)..len {
38 result.push(a.data[i] + b.data[i]);
39 }
40 } else {
41 for i in 0..len {
42 result.push(a.data[i] + b.data[i]);
43 }
44 }
45
46 Ok(MatrixData::from_flat(result, a.rows, a.cols))
47}
48
49pub fn matrix_sub(a: &MatrixData, b: &MatrixData) -> Result<MatrixData, String> {
51 if a.rows != b.rows || a.cols != b.cols {
52 return Err(format!(
53 "Matrix dimension mismatch for sub: {}x{} vs {}x{}",
54 a.rows, a.cols, b.rows, b.cols
55 ));
56 }
57 let len = a.data.len();
58 let mut result = AlignedVec::with_capacity(len);
59
60 if len >= SIMD_THRESHOLD {
61 let chunks = len / 4;
62 let a_ptr = a.data.as_ptr();
63 let b_ptr = b.data.as_ptr();
64 for i in 0..chunks {
65 let offset = i * 4;
66 let va = f64x4::from(unsafe { *(a_ptr.add(offset) as *const [f64; 4]) });
67 let vb = f64x4::from(unsafe { *(b_ptr.add(offset) as *const [f64; 4]) });
68 let vc = va - vb;
69 let arr: [f64; 4] = vc.into();
70 for v in arr {
71 result.push(v);
72 }
73 }
74 for i in (chunks * 4)..len {
75 result.push(a.data[i] - b.data[i]);
76 }
77 } else {
78 for i in 0..len {
79 result.push(a.data[i] - b.data[i]);
80 }
81 }
82
83 Ok(MatrixData::from_flat(result, a.rows, a.cols))
84}
85
86pub fn matrix_scale(a: &MatrixData, scalar: f64) -> MatrixData {
88 let len = a.data.len();
89 let mut result = AlignedVec::with_capacity(len);
90
91 if len >= SIMD_THRESHOLD {
92 let chunks = len / 4;
93 let s = f64x4::splat(scalar);
94 let a_ptr = a.data.as_ptr();
95 for i in 0..chunks {
96 let offset = i * 4;
97 let va = f64x4::from(unsafe { *(a_ptr.add(offset) as *const [f64; 4]) });
98 let vc = va * s;
99 let arr: [f64; 4] = vc.into();
100 for v in arr {
101 result.push(v);
102 }
103 }
104 for i in (chunks * 4)..len {
105 result.push(a.data[i] * scalar);
106 }
107 } else {
108 for i in 0..len {
109 result.push(a.data[i] * scalar);
110 }
111 }
112
113 MatrixData::from_flat(result, a.rows, a.cols)
114}
115
116pub fn matrix_element_mul(a: &MatrixData, b: &MatrixData) -> Result<MatrixData, String> {
118 if a.rows != b.rows || a.cols != b.cols {
119 return Err(format!(
120 "Matrix dimension mismatch for element-wise mul: {}x{} vs {}x{}",
121 a.rows, a.cols, b.rows, b.cols
122 ));
123 }
124 let len = a.data.len();
125 let mut result = AlignedVec::with_capacity(len);
126
127 if len >= SIMD_THRESHOLD {
128 let chunks = len / 4;
129 let a_ptr = a.data.as_ptr();
130 let b_ptr = b.data.as_ptr();
131 for i in 0..chunks {
132 let offset = i * 4;
133 let va = f64x4::from(unsafe { *(a_ptr.add(offset) as *const [f64; 4]) });
134 let vb = f64x4::from(unsafe { *(b_ptr.add(offset) as *const [f64; 4]) });
135 let vc = va * vb;
136 let arr: [f64; 4] = vc.into();
137 for v in arr {
138 result.push(v);
139 }
140 }
141 for i in (chunks * 4)..len {
142 result.push(a.data[i] * b.data[i]);
143 }
144 } else {
145 for i in 0..len {
146 result.push(a.data[i] * b.data[i]);
147 }
148 }
149
150 Ok(MatrixData::from_flat(result, a.rows, a.cols))
151}
152
153pub fn matrix_matmul(a: &MatrixData, b: &MatrixData) -> Result<MatrixData, String> {
156 if a.cols != b.rows {
157 return Err(format!(
158 "Matrix dimension mismatch for matmul: {}x{} * {}x{}",
159 a.rows, a.cols, b.rows, b.cols
160 ));
161 }
162 let m = a.rows as usize;
163 let k = a.cols as usize;
164 let n = b.cols as usize;
165 let mut result = AlignedVec::with_capacity(m * n);
166 for _ in 0..(m * n) {
167 result.push(0.0);
168 }
169
170 for i in 0..m {
172 let a_row_base = i * k;
173 let out_row_base = i * n;
174 for kk in 0..k {
175 let a_ik = a.data[a_row_base + kk];
176 let b_row_base = kk * n;
177 if n >= SIMD_THRESHOLD {
178 let chunks = n / 4;
179 let sa = f64x4::splat(a_ik);
180 for j in 0..chunks {
181 let offset = j * 4;
182 let vb = f64x4::from(unsafe {
183 *(b.data.as_ptr().add(b_row_base + offset) as *const [f64; 4])
184 });
185 let vc = f64x4::from(unsafe {
186 *(result.as_ptr().add(out_row_base + offset) as *const [f64; 4])
187 });
188 let vr = vc + sa * vb;
189 let arr: [f64; 4] = vr.into();
190 for (idx, v) in arr.iter().enumerate() {
191 result[out_row_base + offset + idx] = *v;
192 }
193 }
194 for j in (chunks * 4)..n {
195 result[out_row_base + j] += a_ik * b.data[b_row_base + j];
196 }
197 } else {
198 for j in 0..n {
199 result[out_row_base + j] += a_ik * b.data[b_row_base + j];
200 }
201 }
202 }
203 }
204
205 Ok(MatrixData::from_flat(result, a.rows as u32, b.cols as u32))
206}
207
208pub fn matrix_matvec(a: &MatrixData, v: &[f64]) -> Result<AlignedVec<f64>, String> {
211 let n = a.cols as usize;
212 if n != v.len() {
213 return Err(format!(
214 "Matrix-vector dimension mismatch: {}x{} * vec({})",
215 a.rows,
216 a.cols,
217 v.len()
218 ));
219 }
220 let m = a.rows as usize;
221 let mut result = AlignedVec::with_capacity(m);
222
223 for i in 0..m {
224 let row_base = i * n;
225 let mut acc = 0.0;
226 if n >= SIMD_THRESHOLD {
227 let chunks = n / 4;
228 let mut vacc = f64x4::splat(0.0);
229 for j in 0..chunks {
230 let offset = j * 4;
231 let va = f64x4::from(unsafe {
232 *(a.data.as_ptr().add(row_base + offset) as *const [f64; 4])
233 });
234 let vv = f64x4::from(unsafe { *(v.as_ptr().add(offset) as *const [f64; 4]) });
235 vacc = vacc + va * vv;
236 }
237 let arr: [f64; 4] = vacc.into();
238 acc = arr[0] + arr[1] + arr[2] + arr[3];
239 for j in (chunks * 4)..n {
240 acc += a.data[row_base + j] * v[j];
241 }
242 } else {
243 for j in 0..n {
244 acc += a.data[row_base + j] * v[j];
245 }
246 }
247 result.push(acc);
248 }
249
250 Ok(result)
251}
252
253pub fn matrix_transpose(m: &MatrixData) -> MatrixData {
255 let rows = m.rows as usize;
256 let cols = m.cols as usize;
257 let mut result = AlignedVec::with_capacity(rows * cols);
258 for _ in 0..(rows * cols) {
259 result.push(0.0);
260 }
261
262 for i in 0..rows {
263 for j in 0..cols {
264 result[j * rows + i] = m.data[i * cols + j];
265 }
266 }
267
268 MatrixData::from_flat(result, m.cols, m.rows)
269}
270
271pub fn matrix_inverse(m: &MatrixData) -> Result<MatrixData, String> {
274 if m.rows != m.cols {
275 return Err(format!(
276 "Cannot invert non-square matrix: {}x{}",
277 m.rows, m.cols
278 ));
279 }
280 let n = m.rows as usize;
281 if n == 0 {
282 return Ok(MatrixData::new(0, 0));
283 }
284
285 let mut aug = vec![0.0f64; n * 2 * n];
287 for i in 0..n {
288 for j in 0..n {
289 aug[i * 2 * n + j] = m.data[i * n + j];
290 }
291 aug[i * 2 * n + n + i] = 1.0;
292 }
293
294 for col in 0..n {
296 let mut max_val = aug[col * 2 * n + col].abs();
298 let mut max_row = col;
299 for row in (col + 1)..n {
300 let val = aug[row * 2 * n + col].abs();
301 if val > max_val {
302 max_val = val;
303 max_row = row;
304 }
305 }
306
307 if max_val < 1e-14 {
308 return Err("Matrix is singular and cannot be inverted".to_string());
309 }
310
311 if max_row != col {
313 for j in 0..(2 * n) {
314 aug.swap(col * 2 * n + j, max_row * 2 * n + j);
315 }
316 }
317
318 let pivot = aug[col * 2 * n + col];
320 for j in 0..(2 * n) {
321 aug[col * 2 * n + j] /= pivot;
322 }
323
324 for row in 0..n {
326 if row != col {
327 let factor = aug[row * 2 * n + col];
328 for j in 0..(2 * n) {
329 aug[row * 2 * n + j] -= factor * aug[col * 2 * n + j];
330 }
331 }
332 }
333 }
334
335 let mut result = AlignedVec::with_capacity(n * n);
337 for i in 0..n {
338 for j in 0..n {
339 result.push(aug[i * 2 * n + n + j]);
340 }
341 }
342
343 Ok(MatrixData::from_flat(result, m.rows, m.cols))
344}
345
346pub fn matrix_determinant(m: &MatrixData) -> Result<f64, String> {
348 if m.rows != m.cols {
349 return Err(format!(
350 "Cannot compute determinant of non-square matrix: {}x{}",
351 m.rows, m.cols
352 ));
353 }
354 let n = m.rows as usize;
355 if n == 0 {
356 return Ok(1.0);
357 }
358 if n == 1 {
359 return Ok(m.data[0]);
360 }
361 if n == 2 {
362 return Ok(m.data[0] * m.data[3] - m.data[1] * m.data[2]);
363 }
364
365 let mut a: Vec<f64> = m.data.iter().copied().collect();
367 let mut det = 1.0f64;
368
369 for col in 0..n {
370 let mut max_val = a[col * n + col].abs();
372 let mut max_row = col;
373 for row in (col + 1)..n {
374 let val = a[row * n + col].abs();
375 if val > max_val {
376 max_val = val;
377 max_row = row;
378 }
379 }
380
381 if max_val < 1e-14 {
382 return Ok(0.0);
383 }
384
385 if max_row != col {
386 for j in 0..n {
387 a.swap(col * n + j, max_row * n + j);
388 }
389 det = -det;
390 }
391
392 det *= a[col * n + col];
393
394 let pivot = a[col * n + col];
395 for row in (col + 1)..n {
396 let factor = a[row * n + col] / pivot;
397 for j in (col + 1)..n {
398 a[row * n + j] -= factor * a[col * n + j];
399 }
400 }
401 }
402
403 Ok(det)
404}
405
406pub fn matrix_trace(m: &MatrixData) -> Result<f64, String> {
408 if m.rows != m.cols {
409 return Err(format!(
410 "Cannot compute trace of non-square matrix: {}x{}",
411 m.rows, m.cols
412 ));
413 }
414 let n = m.rows as usize;
415 let mut sum = 0.0;
416 for i in 0..n {
417 sum += m.data[i * n + i];
418 }
419 Ok(sum)
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `MatrixData` with the given dimensions from a flat slice.
    fn mat(data: &[f64], rows: u32, cols: u32) -> MatrixData {
        let mut aligned = AlignedVec::with_capacity(data.len());
        for &v in data {
            aligned.push(v);
        }
        MatrixData::from_flat(aligned, rows, cols)
    }

    /// Absolute-tolerance float comparison used throughout the tests.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-10
    }

    /// True when both matrices have the same shape and pairwise-close data.
    fn mat_approx_eq(a: &MatrixData, b: &MatrixData) -> bool {
        a.rows == b.rows
            && a.cols == b.cols
            && a.data
                .iter()
                .zip(b.data.iter())
                .all(|(x, y)| approx_eq(*x, *y))
    }

    /// Produces `len` values `i * step + offset`, used to fill large inputs.
    fn ramp(len: usize, step: f64, offset: f64) -> Vec<f64> {
        (0..len).map(|i| i as f64 * step + offset).collect()
    }

    #[test]
    fn test_matrix_add_2x2() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = mat(&[5.0, 6.0, 7.0, 8.0], 2, 2);
        let c = matrix_add(&a, &b).unwrap();
        assert_eq!(c.data.as_slice(), &[6.0, 8.0, 10.0, 12.0]);
    }

    #[test]
    fn test_matrix_sub_2x2() {
        let a = mat(&[5.0, 6.0, 7.0, 8.0], 2, 2);
        let b = mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let c = matrix_sub(&a, &b).unwrap();
        assert_eq!(c.data.as_slice(), &[4.0, 4.0, 4.0, 4.0]);
    }

    #[test]
    fn test_matrix_scale() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let c = matrix_scale(&a, 3.0);
        assert_eq!(c.data.as_slice(), &[3.0, 6.0, 9.0, 12.0]);
    }

    #[test]
    fn test_matrix_element_mul() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = mat(&[5.0, 6.0, 7.0, 8.0], 2, 2);
        let c = matrix_element_mul(&a, &b).unwrap();
        assert_eq!(c.data.as_slice(), &[5.0, 12.0, 21.0, 32.0]);
    }

    #[test]
    fn test_matrix_matmul_2x2() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = mat(&[5.0, 6.0, 7.0, 8.0], 2, 2);
        let c = matrix_matmul(&a, &b).unwrap();
        assert_eq!(c.data.as_slice(), &[19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_matrix_matmul_3x3() {
        let a = mat(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], 3, 3);
        let b = mat(&[2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 3, 3);
        let c = matrix_matmul(&a, &b).unwrap();
        assert_eq!(c.data.as_slice(), b.data.as_slice());
    }

    #[test]
    fn test_matrix_matmul_2x3_3x2() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = mat(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], 3, 2);
        let c = matrix_matmul(&a, &b).unwrap();
        assert_eq!(c.rows, 2);
        assert_eq!(c.cols, 2);
        assert_eq!(c.data.as_slice(), &[58.0, 64.0, 139.0, 154.0]);
    }

    #[test]
    fn test_matrix_matvec() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let v = [5.0, 6.0];
        let result = matrix_matvec(&a, &v).unwrap();
        assert_eq!(result.as_slice(), &[17.0, 39.0]);
    }

    #[test]
    fn test_matrix_transpose() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let t = matrix_transpose(&a);
        assert_eq!(t.rows, 3);
        assert_eq!(t.cols, 2);
        assert_eq!(t.data.as_slice(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_matrix_inverse_2x2() {
        let a = mat(&[4.0, 7.0, 2.0, 6.0], 2, 2);
        let inv = matrix_inverse(&a).unwrap();
        let identity = matrix_matmul(&a, &inv).unwrap();
        assert!(approx_eq(identity.get(0, 0), 1.0));
        assert!(approx_eq(identity.get(0, 1), 0.0));
        assert!(approx_eq(identity.get(1, 0), 0.0));
        assert!(approx_eq(identity.get(1, 1), 1.0));
    }

    #[test]
    fn test_matrix_inverse_3x3() {
        let a = mat(&[1.0, 2.0, 3.0, 0.0, 1.0, 4.0, 5.0, 6.0, 0.0], 3, 3);
        let inv = matrix_inverse(&a).unwrap();
        let identity = matrix_matmul(&a, &inv).unwrap();
        for i in 0..3u32 {
            for j in 0..3u32 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    approx_eq(identity.get(i, j), expected),
                    "identity[{},{}] = {} (expected {})",
                    i,
                    j,
                    identity.get(i, j),
                    expected
                );
            }
        }
    }

    #[test]
    fn test_matrix_inverse_singular() {
        let a = mat(&[1.0, 2.0, 2.0, 4.0], 2, 2);
        assert!(matrix_inverse(&a).is_err());
    }

    #[test]
    fn test_matrix_inverse_non_square() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        assert!(matrix_inverse(&a).is_err());
    }

    #[test]
    fn test_matrix_determinant_2x2() {
        let a = mat(&[3.0, 8.0, 4.0, 6.0], 2, 2);
        let det = matrix_determinant(&a).unwrap();
        assert!(approx_eq(det, -14.0));
    }

    #[test]
    fn test_matrix_determinant_3x3() {
        let a = mat(&[6.0, 1.0, 1.0, 4.0, -2.0, 5.0, 2.0, 8.0, 7.0], 3, 3);
        let det = matrix_determinant(&a).unwrap();
        assert!(approx_eq(det, -306.0));
    }

    #[test]
    fn test_matrix_determinant_singular() {
        let a = mat(&[1.0, 2.0, 2.0, 4.0], 2, 2);
        let det = matrix_determinant(&a).unwrap();
        assert!(approx_eq(det, 0.0));
    }

    #[test]
    fn test_matrix_determinant_non_square() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        assert!(matrix_determinant(&a).is_err());
    }

    #[test]
    fn test_matrix_trace() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], 3, 3);
        let tr = matrix_trace(&a).unwrap();
        assert!(approx_eq(tr, 15.0));
    }

    #[test]
    fn test_matrix_trace_non_square() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        assert!(matrix_trace(&a).is_err());
    }

    #[test]
    fn test_matrix_add_dimension_mismatch() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = mat(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        assert!(matrix_add(&a, &b).is_err());
    }

    #[test]
    fn test_matrix_matmul_dimension_mismatch() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = mat(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
        assert!(matrix_matmul(&a, &b).is_err());
    }

    #[test]
    fn test_matrix_matvec_dimension_mismatch() {
        let a = mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let v = [1.0, 2.0, 3.0];
        assert!(matrix_matvec(&a, &v).is_err());
    }

    #[test]
    fn test_matrix_add_large_simd() {
        let n = 20;
        let data_a: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let data_b: Vec<f64> = (0..n).map(|i| (i * 2) as f64).collect();
        let a = mat(&data_a, 4, 5);
        let b = mat(&data_b, 4, 5);
        let c = matrix_add(&a, &b).unwrap();
        for i in 0..n {
            assert!(approx_eq(c.data[i], data_a[i] + data_b[i]));
        }
    }

    // The tests below use 18 elements (or 18 columns): enough to enter the
    // SIMD branch, with a two-element remainder that exercises the scalar tail.

    #[test]
    fn test_matrix_sub_large_simd() {
        let data_a = ramp(18, 3.0, 1.0);
        let data_b = ramp(18, 0.5, -2.0);
        let c = matrix_sub(&mat(&data_a, 3, 6), &mat(&data_b, 3, 6)).unwrap();
        assert_eq!(c.data.len(), 18);
        for i in 0..18 {
            assert!(approx_eq(c.data[i], data_a[i] - data_b[i]));
        }
    }

    #[test]
    fn test_matrix_scale_large_simd() {
        let data_a = ramp(18, 1.5, -4.0);
        let c = matrix_scale(&mat(&data_a, 3, 6), -2.5);
        assert_eq!(c.data.len(), 18);
        for i in 0..18 {
            assert!(approx_eq(c.data[i], data_a[i] * -2.5));
        }
    }

    #[test]
    fn test_matrix_element_mul_large_simd() {
        let data_a = ramp(18, 2.0, 1.0);
        let data_b = ramp(18, -0.25, 3.0);
        let c = matrix_element_mul(&mat(&data_a, 3, 6), &mat(&data_b, 3, 6)).unwrap();
        assert_eq!(c.data.len(), 18);
        for i in 0..18 {
            assert!(approx_eq(c.data[i], data_a[i] * data_b[i]));
        }
    }

    #[test]
    fn test_matrix_matmul_wide_simd() {
        let (m, k, n) = (2usize, 3usize, 18usize);
        let data_a = ramp(m * k, 1.0, 1.0);
        let data_b = ramp(k * n, 0.25, -3.0);
        let c = matrix_matmul(&mat(&data_a, 2, 3), &mat(&data_b, 3, 18)).unwrap();
        assert_eq!(c.rows, 2);
        assert_eq!(c.cols, 18);
        for i in 0..m {
            for j in 0..n {
                let expected: f64 = (0..k)
                    .map(|kk| data_a[i * k + kk] * data_b[kk * n + j])
                    .sum();
                assert!(
                    approx_eq(c.data[i * n + j], expected),
                    "c[{},{}] = {} (expected {})",
                    i,
                    j,
                    c.data[i * n + j],
                    expected
                );
            }
        }
    }

    #[test]
    fn test_matrix_matvec_wide_simd() {
        let (m, n) = (2usize, 18usize);
        let data_a = ramp(m * n, 0.5, 1.0);
        let v = ramp(n, -1.0, 4.0);
        let result = matrix_matvec(&mat(&data_a, 2, 18), &v).unwrap();
        assert_eq!(result.len(), m);
        for i in 0..m {
            let expected: f64 = (0..n).map(|j| data_a[i * n + j] * v[j]).sum();
            assert!(
                approx_eq(result[i], expected),
                "result[{}] = {} (expected {})",
                i,
                result[i],
                expected
            );
        }
    }

    #[test]
    fn test_matrix_matmul_4x4() {
        let a = mat(
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
                16.0,
            ],
            4,
            4,
        );
        let identity = mat(
            &[
                1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
            ],
            4,
            4,
        );
        let c = matrix_matmul(&a, &identity).unwrap();
        assert!(mat_approx_eq(&c, &a));
    }
}