Skip to main content

yui_matrix/sparse/
pivot.rs

// Implementation based on:
//
// "Parallel Sparse PLUQ Factorization modulo p", Charles Bouillaguet, Claire Delaplace, Marie-Emilie Voge.
// https://hal.inria.fr/hal-01646133/document
//
// See also: SpaSM (Sparse direct Solver Modulo p)
// https://github.com/cbouilla/spasm
8
9use std::slice::Iter;
10use std::cmp::Ordering;
11use std::collections::VecDeque;
12use ahash::AHashSet;
13use itertools::Itertools;
14use log::debug;
15use sprs::PermOwned;
16
17use yui_core::{Ring, RingOps};
18use yui_core::algo::top_sort;
19use super::*;
20use super::util::perm_for_indices;
21
22cfg_if::cfg_if! {
23    if #[cfg(feature = "multithread")] {
24        use std::cell::RefCell;
25        use std::sync::RwLock;
26        use thread_local::ThreadLocal;
27        use rayon::prelude::*;
28    }
29}
30
/// Progress of the cycle-free pass is logged every `LOG_THRESHOLD` rows,
/// and only for matrices with more than `LOG_THRESHOLD` rows.
const LOG_THRESHOLD: usize = 10000;
32
/// Orientation of the pivot search.
///
/// `Rows` searches the matrix as given; `Cols` searches its transpose and
/// maps the resulting pairs back to `(row, col)` of the input.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum PivotType {
    Rows,
    Cols,
}
37
/// Criterion deciding which nonzero entries may be chosen as pivots.
#[derive(Clone, Copy, Debug)]
pub enum PivotCondition {
    /// The entry must be ±1 (`is_pm_one`).
    One,
    /// The entry must be a unit whose `c_weight()` does not exceed the given bound.
    Weight(f64),
    /// Any unit entry is acceptable.
    AnyUnit,
}
42
43impl PivotCondition { 
44    fn is_cand<R>(&self, r: &R) -> bool 
45    where R: Ring, for<'x> &'x R: RingOps<R> {
46        match self {
47            PivotCondition::One       => r.is_pm_one(),
48            PivotCondition::Weight(w) => r.is_unit() && r.c_weight() <= *w,
49            PivotCondition::AnyUnit   => r.is_unit(),
50        }
51    }
52}
53
54pub fn find_pivots<R>(a: &SpMat<R>, piv_type: PivotType, pivot_cond: PivotCondition) -> Vec<(usize, usize)>
55where R: Ring, for<'x> &'x R: RingOps<R> {
56    if a.is_zero() { 
57        return vec![];
58    }
59    
60    let mut pf = PivotFinder::new(a, piv_type, pivot_cond);
61    pf.find_pivots();
62    pf.result()
63}
64
65pub fn perms_by_pivots<R>(a: &SpMat<R>, pivs: &[(usize, usize)]) -> (PermOwned, PermOwned)
66where R: Ring, for<'x> &'x R: RingOps<R> {
67    let (m, n) = a.shape();
68    (
69        perm_for_indices(m, pivs.iter().map(|(i, _)| i)), 
70        perm_for_indices(n, pivs.iter().map(|(_, j)| j))
71    )
72}
73
/// Row index in the (possibly transposed) search structure.
type Row = usize;
/// Column index in the (possibly transposed) search structure.
type Col = usize;
76
77pub struct PivotFinder {
78    str: MatrixStr,
79    pivots: PivotData,
80    piv_type: PivotType
81}
82
83impl PivotFinder { 
84    pub fn new<R>(a: &SpMat<R>, piv_type: PivotType, pivot_cond: PivotCondition) -> Self
85    where R: Ring, for<'x> &'x R: RingOps<R> {
86        let str = MatrixStr::new(a, piv_type, pivot_cond);
87        let pivots = PivotData::new(a, piv_type);
88        PivotFinder{ str, pivots, piv_type }
89    }
90
91    pub fn find_pivots(&mut self) {
92        debug!("pivots: {:?} ..", self.str.shape());
93
94        self.find_fl_pivots();
95        self.find_fl_col_pivots();
96        self.find_cycle_free_pivots();
97
98        debug!("pivots: {:?} => {}.", self.str.shape(), self.pivots.count());
99    }
100
101    pub fn result(&self) -> Vec<(usize, usize)> { 
102        let tree = self.pivots.iter().map(|(i, j)| { 
103            let list = self.str.cols_in(i).filter(|&&j2|
104                j != j2 && self.pivots.has_col(j2)
105            ).copied().collect_vec();
106            (j, list)
107        });
108        
109        let sorted = top_sort(tree).unwrap();
110        let is_row_type = self.piv_type == PivotType::Rows;
111        
112        sorted.into_iter().map(|j| {
113            let i = self.pivots.row_for(j).unwrap();
114            if is_row_type { (i, j) } else { (j, i) }
115        }).collect_vec()
116    }
117
118    fn rows(&self) -> Row { 
119        self.str.shape.0
120    }
121
122    fn cols(&self) -> Col { 
123        self.str.shape.1
124    }
125
126    fn remain_rows(&self) -> impl Iterator<Item = Row> { 
127        let piv_rows: AHashSet<_> = self.pivots.iter().map(|(i, _)| i).collect();
128        let m = self.rows();
129
130        (0 .. m).filter(|&i| 
131            !piv_rows.contains(&i) && !self.str.is_empty_row(i)
132        ).sorted_by(|&i1, &i2| 
133            self.str.cmp_rows(i1, i2)
134        )
135    }
136
137    fn occupied_cols(&self) -> AHashSet<Col> {
138        self.pivots.iter().fold(AHashSet::new(), |mut res, (i, _)| {
139            for &j in self.str.cols_in(i) { 
140                res.insert(j);
141            }
142            res
143        })
144    }
145
146    fn find_fl_pivots(&mut self) {
147        let remain_rows: Vec<_> = self.remain_rows().collect();
148
149        for i in remain_rows {
150            let Some(j) = self.str.head_col_in(i) else { continue };
151
152            if !self.pivots.has_col(j) && self.str.is_candidate(i, j) {
153                self.pivots.set(i, j);
154            }
155        }
156
157        let piv_count = self.pivots.count();
158
159        debug!("  fl-pivots: +{}.", piv_count);
160    }
161
162    fn find_fl_col_pivots(&mut self) {
163        let before_piv_count = self.pivots.count();
164
165        let remain_rows: Vec<_> = self.remain_rows().collect();
166        let mut occ_cols = self.occupied_cols();
167
168        for i in remain_rows { 
169            let mut cands = vec![];
170
171            for &j in self.str.cols_in(i) { 
172                if !occ_cols.contains(&j) && self.str.is_candidate(i, j) {
173                    cands.push(j);
174                }
175            }
176
177            let Some(j) = cands.into_iter().sorted_by(|&j1, &j2| 
178                self.str.cmp_cols(j1, j2)
179            ).next() else { continue };
180
181            self.pivots.set(i, j);
182
183            for &j in self.str.cols_in(i) { 
184                occ_cols.insert(j);
185            }
186        }
187
188        let piv_count = self.pivots.count();
189        
190        debug!("  fl-col-pivots: +{}, total: {}.", piv_count - before_piv_count, piv_count);
191    }
192
193    fn find_cycle_free_pivots(&mut self) {
194        let before_piv_count = self.pivots.count();
195
196        cfg_if::cfg_if! { 
197            if #[cfg(feature = "multithread")] { 
198                self.find_cycle_free_pivots_m();
199            } else { 
200                self.find_cycle_free_pivots_s();
201            }
202        }
203
204        let piv_count = self.pivots.count();
205
206        debug!("  cycle-free-pivots: +{}, total: {}.", piv_count - before_piv_count, piv_count);
207    }
208
209    #[allow(unused)]
210    fn find_cycle_free_pivots_s(&mut self) {
211        let remain_rows: Vec<_> = self.remain_rows().collect();
212        let total_rows = remain_rows.len();
213
214        debug!("  start find-cycle-free-pivots: {total_rows} rows");
215
216        let n = self.cols();
217        let mut w = RowWorker::new(n);
218        let mut row_count = 0;
219
220        for i in remain_rows { 
221            if let Some(j) = w.find_cycle_free_pivots(i, &self.str, &self.pivots) { 
222                self.pivots.set(i, j);
223            }
224
225            if self.should_report() {
226                row_count += 1;
227                if row_count % LOG_THRESHOLD == 0 { 
228                    let c = self.pivots.count();
229                    debug!("    [{row_count}/{total_rows}], {c} pivots.");
230                }
231            }
232        }
233     }
234
235     #[cfg(feature = "multithread")]
236     fn find_cycle_free_pivots_m(&mut self) {
237        use yui_core::util::sync::SyncCounter;
238
239        let remain_rows = self.remain_rows().collect_vec();
240        let total_rows = remain_rows.len();
241
242        debug!("  start find-cycle-free-pivots: {total_rows} rows");
243        
244        let n = self.cols();
245        let pivots = RwLock::new(
246            std::mem::take(&mut self.pivots)
247        );
248        let loc_pivots_tls = ThreadLocal::new();
249        let loc_worker_tls = ThreadLocal::new();
250
251        let report = self.should_report();
252        let row_counter = SyncCounter::new();
253
254        remain_rows.par_iter().for_each(|&i| { 
255            let mut loc_pivots = init_tls(&loc_pivots_tls, || 
256                pivots.read().unwrap().clone()
257            ).borrow_mut();
258
259            let mut w = init_tls(&loc_worker_tls, || 
260                RowWorker::new(n)
261            ).borrow_mut();
262
263            loc_pivots.update_from(&pivots.read().unwrap());
264            w.init(i, &self.str, &loc_pivots);
265
266            self.find_cycle_free_pivots_in(&pivots, &mut loc_pivots, &mut w);
267
268            if report { 
269                let row_count = row_counter.incr();            
270                if row_count % LOG_THRESHOLD == 0 { 
271                    let c = loc_pivots.count();
272                    debug!("    [{row_count}/{total_rows}], {c} pivots.");
273                }
274            }
275        });
276
277        self.pivots = pivots.into_inner().unwrap();
278     }
279
280     #[cfg(feature = "multithread")]
281     fn find_cycle_free_pivots_in(&self, pivots: &RwLock<PivotData>, loc_pivots: &mut PivotData, w: &mut RowWorker) {
282        loop { 
283            w.traverse(&self.str, loc_pivots);
284    
285            let Some(j) = w.choose_candidate(&self.str) else {
286                break
287            };
288            
289            // If changes are made in other threads, update `loc_pivots` and retry.
290            // Otherwise, modify `pivots` and exit.
291        
292            let mut pivots = pivots.write().unwrap();
293            w.update_diff(&loc_pivots, &pivots);
294            
295            if w.should_retry() { 
296                loc_pivots.update_from(&pivots);
297                continue
298            } else { 
299                pivots.set(w.row, j);
300                break
301            }    
302        }
303     }
304
305     fn should_report(&self) -> bool { 
306        self.rows() > LOG_THRESHOLD && log::max_level() >= log::LevelFilter::Debug
307     }
308}
309
/// Returns the current thread's slot in `tl`, creating it from `f` on first use.
#[cfg(feature = "multithread")]
fn init_tls<T, F>(tl: &ThreadLocal<RefCell<T>>, f: F) -> &RefCell<T>
where
    T: Send,
    F: FnOnce() -> T,
{
    let make = move || RefCell::new(f());
    tl.get_or(make)
}
315
316struct MatrixStr { 
317    shape: (usize, usize),
318    entries: Vec<Vec<Col>>,     // [row -> [col]]
319    cands: Vec<AHashSet<Col>>,  // [row -> [col]]
320    row_wght: Vec<f64>,         // [row -> weight]
321    col_wght: Vec<f64>,         // [col -> weight]
322}
323
324impl MatrixStr { 
325    fn new<R>(a: &SpMat<R>, piv_type: PivotType, pivot_cond: PivotCondition) -> Self
326    where R: Ring, for<'x> &'x R: RingOps<R> { 
327        let shape = match piv_type {
328            PivotType::Rows => a.shape(),
329            PivotType::Cols => (a.ncols(), a.nrows())
330        };
331        let t = match piv_type {
332            PivotType::Rows => |i: usize, j: usize| (i, j),
333            PivotType::Cols => |i, j| (j, i)
334        };
335
336        let (m, n) = shape;
337        let mut entries = vec![vec![]; m];
338        let mut row_wght = vec![0.0; m];
339        let mut col_wght = vec![0.0; n];
340        let mut cands = vec![AHashSet::new(); m];
341
342        for (i, j, r) in a.iter() { 
343            if r.is_zero() { continue }
344            
345            let (i, j) = t(i, j);
346            entries[i].push(j);
347
348            let w = r.c_weight();
349            row_wght[i] += w;
350            col_wght[j] += w;
351
352            if pivot_cond.is_cand(r) { 
353                cands[i].insert(j);
354            }
355        }
356
357        Self { shape, entries, cands, row_wght, col_wght }
358    }
359
360    fn shape(&self) -> (usize, usize) {
361        self.shape
362    }
363
364    fn is_empty_row(&self, i: Row) -> bool { 
365        self.entries[i].is_empty()
366    }
367
368    fn head_col_in(&self, i: Row) -> Option<Col> {
369        self.entries[i].first().copied()
370    }
371
372    fn cols_in(&self, i: Row) -> Iter<Col> {
373        self.entries[i].iter()
374    }
375
376    fn cmp_rows(&self, i1: Row, i2: Row) -> Ordering {
377        if let Some(o) = self.row_wght[i1].partial_cmp(&self.row_wght[i2]) { 
378            o.then(Ord::cmp(&i1, &i2))
379        } else { 
380            Ordering::Equal
381        }
382    }
383
384    fn cmp_cols(&self, j1: Col, j2: Col) -> Ordering {
385        if let Some(o) = self.col_wght[j1].partial_cmp(&self.col_wght[j2]) { 
386            o.then(Ord::cmp(&j1, &j2))
387        } else {
388            Ordering::Equal
389        }
390    }
391
392    fn is_candidate(&self, i: Row, j: Col) -> bool { 
393        self.cands[i].contains(&j)
394    }
395}
396
397#[derive(Clone, Default)]
398struct PivotData { 
399    data: Vec<Option<Row>>,   // col -> row
400    indices: Vec<Col>
401}
402
403impl PivotData { 
404    fn new<R>(a: &SpMat<R>, piv_type: PivotType) -> Self
405    where R: Ring, for<'x> &'x R: RingOps<R> { 
406        let n = if piv_type == PivotType::Rows { 
407            a.ncols()
408        } else { 
409            a.nrows()
410        };
411        let data = vec![None; n];
412        let indices = vec![];
413        Self { data, indices }
414    }
415
416    fn count(&self) -> usize { 
417        self.indices.len()
418    }
419
420    fn has_col(&self, j: Col) -> bool { 
421        self.data[j].is_some()
422    }
423
424    fn row_for(&self, j: Col) -> Option<Row> { 
425        self.data[j]
426    }
427
428    fn set(&mut self, i: Row, j: Col) {
429        assert!(!self.has_col(j));
430        self.data[j] = Some(i);
431        self.indices.push(j);
432    }
433
434    fn iter(&self) -> impl Iterator<Item = (Row, Col)> + '_ { 
435        self.indices.iter().map(|&j| {
436            let i = self.data[j].unwrap();
437            (i, j)
438        })
439    }
440
441    #[allow(unused)]
442    fn pivot_at(&self, k: usize) -> (Row, Col) { 
443        let j = self.indices[k];
444        let i = self.data[j].unwrap();
445        (i, j)
446    }
447
448    fn update_from(&mut self, from: &Self) { 
449        debug_assert!(self.count() <= from.count());
450        for k in self.count() .. from.count() { 
451            let (i, j) = from.pivot_at(k);
452            self.set(i, j);
453        }
454    }
455}
456
/// Per-column state tracked by `RowWorker` for the row being examined.
#[repr(u8)]
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum EntryStatus {
    /// Not seen yet for the current row.
    None,
    /// Entry of the current row that may still become its pivot.
    Candidate,
    /// Ruled out as a pivot column for the current row.
    Occupied,
}
462
463struct RowWorker { 
464    row: usize,
465    status: Vec<EntryStatus>,
466    ncand: usize,
467    queue: VecDeque<Col>,
468    queued: AHashSet<Col>
469}
470
471impl RowWorker {
472    fn new(size: usize) -> Self { 
473        let status = vec![EntryStatus::None; size];
474        let queue = VecDeque::new();
475        let queued = AHashSet::new();
476        RowWorker {row: 0, status, ncand: 0, queue, queued }
477    }
478
479    fn clear(&mut self) {
480        self.row = 0;
481        self.status.fill(EntryStatus::None);
482        self.ncand = 0;
483        self.queue.clear();
484        self.queued.clear();
485    }
486
487    //  i [  o       #     # ]     [  o   x   x      # ]     [  o   x   x   x  # ]
488    //    [  |               ]     [  |   :   :        ]     [  |   :   :   :    ]
489    //    [  *   .   .       ] ~~> [  * - o - .        ] ~~> [  * - o - .   :    ]
490    //    [                  ]     [      |            ]     [      |       :    ]
491    //    [      *       .   ]     [      *       .    ]     [      *-------.    ]
492    //
493    //  o: queued, #: candidate, x: occupied
494
495    fn find_cycle_free_pivots(&mut self, i: usize, str: &MatrixStr, pivots: &PivotData) -> Option<Col> {
496        self.init(i, str, pivots);
497        self.traverse(str, pivots);
498        self.choose_candidate(str)
499    }
500
501    fn init(&mut self, i: usize, str: &MatrixStr, pivots: &PivotData) { 
502        self.clear();
503        self.row = i;
504
505        for &j in str.cols_in(i) {
506            if pivots.has_col(j) {
507                self.enqueue(j);
508                self.set_occupied(j);
509            } else if str.is_candidate(i, j) {
510                self.set_candidate(j);
511            } else { 
512                self.set_occupied(j);
513            }
514        }
515    }
516
517    fn traverse(&mut self, str: &MatrixStr, pivots: &PivotData) {
518        if !self.has_candidate() { 
519            return
520        }
521
522        while let Some(j) = self.dequeue() { 
523            let i2 = pivots.row_for(j).unwrap();
524
525            for &j2 in str.cols_in(i2) { 
526                if pivots.has_col(j2) && !self.is_queued(j2) { 
527                    self.enqueue(j2);
528                }
529
530                self.set_occupied(j2);
531
532                if !self.has_candidate() { 
533                    break 
534                }
535            }
536        }
537    }
538
539    fn choose_candidate(&self, str: &MatrixStr) -> Option<Col> { 
540        let n = self.status.len();
541        (0 .. n)
542            .filter(|&j| self.is_candidate(j))
543            .sorted_by(|&j1, &j2| 
544                str.cmp_cols(j1, j2)
545            ).next()
546    }
547
548    #[allow(dead_code)]
549    fn update_diff(&mut self, loc_pivots: &PivotData, pivots: &PivotData) {
550        debug_assert!(loc_pivots.count() <= pivots.count());
551        for k in loc_pivots.count()..pivots.count() { 
552            let j = pivots.indices[k];
553            if self.is_candidate(j) || self.is_occupied(j) {
554                self.enqueue(j);
555                self.set_occupied(j);
556            }
557        }
558    }
559
560    fn should_retry(&self) -> bool { 
561        !self.queue.is_empty()
562    }
563
564    fn has_candidate(&self) -> bool { 
565        self.ncand > 0
566    }
567
568    fn is_candidate(&self, i: usize) -> bool { 
569        self.status[i] == EntryStatus::Candidate
570    }
571
572    fn set_candidate(&mut self, i: usize) { 
573        assert_eq!(self.status[i], EntryStatus::None);
574        self.status[i] = EntryStatus::Candidate;
575        self.ncand += 1;
576    }
577
578    fn is_occupied(&self, i: usize) -> bool { 
579        self.status[i] == EntryStatus::Occupied
580    }
581
582    fn set_occupied(&mut self, i: usize) { 
583        if self.is_candidate(i) { 
584            self.ncand -= 1;
585        }
586        self.status[i] = EntryStatus::Occupied;
587    }
588
589    fn enqueue(&mut self, i: Col) { 
590        self.queue.push_back(i);
591        self.queued.insert(i);
592    }
593
594    fn dequeue(&mut self) -> Option<Col> { 
595        self.queue.pop_front()
596    }
597
598    fn is_queued(&self, i: Col) -> bool { 
599        self.queued.contains(&i)
600    }
601}
602
#[cfg(test)]
mod tests {
    use super::*;
    use num_traits::{Zero, One};

    /// 6×9 fixture shared by the FL-column, cycle-free and `result` tests.
    fn sample_6x9() -> SpMat<i32> {
        SpMat::from_dense_data((6, 9), [
            1, 0, 0, 0, 0, 1, 0, 0, 1,
            0, 1, 1, 1, 0, 1, 0, 1, 0,
            0, 0, 1, 1, 0, 0, 0, 1, 1,
            0, 1, 0, 0, 1, 0, 0, 0, 0,
            0, 0, 1, 0, 0, 0, 0, 0, 0,
            0, 1, 0, 0, 0, 1, 0, 1, 0,
        ])
    }

    /// 4×4 fixture whose third row is empty.
    fn sample_4x4() -> SpMat<i32> {
        SpMat::from_dense_data((4, 4), [
            1, 0, 1, 0,
            0, 1, 1, 1,
            0, 0, 0, 0,
            0, 0, 1, 1,
        ])
    }

    #[test]
    fn str_init() {
        let a = SpMat::from_dense_data((6, 9), [
            1, 0, 1, 0, 0, 1, 1, 0, 1,
            0, 1, 1, 1, 0, 1, 0, 2, 0,
            0, 0, 1, 1, 0, 0, 0, 1, 1,
            0, 1, 1, 0, 3, 0, 0, 0, 0,
            0, 1, 0, 1, 0, 0, 1, 0, 1,
            1, 0, 1, 0, 1, 1, 0, 1, 1,
        ]);
        let s = MatrixStr::new(&a, PivotType::Rows, PivotCondition::One);

        let expected_entries = vec![
            vec![0, 2, 5, 6, 8],
            vec![1, 2, 3, 5, 7],
            vec![2, 3, 7, 8],
            vec![1, 2, 4],
            vec![1, 3, 6, 8],
            vec![0, 2, 4, 5, 7, 8],
        ];
        assert_eq!(s.entries, expected_entries);
        assert_eq!(s.row_wght, vec![5.0, 6.0, 4.0, 5.0, 4.0, 6.0]);
        assert_eq!(s.col_wght, vec![2.0, 3.0, 5.0, 3.0, 4.0, 3.0, 2.0, 4.0, 4.0]);

        let expected_cands = vec![
            AHashSet::from_iter([0, 2, 5, 6, 8]),
            AHashSet::from_iter([1, 2, 3, 5]),
            AHashSet::from_iter([2, 3, 7, 8]),
            AHashSet::from_iter([1, 2]),
            AHashSet::from_iter([1, 3, 6, 8]),
            AHashSet::from_iter([0, 2, 4, 5, 7, 8]),
        ];
        assert_eq!(s.cands, expected_cands);
    }

    #[test]
    fn str_row_head() {
        let a = sample_4x4();
        let s = MatrixStr::new(&a, PivotType::Rows, PivotCondition::One);

        let heads = (0..4).map(|i| s.head_col_in(i)).collect_vec();
        assert_eq!(heads, vec![Some(0), Some(1), None, Some(2)]);
    }

    #[test]
    fn rows_cols() {
        let a = SpMat::<i32>::from_dense_data((4, 3), []);
        let pf = PivotFinder::new(&a, PivotType::Rows, PivotCondition::One);
        assert_eq!((pf.rows(), pf.cols()), (4, 3));
    }

    #[test]
    fn pivot_data() {
        let a = SpMat::from_dense_data((2, 4), [
            1, 0, 1, 0,
            0, 0, 1, 1,
        ]);
        let mut piv = PivotData::new(&a, PivotType::Rows);

        assert_eq!(piv.count(), 0);
        assert!(!piv.has_col(0));
        assert_eq!(piv.row_for(0), None);

        piv.set(1, 2);

        assert_eq!(piv.count(), 1);
        assert!(piv.has_col(2));
        assert_eq!(piv.row_for(2), Some(1));
    }

    #[test]
    fn remain_rows() {
        let a = sample_4x4();
        let mut pf = PivotFinder::new(&a, PivotType::Rows, PivotCondition::One);

        assert_eq!(pf.remain_rows().collect_vec(), vec![0, 3, 1]);

        pf.pivots.set(0, 0);
        assert_eq!(pf.remain_rows().collect_vec(), vec![3, 1]);

        pf.pivots.set(1, 1);
        assert_eq!(pf.remain_rows().collect_vec(), vec![3]);
    }

    #[test]
    fn pivots() {
        let a = sample_4x4();
        let mut pf = PivotFinder::new(&a, PivotType::Rows, PivotCondition::One);

        assert_eq!(pf.occupied_cols(), AHashSet::new());

        pf.pivots.set(0, 0);
        assert_eq!(pf.occupied_cols(), AHashSet::from_iter([0, 2]));

        pf.pivots.set(1, 1);
        assert_eq!(pf.occupied_cols(), AHashSet::from_iter([0, 1, 2, 3]));
    }

    #[test]
    fn find_fl_pivots() {
        let a = SpMat::from_dense_data((6, 9), [
            1, 0, 1, 0, 0, 1, 1, 0, 1,
            0, 1, 1, 1, 0, 1, 0, 1, 0,
            0, 0, 1, 1, 0, 0, 0, 1, 1,
            0, 1, 1, 0, 1, 0, 0, 0, 0,
            0, 0, 1, 1, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 1, 0, 1, 1,
        ]);
        let mut pf = PivotFinder::new(&a, PivotType::Rows, PivotCondition::One);

        pf.find_fl_pivots();

        let found = pf.pivots.iter().collect_vec();
        assert_eq!(found, vec![(4, 2), (3, 1), (5, 5), (0, 0)]);
    }

    #[test]
    fn find_fl_col_pivots() {
        let a = sample_6x9();
        let mut pf = PivotFinder::new(&a, PivotType::Rows, PivotCondition::One);

        pf.find_fl_col_pivots();

        let found = pf.pivots.iter().collect_vec();
        assert_eq!(found, vec![(4, 2), (3, 4), (0, 0), (5, 7), (2, 3)]);
    }

    #[test]
    fn find_fl_row_col_pivots() {
        let a = sample_6x9();
        let mut pf = PivotFinder::new(&a, PivotType::Rows, PivotCondition::One);

        pf.find_fl_pivots();
        assert_eq!(pf.pivots.iter().collect_vec(), vec![(4, 2), (3, 1), (0, 0)]);

        pf.find_fl_col_pivots();
        assert_eq!(pf.pivots.iter().collect_vec(), vec![(4, 2), (3, 1), (0, 0), (5, 7), (2, 3)]);
    }

    #[test]
    fn find_cycle_free_pivots_s() {
        let a = sample_6x9();
        let mut pf = PivotFinder::new(&a, PivotType::Rows, PivotCondition::One);

        pf.find_cycle_free_pivots_s();

        let found = pf.pivots.iter().collect_vec();
        assert_eq!(found, vec![(4, 2), (3, 4), (0, 0), (5, 1), (2, 3)]);
    }

    #[cfg(feature = "multithread")]
    #[test]
    fn find_cycle_free_pivots_m() {
        let a = sample_6x9();
        let mut pf = PivotFinder::new(&a, PivotType::Rows, PivotCondition::One);

        pf.find_cycle_free_pivots_m();

        assert!(pf.pivots.count() >= 5);
    }

    #[test]
    fn zero() {
        let a = SpMat::from_dense_data((1, 1), [0]);
        let pivs = find_pivots(&a, PivotType::Rows, PivotCondition::One);
        assert_eq!(pivs.len(), 0);
    }

    #[test]
    fn id_1() {
        let a = SpMat::from_dense_data((1, 1), [1]);
        let pivs = find_pivots(&a, PivotType::Rows, PivotCondition::One);
        assert_eq!(pivs.len(), 1);
    }

    #[test]
    fn id_2() {
        let a = SpMat::from_dense_data((2, 2), [
            1, 0, 0, 1
        ]);
        let pivs = find_pivots(&a, PivotType::Rows, PivotCondition::One);
        assert_eq!(pivs.len(), 2);
    }

    #[test]
    fn result() {
        let a = sample_6x9();
        let pivs = find_pivots(&a, PivotType::Rows, PivotCondition::One);
        let r = pivs.len();
        assert_eq!(r, 5);

        let (p, q) = perms_by_pivots(&a, &pivs);
        let b = a.permute(p.view(), q.view()).into_dense();

        // unit diagonal, zeros below it
        for i in 0..r {
            assert!(b[(i, i)].is_one());
        }
        for j in 0..r {
            for i in (j + 1)..r {
                assert!(b[(i, j)].is_zero());
            }
        }
    }

    #[test]
    fn result_cols() {
        let a = sample_6x9();
        let pivs = find_pivots(&a, PivotType::Cols, PivotCondition::One);
        let r = pivs.len();
        assert_eq!(r, 6);

        let (p, q) = perms_by_pivots(&a, &pivs);
        let b = a.permute(p.view(), q.view()).into_dense();

        // unit diagonal, zeros above it
        for i in 0..r {
            assert!(b[(i, i)].is_one());
        }
        for i in 0..r {
            for j in (i + 1)..r {
                assert!(b[(i, j)].is_zero());
            }
        }
    }

    #[test]
    fn rand() {
        let density = 0.1;
        let a = SpMat::<i32>::rand((60, 80), density);

        let pivs = find_pivots(&a, PivotType::Rows, PivotCondition::One);
        let r = pivs.len();
        assert!(r > 10);

        let (p, q) = perms_by_pivots(&a, &pivs);
        let b = a.permute(p.view(), q.view()).into_dense();

        for i in 0..r {
            assert!(b[(i, i)].is_one());
        }
        for j in 0..r {
            for i in (j + 1)..r {
                assert!(b[(i, j)].is_zero());
            }
        }
    }
}