sprs_superlu/
lib.rs

1use libc::{c_double, c_int};
2use ndarray::{Array1, Array2};
3use sprs::CsMat;
4use std::mem;
5use std::sync::{mpsc, Arc, Mutex};
6use std::thread;
7use std::time::Duration;
8use superlu_sys as ffi;
9
10use std::slice::from_raw_parts_mut;
11use superlu_sys::{Dtype_t, Mtype_t, Stype_t};
12
13mod tests;
14
/// Errors reported by `solve_super_lu`.
///
/// Implements `Display` and `std::error::Error` so it can be propagated with `?`
/// into `Box<dyn Error>` and similar error containers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SolverError {
    /// The inputs do not fit together: the coefficient matrix is not square, or a
    /// right-hand side has a length different from the matrix dimension.
    Conflict,
    /// The system could not be solved (empty matrix, solver reported a failure, or
    /// the solution could not be read back).
    Unsolvable,
    /// The solver did not finish within the requested timeout.
    Timeout,
}

impl std::fmt::Display for SolverError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            SolverError::Conflict => "matrix and right-hand side dimensions are inconsistent",
            SolverError::Unsolvable => "the linear system could not be solved",
            SolverError::Timeout => "the solver did not finish within the timeout",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SolverError {}
21
22pub struct Options {
23    pub ffi: ffi::superlu_options_t,
24}
25
26impl Default for Options {
27    fn default() -> Self {
28        let mut options: ffi::superlu_options_t = unsafe { mem::zeroed() };
29        unsafe {
30            ffi::set_default_options(&mut options);
31        }
32        Self { ffi: options }
33    }
34}
35
36fn vec_of_array1_to_array2(columns: &Vec<Array1<f64>>) -> Array2<f64> {
37    let nrows = columns.first().map_or(0, |first_col| first_col.len());
38    let ncols = columns.len();
39    let mut result = Array2::zeros((nrows, ncols));
40
41    for (col_idx, col) in columns.iter().enumerate() {
42        result.column_mut(col_idx).assign(col);
43    }
44
45    result
46}
47
48pub fn solve_super_lu(
49    a: CsMat<f64>,
50    b: &Vec<Array1<f64>>,
51    timeout: Option<Duration>,
52    options: &mut Options,
53) -> Result<Vec<Array1<f64>>, SolverError> {
54    let m = a.rows();
55    let n = a.cols();
56    if m != n {
57        return Err(SolverError::Conflict);
58    }
59    if b.len() > 0 {
60        if m != b[0].len() {
61            return Err(SolverError::Conflict);
62        }
63        for rhs_col in b {
64            if rhs_col.len() != b[0].len() {
65                return Err(SolverError::Conflict);
66            }
67        }
68    }
69    if a.nnz() == 0 {
70        return Err(SolverError::Unsolvable);
71    }
72
73    let a_mat = Arc::new(Mutex::new(SuperMatrix::from_csc_mat(a)));
74    let b_mat = Arc::new(Mutex::new(SuperMatrix::from_ndarray(
75        vec_of_array1_to_array2(b),
76    )));
77    let options = Arc::new(Mutex::new(options.ffi));
78
79    let a_mat_clone = Arc::clone(&a_mat);
80    let b_mat_clone = Arc::clone(&b_mat);
81    let options_clone = Arc::clone(&options);
82
83    let (sender, receiver) = mpsc::channel();
84
85    thread::spawn(move || unsafe {
86        let perm_r = ffi::intMalloc(m as c_int);
87        assert!(!perm_r.is_null());
88
89        let perm_c = ffi::intMalloc(n as c_int);
90        assert!(!perm_c.is_null());
91
92        ffi::set_default_options(&mut *options_clone.lock().unwrap());
93
94        let mut stat: ffi::SuperLUStat_t = mem::zeroed();
95        ffi::StatInit(&mut stat);
96
97        let mut l_mat: ffi::SuperMatrix = mem::zeroed();
98        let mut u_mat: ffi::SuperMatrix = mem::zeroed();
99
100        let mut info = 0;
101        ffi::dgssv(
102            &mut *options_clone.lock().unwrap(),
103            a_mat_clone.lock().unwrap().raw_mut(),
104            perm_c,
105            perm_r,
106            &mut l_mat,
107            &mut u_mat,
108            b_mat_clone.lock().unwrap().raw_mut(),
109            &mut stat,
110            &mut info,
111        );
112
113        ffi::SUPERLU_FREE(perm_r as *mut _);
114        ffi::SUPERLU_FREE(perm_c as *mut _);
115        ffi::Destroy_SuperNode_Matrix(&mut l_mat);
116        ffi::Destroy_CompCol_Matrix(&mut u_mat);
117        ffi::StatFree(&mut stat);
118
119        if info != 0 {
120            let _ = sender.send(Err(SolverError::Unsolvable));
121        } else {
122            let _ = sender.send(Ok(()));
123        }
124    });
125    match timeout {
126        None => match receiver.recv() {
127            Ok(res) => match res {
128                Ok(_) => {
129                    let res_data = b_mat.lock().unwrap().raw().data_to_vec();
130                    match res_data {
131                        None => Err(SolverError::Unsolvable),
132                        Some(data) => Ok(data
133                            .chunks(n)
134                            .map(|chunk| Array1::from_iter(chunk.iter().cloned()))
135                            .collect()),
136                    }
137                }
138                Err(_) => Err(SolverError::Unsolvable),
139            },
140            Err(_) => {
141                panic!("Unknown internal SuperLU error");
142            }
143        },
144        Some(timeout_value) => match receiver.recv_timeout(timeout_value) {
145            Ok(res) => match res {
146                Ok(_) => {
147                    let res_data = b_mat.lock().unwrap().raw().data_to_vec();
148                    match res_data {
149                        None => Err(SolverError::Unsolvable),
150                        Some(data) => Ok(data
151                            .chunks(n)
152                            .map(|chunk| Array1::from_iter(chunk.iter().cloned()))
153                            .collect()),
154                    }
155                }
156                Err(_) => Err(SolverError::Unsolvable),
157            },
158            Err(mpsc::RecvTimeoutError::Timeout) => {
159                return Err(SolverError::Timeout);
160            }
161            Err(_) => {
162                panic!("Unknown internal SuperLU error");
163            }
164        },
165    }
166}
167
168pub struct SuperMatrix {
169    raw: ffi::SuperMatrix,
170    rust_managed: bool,
171}
172
173pub trait FromSuperMatrix: Sized {
174    fn from_super_matrix(_: &SuperMatrix) -> Option<Self>;
175}
176
177unsafe impl Send for SuperMatrix {}
178
179impl SuperMatrix {
180    pub unsafe fn from_raw(raw: ffi::SuperMatrix) -> SuperMatrix {
181        SuperMatrix {
182            raw,
183            rust_managed: false,
184        }
185    }
186
187    pub fn into_raw(self) -> ffi::SuperMatrix {
188        let raw = self.raw;
189        if self.rust_managed {
190            mem::forget(self);
191        }
192        raw
193    }
194
195    pub fn from_csc_mat(mat: CsMat<f64>) -> Self {
196        assert_eq!(mat.storage(), sprs::CompressedStorage::CSC);
197
198        let m = mat.rows() as c_int;
199        let n = mat.cols() as c_int;
200        let nnz = mat.nnz() as c_int;
201
202        let mut raw: ffi::SuperMatrix = unsafe { mem::zeroed() };
203
204        let nzval: Vec<c_double> = mat.data().iter().map(|&x| x as c_double).collect();
205        let rowind: Vec<c_int> = mat.indices().iter().map(|&x| x as c_int).collect();
206        let mut colptr = Vec::new();
207        let colptr_raw = mat.indptr();
208        for ptr in colptr_raw.as_slice().unwrap() {
209            colptr.push(ptr.clone() as c_int)
210        }
211
212        let nzval_boxed = nzval.into_boxed_slice();
213        let rowind_boxed = rowind.into_boxed_slice();
214        let colptr_boxed = colptr.into_boxed_slice();
215
216        let nzval_ptr = Box::leak(nzval_boxed).as_mut_ptr();
217        let rowind_ptr = Box::leak(rowind_boxed).as_mut_ptr();
218        let colptr_ptr = Box::leak(colptr_boxed).as_mut_ptr();
219
220        unsafe {
221            ffi::dCreate_CompCol_Matrix(
222                &mut raw,
223                m,
224                n,
225                nnz,
226                nzval_ptr as *mut c_double,
227                rowind_ptr as *mut c_int,
228                colptr_ptr as *mut c_int,
229                Stype_t::SLU_NC,
230                Dtype_t::SLU_D,
231                Mtype_t::SLU_GE,
232            );
233        }
234        unsafe { Self::from_raw(raw) }
235    }
236
237    pub fn from_ndarray(array: Array2<f64>) -> Self {
238        let nrows = array.nrows() as c_int;
239        let ncols = array.ncols() as c_int;
240
241        let col_major_data = unsafe { ffi::doubleMalloc(ncols * nrows) };
242        let mut index: usize = 0;
243        let col_major_data_ptr =
244            unsafe { from_raw_parts_mut(col_major_data, (ncols * nrows) as usize) };
245        for col in 0..ncols as usize {
246            for row in 0..nrows as usize {
247                col_major_data_ptr[index] = array[[row, col]];
248                index += 1;
249            }
250        }
251
252        let mut raw: ffi::SuperMatrix = unsafe { std::mem::zeroed() };
253
254        unsafe {
255            ffi::dCreate_Dense_Matrix(
256                &mut raw as *mut ffi::SuperMatrix,
257                nrows,
258                ncols,
259                col_major_data,
260                nrows,
261                Stype_t::SLU_DN,
262                Dtype_t::SLU_D,
263                Mtype_t::SLU_GE,
264            );
265
266            SuperMatrix {
267                raw,
268                rust_managed: true,
269            }
270        }
271    }
272
273    pub fn into_ndarray(self) -> Option<Array2<f64>> {
274        match self.raw.data_to_vec() {
275            None => None,
276            Some(data) => match Array2::from_shape_vec((self.nrows(), self.ncols()), data) {
277                Ok(arr) => Some(arr.t().to_owned()),
278                Err(_) => None,
279            },
280        }
281    }
282
283    pub fn nrows(&self) -> usize {
284        self.raw.nrow as usize
285    }
286
287    pub fn ncols(&self) -> usize {
288        self.raw.ncol as usize
289    }
290
291    pub fn raw(&self) -> &ffi::SuperMatrix {
292        &self.raw
293    }
294
295    pub fn raw_mut(&mut self) -> *mut ffi::SuperMatrix {
296        &mut self.raw
297    }
298}
299
300impl Drop for SuperMatrix {
301    fn drop(&mut self) {
302        unsafe {
303            let store = &*(self.raw().Store as *const ffi::NCformat);
304            if store.nnz == 0 {
305                return;
306            }
307        }
308        if self.rust_managed {
309            match self.raw.Stype {
310                Stype_t::SLU_NC => unsafe {
311                    ffi::Destroy_CompCol_Matrix(&mut self.raw);
312                },
313                Stype_t::SLU_NCP => unsafe {
314                    ffi::Destroy_CompCol_Permuted(&mut self.raw);
315                },
316                Stype_t::SLU_NR => unsafe {
317                    ffi::Destroy_CompRow_Matrix(&mut self.raw);
318                },
319                Stype_t::SLU_SC | ffi::Stype_t::SLU_SCP | ffi::Stype_t::SLU_SR => unsafe {
320                    ffi::Destroy_SuperNode_Matrix(&mut self.raw);
321                },
322                Stype_t::SLU_DN => unsafe {
323                    ffi::Destroy_Dense_Matrix(&mut self.raw);
324                },
325                _ => {}
326            }
327        }
328    }
329}