1use std::slice::Iter;
10use std::cmp::Ordering;
11use std::collections::VecDeque;
12use ahash::AHashSet;
13use itertools::Itertools;
14use log::debug;
15use sprs::PermOwned;
16
17use yui_core::{Ring, RingOps};
18use yui_core::algo::top_sort;
19use super::*;
20use super::util::perm_for_indices;
21
22cfg_if::cfg_if! {
23 if #[cfg(feature = "multithread")] {
24 use std::cell::RefCell;
25 use std::sync::RwLock;
26 use thread_local::ThreadLocal;
27 use rayon::prelude::*;
28 }
29}
30
/// Progress-logging knob for the cycle-free search: logging is enabled only for
/// matrices with more rows than this, and a line is emitted every this many rows.
const LOG_THRESHOLD: usize = 10_000;
32
/// Orientation of the pivot search.
///
/// With `Rows` the search walks the rows of the matrix as given; with `Cols`
/// it works on the transpose, so every internal "row" is a column of the input.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum PivotType {
    Rows,
    Cols,
}
37
/// Rule deciding which nonzero entries are allowed to become pivots.
#[derive(Clone, Copy, Debug)]
pub enum PivotCondition {
    /// Only entries for which `is_pm_one()` holds.
    One,
    /// Units whose `c_weight()` does not exceed the given bound.
    Weight(f64),
    /// Any unit.
    AnyUnit,
}
42
43impl PivotCondition {
44 fn is_cand<R>(&self, r: &R) -> bool
45 where R: Ring, for<'x> &'x R: RingOps<R> {
46 match self {
47 PivotCondition::One => r.is_pm_one(),
48 PivotCondition::Weight(w) => r.is_unit() && r.c_weight() <= *w,
49 PivotCondition::AnyUnit => r.is_unit(),
50 }
51 }
52}
53
54pub fn find_pivots<R>(a: &SpMat<R>, piv_type: PivotType, pivot_cond: PivotCondition) -> Vec<(usize, usize)>
55where R: Ring, for<'x> &'x R: RingOps<R> {
56 if a.is_zero() {
57 return vec![];
58 }
59
60 let mut pf = PivotFinder::new(a, piv_type, pivot_cond);
61 pf.find_pivots();
62 pf.result()
63}
64
65pub fn perms_by_pivots<R>(a: &SpMat<R>, pivs: &[(usize, usize)]) -> (PermOwned, PermOwned)
66where R: Ring, for<'x> &'x R: RingOps<R> {
67 let (m, n) = a.shape();
68 (
69 perm_for_indices(m, pivs.iter().map(|(i, _)| i)),
70 perm_for_indices(n, pivs.iter().map(|(_, j)| j))
71 )
72}
73
/// Row index of the (possibly transposed) matrix seen by the pivot search.
type Row = usize;
/// Column index of the (possibly transposed) matrix seen by the pivot search.
type Col = usize;
76
77pub struct PivotFinder {
78 str: MatrixStr,
79 pivots: PivotData,
80 piv_type: PivotType
81}
82
83impl PivotFinder {
84 pub fn new<R>(a: &SpMat<R>, piv_type: PivotType, pivot_cond: PivotCondition) -> Self
85 where R: Ring, for<'x> &'x R: RingOps<R> {
86 let str = MatrixStr::new(a, piv_type, pivot_cond);
87 let pivots = PivotData::new(a, piv_type);
88 PivotFinder{ str, pivots, piv_type }
89 }
90
91 pub fn find_pivots(&mut self) {
92 debug!("pivots: {:?} ..", self.str.shape());
93
94 self.find_fl_pivots();
95 self.find_fl_col_pivots();
96 self.find_cycle_free_pivots();
97
98 debug!("pivots: {:?} => {}.", self.str.shape(), self.pivots.count());
99 }
100
101 pub fn result(&self) -> Vec<(usize, usize)> {
102 let tree = self.pivots.iter().map(|(i, j)| {
103 let list = self.str.cols_in(i).filter(|&&j2|
104 j != j2 && self.pivots.has_col(j2)
105 ).copied().collect_vec();
106 (j, list)
107 });
108
109 let sorted = top_sort(tree).unwrap();
110 let is_row_type = self.piv_type == PivotType::Rows;
111
112 sorted.into_iter().map(|j| {
113 let i = self.pivots.row_for(j).unwrap();
114 if is_row_type { (i, j) } else { (j, i) }
115 }).collect_vec()
116 }
117
118 fn rows(&self) -> Row {
119 self.str.shape.0
120 }
121
122 fn cols(&self) -> Col {
123 self.str.shape.1
124 }
125
126 fn remain_rows(&self) -> impl Iterator<Item = Row> {
127 let piv_rows: AHashSet<_> = self.pivots.iter().map(|(i, _)| i).collect();
128 let m = self.rows();
129
130 (0 .. m).filter(|&i|
131 !piv_rows.contains(&i) && !self.str.is_empty_row(i)
132 ).sorted_by(|&i1, &i2|
133 self.str.cmp_rows(i1, i2)
134 )
135 }
136
137 fn occupied_cols(&self) -> AHashSet<Col> {
138 self.pivots.iter().fold(AHashSet::new(), |mut res, (i, _)| {
139 for &j in self.str.cols_in(i) {
140 res.insert(j);
141 }
142 res
143 })
144 }
145
146 fn find_fl_pivots(&mut self) {
147 let remain_rows: Vec<_> = self.remain_rows().collect();
148
149 for i in remain_rows {
150 let Some(j) = self.str.head_col_in(i) else { continue };
151
152 if !self.pivots.has_col(j) && self.str.is_candidate(i, j) {
153 self.pivots.set(i, j);
154 }
155 }
156
157 let piv_count = self.pivots.count();
158
159 debug!(" fl-pivots: +{}.", piv_count);
160 }
161
162 fn find_fl_col_pivots(&mut self) {
163 let before_piv_count = self.pivots.count();
164
165 let remain_rows: Vec<_> = self.remain_rows().collect();
166 let mut occ_cols = self.occupied_cols();
167
168 for i in remain_rows {
169 let mut cands = vec![];
170
171 for &j in self.str.cols_in(i) {
172 if !occ_cols.contains(&j) && self.str.is_candidate(i, j) {
173 cands.push(j);
174 }
175 }
176
177 let Some(j) = cands.into_iter().sorted_by(|&j1, &j2|
178 self.str.cmp_cols(j1, j2)
179 ).next() else { continue };
180
181 self.pivots.set(i, j);
182
183 for &j in self.str.cols_in(i) {
184 occ_cols.insert(j);
185 }
186 }
187
188 let piv_count = self.pivots.count();
189
190 debug!(" fl-col-pivots: +{}, total: {}.", piv_count - before_piv_count, piv_count);
191 }
192
193 fn find_cycle_free_pivots(&mut self) {
194 let before_piv_count = self.pivots.count();
195
196 cfg_if::cfg_if! {
197 if #[cfg(feature = "multithread")] {
198 self.find_cycle_free_pivots_m();
199 } else {
200 self.find_cycle_free_pivots_s();
201 }
202 }
203
204 let piv_count = self.pivots.count();
205
206 debug!(" cycle-free-pivots: +{}, total: {}.", piv_count - before_piv_count, piv_count);
207 }
208
209 #[allow(unused)]
210 fn find_cycle_free_pivots_s(&mut self) {
211 let remain_rows: Vec<_> = self.remain_rows().collect();
212 let total_rows = remain_rows.len();
213
214 debug!(" start find-cycle-free-pivots: {total_rows} rows");
215
216 let n = self.cols();
217 let mut w = RowWorker::new(n);
218 let mut row_count = 0;
219
220 for i in remain_rows {
221 if let Some(j) = w.find_cycle_free_pivots(i, &self.str, &self.pivots) {
222 self.pivots.set(i, j);
223 }
224
225 if self.should_report() {
226 row_count += 1;
227 if row_count % LOG_THRESHOLD == 0 {
228 let c = self.pivots.count();
229 debug!(" [{row_count}/{total_rows}], {c} pivots.");
230 }
231 }
232 }
233 }
234
235 #[cfg(feature = "multithread")]
236 fn find_cycle_free_pivots_m(&mut self) {
237 use yui_core::util::sync::SyncCounter;
238
239 let remain_rows = self.remain_rows().collect_vec();
240 let total_rows = remain_rows.len();
241
242 debug!(" start find-cycle-free-pivots: {total_rows} rows");
243
244 let n = self.cols();
245 let pivots = RwLock::new(
246 std::mem::take(&mut self.pivots)
247 );
248 let loc_pivots_tls = ThreadLocal::new();
249 let loc_worker_tls = ThreadLocal::new();
250
251 let report = self.should_report();
252 let row_counter = SyncCounter::new();
253
254 remain_rows.par_iter().for_each(|&i| {
255 let mut loc_pivots = init_tls(&loc_pivots_tls, ||
256 pivots.read().unwrap().clone()
257 ).borrow_mut();
258
259 let mut w = init_tls(&loc_worker_tls, ||
260 RowWorker::new(n)
261 ).borrow_mut();
262
263 loc_pivots.update_from(&pivots.read().unwrap());
264 w.init(i, &self.str, &loc_pivots);
265
266 self.find_cycle_free_pivots_in(&pivots, &mut loc_pivots, &mut w);
267
268 if report {
269 let row_count = row_counter.incr();
270 if row_count % LOG_THRESHOLD == 0 {
271 let c = loc_pivots.count();
272 debug!(" [{row_count}/{total_rows}], {c} pivots.");
273 }
274 }
275 });
276
277 self.pivots = pivots.into_inner().unwrap();
278 }
279
280 #[cfg(feature = "multithread")]
281 fn find_cycle_free_pivots_in(&self, pivots: &RwLock<PivotData>, loc_pivots: &mut PivotData, w: &mut RowWorker) {
282 loop {
283 w.traverse(&self.str, loc_pivots);
284
285 let Some(j) = w.choose_candidate(&self.str) else {
286 break
287 };
288
289 let mut pivots = pivots.write().unwrap();
293 w.update_diff(&loc_pivots, &pivots);
294
295 if w.should_retry() {
296 loc_pivots.update_from(&pivots);
297 continue
298 } else {
299 pivots.set(w.row, j);
300 break
301 }
302 }
303 }
304
305 fn should_report(&self) -> bool {
306 self.rows() > LOG_THRESHOLD && log::max_level() >= log::LevelFilter::Debug
307 }
308}
309
/// Returns the calling thread's cell in `tl`, creating it from `f()` the
/// first time this thread asks for it.
#[cfg(feature = "multithread")]
fn init_tls<T, F>(tl: &ThreadLocal<RefCell<T>>, f: F) -> &RefCell<T>
where T: Send, F: FnOnce() -> T {
    let make_cell = move || RefCell::new(f());
    tl.get_or(make_cell)
}
315
316struct MatrixStr {
317 shape: (usize, usize),
318 entries: Vec<Vec<Col>>, cands: Vec<AHashSet<Col>>, row_wght: Vec<f64>, col_wght: Vec<f64>, }
323
324impl MatrixStr {
325 fn new<R>(a: &SpMat<R>, piv_type: PivotType, pivot_cond: PivotCondition) -> Self
326 where R: Ring, for<'x> &'x R: RingOps<R> {
327 let shape = match piv_type {
328 PivotType::Rows => a.shape(),
329 PivotType::Cols => (a.ncols(), a.nrows())
330 };
331 let t = match piv_type {
332 PivotType::Rows => |i: usize, j: usize| (i, j),
333 PivotType::Cols => |i, j| (j, i)
334 };
335
336 let (m, n) = shape;
337 let mut entries = vec![vec![]; m];
338 let mut row_wght = vec![0.0; m];
339 let mut col_wght = vec![0.0; n];
340 let mut cands = vec![AHashSet::new(); m];
341
342 for (i, j, r) in a.iter() {
343 if r.is_zero() { continue }
344
345 let (i, j) = t(i, j);
346 entries[i].push(j);
347
348 let w = r.c_weight();
349 row_wght[i] += w;
350 col_wght[j] += w;
351
352 if pivot_cond.is_cand(r) {
353 cands[i].insert(j);
354 }
355 }
356
357 Self { shape, entries, cands, row_wght, col_wght }
358 }
359
360 fn shape(&self) -> (usize, usize) {
361 self.shape
362 }
363
364 fn is_empty_row(&self, i: Row) -> bool {
365 self.entries[i].is_empty()
366 }
367
368 fn head_col_in(&self, i: Row) -> Option<Col> {
369 self.entries[i].first().copied()
370 }
371
372 fn cols_in(&self, i: Row) -> Iter<Col> {
373 self.entries[i].iter()
374 }
375
376 fn cmp_rows(&self, i1: Row, i2: Row) -> Ordering {
377 if let Some(o) = self.row_wght[i1].partial_cmp(&self.row_wght[i2]) {
378 o.then(Ord::cmp(&i1, &i2))
379 } else {
380 Ordering::Equal
381 }
382 }
383
384 fn cmp_cols(&self, j1: Col, j2: Col) -> Ordering {
385 if let Some(o) = self.col_wght[j1].partial_cmp(&self.col_wght[j2]) {
386 o.then(Ord::cmp(&j1, &j2))
387 } else {
388 Ordering::Equal
389 }
390 }
391
392 fn is_candidate(&self, i: Row, j: Col) -> bool {
393 self.cands[i].contains(&j)
394 }
395}
396
397#[derive(Clone, Default)]
398struct PivotData {
399 data: Vec<Option<Row>>, indices: Vec<Col>
401}
402
403impl PivotData {
404 fn new<R>(a: &SpMat<R>, piv_type: PivotType) -> Self
405 where R: Ring, for<'x> &'x R: RingOps<R> {
406 let n = if piv_type == PivotType::Rows {
407 a.ncols()
408 } else {
409 a.nrows()
410 };
411 let data = vec![None; n];
412 let indices = vec![];
413 Self { data, indices }
414 }
415
416 fn count(&self) -> usize {
417 self.indices.len()
418 }
419
420 fn has_col(&self, j: Col) -> bool {
421 self.data[j].is_some()
422 }
423
424 fn row_for(&self, j: Col) -> Option<Row> {
425 self.data[j]
426 }
427
428 fn set(&mut self, i: Row, j: Col) {
429 assert!(!self.has_col(j));
430 self.data[j] = Some(i);
431 self.indices.push(j);
432 }
433
434 fn iter(&self) -> impl Iterator<Item = (Row, Col)> + '_ {
435 self.indices.iter().map(|&j| {
436 let i = self.data[j].unwrap();
437 (i, j)
438 })
439 }
440
441 #[allow(unused)]
442 fn pivot_at(&self, k: usize) -> (Row, Col) {
443 let j = self.indices[k];
444 let i = self.data[j].unwrap();
445 (i, j)
446 }
447
448 fn update_from(&mut self, from: &Self) {
449 debug_assert!(self.count() <= from.count());
450 for k in self.count() .. from.count() {
451 let (i, j) = from.pivot_at(k);
452 self.set(i, j);
453 }
454 }
455}
456
/// Per-column state kept by `RowWorker` while it examines one row.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[repr(u8)]
enum EntryStatus {
    /// Not looked at for the current row.
    None,
    /// A non-pivot column of the row whose entry may still become its pivot.
    Candidate,
    /// Ruled out as the row's pivot.
    Occupied,
}
462
463struct RowWorker {
464 row: usize,
465 status: Vec<EntryStatus>,
466 ncand: usize,
467 queue: VecDeque<Col>,
468 queued: AHashSet<Col>
469}
470
471impl RowWorker {
472 fn new(size: usize) -> Self {
473 let status = vec![EntryStatus::None; size];
474 let queue = VecDeque::new();
475 let queued = AHashSet::new();
476 RowWorker {row: 0, status, ncand: 0, queue, queued }
477 }
478
479 fn clear(&mut self) {
480 self.row = 0;
481 self.status.fill(EntryStatus::None);
482 self.ncand = 0;
483 self.queue.clear();
484 self.queued.clear();
485 }
486
487 fn find_cycle_free_pivots(&mut self, i: usize, str: &MatrixStr, pivots: &PivotData) -> Option<Col> {
496 self.init(i, str, pivots);
497 self.traverse(str, pivots);
498 self.choose_candidate(str)
499 }
500
501 fn init(&mut self, i: usize, str: &MatrixStr, pivots: &PivotData) {
502 self.clear();
503 self.row = i;
504
505 for &j in str.cols_in(i) {
506 if pivots.has_col(j) {
507 self.enqueue(j);
508 self.set_occupied(j);
509 } else if str.is_candidate(i, j) {
510 self.set_candidate(j);
511 } else {
512 self.set_occupied(j);
513 }
514 }
515 }
516
517 fn traverse(&mut self, str: &MatrixStr, pivots: &PivotData) {
518 if !self.has_candidate() {
519 return
520 }
521
522 while let Some(j) = self.dequeue() {
523 let i2 = pivots.row_for(j).unwrap();
524
525 for &j2 in str.cols_in(i2) {
526 if pivots.has_col(j2) && !self.is_queued(j2) {
527 self.enqueue(j2);
528 }
529
530 self.set_occupied(j2);
531
532 if !self.has_candidate() {
533 break
534 }
535 }
536 }
537 }
538
539 fn choose_candidate(&self, str: &MatrixStr) -> Option<Col> {
540 let n = self.status.len();
541 (0 .. n)
542 .filter(|&j| self.is_candidate(j))
543 .sorted_by(|&j1, &j2|
544 str.cmp_cols(j1, j2)
545 ).next()
546 }
547
548 #[allow(dead_code)]
549 fn update_diff(&mut self, loc_pivots: &PivotData, pivots: &PivotData) {
550 debug_assert!(loc_pivots.count() <= pivots.count());
551 for k in loc_pivots.count()..pivots.count() {
552 let j = pivots.indices[k];
553 if self.is_candidate(j) || self.is_occupied(j) {
554 self.enqueue(j);
555 self.set_occupied(j);
556 }
557 }
558 }
559
560 fn should_retry(&self) -> bool {
561 !self.queue.is_empty()
562 }
563
564 fn has_candidate(&self) -> bool {
565 self.ncand > 0
566 }
567
568 fn is_candidate(&self, i: usize) -> bool {
569 self.status[i] == EntryStatus::Candidate
570 }
571
572 fn set_candidate(&mut self, i: usize) {
573 assert_eq!(self.status[i], EntryStatus::None);
574 self.status[i] = EntryStatus::Candidate;
575 self.ncand += 1;
576 }
577
578 fn is_occupied(&self, i: usize) -> bool {
579 self.status[i] == EntryStatus::Occupied
580 }
581
582 fn set_occupied(&mut self, i: usize) {
583 if self.is_candidate(i) {
584 self.ncand -= 1;
585 }
586 self.status[i] = EntryStatus::Occupied;
587 }
588
589 fn enqueue(&mut self, i: Col) {
590 self.queue.push_back(i);
591 self.queued.insert(i);
592 }
593
594 fn dequeue(&mut self) -> Option<Col> {
595 self.queue.pop_front()
596 }
597
598 fn is_queued(&self, i: Col) -> bool {
599 self.queued.contains(&i)
600 }
601}
602
#[cfg(test)]
mod tests {
    // Unit tests for the pivot search. Matrices are given densely, row by row,
    // through `SpMat::from_dense_data`.
    use super::*;
    use num_traits::{Zero, One};

    // `MatrixStr::new` records each row's nonzero columns, the row/column sums
    // of `c_weight()`, and the candidate entries under `PivotCondition::One`
    // (the `2` in row 1 and the `3` in row 3 are not candidates).
    #[test]
    fn str_init() {
        let a = SpMat::from_dense_data((6, 9), [
            1, 0, 1, 0, 0, 1, 1, 0, 1,
            0, 1, 1, 1, 0, 1, 0, 2, 0,
            0, 0, 1, 1, 0, 0, 0, 1, 1,
            0, 1, 1, 0, 3, 0, 0, 0, 0,
            0, 1, 0, 1, 0, 0, 1, 0, 1,
            1, 0, 1, 0, 1, 1, 0, 1, 1
        ]);
        let str = MatrixStr::new(&a, PivotType::Rows, PivotCondition::One);

        assert_eq!(str.entries, vec![
            vec![0,2,5,6,8],
            vec![1,2,3,5,7],
            vec![2,3,7,8],
            vec![1,2,4],
            vec![1,3,6,8],
            vec![0,2,4,5,7,8]]
        );
        assert_eq!(str.row_wght, vec![5.0, 6.0, 4.0, 5.0, 4.0, 6.0]);
        assert_eq!(str.col_wght, vec![2.0, 3.0, 5.0, 3.0, 4.0, 3.0, 2.0, 4.0, 4.0]);
        assert_eq!(str.cands, vec![
            AHashSet::from_iter([0,2,5,6,8]),
            AHashSet::from_iter([1,2,3,5]),
            AHashSet::from_iter([2,3,7,8]),
            AHashSet::from_iter([1,2]),
            AHashSet::from_iter([1,3,6,8]),
            AHashSet::from_iter([0,2,4,5,7,8])
        ]);
    }

    // `head_col_in` yields the first stored column of a row, or `None` for an
    // empty row.
    #[test]
    fn str_row_head() {
        let a = SpMat::from_dense_data((4, 4), [
            1, 0, 1, 0,
            0, 1, 1, 1,
            0, 0, 0, 0,
            0, 0, 1, 1,
        ]);
        let str = MatrixStr::new(&a, PivotType::Rows, PivotCondition::One);

        assert_eq!(str.head_col_in(0), Some(0));
        assert_eq!(str.head_col_in(1), Some(1));
        assert_eq!(str.head_col_in(2), None);
        assert_eq!(str.head_col_in(3), Some(2));
    }

    // `rows()` / `cols()` report the structure's shape; only the shape matters here.
    #[test]
    fn rows_cols() {
        let a = SpMat::<i32>::from_dense_data((4, 3), []);
        let pf = PivotFinder::new(&a, PivotType::Rows, PivotCondition::One);
        assert_eq!(pf.rows(), 4);
        assert_eq!(pf.cols(), 3);
    }

    // Setting and looking up a pivot in `PivotData`.
    #[test]
    fn pivot_data() {
        let a = SpMat::from_dense_data((2, 4), [
            1, 0, 1, 0,
            0, 0, 1, 1,
        ]);
        let mut piv = PivotData::new(&a, PivotType::Rows);

        assert_eq!(piv.count(), 0);
        assert!(!piv.has_col(0));
        assert_eq!(piv.row_for(0), None);

        piv.set(1, 2);

        assert_eq!(piv.count(), 1);
        assert!(piv.has_col(2));
        assert_eq!(piv.row_for(2), Some(1));
    }

    // Remaining rows exclude pivot rows and the empty row 2, ordered by
    // weight, then index.
    #[test]
    fn remain_rows() {
        let a = SpMat::from_dense_data((4, 4), [
            1, 0, 1, 0,
            0, 1, 1, 1,
            0, 0, 0, 0,
            0, 0, 1, 1,
        ]);
        let mut pf = PivotFinder::new(&a, PivotType::Rows, PivotCondition::One);

        assert_eq!(pf.remain_rows().collect_vec(), vec![0,3,1]);

        pf.pivots.set(0, 0);

        assert_eq!(pf.remain_rows().collect_vec(), vec![3,1]);

        pf.pivots.set(1, 1);

        assert_eq!(pf.remain_rows().collect_vec(), vec![3]);
    }

    // `occupied_cols` is the union of the columns of all pivot rows.
    #[test]
    fn pivots() {
        let a = SpMat::from_dense_data((4, 4), [
            1, 0, 1, 0,
            0, 1, 1, 1,
            0, 0, 0, 0,
            0, 0, 1, 1,
        ]);
        let mut pf = PivotFinder::new(&a, PivotType::Rows, PivotCondition::One);

        assert_eq!(pf.occupied_cols(), AHashSet::new());

        pf.pivots.set(0, 0);

        assert_eq!(pf.occupied_cols(), AHashSet::from_iter([0,2]));

        pf.pivots.set(1, 1);

        assert_eq!(pf.occupied_cols(), AHashSet::from_iter([0,1,2,3]));
    }

    // Rows are visited lightest first; each takes its leading column unless
    // that column is already a pivot column.
    #[test]
    fn find_fl_pivots() {
        let a = SpMat::from_dense_data((6, 9), [
            1, 0, 1, 0, 0, 1, 1, 0, 1,
            0, 1, 1, 1, 0, 1, 0, 1, 0,
            0, 0, 1, 1, 0, 0, 0, 1, 1,
            0, 1, 1, 0, 1, 0, 0, 0, 0,
            0, 0, 1, 1, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 1, 0, 1, 1
        ]);
        let mut pf = PivotFinder::new(&a, PivotType::Rows, PivotCondition::One);

        pf.find_fl_pivots();

        assert_eq!(pf.pivots.iter().collect_vec(), vec![(4, 2), (3, 1), (5, 5), (0, 0)]);
    }

    // Each row takes its lightest candidate column not present in any earlier
    // pivot row.
    #[test]
    fn find_fl_col_pivots() {
        let a = SpMat::from_dense_data((6, 9), [
            1, 0, 0, 0, 0, 1, 0, 0, 1,
            0, 1, 1, 1, 0, 1, 0, 1, 0,
            0, 0, 1, 1, 0, 0, 0, 1, 1,
            0, 1, 0, 0, 1, 0, 0, 0, 0,
            0, 0, 1, 0, 0, 0, 0, 0, 0,
            0, 1, 0, 0, 0, 1, 0, 1, 0
        ]);
        let mut pf = PivotFinder::new(&a, PivotType::Rows, PivotCondition::One);

        pf.find_fl_col_pivots();

        assert_eq!(pf.pivots.iter().collect_vec(), vec![(4, 2), (3, 4), (0, 0), (5, 7), (2, 3)]);
    }

    // The column phase extends the pivots found by the leading-column phase.
    #[test]
    fn find_fl_row_col_pivots() {
        let a = SpMat::from_dense_data((6, 9), [
            1, 0, 0, 0, 0, 1, 0, 0, 1,
            0, 1, 1, 1, 0, 1, 0, 1, 0,
            0, 0, 1, 1, 0, 0, 0, 1, 1,
            0, 1, 0, 0, 1, 0, 0, 0, 0,
            0, 0, 1, 0, 0, 0, 0, 0, 0,
            0, 1, 0, 0, 0, 1, 0, 1, 0
        ]);
        let mut pf = PivotFinder::new(&a, PivotType::Rows, PivotCondition::One);

        pf.find_fl_pivots();

        assert_eq!(pf.pivots.iter().collect_vec(), vec![(4, 2), (3, 1), (0, 0)]);

        pf.find_fl_col_pivots();

        assert_eq!(pf.pivots.iter().collect_vec(), vec![(4, 2), (3, 1), (0, 0), (5, 7), (2, 3)]);
    }

    // Single-threaded cycle-free phase run on its own, starting from no pivots.
    #[test]
    fn find_cycle_free_pivots_s() {
        let a = SpMat::from_dense_data((6, 9), [
            1, 0, 0, 0, 0, 1, 0, 0, 1,
            0, 1, 1, 1, 0, 1, 0, 1, 0,
            0, 0, 1, 1, 0, 0, 0, 1, 1,
            0, 1, 0, 0, 1, 0, 0, 0, 0,
            0, 0, 1, 0, 0, 0, 0, 0, 0,
            0, 1, 0, 0, 0, 1, 0, 1, 0
        ]);
        let mut pf = PivotFinder::new(&a, PivotType::Rows, PivotCondition::One);

        pf.find_cycle_free_pivots_s();

        assert_eq!(pf.pivots.iter().collect_vec(), vec![(4, 2), (3, 4), (0, 0), (5, 1), (2, 3)]);
    }

    // Multithreaded variant; only a lower bound on the pivot count is checked.
    #[cfg(feature = "multithread")]
    #[test]
    fn find_cycle_free_pivots_m() {
        let a = SpMat::from_dense_data((6, 9), [
            1, 0, 0, 0, 0, 1, 0, 0, 1,
            0, 1, 1, 1, 0, 1, 0, 1, 0,
            0, 0, 1, 1, 0, 0, 0, 1, 1,
            0, 1, 0, 0, 1, 0, 0, 0, 0,
            0, 0, 1, 0, 0, 0, 0, 0, 0,
            0, 1, 0, 0, 0, 1, 0, 1, 0
        ]);
        let mut pf = PivotFinder::new(&a, PivotType::Rows, PivotCondition::One);

        pf.find_cycle_free_pivots_m();

        assert!(pf.pivots.count() >= 5);
    }

    // A zero matrix has no pivots.
    #[test]
    fn zero() {
        let a = SpMat::from_dense_data((1, 1), [0]);
        let pivs = find_pivots(&a, PivotType::Rows, PivotCondition::One);
        let r = pivs.len();
        assert_eq!(r, 0);
    }

    #[test]
    fn id_1() {
        let a = SpMat::from_dense_data((1, 1), [1]);
        let pivs = find_pivots(&a, PivotType::Rows, PivotCondition::One);
        let r = pivs.len();
        assert_eq!(r, 1);
    }

    #[test]
    fn id_2() {
        let a = SpMat::from_dense_data((2, 2), [
            1, 0, 0, 1
        ]);
        let pivs = find_pivots(&a, PivotType::Rows, PivotCondition::One);
        let r = pivs.len();
        assert_eq!(r, 2);
    }

    // After permuting by the row pivots, the leading r x r block has ones on
    // the diagonal and zeros below it.
    #[test]
    fn result() {
        let a = SpMat::from_dense_data((6, 9), [
            1, 0, 0, 0, 0, 1, 0, 0, 1,
            0, 1, 1, 1, 0, 1, 0, 1, 0,
            0, 0, 1, 1, 0, 0, 0, 1, 1,
            0, 1, 0, 0, 1, 0, 0, 0, 0,
            0, 0, 1, 0, 0, 0, 0, 0, 0,
            0, 1, 0, 0, 0, 1, 0, 1, 0
        ]);
        let pivs = find_pivots(&a, PivotType::Rows, PivotCondition::One);
        let r = pivs.len();
        assert_eq!(r, 5);

        let (p, q) = perms_by_pivots(&a, &pivs);
        let b = a.permute(p.view(), q.view()).into_dense();

        assert!((0..r).all(|i| b[(i, i)].is_one()));
        assert!((0..r).all(|j| {
            (j+1..r).all(|i| b[(i, j)].is_zero())
        }));
    }

    // With column pivots the leading block has ones on the diagonal and zeros
    // above it.
    #[test]
    fn result_cols() {
        let a = SpMat::from_dense_data((6, 9), [
            1, 0, 0, 0, 0, 1, 0, 0, 1,
            0, 1, 1, 1, 0, 1, 0, 1, 0,
            0, 0, 1, 1, 0, 0, 0, 1, 1,
            0, 1, 0, 0, 1, 0, 0, 0, 0,
            0, 0, 1, 0, 0, 0, 0, 0, 0,
            0, 1, 0, 0, 0, 1, 0, 1, 0
        ]);
        let pivs = find_pivots(&a, PivotType::Cols, PivotCondition::One);
        let r = pivs.len();
        assert_eq!(r, 6);

        let (p, q) = perms_by_pivots(&a, &pivs);
        let b = a.permute(p.view(), q.view()).into_dense();

        assert!((0..r).all(|i| b[(i, i)].is_one()));
        assert!((0..r).all(|i| {
            (i+1..r).all(|j| b[(i, j)].is_zero())
        }));
    }

    // Same triangularity check as `result`, on a random 60x80 matrix
    // (density parameter 0.1).
    #[test]
    fn rand() {
        let d = 0.1;
        let shape = (60, 80);
        let a = SpMat::<i32>::rand(shape, d);

        let pivs = find_pivots(&a, PivotType::Rows, PivotCondition::One);
        let r = pivs.len();
        assert!(r > 10);

        let (p, q) = perms_by_pivots(&a, &pivs);
        let b = a.permute(p.view(), q.view()).into_dense();

        assert!((0..r).all(|i| b[(i, i)].is_one()));
        assert!((0..r).all(|j| {
            (j+1..r).all(|i| b[(i, j)].is_zero())
        }))
    }
}