rust_blas/math/bandmat.rs

1// Copyright 2015 Michael Yang. All rights reserved.
2// Use of this source code is governed by a MIT-style
3// license that can be found in the LICENSE file.
4use crate::math::Mat;
5use crate::matrix::BandMatrix;
6use crate::vector::ops::Copy;
7use crate::Matrix;
8use num::traits::NumCast;
9use std::cmp::{max, min};
10use std::fmt;
11use std::fmt::Debug;
12use std::iter::repeat;
13use std::mem::ManuallyDrop;
14use std::ops::Index;
15use std::slice;
16
#[derive(Debug, PartialEq)]
/// Banded Matrix
/// A banded matrix is a matrix where only the diagonal, a number of super-diagonals and a number of
/// sub-diagonals are non-zero.
/// https://en.wikipedia.org/wiki/Band_matrix
pub struct BandMat<T> {
    /// Number of rows of the represented matrix.
    rows: usize,
    /// Number of columns of the represented matrix.
    cols: usize,
    /// Number of non-zero diagonals below the main diagonal.
    sub_diagonals: u32,
    /// Number of non-zero diagonals above the main diagonal.
    sup_diagonals: u32,
    /// Backing buffer; every constructor in this module sizes it to `rows * cols` elements.
    /// After `from_matrix`, row `r` of the band storage occupies
    /// `r * LDA .. (r + 1) * LDA` (with `LDA = sub_diagonals + sup_diagonals + 1`) and the
    /// remaining tail of the buffer is unused.
    data: Vec<T>,
}
29
30impl<T> BandMat<T> {
31    pub fn new(n: usize, m: usize, sub: u32, sup: u32) -> BandMat<T> {
32        let len = n * m;
33        let mut data = Vec::with_capacity(len);
34        unsafe {
35            data.set_len(len);
36        }
37
38        BandMat {
39            rows: n,
40            cols: m,
41            data,
42            sub_diagonals: sub,
43            sup_diagonals: sup,
44        }
45    }
46
47    pub fn rows(&self) -> usize {
48        self.rows
49    }
50    pub fn cols(&self) -> usize {
51        self.cols
52    }
53    /// Set Rows Manually
54    /// # Safety
55    /// No guarantees are made about rows x columns being equivalent to data length after this
56    /// operation
57    pub unsafe fn set_rows(&mut self, n: usize) {
58        self.rows = n;
59    }
60    /// Set Columns Manually
61    /// # Safety
62    /// No guarantees are made about rows x columns being equivalent to data length after this
63    /// operation
64    pub unsafe fn set_cols(&mut self, n: usize) {
65        self.cols = n;
66    }
67    pub unsafe fn set_sub_diagonals(&mut self, n: u32) {
68        self.sub_diagonals = n;
69    }
70    pub unsafe fn set_sup_diagonals(&mut self, n: u32) {
71        self.sup_diagonals = n;
72    }
73
74    pub unsafe fn push(&mut self, val: T) {
75        self.data.push(val);
76    }
77}
78
79impl<T: std::marker::Copy> BandMat<T> {
80    /// Converts a [`Mat`] into a [`BandMat`].
81    ///
82    /// The idea is to compress the the band matrix by compressing it to a form that is as legible
83    /// as possible but without many of the extraneous zeros. You can read more about the process
84    /// here: [Wikipedia](https://en.wikipedia.org/wiki/Band_matrix#Band_storage) and [Official
85    /// BLAS
86    /// Docs](http://www.netlib.org/lapack/explore-html/d7/d15/group__double__blas__level2_ga0dc187c15a47772440defe879d034888.html#ga0dc187c15a47772440defe879d034888),
87    /// but the best demonstration is probably by example.
88    ///
89    /// Say you have a matrix:
90    ///
91    /// ```
92    /// let m =
93    /// [
94    ///   0.5, 2.0, 0.0, 0.0,
95    ///   1.0, 0.5, 2.0, 0.0,
96    ///   0.0, 1.0, 0.5, 2.0,
97    ///   0.0, 0.0, 1.0, 0.5,
98    /// ];
99    /// ```
100    ///
101    /// This method will transform it into:
102    ///
103    /// ```
104    /// let x = 0.0;
105    /// let m =
106    /// [
107    ///   x,   0.5, 2.0,
108    ///   1.0, 0.5, 2.0,
109    ///   1.0, 0.5, 2.0,
110    ///   1.0, 0.5,   x,
111    /// ];
112    /// ```
113    ///
114    /// The `x`'s represent the values that will not be read by the blas operation, and therefore
115    /// can remain unchanged. Notice that the dimensions of the new matrix are `(rows, LDA)`, where
116    /// `LDA = <sub diagonals> + <sup diagonals> + 1`. This matrix will be stored in the original
117    /// memory of the matrix that is consumed by this method.
118    ///
119    ///  For details about how the conversion actually happens, consult the code comments.
120    ///
121    /// # Panics
122    ///
123    /// Panics if the size of the vector representing the input matrix is too small, that is
124    /// `rows * LDA > rows * cols`. In this case there is not enough space to perform a safe
125    /// conversion to the Band Storage format.
126    ///
127    /// [`BandMat`]: struct.BandMat.html
128    /// [`Mat`]: ../mat/struct.Mat.html
129    pub fn from_matrix(mat: Mat<T>, sub_diagonals: u32, sup_diagonals: u32) -> BandMat<T> {
130        let mut mat = ManuallyDrop::new(mat);
131
132        let cols = mat.cols();
133        let rows = mat.rows();
134        let lda = (sub_diagonals + 1 + sup_diagonals) as usize;
135        let length = rows * cols;
136
137        // Not enough space to represent the matrix in BandMatrix storage
138        if rows * lda > length {
139            panic!("BandMatrix conversion needed {} space, but only {} was provided. LDA was {}. Not enough space to safely convert to band matrix storage. Please consider expanding the size of the vector for the underlying Matrix", rows * lda, length, lda);
140        }
141
142        let mut v = unsafe { Vec::from_raw_parts(mat.as_mut_ptr(), length, length) };
143
144        /*
145         * For each row in the original matrix we do the following:
146         *
147         *    1. We identify where the numbers start. Represented by the s variable.
148         *    2. We identify where the numbers end. Represented by the e variable.
149         *    3. We identify at which index in the resulting matrix they should be placed. That is
150         *       represented by i.
151         *    4. We call copy_within to move all of those values to their positions in the new
152         *       matrix.
153         */
154        for r in 0..rows {
155            let s = (r * cols) + max(0, r as isize - sub_diagonals as isize) as usize;
156            let e = (r * cols) + min(cols, r + sup_diagonals as usize + 1usize);
157
158            let bandmat_offset =
159                max(0, (lda as isize) - sup_diagonals as isize - r as isize - 1) as usize;
160
161            let i = (r * lda) + bandmat_offset;
162            let i = i as usize;
163            (&mut v).copy_within(s..e, i);
164        }
165
166        BandMat {
167            cols,
168            rows,
169            data: v,
170            sub_diagonals,
171            sup_diagonals,
172        }
173    }
174}
175
176impl<T: std::marker::Copy + Default> BandMat<T> {
177    /// Converts a [`BandMat`] back into a [`Mat`].
178    ///
179    /// This method creates a [`Mat`] instance by reversing the steps from
180    /// the [`from_matrix`] method. It will also fill in all the values that are "zero" to the
181    /// default value of `T`.
182    ///
183    /// For more information about the implementation, please consult the code comments.
184    ///
185    /// # Panics
186    ///
187    /// Panics if the values of `rows * cols` doesn't correspond to the length of the data vector.
188    ///
189    /// [`BandMat`]: struct.BandMat.html
190    /// [`Mat`]: ../mat/struct.Mat.html
191    /// [`from_matrix`]: #method.from_matrix
192    pub fn to_matrix(bandmat: Self) -> Mat<T> {
193        let mut bandmat = ManuallyDrop::new(bandmat);
194
195        let ku = bandmat.sup_diagonals() as usize;
196        let kl = bandmat.sub_diagonals() as usize;
197        let lda = ku + kl + 1;
198        let rows = bandmat.rows();
199        let cols = bandmat.cols();
200        let length = rows * cols;
201
202        if length < lda * rows {
203            panic!("Could not convert BandMat to Mat. The specified length of the data vector is {}, which is less than the expected minimum {} x {} = {}", length, rows, lda, rows * lda);
204        }
205        let mut v = unsafe { Vec::from_raw_parts(bandmat.as_mut_ptr(), length, length) };
206
207        let num_of_last_row_terms = kl + 1 - (rows - min(rows, cols));
208
209        /*
210         * Refer to the `from_matrix` method for explanations of the meanings of the variables, but
211         * now with respect to the band matrix. That is, s now represents the start point in the
212         * band matrix for a particular row. The offset variable just inverts the index of the row
213         * (if we have a total of 10 rows, row 7 will have offset 3).
214         *
215         * We have to iterate on the rows in reverse order, because we need to be careful not to
216         * overwrite anything from the space of the original band matrix and lose values.
217         */
218        for r in (0..rows).rev() {
219            let offset = rows - r - 1;
220
221            let s = max(
222                0,
223                -(kl as isize + 1)
224                    + (num_of_last_row_terms - (if rows > cols { 1 } else { 2 })) as isize
225                    + offset as isize,
226            );
227            let s = (r * lda) as isize + s;
228            let s = s as usize;
229
230            let e = min(lda, num_of_last_row_terms + offset);
231            let e = (r * lda) + e;
232
233            let original_mat_offset =
234                cols as isize - num_of_last_row_terms as isize - offset as isize;
235            let i = (r * cols) + max(0, original_mat_offset) as usize;
236
237            v.copy_within(s..e, i);
238
239            // Fill the rest of the values for that row with "0"
240            let l = e - s;
241            let zero_range = (r * cols)..max(0, i);
242            let zero_range = zero_range.chain(min((r + 1) * cols, i + l)..((r + 1) * cols));
243            for i in zero_range {
244                v[i] = T::default();
245            }
246        }
247
248        Mat::new_from_data(rows, cols, v)
249    }
250}
251
252impl<T: Clone> BandMat<T> {
253    pub fn fill(value: T, n: usize, m: usize) -> BandMat<T> {
254        BandMat {
255            rows: n,
256            cols: m,
257            data: repeat(value).take(n * m).collect(),
258            sub_diagonals: n as u32,
259            sup_diagonals: m as u32,
260        }
261    }
262}
263
264impl<T> Index<usize> for BandMat<T> {
265    type Output = [T];
266
267    fn index(&self, index: usize) -> &[T] {
268        let offset = (index * self.cols) as isize;
269
270        unsafe {
271            let ptr = (&self.data[..]).as_ptr().offset(offset);
272            slice::from_raw_parts(ptr, self.cols)
273        }
274    }
275}
276
277impl<T: fmt::Display> fmt::Display for BandMat<T> {
278    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
279        for i in 0usize..self.rows {
280            for j in 0usize..self.cols {
281                match write!(f, "{}", self[i][j]) {
282                    Ok(_) => (),
283                    x => return x,
284                }
285            }
286
287            match writeln!(f) {
288                Ok(_) => (),
289                x => return x,
290            }
291        }
292
293        Ok(())
294    }
295}
296
297impl<T> Matrix<T> for BandMat<T> {
298    fn lead_dim(&self) -> u32 {
299        self.sub_diagonals + self.sup_diagonals + 1
300    }
301
302    fn rows(&self) -> u32 {
303        let n: Option<u32> = NumCast::from(self.rows);
304        n.unwrap()
305    }
306
307    fn cols(&self) -> u32 {
308        let n: Option<u32> = NumCast::from(self.cols);
309        n.unwrap()
310    }
311
312    fn as_ptr(&self) -> *const T {
313        self.data[..].as_ptr()
314    }
315
316    fn as_mut_ptr(&mut self) -> *mut T {
317        (&mut self.data[..]).as_mut_ptr()
318    }
319}
320
321impl<T> BandMatrix<T> for BandMat<T> {
322    fn sub_diagonals(&self) -> u32 {
323        self.sub_diagonals
324    }
325
326    fn sup_diagonals(&self) -> u32 {
327        self.sup_diagonals
328    }
329
330    fn as_matrix(&self) -> &dyn Matrix<T> {
331        self
332    }
333}
334
335impl<'a, T> From<&'a dyn BandMatrix<T>> for BandMat<T>
336where
337    T: Copy,
338{
339    fn from(a: &dyn BandMatrix<T>) -> BandMat<T> {
340        let n = a.rows() as usize;
341        let m = a.cols() as usize;
342        let len = n * m;
343
344        let sub = a.sub_diagonals() as u32;
345        let sup = a.sup_diagonals() as u32;
346
347        let mut result = BandMat {
348            rows: n,
349            cols: m,
350            data: Vec::with_capacity(len),
351            sub_diagonals: sub,
352            sup_diagonals: sup,
353        };
354        unsafe {
355            result.data.set_len(len);
356        }
357
358        Copy::copy_mat(a.as_matrix(), &mut result);
359        result
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes a clone of each element of `source` into consecutive slots starting at `dest`.
    ///
    /// `dest` must point to at least `source.len()` writable `T` slots. The existing contents
    /// are not assumed to be initialised, so `ptr::write` is used (assignment would drop the
    /// old value).
    fn write_to_memory<T: Clone>(dest: *mut T, source: &[T]) {
        for (k, value) in source.iter().enumerate() {
            // SAFETY: the caller guarantees `dest` has room for `source.len()` elements.
            unsafe { std::ptr::write(dest.add(k), value.clone()) };
        }
    }

    /// Copies the first `l` elements of `t`'s backing buffer into a fresh `Vec`.
    ///
    /// The buffer is only borrowed, so no `Vec` with a guessed capacity is fabricated around it.
    fn retrieve_memory<T: Clone>(t: &mut dyn Matrix<T>, l: usize) -> Vec<T> {
        // SAFETY: the caller guarantees `t`'s buffer holds at least `l` initialised elements.
        unsafe { std::slice::from_raw_parts(t.as_mut_ptr() as *const T, l) }.to_vec()
    }

    #[test]
    fn basic_conversion_test() {
        let v: Vec<f32> = vec![
            0.5, 2.0, 0.0, 0.0, 1.0, 0.5, 2.0, 0.0, 0.0, 1.0, 0.5, 2.0, 0.0, 0.0, 1.0, 0.5,
        ];

        let mut m: Mat<f32> = Mat::new(4, 4);
        let length = m.rows() * m.cols();

        write_to_memory(m.as_mut_ptr(), &v);

        let mut band_m = BandMat::from_matrix(m, 1, 1);

        let result_vec = retrieve_memory(&mut band_m, length);

        // Check selected values in position to make sure that they're correct, since it's hard
        // to actually predict the real vector values
        assert_eq!(result_vec[1], 0.5f32);
        assert_eq!(result_vec[2], 2.0f32);
        assert_eq!(result_vec[3], 1.0f32);
        assert_eq!(result_vec[7], 0.5f32);
        assert_eq!(result_vec[9], 1.0f32);
    }

    #[test]
    fn nonsquare_conversion_test() {
        let v: Vec<f32> = vec![
            0.5, 1.0, 0.0, 0.0, 2.0, 0.5, 1.0, 0.0, 3.0, 2.0, 0.5, 1.0, 0.0, 3.0, 2.0, 0.5, 0.0,
            0.0, 3.0, 2.0, 0.0, 0.0, 0.0, 3.0,
        ];

        let mut m: Mat<f32> = Mat::new(6, 4);
        let length = m.rows() * m.cols();

        write_to_memory(m.as_mut_ptr(), &v);

        let mut band_m = BandMat::from_matrix(m, 2, 1);

        let result_vec = retrieve_memory(&mut band_m, length);

        assert_eq!(result_vec[2], 0.5);
        assert_eq!(result_vec[5], 2.0);
        assert_eq!(result_vec[7], 1.0);
        assert_eq!(result_vec[8], 3.0);
        assert_eq!(result_vec[16], 3.0);
        assert_eq!(result_vec[20], 3.0);
    }

    #[test]
    #[should_panic(expected = "BandMatrix conversion needed")]
    fn from_big_matrix_panic_test() {
        let original: Vec<f32> = vec![
            0.5, 2.0, 3.0, 4.0, 1.0, 0.5, 2.0, 3.0, 5.0, 1.0, 0.5, 2.0, 6.0, 5.0, 1.0, 0.5,
        ];
        let mut m: Mat<f32> = Mat::new(4, 4);

        write_to_memory(m.as_mut_ptr(), &original);

        let _ = BandMat::from_matrix(m, 3, 3);
    }

    #[test]
    fn to_and_from_conversion_test() {
        let original: Vec<f32> = vec![
            0.5, 2.0, 0.0, 0.0, 1.0, 0.5, 2.0, 0.0, 0.0, 1.0, 0.5, 2.0, 0.0, 0.0, 1.0, 0.5,
        ];
        let v = original.clone();

        let mut m: Mat<f32> = Mat::new(4, 4);
        let length = m.rows() * m.cols();

        write_to_memory(m.as_mut_ptr(), &v);

        let band_m = BandMat::from_matrix(m, 1, 1);
        let mut m = BandMat::to_matrix(band_m);

        let result_vec = retrieve_memory(&mut m, length);

        assert_eq!(result_vec, original);
    }

    #[test]
    fn to_and_from_nonsquare_test() {
        let original: Vec<f32> = vec![
            0.5, 1.0, 0.0, 0.0, 2.0, 0.5, 1.0, 0.0, 3.0, 2.0, 0.5, 1.0, 0.0, 3.0, 2.0, 0.5, 0.0,
            0.0, 3.0, 2.0,
        ];
        let v = original.clone();

        let mut m: Mat<f32> = Mat::new(5, 4);
        let length = m.rows() * m.cols();

        write_to_memory(m.as_mut_ptr(), &v);

        let band_m = BandMat::from_matrix(m, 2, 1);
        let mut m = BandMat::to_matrix(band_m);

        let result_vec = retrieve_memory(&mut m, length);

        assert_eq!(result_vec, original);
    }

    #[test]
    fn to_and_from_nonsquare2_test() {
        let original: Vec<f32> = vec![
            0.5, 1.0, 0.0, 0.0, 2.0, 0.5, 1.0, 0.0, 3.0, 2.0, 0.5, 1.0, 0.0, 3.0, 2.0, 0.5, 0.0,
            0.0, 3.0, 2.0, 0.0, 0.0, 0.0, 3.0,
        ];
        let v = original.clone();

        let mut m: Mat<f32> = Mat::new(6, 4);
        let length = m.rows() * m.cols();

        write_to_memory(m.as_mut_ptr(), &v);

        let band_m = BandMat::from_matrix(m, 2, 1);
        let mut m = BandMat::to_matrix(band_m);

        let result_vec = retrieve_memory(&mut m, length);

        assert_eq!(result_vec, original);
    }
}