1use libc::{c_double, c_int};
2use ndarray::{Array1, Array2};
3use sprs::CsMat;
4use std::mem;
5use std::sync::{mpsc, Arc, Mutex};
6use std::thread;
7use std::time::Duration;
8use superlu_sys as ffi;
9
10use std::slice::from_raw_parts_mut;
11use superlu_sys::{Dtype_t, Mtype_t, Stype_t};
12
13mod tests;
14
/// Errors reported by [`solve_super_lu`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SolverError {
    /// The matrix is not square, or a right-hand side column has the wrong length.
    Conflict,
    /// The system could not be solved: the matrix stores no entries, SuperLU reported a
    /// nonzero `info`, or the solution could not be read back.
    Unsolvable,
    /// The solver did not finish before the requested timeout.
    Timeout,
}

impl std::fmt::Display for SolverError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            SolverError::Conflict => "matrix and right-hand side dimensions are incompatible",
            SolverError::Unsolvable => "SuperLU could not solve the linear system",
            SolverError::Timeout => "SuperLU did not finish before the timeout",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SolverError {}
21
22pub struct Options {
23 pub ffi: ffi::superlu_options_t,
24}
25
26impl Default for Options {
27 fn default() -> Self {
28 let mut options: ffi::superlu_options_t = unsafe { mem::zeroed() };
29 unsafe {
30 ffi::set_default_options(&mut options);
31 }
32 Self { ffi: options }
33 }
34}
35
36fn vec_of_array1_to_array2(columns: &Vec<Array1<f64>>) -> Array2<f64> {
37 let nrows = columns.first().map_or(0, |first_col| first_col.len());
38 let ncols = columns.len();
39 let mut result = Array2::zeros((nrows, ncols));
40
41 for (col_idx, col) in columns.iter().enumerate() {
42 result.column_mut(col_idx).assign(col);
43 }
44
45 result
46}
47
48pub fn solve_super_lu(
49 a: CsMat<f64>,
50 b: &Vec<Array1<f64>>,
51 timeout: Option<Duration>,
52 options: &mut Options,
53) -> Result<Vec<Array1<f64>>, SolverError> {
54 let m = a.rows();
55 let n = a.cols();
56 if m != n {
57 return Err(SolverError::Conflict);
58 }
59 if b.len() > 0 {
60 if m != b[0].len() {
61 return Err(SolverError::Conflict);
62 }
63 for rhs_col in b {
64 if rhs_col.len() != b[0].len() {
65 return Err(SolverError::Conflict);
66 }
67 }
68 }
69 if a.nnz() == 0 {
70 return Err(SolverError::Unsolvable);
71 }
72
73 let a_mat = Arc::new(Mutex::new(SuperMatrix::from_csc_mat(a)));
74 let b_mat = Arc::new(Mutex::new(SuperMatrix::from_ndarray(
75 vec_of_array1_to_array2(b),
76 )));
77 let options = Arc::new(Mutex::new(options.ffi));
78
79 let a_mat_clone = Arc::clone(&a_mat);
80 let b_mat_clone = Arc::clone(&b_mat);
81 let options_clone = Arc::clone(&options);
82
83 let (sender, receiver) = mpsc::channel();
84
85 thread::spawn(move || unsafe {
86 let perm_r = ffi::intMalloc(m as c_int);
87 assert!(!perm_r.is_null());
88
89 let perm_c = ffi::intMalloc(n as c_int);
90 assert!(!perm_c.is_null());
91
92 ffi::set_default_options(&mut *options_clone.lock().unwrap());
93
94 let mut stat: ffi::SuperLUStat_t = mem::zeroed();
95 ffi::StatInit(&mut stat);
96
97 let mut l_mat: ffi::SuperMatrix = mem::zeroed();
98 let mut u_mat: ffi::SuperMatrix = mem::zeroed();
99
100 let mut info = 0;
101 ffi::dgssv(
102 &mut *options_clone.lock().unwrap(),
103 a_mat_clone.lock().unwrap().raw_mut(),
104 perm_c,
105 perm_r,
106 &mut l_mat,
107 &mut u_mat,
108 b_mat_clone.lock().unwrap().raw_mut(),
109 &mut stat,
110 &mut info,
111 );
112
113 ffi::SUPERLU_FREE(perm_r as *mut _);
114 ffi::SUPERLU_FREE(perm_c as *mut _);
115 ffi::Destroy_SuperNode_Matrix(&mut l_mat);
116 ffi::Destroy_CompCol_Matrix(&mut u_mat);
117 ffi::StatFree(&mut stat);
118
119 if info != 0 {
120 let _ = sender.send(Err(SolverError::Unsolvable));
121 } else {
122 let _ = sender.send(Ok(()));
123 }
124 });
125 match timeout {
126 None => match receiver.recv() {
127 Ok(res) => match res {
128 Ok(_) => {
129 let res_data = b_mat.lock().unwrap().raw().data_to_vec();
130 match res_data {
131 None => Err(SolverError::Unsolvable),
132 Some(data) => Ok(data
133 .chunks(n)
134 .map(|chunk| Array1::from_iter(chunk.iter().cloned()))
135 .collect()),
136 }
137 }
138 Err(_) => Err(SolverError::Unsolvable),
139 },
140 Err(_) => {
141 panic!("Unknown internal SuperLU error");
142 }
143 },
144 Some(timeout_value) => match receiver.recv_timeout(timeout_value) {
145 Ok(res) => match res {
146 Ok(_) => {
147 let res_data = b_mat.lock().unwrap().raw().data_to_vec();
148 match res_data {
149 None => Err(SolverError::Unsolvable),
150 Some(data) => Ok(data
151 .chunks(n)
152 .map(|chunk| Array1::from_iter(chunk.iter().cloned()))
153 .collect()),
154 }
155 }
156 Err(_) => Err(SolverError::Unsolvable),
157 },
158 Err(mpsc::RecvTimeoutError::Timeout) => {
159 return Err(SolverError::Timeout);
160 }
161 Err(_) => {
162 panic!("Unknown internal SuperLU error");
163 }
164 },
165 }
166}
167
168pub struct SuperMatrix {
169 raw: ffi::SuperMatrix,
170 rust_managed: bool,
171}
172
173pub trait FromSuperMatrix: Sized {
174 fn from_super_matrix(_: &SuperMatrix) -> Option<Self>;
175}
176
177unsafe impl Send for SuperMatrix {}
178
179impl SuperMatrix {
180 pub unsafe fn from_raw(raw: ffi::SuperMatrix) -> SuperMatrix {
181 SuperMatrix {
182 raw,
183 rust_managed: false,
184 }
185 }
186
187 pub fn into_raw(self) -> ffi::SuperMatrix {
188 let raw = self.raw;
189 if self.rust_managed {
190 mem::forget(self);
191 }
192 raw
193 }
194
195 pub fn from_csc_mat(mat: CsMat<f64>) -> Self {
196 assert_eq!(mat.storage(), sprs::CompressedStorage::CSC);
197
198 let m = mat.rows() as c_int;
199 let n = mat.cols() as c_int;
200 let nnz = mat.nnz() as c_int;
201
202 let mut raw: ffi::SuperMatrix = unsafe { mem::zeroed() };
203
204 let nzval: Vec<c_double> = mat.data().iter().map(|&x| x as c_double).collect();
205 let rowind: Vec<c_int> = mat.indices().iter().map(|&x| x as c_int).collect();
206 let mut colptr = Vec::new();
207 let colptr_raw = mat.indptr();
208 for ptr in colptr_raw.as_slice().unwrap() {
209 colptr.push(ptr.clone() as c_int)
210 }
211
212 let nzval_boxed = nzval.into_boxed_slice();
213 let rowind_boxed = rowind.into_boxed_slice();
214 let colptr_boxed = colptr.into_boxed_slice();
215
216 let nzval_ptr = Box::leak(nzval_boxed).as_mut_ptr();
217 let rowind_ptr = Box::leak(rowind_boxed).as_mut_ptr();
218 let colptr_ptr = Box::leak(colptr_boxed).as_mut_ptr();
219
220 unsafe {
221 ffi::dCreate_CompCol_Matrix(
222 &mut raw,
223 m,
224 n,
225 nnz,
226 nzval_ptr as *mut c_double,
227 rowind_ptr as *mut c_int,
228 colptr_ptr as *mut c_int,
229 Stype_t::SLU_NC,
230 Dtype_t::SLU_D,
231 Mtype_t::SLU_GE,
232 );
233 }
234 unsafe { Self::from_raw(raw) }
235 }
236
237 pub fn from_ndarray(array: Array2<f64>) -> Self {
238 let nrows = array.nrows() as c_int;
239 let ncols = array.ncols() as c_int;
240
241 let col_major_data = unsafe { ffi::doubleMalloc(ncols * nrows) };
242 let mut index: usize = 0;
243 let col_major_data_ptr =
244 unsafe { from_raw_parts_mut(col_major_data, (ncols * nrows) as usize) };
245 for col in 0..ncols as usize {
246 for row in 0..nrows as usize {
247 col_major_data_ptr[index] = array[[row, col]];
248 index += 1;
249 }
250 }
251
252 let mut raw: ffi::SuperMatrix = unsafe { std::mem::zeroed() };
253
254 unsafe {
255 ffi::dCreate_Dense_Matrix(
256 &mut raw as *mut ffi::SuperMatrix,
257 nrows,
258 ncols,
259 col_major_data,
260 nrows,
261 Stype_t::SLU_DN,
262 Dtype_t::SLU_D,
263 Mtype_t::SLU_GE,
264 );
265
266 SuperMatrix {
267 raw,
268 rust_managed: true,
269 }
270 }
271 }
272
273 pub fn into_ndarray(self) -> Option<Array2<f64>> {
274 match self.raw.data_to_vec() {
275 None => None,
276 Some(data) => match Array2::from_shape_vec((self.nrows(), self.ncols()), data) {
277 Ok(arr) => Some(arr.t().to_owned()),
278 Err(_) => None,
279 },
280 }
281 }
282
283 pub fn nrows(&self) -> usize {
284 self.raw.nrow as usize
285 }
286
287 pub fn ncols(&self) -> usize {
288 self.raw.ncol as usize
289 }
290
291 pub fn raw(&self) -> &ffi::SuperMatrix {
292 &self.raw
293 }
294
295 pub fn raw_mut(&mut self) -> *mut ffi::SuperMatrix {
296 &mut self.raw
297 }
298}
299
300impl Drop for SuperMatrix {
301 fn drop(&mut self) {
302 unsafe {
303 let store = &*(self.raw().Store as *const ffi::NCformat);
304 if store.nnz == 0 {
305 return;
306 }
307 }
308 if self.rust_managed {
309 match self.raw.Stype {
310 Stype_t::SLU_NC => unsafe {
311 ffi::Destroy_CompCol_Matrix(&mut self.raw);
312 },
313 Stype_t::SLU_NCP => unsafe {
314 ffi::Destroy_CompCol_Permuted(&mut self.raw);
315 },
316 Stype_t::SLU_NR => unsafe {
317 ffi::Destroy_CompRow_Matrix(&mut self.raw);
318 },
319 Stype_t::SLU_SC | ffi::Stype_t::SLU_SCP | ffi::Stype_t::SLU_SR => unsafe {
320 ffi::Destroy_SuperNode_Matrix(&mut self.raw);
321 },
322 Stype_t::SLU_DN => unsafe {
323 ffi::Destroy_Dense_Matrix(&mut self.raw);
324 },
325 _ => {}
326 }
327 }
328 }
329}