1use super::super::core::Tensor;
5use crate::error::{RusTorchError, RusTorchResult};
6use ndarray::{ArrayD, IxDyn};
7use num_traits::Float;
8
9impl<
10 T: Float
11 + 'static
12 + ndarray::ScalarOperand
13 + num_traits::FromPrimitive
14 + Clone
15 + std::fmt::Debug,
16 > Tensor<T>
17{
18 pub fn norm(&self) -> T {
24 let sum_squares: T = self
25 .data
26 .iter()
27 .map(|&x| x * x)
28 .fold(T::zero(), |acc, x| acc + x);
29 sum_squares.sqrt()
30 }
31
32 pub fn norm_p(&self, p: T) -> T {
35 if p == T::from(2.0).unwrap() {
36 return self.norm();
37 }
38
39 if p == T::one() {
40 return self
42 .data
43 .iter()
44 .map(|&x| x.abs())
45 .fold(T::zero(), |acc, x| acc + x);
46 }
47
48 if p == T::infinity() {
49 return self
51 .data
52 .iter()
53 .map(|&x| x.abs())
54 .fold(T::zero(), |acc, x| if x > acc { x } else { acc });
55 }
56
57 let sum_powers: T = self
59 .data
60 .iter()
61 .map(|&x| x.abs().powf(p))
62 .fold(T::zero(), |acc, x| acc + x);
63
64 sum_powers.powf(T::one() / p)
65 }
66
67 pub fn frobenius_norm(&self) -> T {
70 self.norm()
71 }
72
73 pub fn nuclear_norm(&self) -> RusTorchResult<T> {
76 if self.shape().len() != 2 {
77 return Err(RusTorchError::InvalidOperation {
78 operation: "nuclear_norm".to_string(),
79 message: "Nuclear norm is only defined for 2D matrices".to_string(),
80 });
81 }
82
83 let (_, s, _) = self.svd()?;
84 Ok(s.data.iter().fold(T::zero(), |acc, &x| acc + x))
85 }
86
87 pub fn svd(&self) -> RusTorchResult<(Tensor<T>, Tensor<T>, Tensor<T>)> {
93 if self.shape().len() != 2 {
94 return Err(RusTorchError::InvalidOperation {
95 operation: "svd".to_string(),
96 message: "SVD is only defined for 2D matrices".to_string(),
97 });
98 }
99
100 let [m, n] = [self.shape()[0], self.shape()[1]];
101 let min_dim = std::cmp::min(m, n);
102
103 let mut a = self.clone();
106
107 let mut u_vecs = Vec::new();
110 let mut singular_values = Vec::new();
111
112 for _ in 0..min_dim {
113 let (u, s, v) = self.power_iteration_svd(&a)?;
115
116 singular_values.push(s);
117 u_vecs.push(u);
118
119 let outer_product = self.outer_product(&u_vecs.last().unwrap(), &v)?;
121 a = a.sub(&outer_product.mul_scalar(s))?;
122
123 if s < T::from(1e-10).unwrap() {
125 break;
126 }
127 }
128
129 let u = self.construct_orthogonal_matrix(&u_vecs, m, u_vecs.len())?;
131 let s = Tensor::from_vec(singular_values, vec![u_vecs.len()]);
132
133 let mut v_data = vec![T::zero(); n * n];
136 for i in 0..n {
137 v_data[i * n + i] = T::one();
138 }
139 let v = Tensor::from_vec(v_data, vec![n, n]);
140
141 Ok((u, s, v))
142 }
143
144 fn power_iteration_svd(&self, matrix: &Tensor<T>) -> RusTorchResult<(Tensor<T>, T, Tensor<T>)> {
145 let [m, n] = [matrix.shape()[0], matrix.shape()[1]];
146
147 let v: Vec<T> = (0..n)
149 .map(|i| T::from(i as f64 % 7.0 + 1.0).unwrap())
150 .collect();
151 let mut v_tensor = Tensor::from_vec(v, vec![n]);
152 v_tensor = v_tensor.normalize()?;
153
154 for _ in 0..100 {
156 let u_tensor = matrix.matmul(&v_tensor.unsqueeze(1)?)?.squeeze();
159 let u_norm = u_tensor.norm();
160 let u_normalized = u_tensor.div_scalar(u_norm);
161
162 let at = matrix.transpose()?;
164 let v_new = at.matmul(&u_normalized.unsqueeze(1)?)?.squeeze();
165 let v_norm = v_new.norm();
166 let v_normalized = v_new.div_scalar(v_norm);
167
168 let diff = v_normalized.sub(&v_tensor)?.norm();
170 if diff < T::from(1e-8).unwrap() {
171 break;
172 }
173
174 v_tensor = v_normalized;
175 }
176
177 let av = matrix.matmul(&v_tensor.unsqueeze(1)?)?.squeeze();
179 let sigma = av.norm();
180 let u = av.div_scalar(sigma);
181
182 Ok((u, sigma, v_tensor))
183 }
184
185 fn construct_orthogonal_matrix(
186 &self,
187 vectors: &[Tensor<T>],
188 rows: usize,
189 cols: usize,
190 ) -> RusTorchResult<Tensor<T>> {
191 let mut data = vec![T::zero(); rows * cols];
192
193 for (col, vec) in vectors.iter().enumerate() {
194 if col >= cols {
195 break;
196 }
197 for (row, &val) in vec.data.iter().enumerate() {
198 if row >= rows {
199 break;
200 }
201 data[row * cols + col] = val;
202 }
203 }
204
205 Ok(Tensor::from_vec(data, vec![rows, cols]))
206 }
207
208 pub fn eigh(&self) -> RusTorchResult<(Tensor<T>, Tensor<T>)> {
211 if self.shape().len() != 2 || self.shape()[0] != self.shape()[1] {
212 return Err(RusTorchError::InvalidOperation {
213 operation: "eigh".to_string(),
214 message: "Eigenvalue decomposition requires square matrices".to_string(),
215 });
216 }
217
218 let transpose = self.transpose()?;
220 let diff = self.sub(&transpose)?.norm();
221 if diff > T::from(1e-10).unwrap() {
222 return Err(RusTorchError::InvalidOperation {
223 operation: "eigh".to_string(),
224 message: "Matrix must be symmetric for eigh".to_string(),
225 });
226 }
227
228 let n = self.shape()[0];
229
230 let mut eigenvalues = Vec::new();
233 let mut eigenvectors = Vec::new();
234 let mut a = self.clone();
235
236 for _ in 0..std::cmp::min(n, 3) {
237 let (eigval, eigvec) = self.power_iteration_eigen(&a)?;
239 eigenvalues.push(eigval);
240 eigenvectors.push(eigvec);
241
242 let outer =
244 self.outer_product(&eigenvectors.last().unwrap(), &eigenvectors.last().unwrap())?;
245 a = a.sub(&outer.mul_scalar(eigval))?;
246 }
247
248 let eigenvalue_len = eigenvalues.len();
249 let eigenvalue_tensor = Tensor::from_vec(eigenvalues, vec![eigenvalue_len]);
250 let eigenvector_matrix =
251 self.construct_orthogonal_matrix(&eigenvectors, n, eigenvectors.len())?;
252
253 Ok((eigenvalue_tensor, eigenvector_matrix))
254 }
255
256 fn power_iteration_eigen(&self, matrix: &Tensor<T>) -> RusTorchResult<(T, Tensor<T>)> {
257 let n = matrix.shape()[0];
258
259 let v: Vec<T> = (0..n)
261 .map(|i| T::from((i * 3 + 1) as f64).unwrap())
262 .collect();
263 let mut v_tensor = Tensor::from_vec(v, vec![n]);
264 v_tensor = v_tensor.normalize()?;
265
266 let mut eigenvalue = T::zero();
267
268 for _ in 0..100 {
270 let v_new = matrix.matmul(&v_tensor.unsqueeze(1)?)?.squeeze();
272
273 let vt_av = v_tensor.dot(&v_new);
275 let vt_v = v_tensor.dot(&v_tensor);
276 let new_eigenvalue = vt_av / vt_v;
277
278 v_tensor = v_new.normalize()?;
280
281 if (new_eigenvalue - eigenvalue).abs() < T::from(1e-10).unwrap() {
283 eigenvalue = new_eigenvalue;
284 break;
285 }
286 eigenvalue = new_eigenvalue;
287 }
288
289 Ok((eigenvalue, v_tensor))
290 }
291
292 pub fn qr(&self) -> RusTorchResult<(Tensor<T>, Tensor<T>)> {
295 if self.shape().len() != 2 {
296 return Err(RusTorchError::InvalidOperation {
297 operation: "qr".to_string(),
298 message: "QR decomposition is only defined for 2D matrices".to_string(),
299 });
300 }
301
302 let [m, n] = [self.shape()[0], self.shape()[1]];
303
304 let mut q_vectors: Vec<Tensor<T>> = Vec::new();
306 let mut r_data = vec![T::zero(); n * n];
307
308 for j in 0..n {
309 let mut col_data = Vec::new();
311 for i in 0..m {
312 if let Some(val) = self.data.get(IxDyn(&[i * n + j])) {
313 col_data.push(*val);
314 } else {
315 col_data.push(T::zero());
316 }
317 }
318 let mut q_j = Tensor::from_vec(col_data, vec![m]);
319
320 for (k, q_k) in q_vectors.iter().enumerate() {
322 let r_kj = q_k.dot(&q_j);
323 r_data[k * n + j] = r_kj;
324 let proj = q_k.mul_scalar(r_kj);
325 q_j = q_j.sub(&proj)?;
326 }
327
328 let r_jj = q_j.norm();
330 r_data[j * n + j] = r_jj;
331
332 if r_jj > T::from(1e-10).unwrap() {
333 q_j = q_j.div_scalar(r_jj);
334 }
335
336 q_vectors.push(q_j);
337 }
338
339 let q = self.construct_orthogonal_matrix(&q_vectors, m, n)?;
341 let r = Tensor::from_vec(r_data, vec![n, n]);
342
343 Ok((q, r))
344 }
345
346 pub fn cholesky(&self) -> RusTorchResult<Tensor<T>> {
349 if self.shape().len() != 2 || self.shape()[0] != self.shape()[1] {
350 return Err(RusTorchError::InvalidOperation {
351 operation: "cholesky".to_string(),
352 message: "Cholesky decomposition requires square matrices".to_string(),
353 });
354 }
355
356 let n = self.shape()[0];
357 let mut l_data = vec![T::zero(); n * n];
358
359 for i in 0..n {
361 for j in 0..=i {
362 if i == j {
363 let mut sum = T::zero();
365 for k in 0..j {
366 let l_jk = l_data[j * n + k];
367 sum = sum + l_jk * l_jk;
368 }
369
370 let a_jj = self
371 .data
372 .get(IxDyn(&[j * n + j]))
373 .copied()
374 .unwrap_or(T::zero());
375 let l_jj_squared = a_jj - sum;
376
377 if l_jj_squared <= T::zero() {
378 return Err(RusTorchError::InvalidOperation {
379 operation: "cholesky".to_string(),
380 message: "Matrix is not positive definite".to_string(),
381 });
382 }
383
384 l_data[j * n + j] = l_jj_squared.sqrt();
385 } else {
386 let mut sum = T::zero();
388 for k in 0..j {
389 let l_ik = l_data[i * n + k];
390 let l_jk = l_data[j * n + k];
391 sum = sum + l_ik * l_jk;
392 }
393
394 let a_ij = self
395 .data
396 .get(IxDyn(&[i * n + j]))
397 .copied()
398 .unwrap_or(T::zero());
399 let l_jj = l_data[j * n + j];
400
401 if l_jj == T::zero() {
402 return Err(RusTorchError::InvalidOperation {
403 operation: "cholesky".to_string(),
404 message: "Division by zero in Cholesky decomposition".to_string(),
405 });
406 }
407
408 l_data[i * n + j] = (a_ij - sum) / l_jj;
409 }
410 }
411 }
412
413 Ok(Tensor::from_vec(l_data, vec![n, n]))
414 }
415
416 pub fn inverse(&self) -> RusTorchResult<Tensor<T>> {
422 if self.shape().len() != 2 || self.shape()[0] != self.shape()[1] {
423 return Err(RusTorchError::InvalidOperation {
424 operation: "inverse".to_string(),
425 message: "Matrix inverse requires square matrices".to_string(),
426 });
427 }
428
429 let n = self.shape()[0];
430
431 let det = self.det()?;
433 if det.abs() < T::from(1e-12).unwrap() {
434 return Err(RusTorchError::InvalidOperation {
435 operation: "inverse".to_string(),
436 message: "Matrix is singular and cannot be inverted".to_string(),
437 });
438 }
439
440 let mut augmented = vec![T::zero(); n * 2 * n];
442
443 for i in 0..n {
445 for j in 0..n {
446 let val = self
447 .data
448 .get(IxDyn(&[i * n + j]))
449 .copied()
450 .unwrap_or(T::zero());
451 augmented[i * 2 * n + j] = val;
452 }
453 augmented[i * 2 * n + n + i] = T::one();
455 }
456
457 for i in 0..n {
459 let mut max_row = i;
461 for k in (i + 1)..n {
462 if augmented[k * 2 * n + i].abs() > augmented[max_row * 2 * n + i].abs() {
463 max_row = k;
464 }
465 }
466
467 if max_row != i {
469 for j in 0..(2 * n) {
470 let temp = augmented[i * 2 * n + j];
471 augmented[i * 2 * n + j] = augmented[max_row * 2 * n + j];
472 augmented[max_row * 2 * n + j] = temp;
473 }
474 }
475
476 let pivot = augmented[i * 2 * n + i];
478 if pivot.abs() < T::from(1e-12).unwrap() {
479 return Err(RusTorchError::InvalidOperation {
480 operation: "inverse".to_string(),
481 message: "Matrix is singular".to_string(),
482 });
483 }
484
485 for j in 0..(2 * n) {
486 augmented[i * 2 * n + j] = augmented[i * 2 * n + j] / pivot;
487 }
488
489 for k in 0..n {
491 if k != i {
492 let factor = augmented[k * 2 * n + i];
493 for j in 0..(2 * n) {
494 augmented[k * 2 * n + j] =
495 augmented[k * 2 * n + j] - factor * augmented[i * 2 * n + j];
496 }
497 }
498 }
499 }
500
501 let mut inverse_data = vec![T::zero(); n * n];
503 for i in 0..n {
504 for j in 0..n {
505 inverse_data[i * n + j] = augmented[i * 2 * n + n + j];
506 }
507 }
508
509 Ok(Tensor::from_vec(inverse_data, vec![n, n]))
510 }
511
512 pub fn pinv(&self) -> RusTorchResult<Tensor<T>> {
515 if self.shape().len() != 2 {
516 return Err(RusTorchError::InvalidOperation {
517 operation: "pinv".to_string(),
518 message: "Pseudo-inverse is only defined for 2D matrices".to_string(),
519 });
520 }
521
522 let [m, n] = [self.shape()[0], self.shape()[1]];
523
524 if m >= n {
525 let at = self.transpose()?;
527 let ata = at.matmul(self)?;
528 let ata_inv = ata.inverse()?;
529 ata_inv.matmul(&at)
530 } else {
531 let at = self.transpose()?;
533 let aat = self.matmul(&at)?;
534 let aat_inv = aat.inverse()?;
535 at.matmul(&aat_inv)
536 }
537 }
538
539 fn normalize(&self) -> RusTorchResult<Self> {
543 let norm = self.norm();
544 if norm == T::zero() {
545 Ok(self.clone())
546 } else {
547 Ok(self.div_scalar(norm))
548 }
549 }
550
551 fn dot(&self, other: &Self) -> T {
552 self.data
553 .iter()
554 .zip(other.data.iter())
555 .map(|(&a, &b)| a * b)
556 .fold(T::zero(), |acc, x| acc + x)
557 }
558
559 fn outer_product(&self, u: &Self, v: &Self) -> RusTorchResult<Self> {
560 if u.numel() != self.shape()[0] || v.numel() != self.shape()[1] {
561 return Err(RusTorchError::InvalidOperation {
562 operation: "outer_product".to_string(),
563 message: "Vector dimensions don't match matrix dimensions".to_string(),
564 });
565 }
566
567 let m = u.numel();
568 let n = v.numel();
569 let mut result = vec![T::zero(); m * n];
570
571 for i in 0..m {
572 for j in 0..n {
573 let u_val = u.data.get(IxDyn(&[i])).copied().unwrap_or(T::zero());
574 let v_val = v.data.get(IxDyn(&[j])).copied().unwrap_or(T::zero());
575 result[i * n + j] = u_val * v_val;
576 }
577 }
578
579 Ok(Tensor::from_vec(result, vec![m, n]))
580 }
581
582 fn linalg_mul_scalar(&self, scalar: T) -> RusTorchResult<Self> {
586 let result_data: Vec<T> = self.data.iter().map(|&x| x * scalar).collect();
587 Ok(Tensor::from_vec(result_data, self.shape().to_vec()))
588 }
589
590 fn linalg_div_scalar(&self, scalar: T) -> RusTorchResult<Self> {
591 if scalar == T::zero() {
592 return Err(RusTorchError::InvalidOperation {
593 operation: "div_scalar".to_string(),
594 message: "Division by zero".to_string(),
595 });
596 }
597 let result_data: Vec<T> = self.data.iter().map(|&x| x / scalar).collect();
598 Ok(Tensor::from_vec(result_data, self.shape().to_vec()))
599 }
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_norms() {
        let vector = Tensor::from_vec(vec![3.0, 4.0], vec![2]);

        assert!((vector.norm() - 5.0).abs() < 1e-10);
        assert!((vector.norm_p(1.0) - 7.0).abs() < 1e-10);
        assert!((vector.frobenius_norm() - 5.0).abs() < 1e-10);
    }

    #[test]
    #[cfg(feature = "linalg")]
    fn test_qr_decomposition() {
        let matrix = Tensor::from_vec(vec![1.0, 1.0, 0.0, 1.0], vec![2, 2]);
        let (q, r) = matrix.qr().unwrap();

        assert_eq!(q.shape(), &[2, 2]);
        assert_eq!(r.shape(), &[2, 2]);

        // Q must have orthonormal columns: Q^T Q == I.
        let gram = q.transpose().unwrap().matmul(&q).unwrap();
        let identity = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]);
        assert!(gram.sub(&identity).unwrap().norm() < 1e-10);
    }

    #[test]
    #[cfg(feature = "linalg")]
    fn test_cholesky_decomposition() {
        let matrix = Tensor::from_vec(vec![4.0, 2.0, 2.0, 2.0], vec![2, 2]);
        let lower = matrix.cholesky().unwrap();

        assert_eq!(lower.shape(), &[2, 2]);

        // L L^T must reproduce the input.
        let lower_transposed = lower.transpose().unwrap();
        let reconstructed = lower.matmul(&lower_transposed).unwrap();
        assert!(matrix.sub(&reconstructed).unwrap().norm() < 1e-10);
    }

    #[test]
    #[cfg(feature = "linalg")]
    fn test_matrix_inverse() {
        let matrix = Tensor::from_vec(vec![4.0, 2.0, 1.0, 3.0], vec![2, 2]);
        let inverse = matrix.inverse().unwrap();

        let product = matrix.matmul(&inverse).unwrap();
        let identity = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]);
        assert!(product.sub(&identity).unwrap().norm() < 1e-10);
    }

    #[test]
    #[cfg(feature = "linalg")]
    fn test_pinv() {
        let matrix = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let pseudo_inverse = matrix.pinv().unwrap();

        assert_eq!(pseudo_inverse.shape(), &[3, 2]);

        // Penrose condition A A^+ A == A.
        let roundtrip = matrix
            .matmul(&pseudo_inverse)
            .unwrap()
            .matmul(&matrix)
            .unwrap();
        assert!(matrix.sub(&roundtrip).unwrap().norm() < 1e-8);
    }

    #[test]
    fn test_svd_basic() {
        let matrix = Tensor::from_vec(vec![3.0, 1.0, 1.0, 3.0], vec![2, 2]);
        let (u, s, v) = matrix.svd().unwrap();

        assert_eq!(u.shape()[0], 2);
        assert_eq!(v.shape()[1], 2);
        assert!(s.numel() > 0);

        let singular_values = s.as_slice().unwrap();
        // Singular values are non-negative and come out in non-increasing order.
        assert!(singular_values.iter().all(|&value| value >= 0.0));
        for pair in singular_values.windows(2) {
            assert!(pair[0] >= pair[1]);
        }
    }

    #[test]
    fn test_eigenvalue_decomposition() {
        let matrix = Tensor::from_vec(vec![2.0, 1.0, 1.0, 2.0], vec![2, 2]);
        let (eigenvalues, eigenvectors) = matrix.eigh().unwrap();

        assert!(eigenvalues.numel() > 0);
        assert_eq!(eigenvectors.shape()[0], 2);

        let values = eigenvalues.as_slice().unwrap();
        assert!(values.iter().all(|value| value.is_finite()));
    }
}