Skip to main content

yui_matrix/sparse/
trans.rs

1use sprs::PermView;
2use yui_core::{CloneAnd, Ring, RingOps};
3use crate::sparse::{SpMat, MatTrait, SpVec};
4
5#[derive(Clone, Debug)]
6#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
7pub struct Trans<R> 
8where R: Ring, for <'x> &'x R: RingOps<R> {
9    src_dim: usize, 
10    tgt_dim: usize,
11    f_mats: Vec<SpMat<R>>,
12    b_mats: Vec<SpMat<R>>,
13}
14
15impl<R> Trans<R> 
16where R: Ring, for <'x> &'x R: RingOps<R> { 
17    pub fn id(n: usize) -> Self { 
18        Self { 
19            src_dim: n, 
20            tgt_dim: n, 
21            f_mats: vec![], 
22            b_mats: vec![] 
23        }
24    }
25
26    pub fn zero() -> Self { 
27        Self::id(0)
28    }
29
30    pub fn new(f: SpMat<R>, b: SpMat<R>) -> Self {
31        let mut t = Self::id(f.ncols());
32        t.append(f, b);
33        t
34    }
35
36    pub fn src_dim(&self) -> usize { 
37        self.src_dim
38    }
39
40    pub fn tgt_dim(&self) -> usize { 
41        self.tgt_dim
42    }
43
44    pub fn is_id(&self) -> bool { 
45        self.f_mats.is_empty()
46    }
47
48    pub fn forward(&self, v: &SpVec<R>) -> SpVec<R> {
49        assert_eq!(v.dim(), self.src_dim);
50        self.f_mats.iter().fold(v.clone(), |v, f| f * v)
51    }
52
53    pub fn backward(&self, v: &SpVec<R>) -> SpVec<R> {
54        assert_eq!(v.dim(), self.tgt_dim);
55        self.b_mats.iter().rev().fold(v.clone(), |v, f| f * v)
56    }
57
58    pub fn append(&mut self, f: SpMat<R>, b: SpMat<R>) { 
59        assert_eq!(f.ncols(), b.nrows());
60        assert_eq!(f.nrows(), b.ncols());
61        assert_eq!(f.ncols(), self.tgt_dim);
62
63        self.tgt_dim = f.nrows();
64        self.f_mats.push(f);
65        self.b_mats.push(b);
66    }
67
68    pub fn append_perm(&mut self, p: PermView) { 
69        assert_eq!(p.dim(), self.tgt_dim);
70        let f = SpMat::from_row_perm(p.clone());
71        let b = SpMat::from_col_perm(p);
72        self.append(f, b)
73    }
74
75    pub fn merge(&mut self, mut other: Trans<R>) { 
76        assert_eq!(self.tgt_dim, other.src_dim);
77
78        self.tgt_dim = other.tgt_dim;
79        self.f_mats.append(&mut other.f_mats);
80        self.b_mats.append(&mut other.b_mats);
81    }
82
83    pub fn merged(&self, other: &Trans<R>) -> Self { 
84        self.clone_and(|t| 
85            t.merge(other.clone())
86        )
87    }
88
89    pub fn forward_mat(&self) -> SpMat<R> {
90        // f = fn * ... f1 * f0
91        if self.f_mats.len() == 1 { 
92            self.f_mats[0].clone()
93        } else { 
94            self.f_mats.iter().rev().fold(
95                SpMat::id(self.tgt_dim), 
96                |res, f| res * f
97            )
98        }
99    }
100
101    pub fn backward_mat(&self) -> SpMat<R> {
102        // b = b0 * b1 * ... * bn
103        if self.b_mats.len() == 1 { 
104            self.b_mats[0].clone()
105        } else { 
106            self.b_mats.iter().rev().fold(
107                SpMat::id(self.tgt_dim), 
108                |res, b| b * res
109            )
110        }
111    }
112
113    pub fn reduce(&mut self) {
114        if self.f_mats.len() > 1 { 
115            let f = self.forward_mat();
116            self.f_mats = vec![f];
117        }
118
119        if self.b_mats.len() > 1 { 
120            let b = self.backward_mat();
121            self.b_mats = vec![b];
122        }
123    }
124
125    pub fn sub(&self, indices: &[usize]) -> Self { 
126        let n = self.tgt_dim();
127        let p = indices.len();
128        let f = SpMat::from_entries(
129            (p, n), 
130            indices.iter().enumerate().map(|(i, &j)|
131                (i, j, R::one())
132            )
133        );
134        let b = SpMat::from_entries(
135            (n, p), 
136            indices.iter().enumerate().map(|(i, &j)|
137                (j, i, R::one())
138            )
139        );
140        self.clone_and(|sub|
141            sub.append(f, b)
142        )
143    }
144}
145
146impl<R> Default for Trans<R>
147where R: Ring, for <'x> &'x R: RingOps<R> {
148    fn default() -> Self {
149        Self::zero()
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::*;
    use sprs::PermOwned;

    /// Fixture shared by several tests: forward is the first three rows of
    /// the 5x5 identity, backward is the first three columns of it.
    fn first_three_of_five() -> Trans<i32> {
        let fwd = SpMat::<i32>::id(5).submat_rows(0..3);
        let bwd = SpMat::<i32>::id(5).submat_cols(0..3);
        Trans::<i32>::new(fwd, bwd)
    }

    #[test]
    fn id() {
        let identity = Trans::<i32>::id(5);
        let input = SpVec::from(vec![0, 1, 2, 3, 4]);

        // An empty chain must hand the vector back unchanged in both directions.
        let pushed = identity.forward(&input);
        let pulled = identity.backward(&input);

        assert_eq!(pushed, SpVec::from(vec![0, 1, 2, 3, 4]));
        assert_eq!(pulled, SpVec::from(vec![0, 1, 2, 3, 4]));
    }

    #[test]
    fn trans() {
        let proj = first_three_of_five();
        let input = SpVec::from(vec![0, 1, 2, 3, 4]);

        let pushed = proj.forward(&input);
        let pulled = proj.backward(&pushed);

        assert_eq!(pushed, SpVec::from(vec![0, 1, 2]));
        assert_eq!(pulled, SpVec::from(vec![0, 1, 2, 0, 0]));
    }

    #[test]
    fn append_perm() {
        let mut chain = first_three_of_five();
        let perm = PermOwned::new(vec![1, 2, 0]);
        chain.append_perm(perm.view());

        let input = SpVec::from(vec![0, 1, 2, 3, 4]);
        let pushed = chain.forward(&input);
        let pulled = chain.backward(&pushed);

        assert_eq!(pushed.into_vec(), vec![2, 0, 1]);
        assert_eq!(pulled.into_vec(), vec![0, 1, 2, 0, 0]);
    }

    #[test]
    fn is_id() {
        let identity = Trans::<i64>::id(10);
        assert!(identity.is_id());
    }
}