1#![allow(clippy::comparison_chain)]
2use std::cmp;
3use super::GenTensor;
4use crate::tensor_trait::reduction::ReduceTensor;
5use crate::tensor_trait::elemwise::ElemwiseTensorOp;
6use crate::tensor_trait::index_slicing::IndexSlicing;
7use crate::tensor_trait::linalg::LinearAlgbra;
8
9impl<T> LinearAlgbra for GenTensor<T>
10where T: num_traits::Float {
11 type TensorType = GenTensor<T>;
12 type ElementType = T;
13
14 fn norm(&self) -> Self::TensorType {
15 self.mul(self).sum(None, false).sqrt()
17 }
18
19 fn normalize_unit(&self) -> Self::TensorType {
20 let s = self.mul(self).sum(None, false);
21 self.div(&s.sqrt())
22 }
23
24 fn lu(&self) -> Option<[Self::TensorType; 2]> {
25 if self.size().len() != 2 {
28 return None;
29 }
30 if self.size()[0] != self.size()[1] {
31 return None;
32 }
33 let nr = self.size()[0];
34 let mut l = GenTensor::<T>::eye(nr, nr);
35 let mut u = self.clone();
36 for i in 0..nr-1 {
37 let leading = u.get(&[i, i]);
38 for j in i+1..nr {
39 let multiplier = u.get(&[j, i])/leading;
40 l.set(&[j, i], multiplier);
41 for k in i..nr {
42 u.set(&[j, k], u.get(&[j, k]) - u.get(&[i, k])*multiplier);
43 }
44 }
45 }
46
47 Some([l, u])
48 }
49
50 fn lu_solve(&self, b: &Self::TensorType) -> Option<Self::TensorType> {
51 if self.size().len() != 2 {
52 return None;
53 }
54 if self.size()[0] != self.size()[1] {
55 return None;
56 }
57 let n = self.size()[0];
58 if b.size().len() != 2 || b.size()[0] != n || b.size()[1] != 1 {
59 return None;
60 }
61
62 match self.lu() {
63 Some([l, u]) => {
64
65 let mut y = GenTensor::<T>::zeros(&[n, 1]);
66 for i in 0..n {
67 y.set(&[i, 0],
68 (b.get(&[i, 0]) - y.dot(&l.get_row(i))) / l.get(&[i, i]));
69 }
70 let mut x = GenTensor::<T>::zeros(&[n, 1]);
71 for i in 0..n {
72 x.set(&[n-i-1, 0],
73 (y.get(&[n-i-1, 0]) - x.dot(&u.get_row(n-i-1))) / u.get(&[n-i-1, n-i-1]));
74 }
75
76 Some(x)
77 },
78 None => {None}
79 }
80 }
81
82 fn qr(&self) -> Option<[Self::TensorType; 2]> {
83 if self.size().len() != 2 {
86 return None;
87 }
88 let m = self.size()[self.size().len()-2];
92 let n = self.size()[self.size().len()-1];
93
94 let mut q = GenTensor::<T>::zeros(&[m, cmp::min(m, n)]);
95 let mut r = GenTensor::<T>::zeros(&[n, n]);
96 for i in 0..n {
97 let a = self.get_column(i);
98 let mut u = a.clone();
99 for j in 0..i {
100 u = u.sub(&a.proj(&q.get_column(j)));
101 }
102 if i < cmp::min(m, n) {
103 let e = u.normalize_unit();
104 q.set_column(&e, i);
105 }
106 for j in 0..cmp::min(i+1, cmp::min(m, n)) {
107 if j <= m {
108 r.set(&[j, i], a.dot(&q.get_column(j)));
109 }
110 }
111 }
112
113 Some([q, r])
114 }
115
116 fn eigen(&self) -> Option<[Self::TensorType; 2]> {
117 if self.size().len() != 2 {
119 return None;
120 }
121 if self.size()[0] != self.size()[1] {
122 return None;
123 }
124 let n = self.size()[0];
125 let mut cap_a = self.clone();
126
127 let tolerance: f64 = 1e-9;
128 let iter_max = 100;
129
130 let mut evec = GenTensor::<T>::zeros(&[n, n]);
131 let mut eval = GenTensor::<T>::zeros(&[n, 1]);
132 for i in 0..n {
133 let mut x = GenTensor::<T>::fill(T::one(), &[n, 1]);
134 let mut iter_counter = 0;
135 loop {
136 if iter_counter > iter_max {
137 break;
138 }
139 let x1 = x.clone();
140 x = cap_a.matmul(&x).normalize_unit();
141 if x1.sub(&x).norm().get_scale() < T::from(tolerance).unwrap() {
142 break;
143 }
144 iter_counter += 1;
145 }
146 let lambda = x.permute(&[1, 0]).matmul(self).matmul(&x).squeeze(None);
148
149 evec.set_column(&x, i);
150 eval.set(&[i, 0], lambda.get_scale());
151
152 cap_a = cap_a.sub(&GenTensor::<T>::eye(n, n).mul(&lambda));
153
154 }
156
157 Some([evec, eval])
158 }
159 fn cholesky(&self) -> Option<Self::TensorType> {
160 if self.size().len() != 2 {
162 return None;
163 }
164 if self.size()[0] != self.size()[1] {
165 return None;
166 }
167 let n = self.size()[0];
168
169 let mut ret = GenTensor::<T>::zeros(&[n, n]);
170 for i in 0..n {
171 for j in 0..i {
172 ret.set(&[j, i],
173 (self.get(&[j, i]) -
174 ret.get_column(j).dot(&ret.get_column(i)))/ret.get(&[j, j]))
175 }
176 ret.set(&[i, i],
177 T::sqrt(self.get(&[i,i]) - ret.get_column(i).dot(&ret.get_column(i))));
178 }
179 Some(ret)
180 }
181
182 fn det(&self) -> Option<Self::TensorType> {
183 if self.size().len() != 2 {
184 return None
185 }
186 if self.size()[0] != self.size()[1] {
187 return None
188 }
189 let n = self.size()[0];
190 let mut sign_pos = true;
191 let mut self_data = self.clone();
192
193 for i in 0..n {
194 if self_data.get(&[i, i]) == T::zero() {
195 let mut row_counter = 1;
196
197 loop {
198 if i+row_counter == n {
199 return Some(GenTensor::zeros(&[1])); }
201 if self_data.get(&[i+row_counter, i]) == T::zero() {
202 row_counter += 1;
203 } else {
204 sign_pos ^= true;
205 let tmp_row = self.get_row(i);
206 self_data.set_row(&self_data.get_row(i+row_counter), i);
207 self_data.set_row(&tmp_row, i+row_counter);
208 break;
209 }
210 }
211 }
212 }
213
214 if let Some(v) = self_data.lu() {
215 let [_l, u] = v;
216 let mut ret = u.get_diag().prod(None, false).get(&[0]);
217 if !sign_pos {
218 ret = ret.neg();
219 }
220 let ret = GenTensor::new_raw(&[ret], &[1]);
221 Some(ret)
222 } else {
223 None
224 }
225 }
226
227 fn svd(&self) -> Option<[Self::TensorType; 3]> {
228 let m = self.size()[self.size().len()-2];
231 let n = self.size()[self.size().len()-1];
232
233 let cap_a: GenTensor<T>;
234 if m > n {
235 cap_a = self.permute(&[1, 0]).matmul(self);
236 } else if m < n {
237 cap_a = self.matmul(&self.permute(&[1, 0]));
238 } else {
239 cap_a = self.clone();
240 }
241
242 let tolerance: f64 = 1e-9;
243 let iter_max = 100;
244
245 let mut s: GenTensor<T>;
246 let mut v = GenTensor::<T>::eye(n, n);
247 let mut iter_counter = 0;
248 loop {
249
250 let v1 = v.clone();
251 let [qv, r] = cap_a.matmul(&v).qr().unwrap();
252 v = qv;
253
254 if v1.sub(&v).norm().get_scale() < T::from(tolerance).unwrap() {
255 s = r;
256 break;
257 }
258
259 if iter_counter > iter_max {
260 s = r;
261 break;
262 }
263
264 iter_counter += 1;
265 }
267
268 let u: GenTensor<T>;
269 if m > n {
270 s = s.sqrt();
271 v = v.permute(&[1, 0]);
272 let invs = GenTensor::<T>::ones(&[n]).div(&s.get_diag());
273 u = self.matmul(&v.permute(&[1, 0])).matmul(&invs);
274 } else if m < n {
275 s = s.sqrt();
276 u = v.permute(&[1, 0]);
277 let invs = GenTensor::<T>::ones(&[n]).div(&s.get_diag());
278 v = invs.matmul(&u.permute(&[1, 0])).matmul(self);
279 } else {
280 u = v.permute(&[1, 0]);
281 }
282
283 Some([u, s, v])
284 }
285
286 fn inv(&self) -> Option<Self::TensorType> {
287 if self.size().len() != 2 {
288 return None;
289 }
290 if self.size()[self.size().len()-2] != self.size()[self.size().len()-1] {
291 return None;
292 }
293
294 let mut ret = GenTensor::zeros_like(self);
295 for i in 0..self.numel() {
296 let index = self.index2dimpos(i);
297 let minor = self.index_exclude(0, &GenTensor::new_raw(&[T::from(index[0]).unwrap()], &[1]))
298 .index_exclude(1, &GenTensor::new_raw(&[T::from(index[1]).unwrap()], &[1]));
299 let minor = minor.det().unwrap();
300
301 if (index[0] + index[1]) %2 == 0 {
302 ret.set(&index, minor.get_scale());
303 } else {
304 ret.set(&index, minor.get_scale().neg());
305 }
306 }
307
308 let ret = ret.t();
309
310 let det = self.det()?;
311
312 Some(ret.div(&det))
313 }
314
315 fn pinv(&self) -> Self::TensorType {
316 let [u, s, v] = self.svd().unwrap();
317 let m = s.size()[self.size().len()-2];
318 let n = s.size()[self.size().len()-1];
319 let mut diag_v = Vec::new();
320 for i in 0..cmp::min(m, n) {
321 if s.get(&[i, i]) != T::zero() {
322 diag_v.push(s.get(&[i, i]));
323 } else {
324 break;
325 }
326 }
327 let mut s = GenTensor::zeros(&[diag_v.len(), diag_v.len()]);
328 s.set_diag(&GenTensor::new_raw(&diag_v, &[diag_v.len()]));
329 v.matmul(&s).matmul(&u.t())
330 }
331
332 fn tr(&self) -> Self::TensorType {
333 self.get_diag().sum(None, false)
334 }
335}
336
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for `GenTensor::<f64>::new_raw`.
    fn t64(data: &[f64], shape: &[usize]) -> GenTensor<f64> {
        GenTensor::<f64>::new_raw(data, shape)
    }

    #[test]
    fn normalize_unit() {
        let input = t64(&[1., 1., 0., 1., 0., 1., 0., 1., 1.], &[3, 3]);
        let w = 0.4082482904638631;
        let expected = t64(&[w, w, 0., w, 0., w, 0., w, w], &[3, 3]);
        assert_eq!(input.normalize_unit(), expected);
    }

    #[test]
    fn lu() {
        let a = t64(&[1., 1., 1., 4., 3., -1., 3., 5., 3.], &[3, 3]);
        let [lower, upper] = a.lu().unwrap();
        assert_eq!(lower, t64(&[1., 0., 0., 4., 1., 0., 3., -2., 1.], &[3, 3]));
        assert_eq!(upper, t64(&[1., 1., 1., 0., -1., -5., 0., 0., -10.], &[3, 3]));
    }

    #[test]
    fn lu_solve() {
        let system = t64(&[7., -2., 1., 14., -7., -3., -7., 11., 18.], &[3, 3]);
        let rhs = t64(&[12., 17., 5.], &[3, 1]);
        let solution = system.lu_solve(&rhs).unwrap();
        assert_eq!(solution, t64(&[3., 4., -1.], &[3, 1]));
    }

    #[test]
    fn det() {
        let cases: [(&[f64], &[usize], f64); 2] = [
            (&[1., 1., 1., 4., 3., -1., 3., 5., 3.], &[3, 3], 10.),
            (&[0., -2., 1., 1.], &[2, 2], 2.),
        ];
        for &(data, shape, expected) in cases.iter() {
            let got = t64(data, shape).det().unwrap().get_scale();
            assert_eq!(got, expected);
        }
    }

    #[test]
    fn qr() {
        let a = t64(&[1., 1., 0., 1., 0., 1., 0., 1., 1.], &[3, 3]);
        let [q_got, r_got] = a.qr().unwrap();
        let q_want = t64(
            &[
                0.7071067811865475, 0.40824829046386313, -0.5773502691896257,
                0.7071067811865475, -0.40824829046386296, 0.577350269189626,
                0., 0.8164965809277261, 0.5773502691896256,
            ],
            &[3, 3],
        );
        let r_want = t64(
            &[
                1.414213562373095, 0.7071067811865475, 0.7071067811865475,
                0., 1.2247448713915894, 0.4082482904638632,
                0., 0., 1.1547005383792515,
            ],
            &[3, 3],
        );
        assert_eq!(q_got, q_want);
        assert_eq!(r_got, r_want);
    }

    #[test]
    fn cholesky() {
        let spd = t64(&[4., 12., -16., 12., 37., -43., -16., -43., 98.], &[3, 3]);
        let factor = spd.cholesky().unwrap();
        assert_eq!(factor, t64(&[2., 6., -8., 0., 1., 5., 0., 0., 3.], &[3, 3]));
    }

    #[test]
    fn eigen() {
        let a = t64(&[4., 3., -2., -3.], &[2, 2]);
        let expected = t64(&[3., -2.], &[2, 1]);
        let [_vectors, values] = a.eigen().unwrap();
        assert!(values.sub(&expected).norm().get_scale() < 1e-6);
    }

    #[test]
    fn svd() {
        let a = t64(&[4., 12., -16., 12., 37., -43., -16., -43., 98.], &[3, 3]);
        let [left, sing, right] = a.svd().unwrap();
        println!("{:?}, {:?}, {:?}", left, sing, right);
        let expected = t64(
            &[123.47723179013161, 15.503963229407585, 0.018804980460810704],
            &[3],
        );
        assert!(expected.sub(&sing.get_diag()).norm().get_scale() < 1e-6);

        println!("{:?}", left.matmul(&sing).matmul(&right.t()));
        println!("{:?}", left.matmul(&left.t()));
        println!("{:?}", right.matmul(&right.t()));
    }

    #[test]
    fn inv() {
        let a = t64(&[3., 0., 2., 2., 0., -2., 0., 1., 1.], &[3, 3]);
        let inverse = a.inv().unwrap();
        let expected = t64(&[0.2, 0.2, 0., -0.2, 0.3, 1., 0.2, -0.3, 0.], &[3, 3]);
        assert_eq!(inverse, expected);
    }

    #[test]
    fn pinv() {
        let a = t64(&[2., -1., 1., 4., 3., -2., 4., 5., -2.], &[3, 3]);
        let pseudo = a.pinv();
        println!("{:?}", pseudo.matmul(&a));
    }
}
440}