1use libc::{c_float, c_uint};
2use log::info;
3use std::convert::TryInto;
4use std::os::unix::ffi::OsStrExt;
5use std::{ffi, path::Path, ptr, slice};
6
7use crate::{XGBError, XGBResult};
8
// Meta-info field names passed to XGBoost's `XGDMatrix{Get,Set}{Float,UInt}Info` functions.

/// Field read back when querying groups.
const KEY_GROUP_PTR: &str = "group_ptr";
/// Field written when setting groups.
const KEY_GROUP: &str = "group";
/// Per-row label field.
const KEY_LABEL: &str = "label";
/// Per-row weight field.
const KEY_WEIGHT: &str = "weight";
/// Base-margin field.
const KEY_BASE_MARGIN: &str = "base_margin";
14
/// Owned wrapper around an XGBoost `DMatrixHandle`.
///
/// The row and column counts are queried once, when the wrapper is constructed, and cached.
/// The handle is released with `XGDMatrixFree` when the value is dropped.
#[derive(Debug)]
pub struct DMatrix {
    /// Raw XGBoost handle; visible to the parent module so other wrappers can pass it to the C API.
    pub(super) handle: xgboost_rs_sys::DMatrixHandle,
    /// Row count cached at construction time.
    num_rows: usize,
    /// Column count cached at construction time.
    num_cols: usize,
}
25
// SAFETY: `DMatrix` exclusively owns its handle. The handle is only freed in `Drop`, and the
// methods that change the matrix's meta info take `&mut self`.
// NOTE(review): `Sync` additionally assumes XGBoost tolerates concurrent read-only calls
// (`XGDMatrixGet*Info`, `XGDMatrixSliceDMatrix`) on one handle from several threads. Confirm
// this against the XGBoost version in use.
unsafe impl Send for DMatrix {}
unsafe impl Sync for DMatrix {}
28
29impl DMatrix {
30    fn new(handle: xgboost_rs_sys::DMatrixHandle) -> XGBResult<Self> {
32        let mut out = 0;
35        xgb_call!(xgboost_rs_sys::XGDMatrixNumRow(handle, &mut out))?;
36        let num_rows = out as usize;
37
38        let mut out = 0;
39        xgb_call!(xgboost_rs_sys::XGDMatrixNumCol(handle, &mut out))?;
40        let num_cols = out as usize;
41        info!("Loaded DMatrix with shape: {}x{}", num_rows, num_cols);
42        Ok(DMatrix {
43            handle,
44            num_rows,
45            num_cols,
46        })
47    }
48
49    pub fn from_col_major_f32(
55        data: &[f32],
56        byte_size_ax_0: usize,
57        byte_size_ax_1: usize,
58        n_rows: usize,
59        n_cols: usize,
60        n_thread: i32,
61        nan: f32,
62    ) -> XGBResult<Self> {
63        let mut handle = ptr::null_mut();
64
65        let data_ptr_address = data.as_ptr() as usize;
67
68        let array_config = format!(
72            "{{
73            \"data\": [{data_ptr_address}, false], 
74            \"strides\": [{byte_size_ax_0}, {byte_size_ax_1}], 
75            \"descr\": [[\"\", \"<f4\"]], 
76            \"typestr\": \"<f4\", 
77            \"shape\": [{n_rows}, {n_cols}], 
78            \"version\": 3
79        }}"
80        );
81
82        let json_config = format!(
83            "
84                {{ \"missing\": {nan}, \"nthread\": {n_thread}}}
85                "
86        );
87
88        let array_config_cstr = ffi::CString::new(array_config).unwrap();
89        let json_config_cstr = ffi::CString::new(json_config).unwrap();
90
91        xgb_call!(xgboost_rs_sys::XGDMatrixCreateFromDense(
92            array_config_cstr.as_ptr(),
93            json_config_cstr.as_ptr(),
94            &mut handle
95        ))?;
96        Ok(DMatrix::new(handle).unwrap())
97    }
98
99    pub fn from_dense(data: &[f32], num_rows: usize) -> XGBResult<Self> {
106        let mut handle = ptr::null_mut();
107        xgb_call!(xgboost_rs_sys::XGDMatrixCreateFromMat(
108            data.as_ptr(),
109            num_rows as xgboost_rs_sys::bst_ulong,
110            (data.len() / num_rows) as xgboost_rs_sys::bst_ulong,
111            0.0, &mut handle
113        ))?;
114        Ok(DMatrix::new(handle).unwrap())
115    }
116
117    pub fn from_csr(
130        indptr: &[usize],
131        indices: &[usize],
132        data: &[f32],
133        num_cols: Option<usize>,
134    ) -> XGBResult<Self> {
135        assert_eq!(indices.len(), data.len());
136        let mut handle = ptr::null_mut();
137        let indptr: Vec<u64> = indptr.iter().map(|x| *x as u64).collect();
138        let indices: Vec<u32> = indices.iter().map(|x| *x as u32).collect();
139        let num_cols = num_cols.unwrap_or(0); xgb_call!(xgboost_rs_sys::XGDMatrixCreateFromCSREx(
141            indptr.as_ptr(),
142            indices.as_ptr(),
143            data.as_ptr(),
144            indptr.len().try_into().unwrap(),
145            data.len().try_into().unwrap(),
146            num_cols.try_into().unwrap(),
147            &mut handle
148        ))?;
149        Ok(DMatrix::new(handle).unwrap())
150    }
151
152    pub fn from_csc(
165        indptr: &[usize],
166        indices: &[usize],
167        data: &[f32],
168        num_rows: Option<usize>,
169    ) -> XGBResult<Self> {
170        assert_eq!(indices.len(), data.len());
171        let mut handle = ptr::null_mut();
172        let indptr: Vec<u64> = indptr.iter().map(|x| *x as u64).collect();
173        let indices: Vec<u32> = indices.iter().map(|x| *x as u32).collect();
174        let num_rows = num_rows.unwrap_or(0); xgb_call!(xgboost_rs_sys::XGDMatrixCreateFromCSCEx(
176            indptr.as_ptr(),
177            indices.as_ptr(),
178            data.as_ptr(),
179            indptr.len().try_into().unwrap(),
180            data.len().try_into().unwrap(),
181            num_rows.try_into().unwrap(),
182            &mut handle
183        ))?;
184        Ok(DMatrix::new(handle).unwrap())
185    }
186
187    pub fn load<P: AsRef<Path>>(path: P) -> XGBResult<Self> {
214        let path_as_string = path.as_ref().display().to_string();
215        let path_as_bytes = Path::new(&path_as_string).as_os_str().as_bytes();
216
217        let mut handle = ptr::null_mut();
218        let path_cstr = ffi::CString::new(path_as_bytes).unwrap();
219        let silent = true;
220        xgb_call!(xgboost_rs_sys::XGDMatrixCreateFromFile(
221            path_cstr.as_ptr(),
222            i32::from(silent),
223            &mut handle
224        ))?;
225        Ok(DMatrix::new(handle).unwrap())
226    }
227
228    pub fn save<P: AsRef<Path>>(&self, path: P) -> XGBResult<()> {
234        let fname = ffi::CString::new(path.as_ref().as_os_str().as_bytes()).unwrap();
235        let silent = true;
236        xgb_call!(xgboost_rs_sys::XGDMatrixSaveBinary(
237            self.handle,
238            fname.as_ptr(),
239            i32::from(silent)
240        ))
241    }
242
243    pub fn num_rows(&self) -> usize {
245        self.num_rows
246    }
247
248    pub fn num_cols(&self) -> usize {
250        self.num_cols
251    }
252
253    pub fn shape(&self) -> (usize, usize) {
255        (self.num_rows(), self.num_cols())
256    }
257
258    pub fn slice(&self, indices: &[usize]) -> XGBResult<DMatrix> {
264        let mut out_handle = ptr::null_mut();
265        let indices: Vec<i32> = indices.iter().map(|x| *x as i32).collect();
266        xgb_call!(xgboost_rs_sys::XGDMatrixSliceDMatrix(
267            self.handle,
268            indices.as_ptr(),
269            indices.len() as xgboost_rs_sys::bst_ulong,
270            &mut out_handle
271        ))?;
272        Ok(DMatrix::new(out_handle).unwrap())
273    }
274
275    pub fn get_labels(&self) -> XGBResult<&[f32]> {
277        self.get_float_info(KEY_LABEL)
278    }
279
280    pub fn set_labels(&mut self, array: &[f32]) -> XGBResult<()> {
282        self.set_float_info(KEY_LABEL, array)
283    }
284
285    pub fn get_weights(&self) -> XGBResult<&[f32]> {
287        self.get_float_info(KEY_WEIGHT)
288    }
289
290    pub fn set_weights(&mut self, array: &[f32]) -> XGBResult<()> {
292        self.set_float_info(KEY_WEIGHT, array)
293    }
294
295    pub fn get_base_margin(&self) -> XGBResult<&[f32]> {
297        self.get_float_info(KEY_BASE_MARGIN)
298    }
299
300    pub fn set_base_margin(&mut self, array: &[f32]) -> XGBResult<()> {
304        self.set_float_info(KEY_BASE_MARGIN, array)
305    }
306
307    pub fn set_group(&mut self, group: &[u32]) -> XGBResult<()> {
313        self.set_uint_info(KEY_GROUP, group)
315    }
316
317    pub fn get_group(&self) -> XGBResult<&[u32]> {
323        self.get_uint_info(KEY_GROUP_PTR)
324    }
325
326    fn get_float_info(&self, field: &str) -> XGBResult<&[f32]> {
327        let field = ffi::CString::new(field).unwrap();
328        let mut out_len = 0;
329        let mut out_dptr = ptr::null();
330        xgb_call!(xgboost_rs_sys::XGDMatrixGetFloatInfo(
331            self.handle,
332            field.as_ptr(),
333            &mut out_len,
334            &mut out_dptr
335        ))?;
336
337        Ok(unsafe { slice::from_raw_parts(out_dptr as *mut c_float, out_len as usize) })
338    }
339
340    fn set_float_info(&mut self, field: &str, array: &[f32]) -> XGBResult<()> {
341        let field = ffi::CString::new(field).unwrap();
342        xgb_call!(xgboost_rs_sys::XGDMatrixSetFloatInfo(
343            self.handle,
344            field.as_ptr(),
345            array.as_ptr(),
346            array.len() as u64
347        ))
348    }
349
350    fn get_uint_info(&self, field: &str) -> XGBResult<&[u32]> {
351        let field = ffi::CString::new(field).unwrap();
352        let mut out_len = 0;
353        let mut out_dptr = ptr::null();
354        xgb_call!(xgboost_rs_sys::XGDMatrixGetUIntInfo(
355            self.handle,
356            field.as_ptr(),
357            &mut out_len,
358            &mut out_dptr
359        ))?;
360        Ok(unsafe { slice::from_raw_parts(out_dptr as *mut c_uint, out_len as usize) })
361    }
362
363    fn set_uint_info(&mut self, field: &str, array: &[u32]) -> XGBResult<()> {
364        let field = ffi::CString::new(field).unwrap();
365        xgb_call!(xgboost_rs_sys::XGDMatrixSetUIntInfo(
366            self.handle,
367            field.as_ptr(),
368            array.as_ptr(),
369            array.len() as u64
370        ))
371    }
372}
373
374impl Drop for DMatrix {
375    fn drop(&mut self) {
376        xgb_call!(xgboost_rs_sys::XGDMatrixFree(self.handle)).unwrap();
377    }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Load the CSV fixture shipped in `src/`, using XGBoost's `?format=csv` URI suffix.
    fn read_train_matrix() -> XGBResult<DMatrix> {
        let data_path = concat!(env!("CARGO_MANIFEST_DIR"), "/src");
        DMatrix::load(format!("{data_path}/data.csv?format=csv"))
    }

    #[test]
    fn read_matrix() {
        assert!(read_train_matrix().is_ok());
    }

    #[test]
    fn read_num_rows() {
        assert_eq!(read_train_matrix().unwrap().num_rows(), 23946);
    }

    #[test]
    fn read_num_cols() {
        assert_eq!(read_train_matrix().unwrap().num_cols(), 6);
    }

    #[test]
    fn writing_and_reading() {
        let dmat = read_train_matrix().unwrap();

        let tmp_dir = tempfile::tempdir().expect("failed to create temp dir");
        let out_path = tmp_dir.path().join("dmat.bin");
        dmat.save(&out_path).unwrap();

        let dmat2 = DMatrix::load(&out_path).unwrap();

        assert_eq!(dmat.num_rows(), dmat2.num_rows());
        assert_eq!(dmat.num_cols(), dmat2.num_cols());
    }

    #[test]
    fn get_set_labels() {
        let mut dmat = read_train_matrix().unwrap();
        assert_eq!(dmat.get_labels().unwrap().len(), 23946);
        let labels = vec![0.0; dmat.get_labels().unwrap().len()];
        assert!(dmat.set_labels(&labels).is_ok());
        assert_eq!(dmat.get_labels().unwrap(), labels);
    }

    #[test]
    fn get_set_weights() {
        let error_margin = f32::EPSILON;
        let mut dmat = read_train_matrix().unwrap();
        let empty_weights: Vec<f32> = vec![];
        assert_eq!(dmat.get_weights().unwrap(), empty_weights.as_slice());

        let weight = [1.0, 10.0, 44.9555];
        assert!(dmat.set_weights(&weight).is_ok());
        let stored = dmat.get_weights().unwrap();
        // `zip` stops at the shorter side, so without this check a truncated result would pass.
        assert_eq!(stored.len(), weight.len());
        stored.iter().zip(weight.iter()).for_each(|(a, b)| {
            assert!((a - b).abs() < error_margin);
        });
    }

    #[test]
    fn get_set_base_margin() {
        let mut dmat = read_train_matrix().unwrap();
        let empty_slice: Vec<f32> = vec![];
        assert_eq!(dmat.get_base_margin().unwrap(), empty_slice.as_slice());
        let base_margin = vec![1337.0; dmat.num_rows()];
        assert!(dmat.set_base_margin(&base_margin).is_ok());
        assert_eq!(dmat.get_base_margin().unwrap(), base_margin);
    }

    #[test]
    fn get_set_group() {
        let mut dmat = read_train_matrix().unwrap();
        let empty_slice: Vec<u32> = vec![];
        assert_eq!(dmat.get_group().unwrap(), empty_slice.as_slice());

        let group = [1];
        assert!(dmat.set_group(&group).is_ok());
        // Group sizes go in; group boundaries (`group_ptr`) come out.
        assert_eq!(dmat.get_group().unwrap(), &[0, 1]);
    }

    #[test]
    fn from_csr() {
        let indptr = [0, 2, 3, 6, 8];
        let indices = [0, 2, 2, 0, 1, 2, 1, 2];
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let dmat = DMatrix::from_csr(&indptr, &indices, &data, None).unwrap();
        assert_eq!(dmat.num_rows(), 4);
        // Pins the observed XGBoost behaviour: with `num_cols = 0` the column count is reported
        // as 0 rather than inferred from `indices`.
        assert_eq!(dmat.num_cols(), 0);

        let dmat = DMatrix::from_csr(&indptr, &indices, &data, Some(10)).unwrap();
        assert_eq!(dmat.num_rows(), 4);
        assert_eq!(dmat.num_cols(), 10);
    }

    #[test]
    fn from_csc() {
        let indptr = [0, 2, 3, 6, 8];
        let indices = [0, 2, 2, 0, 1, 2, 1, 2];
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let dmat = DMatrix::from_csc(&indptr, &indices, &data, None).unwrap();
        assert_eq!(dmat.num_rows(), 3);
        assert_eq!(dmat.num_cols(), 4);

        let dmat = DMatrix::from_csc(&indptr, &indices, &data, Some(10)).unwrap();
        assert_eq!(dmat.num_rows(), 10);
        assert_eq!(dmat.num_cols(), 4);
    }

    #[test]
    fn from_dense() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let num_rows = 2;

        let dmat = DMatrix::from_dense(&data, num_rows).unwrap();
        assert_eq!(dmat.num_rows(), 2);
        assert_eq!(dmat.num_cols(), 3);

        let data = vec![1.0, 2.0, 3.0];
        let num_rows = 3;

        let dmat = DMatrix::from_dense(&data, num_rows).unwrap();
        assert_eq!(dmat.num_rows(), 3);
        assert_eq!(dmat.num_cols(), 1);
    }

    #[test]
    fn slice_from_indices() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let num_rows = 4;

        let dmat = DMatrix::from_dense(&data, num_rows).unwrap();

        assert_eq!(dmat.shape(), (4, 2));

        assert_eq!(dmat.slice(&[]).unwrap().shape(), (0, 2));
        assert_eq!(dmat.slice(&[1]).unwrap().shape(), (1, 2));
        assert_eq!(dmat.slice(&[0, 1]).unwrap().shape(), (2, 2));
        assert_eq!(dmat.slice(&[3, 2, 1]).unwrap().shape(), (3, 2));
    }

    #[test]
    fn slice() {
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let num_rows = 4;

        let dmat = DMatrix::from_dense(&data, num_rows).unwrap();
        assert_eq!(dmat.shape(), (4, 3));

        assert_eq!(dmat.slice(&[0, 1, 2, 3]).unwrap().shape(), (4, 3));
        assert_eq!(dmat.slice(&[0, 1]).unwrap().shape(), (2, 3));
        assert_eq!(dmat.slice(&[1, 0]).unwrap().shape(), (2, 3));
        assert_eq!(dmat.slice(&[0, 1, 2]).unwrap().shape(), (3, 3));
        assert_eq!(dmat.slice(&[3, 2, 1]).unwrap().shape(), (3, 3));
    }
}