tensor_rs/tensor_impl/gen_tensor/
linalg.rs

1#![allow(clippy::comparison_chain)]
2use std::cmp;
3use super::GenTensor;
4use crate::tensor_trait::reduction::ReduceTensor;
5use crate::tensor_trait::elemwise::ElemwiseTensorOp;
6use crate::tensor_trait::index_slicing::IndexSlicing;
7use crate::tensor_trait::linalg::LinearAlgbra;
8
9impl<T> LinearAlgbra for GenTensor<T>
10where T: num_traits::Float {
11    type TensorType = GenTensor<T>;
12    type ElementType = T;
13
14    fn norm(&self) -> Self::TensorType {
15        // TODO: support 'fro', 'nuc', 'inf', '-inf'...
16        self.mul(self).sum(None, false).sqrt()
17    }
18
19    fn normalize_unit(&self) -> Self::TensorType {
20        let s = self.mul(self).sum(None, false);
21        self.div(&s.sqrt())
22    }
23
24    fn lu(&self) -> Option<[Self::TensorType; 2]> {
25        // lu is for square matrix only.
26        // TODO; handle the batched/3d case.
27        if self.size().len() != 2 {
28            return None;
29        }
30        if self.size()[0] != self.size()[1] {
31            return None;
32        }
33        let nr = self.size()[0];
34        let mut l = GenTensor::<T>::eye(nr, nr);
35        let mut u = self.clone();
36        for i in 0..nr-1 {
37            let leading = u.get(&[i, i]);
38            for j in i+1..nr {
39                let multiplier = u.get(&[j, i])/leading;
40                l.set(&[j, i], multiplier);
41                for k in i..nr {
42                    u.set(&[j, k], u.get(&[j, k]) - u.get(&[i, k])*multiplier);
43                }
44            }
45        }
46
47        Some([l, u])
48    }
49
50    fn lu_solve(&self, b: &Self::TensorType) -> Option<Self::TensorType> {
51        if self.size().len() != 2 {
52            return None;
53        }
54        if self.size()[0] != self.size()[1] {
55            return None;
56        }
57        let n = self.size()[0];
58        if b.size().len() != 2 || b.size()[0] != n || b.size()[1] != 1 {
59            return None;
60        }
61        
62        match self.lu() {
63            Some([l, u]) => {
64         
65                let mut y = GenTensor::<T>::zeros(&[n, 1]);
66                for i in 0..n {
67                    y.set(&[i, 0],
68                          (b.get(&[i, 0]) - y.dot(&l.get_row(i))) / l.get(&[i, i]));
69                }
70                let mut x = GenTensor::<T>::zeros(&[n, 1]);
71                for i in 0..n {
72                    x.set(&[n-i-1, 0],
73                          (y.get(&[n-i-1, 0]) - x.dot(&u.get_row(n-i-1))) / u.get(&[n-i-1, n-i-1]));
74                }
75                
76                Some(x)
77            },
78            None => {None}
79        }
80    }
81
82    fn qr(&self) -> Option<[Self::TensorType; 2]> {
83        // qr is for square matrix only.
84        // TODO; handle the batched/3d case.
85        if self.size().len() != 2 {
86            return None;
87        }
88        //if self.size()[0] != self.size()[1] {
89        //    return None;
90        //}
91        let m = self.size()[self.size().len()-2];
92        let n = self.size()[self.size().len()-1];
93
94        let mut q = GenTensor::<T>::zeros(&[m, cmp::min(m, n)]);
95        let mut r = GenTensor::<T>::zeros(&[n, n]);
96        for i in 0..n {
97            let a = self.get_column(i);
98            let mut u = a.clone();
99            for j in 0..i {
100                u = u.sub(&a.proj(&q.get_column(j)));
101            }
102            if i < cmp::min(m, n) {
103                let e = u.normalize_unit();
104                q.set_column(&e, i);
105            }
106            for j in 0..cmp::min(i+1, cmp::min(m, n)) {
107                if j <= m {
108                    r.set(&[j, i], a.dot(&q.get_column(j)));
109                }
110            }
111        }
112        
113        Some([q, r])
114    }
115
    fn eigen(&self) -> Option<[Self::TensorType; 2]> {
        // Eigen decomposition by repeated power iteration with a spectral
        // shift between rounds. Returns [eigenvectors (n x n, one per
        // column), eigenvalues (n x 1)], or None for a non-square or
        // non-2-D input.
        // NOTE(review): after each eigenpair the working matrix is shifted
        // by lambda*I rather than deflated by lambda*x*x^T; this relies on
        // the shift moving the found eigenvalue away so the next power
        // iteration converges to a different direction — confirm it behaves
        // for the intended (presumably symmetric) inputs.
        // TODO; handle the batched/3d case.
        if self.size().len() != 2 {
            return None;
        }
        if self.size()[0] != self.size()[1] {
            return None;
        }
        let n = self.size()[0];
        let mut cap_a = self.clone();

        // Power-iteration stopping criteria: direction change below
        // `tolerance`, or at most `iter_max` rounds.
        let tolerance: f64 = 1e-9;
        let iter_max = 100;

        let mut evec = GenTensor::<T>::zeros(&[n, n]);
        let mut eval = GenTensor::<T>::zeros(&[n, 1]);
        for i in 0..n {
            // Power iteration on the (shifted) matrix, seeded with an
            // all-ones vector and re-normalized every step.
            let mut x = GenTensor::<T>::fill(T::one(), &[n, 1]);
            let mut iter_counter = 0;
            loop {
                if iter_counter > iter_max {
                    break;
                }
                let x1 = x.clone();
                x = cap_a.matmul(&x).normalize_unit();
                // Converged once the direction stops changing.
                if x1.sub(&x).norm().get_scale() < T::from(tolerance).unwrap() {
                    break;
                }
                iter_counter += 1;
            }
            // Rayleigh quotient against the ORIGINAL matrix; x is unit
            // length so no division by x^T x is needed.
            let lambda = x.permute(&[1, 0]).matmul(self).matmul(&x).squeeze(None);

            evec.set_column(&x, i);
            eval.set(&[i, 0], lambda.get_scale());

            // Shift the working matrix so the next round locks onto a
            // different eigenpair.
            cap_a = cap_a.sub(&GenTensor::<T>::eye(n, n).mul(&lambda));
        }

        Some([evec, eval])
    }
    fn cholesky(&self) -> Option<Self::TensorType> {
        // Cholesky factorization returning the UPPER-triangular factor R
        // such that A = R^T * R, filled column by column.
        // Returns None for a non-square or non-2-D input.
        // NOTE(review): assumes the input is symmetric positive definite —
        // a non-PD matrix silently produces NaNs (sqrt of a negative
        // number) instead of an error.
        // TODO; handle the batched/3d case.
        if self.size().len() != 2 {
            return None;
        }
        if self.size()[0] != self.size()[1] {
            return None;
        }
        let n = self.size()[0];

        let mut ret = GenTensor::<T>::zeros(&[n, n]);
        for i in 0..n {
            for j in 0..i {
                // Off-diagonal: R[j,i] = (A[j,i] - <col_j, col_i>) / R[j,j].
                // The full-column dot products are safe because every entry
                // of column i at row j and below is still zero here.
                ret.set(&[j, i],
                        (self.get(&[j, i]) -
                         ret.get_column(j).dot(&ret.get_column(i)))/ret.get(&[j, j]))
            }
            // Diagonal: R[i,i] = sqrt(A[i,i] - <col_i, col_i>).
            ret.set(&[i, i],
                    T::sqrt(self.get(&[i,i]) - ret.get_column(i).dot(&ret.get_column(i))));
        }
        Some(ret)
    }
181    
182    fn det(&self) -> Option<Self::TensorType> {
183        if self.size().len() != 2 {
184            return None
185        }
186        if self.size()[0] != self.size()[1] {
187            return None
188        }
189        let n = self.size()[0];
190        let mut sign_pos = true;
191        let mut self_data = self.clone();
192
193        for i in 0..n {
194            if self_data.get(&[i, i]) == T::zero() {
195                let mut row_counter = 1;
196
197                loop {
198                    if i+row_counter == n {
199                        return Some(GenTensor::zeros(&[1])); // invalid
200                    }
201                    if self_data.get(&[i+row_counter, i]) == T::zero() {
202                        row_counter += 1;
203                    } else {
204                        sign_pos ^= true;
205                        let tmp_row = self.get_row(i);
206                        self_data.set_row(&self_data.get_row(i+row_counter), i);
207                        self_data.set_row(&tmp_row, i+row_counter);
208                        break;
209                    }
210                }
211            }
212        }
213        
214        if let Some(v) = self_data.lu() {
215            let [_l, u] = v;
216            let mut ret = u.get_diag().prod(None, false).get(&[0]);
217            if !sign_pos {
218                ret = ret.neg();
219            }
220            let ret = GenTensor::new_raw(&[ret], &[1]);
221            Some(ret)
222        } else {
223            None
224        }
225    }
226
    fn svd(&self) -> Option<[Self::TensorType; 3]> {
        // Singular value decomposition via QR-based simultaneous iteration,
        // returning [u, s, v]. For a non-square input the iteration runs on
        // the Gram matrix (A^T*A when tall, A*A^T when wide), so the
        // converged diagonal holds SQUARED singular values — hence the
        // sqrt() in the post-processing below.
        // NOTE(review): size() is indexed at len-2 without a rank check, so
        // a 1-D input panics instead of returning None — confirm callers
        // only pass 2-D matrices.
        // TODO; handle the batched/3d case.
        // TODO: assume the input is thin matrix.
        let m = self.size()[self.size().len()-2];
        let n = self.size()[self.size().len()-1];

        // Work on the smaller symmetric product for rectangular inputs.
        let cap_a: GenTensor<T>;
        if m > n {
            cap_a = self.permute(&[1, 0]).matmul(self);
        } else if m < n {
            cap_a = self.matmul(&self.permute(&[1, 0]));
        } else {
            cap_a = self.clone();
        }

        // Iteration stopping criteria.
        let tolerance: f64 = 1e-9;
        let iter_max = 100;

        let mut s: GenTensor<T>;
        let mut v = GenTensor::<T>::eye(n, n);
        let mut iter_counter = 0;
        loop {

            // One round of subspace iteration: orthonormalize cap_a * V;
            // the R factor converges towards the (squared) singular values.
            let v1 = v.clone();
            let [qv, r] = cap_a.matmul(&v).qr().unwrap();
            v = qv;

            // Converged: the subspace has stopped rotating.
            if v1.sub(&v).norm().get_scale() < T::from(tolerance).unwrap() {
                s = r;
                break;
            }

            // Give up after iter_max rounds and keep the latest estimate.
            if iter_counter > iter_max {
                s = r;
                break;
            }

            iter_counter += 1;
        }

        // Recover the factor belonging to the side that was squared away.
        let u: GenTensor<T>;
        if m > n {
            // Tall input: v approximates the right singular vectors;
            // u is rebuilt as A * V scaled by the reciprocal singular
            // values (invs holds 1/sigma).
            s = s.sqrt();
            v = v.permute(&[1, 0]);
            let invs = GenTensor::<T>::ones(&[n]).div(&s.get_diag());
            u = self.matmul(&v.permute(&[1, 0])).matmul(&invs);
        } else if m < n {
            // Wide input: v approximates the left singular vectors;
            // the right factor is rebuilt symmetrically.
            s = s.sqrt();
            u = v.permute(&[1, 0]);
            let invs = GenTensor::<T>::ones(&[n]).div(&s.get_diag());
            v = invs.matmul(&u.permute(&[1, 0])).matmul(self);
        } else {
            // Square input: no sqrt here — presumably intended for
            // symmetric PD matrices where the QR values already match the
            // singular values. TODO confirm.
            u = v.permute(&[1, 0]);
        }

        Some([u, s, v])
    }
285
286    fn inv(&self) -> Option<Self::TensorType> {
287        if self.size().len() != 2 {
288            return None;
289        }
290        if self.size()[self.size().len()-2] != self.size()[self.size().len()-1] {
291            return None;
292        }
293
294        let mut ret = GenTensor::zeros_like(self);
295        for i in 0..self.numel() {
296            let index = self.index2dimpos(i);
297            let minor = self.index_exclude(0, &GenTensor::new_raw(&[T::from(index[0]).unwrap()], &[1]))
298                .index_exclude(1, &GenTensor::new_raw(&[T::from(index[1]).unwrap()], &[1]));
299            let minor = minor.det().unwrap();
300            
301            if (index[0] + index[1]) %2 == 0 {
302                ret.set(&index, minor.get_scale());
303            } else {
304                ret.set(&index, minor.get_scale().neg());
305            }
306        }
307
308        let ret = ret.t();
309
310        let det = self.det()?;
311
312        Some(ret.div(&det))
313    }
314    
315    fn pinv(&self) -> Self::TensorType {
316        let [u, s, v] = self.svd().unwrap();
317        let m = s.size()[self.size().len()-2];
318        let n = s.size()[self.size().len()-1];
319        let mut diag_v = Vec::new();
320        for i in 0..cmp::min(m, n) {
321            if s.get(&[i, i]) != T::zero() {
322                diag_v.push(s.get(&[i, i]));
323            } else {
324                break;
325            }
326        }
327        let mut s = GenTensor::zeros(&[diag_v.len(), diag_v.len()]);
328        s.set_diag(&GenTensor::new_raw(&diag_v, &[diag_v.len()]));
329        v.matmul(&s).matmul(&u.t())
330    }
331
332    fn tr(&self) -> Self::TensorType {
333        self.get_diag().sum(None, false)
334    }
335}
336
337
#[cfg(test)]
mod tests {
    use super::*;

    // Dividing by the Frobenius norm scales every element to 1/sqrt(6).
    #[test]
    fn normalize_unit() {
        let m = GenTensor::<f64>::new_raw(&[1., 1., 0., 1., 0., 1., 0., 1., 1.], &[3,3]);
        let nm = m.normalize_unit();
        assert_eq!(nm, GenTensor::<f64>::new_raw(&[0.4082482904638631, 0.4082482904638631, 0.,
                                                   0.4082482904638631, 0., 0.4082482904638631,
                                                   0., 0.4082482904638631, 0.4082482904638631, ],
                                                 &[3,3]));
    }

    // Doolittle LU: L is unit lower-triangular, U upper-triangular.
    #[test]
    fn lu() {
        let m = GenTensor::<f64>::new_raw(&[1., 1., 1., 4., 3., -1., 3., 5., 3.], &[3,3]);
        let [l, u] = m.lu().unwrap();
        let el = GenTensor::<f64>::new_raw(&[1., 0., 0., 4., 1., 0., 3., -2., 1.], &[3,3]);
        let eu = GenTensor::<f64>::new_raw(&[1., 1., 1., 0., -1., -5., 0., 0., -10.], &[3,3]);
        assert_eq!(l, el);
        assert_eq!(u, eu);
    }

    // Exact solve of a well-conditioned 3x3 system A*x = b.
    #[test]
    fn lu_solve() {
        let cap_a = GenTensor::<f64>::new_raw(&[7., -2., 1., 14., -7., -3., -7., 11., 18.], &[3,3]);
        let b = GenTensor::<f64>::new_raw(&[12., 17., 5.], &[3,1]);
        let x = cap_a.lu_solve(&b).unwrap();
        let ex = GenTensor::<f64>::new_raw(&[3., 4., -1.,], &[3,1]);
        assert_eq!(x, ex);
    }

    // Second case has a zero at [0,0], exercising the row-swap (sign flip)
    // path inside det().
    #[test]
    fn det() {
        let m = GenTensor::<f64>::new_raw(&[1., 1., 1., 4., 3., -1., 3., 5., 3.], &[3,3]);
        let r = m.det().unwrap().get_scale();
        assert_eq!(r, 10.);

        let m = GenTensor::<f64>::new_raw(&[0., -2., 1., 1.], &[2,2]);
        let r = m.det().unwrap().get_scale();
        assert_eq!(r, 2.);
    }

    // Gram-Schmidt QR against precomputed orthonormal Q and triangular R.
    #[test]
    fn qr() {
        let m = GenTensor::<f64>::new_raw(&[1., 1., 0., 1., 0., 1., 0., 1., 1.], &[3,3]);
        let [q, r] = m.qr().unwrap();
        let eq = GenTensor::<f64>::new_raw(&[0.7071067811865475, 0.40824829046386313, -0.5773502691896257,
                                             0.7071067811865475, -0.40824829046386296, 0.577350269189626,
        0., 0.8164965809277261, 0.5773502691896256, ], &[3,3]);
        let er = GenTensor::<f64>::new_raw(&[1.414213562373095, 0.7071067811865475, 0.7071067811865475, 0., 1.2247448713915894, 0.4082482904638632, 0., 0., 1.1547005383792515, ], &[3,3]);
        assert_eq!(q, eq);
        assert_eq!(r, er);
    }

    // SPD input; expected factor is UPPER-triangular (A = R^T * R).
    #[test]
    fn cholesky() {
        let m = GenTensor::<f64>::new_raw(&[4., 12., -16., 12., 37., -43., -16., -43., 98.], &[3,3]);
        let c = m.cholesky().unwrap();
        let ec = GenTensor::<f64>::new_raw(&[2., 6., -8., 0., 1., 5., 0., 0., 3.], &[3,3]);
        assert_eq!(c, ec);
    }

    // Eigenvalues of [[4,3],[-2,-3]] are 3 and -2; only the values are
    // checked (to a tolerance), not the eigenvectors.
    #[test]
    fn eigen() {
        let m = GenTensor::<f64>::new_raw(&[4., 3., -2., -3.], &[2,2]);
        let el = GenTensor::<f64>::new_raw(&[3., -2.], &[2,1]);
        let [_evec, eval] = m.eigen().unwrap();
        assert!(eval.sub(&el).norm().get_scale() < 1e-6);
    }

    // Only the singular values are asserted; the reconstruction and
    // orthogonality products are printed for manual inspection.
    #[test]
    fn svd() {
        let m = GenTensor::<f64>::new_raw(&[4., 12., -16., 12., 37., -43., -16., -43., 98.], &[3,3]);
        let [_u, s, _v] = m.svd().unwrap();
        println!("{:?}, {:?}, {:?}", _u, s, _v);
        let es = GenTensor::<f64>::new_raw(&[123.47723179013161, 15.503963229407585, 0.018804980460810704], &[3]);
        assert!(es.sub(&s.get_diag()).norm().get_scale() < 1e-6);

        println!("{:?}", _u.matmul(&s).matmul(&_v.t()));
        println!("{:?}", _u.matmul(&_u.t()));
        println!("{:?}", _v.matmul(&_v.t()));
    }

    // Adjugate-based inverse against a hand-computed expected matrix.
    #[test]
    fn inv() {
        let m = GenTensor::<f64>::new_raw(&[3., 0., 2., 2., 0., -2., 0., 1., 1.], &[3,3]);
        let inv_m = m.inv().unwrap();
        let e_inv = GenTensor::<f64>::new_raw(&[0.2, 0.2, 0., -0.2, 0.3, 1., 0.2, -0.3, 0.], &[3,3]);
        assert_eq!(inv_m, e_inv);
    }

    // Smoke test only: prints pinv(m)*m (ideally near identity) but
    // asserts nothing.
    #[test]
    fn pinv() {
        let m = GenTensor::<f64>::new_raw(&[2., -1., 1., 4., 3., -2., 4., 5., -2.], &[3, 3]);
        let pinv_m = m.pinv();
        println!("{:?}", pinv_m.matmul(&m));
    }
}