1use crate::math::Mat;
5use crate::matrix::BandMatrix;
6use crate::vector::ops::Copy;
7use crate::Matrix;
8use num::traits::NumCast;
9use std::cmp::{max, min};
10use std::fmt;
11use std::fmt::Debug;
12use std::iter::repeat;
13use std::mem::ManuallyDrop;
14use std::ops::Index;
15use std::slice;
16
/// A matrix whose entries are kept in a single contiguous buffer together with
/// the band description (how many diagonals lie below and above the main one).
///
/// The constructors in this file allocate `rows * cols` elements for `data`;
/// `from_matrix` packs each row's band entries at a stride of
/// `sub_diagonals + sup_diagonals + 1` inside that same buffer.
#[derive(Debug, PartialEq)]
pub struct BandMat<T> {
    /// Number of rows of the logical matrix.
    rows: usize,
    /// Number of columns of the logical matrix.
    cols: usize,
    /// Number of diagonals below the main diagonal.
    sub_diagonals: u32,
    /// Number of diagonals above the main diagonal.
    sup_diagonals: u32,
    /// Backing storage for the matrix elements.
    data: Vec<T>,
}
29
30impl<T> BandMat<T> {
31 pub fn new(n: usize, m: usize, sub: u32, sup: u32) -> BandMat<T> {
32 let len = n * m;
33 let mut data = Vec::with_capacity(len);
34 unsafe {
35 data.set_len(len);
36 }
37
38 BandMat {
39 rows: n,
40 cols: m,
41 data,
42 sub_diagonals: sub,
43 sup_diagonals: sup,
44 }
45 }
46
47 pub fn rows(&self) -> usize {
48 self.rows
49 }
50 pub fn cols(&self) -> usize {
51 self.cols
52 }
53 pub unsafe fn set_rows(&mut self, n: usize) {
58 self.rows = n;
59 }
60 pub unsafe fn set_cols(&mut self, n: usize) {
65 self.cols = n;
66 }
67 pub unsafe fn set_sub_diagonals(&mut self, n: u32) {
68 self.sub_diagonals = n;
69 }
70 pub unsafe fn set_sup_diagonals(&mut self, n: u32) {
71 self.sup_diagonals = n;
72 }
73
74 pub unsafe fn push(&mut self, val: T) {
75 self.data.push(val);
76 }
77}
78
79impl<T: std::marker::Copy> BandMat<T> {
80 pub fn from_matrix(mat: Mat<T>, sub_diagonals: u32, sup_diagonals: u32) -> BandMat<T> {
130 let mut mat = ManuallyDrop::new(mat);
131
132 let cols = mat.cols();
133 let rows = mat.rows();
134 let lda = (sub_diagonals + 1 + sup_diagonals) as usize;
135 let length = rows * cols;
136
137 if rows * lda > length {
139 panic!("BandMatrix conversion needed {} space, but only {} was provided. LDA was {}. Not enough space to safely convert to band matrix storage. Please consider expanding the size of the vector for the underlying Matrix", rows * lda, length, lda);
140 }
141
142 let mut v = unsafe { Vec::from_raw_parts(mat.as_mut_ptr(), length, length) };
143
144 for r in 0..rows {
155 let s = (r * cols) + max(0, r as isize - sub_diagonals as isize) as usize;
156 let e = (r * cols) + min(cols, r + sup_diagonals as usize + 1usize);
157
158 let bandmat_offset =
159 max(0, (lda as isize) - sup_diagonals as isize - r as isize - 1) as usize;
160
161 let i = (r * lda) + bandmat_offset;
162 let i = i as usize;
163 (&mut v).copy_within(s..e, i);
164 }
165
166 BandMat {
167 cols,
168 rows,
169 data: v,
170 sub_diagonals,
171 sup_diagonals,
172 }
173 }
174}
175
176impl<T: std::marker::Copy + Default> BandMat<T> {
177 pub fn to_matrix(bandmat: Self) -> Mat<T> {
193 let mut bandmat = ManuallyDrop::new(bandmat);
194
195 let ku = bandmat.sup_diagonals() as usize;
196 let kl = bandmat.sub_diagonals() as usize;
197 let lda = ku + kl + 1;
198 let rows = bandmat.rows();
199 let cols = bandmat.cols();
200 let length = rows * cols;
201
202 if length < lda * rows {
203 panic!("Could not convert BandMat to Mat. The specified length of the data vector is {}, which is less than the expected minimum {} x {} = {}", length, rows, lda, rows * lda);
204 }
205 let mut v = unsafe { Vec::from_raw_parts(bandmat.as_mut_ptr(), length, length) };
206
207 let num_of_last_row_terms = kl + 1 - (rows - min(rows, cols));
208
209 for r in (0..rows).rev() {
219 let offset = rows - r - 1;
220
221 let s = max(
222 0,
223 -(kl as isize + 1)
224 + (num_of_last_row_terms - (if rows > cols { 1 } else { 2 })) as isize
225 + offset as isize,
226 );
227 let s = (r * lda) as isize + s;
228 let s = s as usize;
229
230 let e = min(lda, num_of_last_row_terms + offset);
231 let e = (r * lda) + e;
232
233 let original_mat_offset =
234 cols as isize - num_of_last_row_terms as isize - offset as isize;
235 let i = (r * cols) + max(0, original_mat_offset) as usize;
236
237 v.copy_within(s..e, i);
238
239 let l = e - s;
241 let zero_range = (r * cols)..max(0, i);
242 let zero_range = zero_range.chain(min((r + 1) * cols, i + l)..((r + 1) * cols));
243 for i in zero_range {
244 v[i] = T::default();
245 }
246 }
247
248 Mat::new_from_data(rows, cols, v)
249 }
250}
251
252impl<T: Clone> BandMat<T> {
253 pub fn fill(value: T, n: usize, m: usize) -> BandMat<T> {
254 BandMat {
255 rows: n,
256 cols: m,
257 data: repeat(value).take(n * m).collect(),
258 sub_diagonals: n as u32,
259 sup_diagonals: m as u32,
260 }
261 }
262}
263
264impl<T> Index<usize> for BandMat<T> {
265 type Output = [T];
266
267 fn index(&self, index: usize) -> &[T] {
268 let offset = (index * self.cols) as isize;
269
270 unsafe {
271 let ptr = (&self.data[..]).as_ptr().offset(offset);
272 slice::from_raw_parts(ptr, self.cols)
273 }
274 }
275}
276
277impl<T: fmt::Display> fmt::Display for BandMat<T> {
278 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
279 for i in 0usize..self.rows {
280 for j in 0usize..self.cols {
281 match write!(f, "{}", self[i][j]) {
282 Ok(_) => (),
283 x => return x,
284 }
285 }
286
287 match writeln!(f) {
288 Ok(_) => (),
289 x => return x,
290 }
291 }
292
293 Ok(())
294 }
295}
296
297impl<T> Matrix<T> for BandMat<T> {
298 fn lead_dim(&self) -> u32 {
299 self.sub_diagonals + self.sup_diagonals + 1
300 }
301
302 fn rows(&self) -> u32 {
303 let n: Option<u32> = NumCast::from(self.rows);
304 n.unwrap()
305 }
306
307 fn cols(&self) -> u32 {
308 let n: Option<u32> = NumCast::from(self.cols);
309 n.unwrap()
310 }
311
312 fn as_ptr(&self) -> *const T {
313 self.data[..].as_ptr()
314 }
315
316 fn as_mut_ptr(&mut self) -> *mut T {
317 (&mut self.data[..]).as_mut_ptr()
318 }
319}
320
321impl<T> BandMatrix<T> for BandMat<T> {
322 fn sub_diagonals(&self) -> u32 {
323 self.sub_diagonals
324 }
325
326 fn sup_diagonals(&self) -> u32 {
327 self.sup_diagonals
328 }
329
330 fn as_matrix(&self) -> &dyn Matrix<T> {
331 self
332 }
333}
334
335impl<'a, T> From<&'a dyn BandMatrix<T>> for BandMat<T>
336where
337 T: Copy,
338{
339 fn from(a: &dyn BandMatrix<T>) -> BandMat<T> {
340 let n = a.rows() as usize;
341 let m = a.cols() as usize;
342 let len = n * m;
343
344 let sub = a.sub_diagonals() as u32;
345 let sup = a.sup_diagonals() as u32;
346
347 let mut result = BandMat {
348 rows: n,
349 cols: m,
350 data: Vec::with_capacity(len),
351 sub_diagonals: sub,
352 sup_diagonals: sup,
353 };
354 unsafe {
355 result.data.set_len(len);
356 }
357
358 Copy::copy_mat(a.as_matrix(), &mut result);
359 result
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Clones `source` element by element into the memory starting at `dest`.
    ///
    /// `dest` must point at no fewer than `source.len()` writable elements.
    fn write_to_memory<T: Clone>(dest: *mut T, source: &[T]) {
        let target = unsafe { slice::from_raw_parts_mut(dest, source.len()) };
        target.clone_from_slice(source);
    }

    /// Clones the first `l` elements of the matrix's backing memory into a
    /// fresh vector.
    fn retrieve_memory<T: Clone>(t: &mut dyn Matrix<T>, l: usize) -> Vec<T> {
        let view = unsafe { slice::from_raw_parts(t.as_mut_ptr(), l) };
        view.to_vec()
    }

    /// Writes `values` into `m` and returns the matrix together with its
    /// element count (`rows * cols`).
    fn load(mut m: Mat<f32>, values: &[f32]) -> (Mat<f32>, usize) {
        let length = m.rows() * m.cols();
        write_to_memory(m.as_mut_ptr(), values);
        (m, length)
    }

    /// Converts `m` (pre-filled with `original`) to band storage and back,
    /// then checks that the dense contents are unchanged.
    fn assert_round_trip(m: Mat<f32>, sub: u32, sup: u32, original: Vec<f32>) {
        let (m, length) = load(m, &original);

        let band_m = BandMat::from_matrix(m, sub, sup);
        let mut m = BandMat::to_matrix(band_m);

        let result_vec = retrieve_memory(&mut m, length);
        assert_eq!(result_vec, original);
    }

    /// 6 x 4 matrix with two sub-diagonals and one super-diagonal.
    fn tall_6x4() -> Vec<f32> {
        vec![
            0.5, 1.0, 0.0, 0.0, 2.0, 0.5, 1.0, 0.0, 3.0, 2.0, 0.5, 1.0, 0.0, 3.0, 2.0, 0.5, 0.0,
            0.0, 3.0, 2.0, 0.0, 0.0, 0.0, 3.0,
        ]
    }

    /// 4 x 4 tridiagonal matrix.
    fn tridiagonal_4x4() -> Vec<f32> {
        vec![
            0.5, 2.0, 0.0, 0.0, 1.0, 0.5, 2.0, 0.0, 0.0, 1.0, 0.5, 2.0, 0.0, 0.0, 1.0, 0.5,
        ]
    }

    #[test]
    fn basic_conversion_test() {
        let v = tridiagonal_4x4();
        let (m, length) = load(Mat::new(4, 4), &v);

        let mut band_m = BandMat::from_matrix(m, 1, 1);
        let result_vec = retrieve_memory(&mut band_m, length);

        let expected = [(1, 0.5f32), (2, 2.0f32), (3, 1.0f32), (7, 0.5f32), (9, 1.0f32)];
        for &(position, value) in expected.iter() {
            assert_eq!(result_vec[position], value);
        }
    }

    #[test]
    fn nonsquare_conversion_test() {
        let v = tall_6x4();
        let (m, length) = load(Mat::new(6, 4), &v);

        let mut band_m = BandMat::from_matrix(m, 2, 1);
        let result_vec = retrieve_memory(&mut band_m, length);

        let expected = [(2, 0.5), (5, 2.0), (7, 1.0), (8, 3.0), (16, 3.0), (20, 3.0)];
        for &(position, value) in expected.iter() {
            assert_eq!(result_vec[position], value);
        }
    }

    #[test]
    #[should_panic]
    fn from_big_matrix_panic_test() {
        let original: Vec<f32> = vec![
            0.5, 2.0, 3.0, 4.0, 1.0, 0.5, 2.0, 3.0, 5.0, 1.0, 0.5, 2.0, 6.0, 5.0, 1.0, 0.5,
        ];
        let (m, _) = load(Mat::new(4, 4), &original);

        // Seven diagonals cannot be packed into rows of four elements.
        let _ = BandMat::from_matrix(m, 3, 3);
    }

    #[test]
    fn to_and_from_conversion_test() {
        assert_round_trip(Mat::new(4, 4), 1, 1, tridiagonal_4x4());
    }

    #[test]
    fn to_and_from_nonsquare_test() {
        let original: Vec<f32> = vec![
            0.5, 1.0, 0.0, 0.0, 2.0, 0.5, 1.0, 0.0, 3.0, 2.0, 0.5, 1.0, 0.0, 3.0, 2.0, 0.5, 0.0,
            0.0, 3.0, 2.0,
        ];
        assert_round_trip(Mat::new(5, 4), 2, 1, original);
    }

    #[test]
    fn to_and_from_nonsquare2_test() {
        assert_round_trip(Mat::new(6, 4), 2, 1, tall_6x4());
    }
}