xgboost/
dmatrix.rs

1use std::{slice, ffi, ptr, path::Path};
2use libc::{c_uint, c_float};
3use std::os::unix::ffi::OsStrExt;
4
5use xgboost_sys;
6
7use super::{XGBResult, XGBError};
8
// Field names passed to the XGDMatrix{Get,Set}{Float,UInt}Info C API calls.
const KEY_ROOT_INDEX: &str = "root_index";
const KEY_LABEL: &str = "label";
const KEY_WEIGHT: &str = "weight";
const KEY_BASE_MARGIN: &str = "base_margin";
13
14/// Data matrix used throughout XGBoost for training/predicting [`Booster`](struct.Booster.html) models.
15///
16/// It's used as a container for both features (i.e. a row for every instance), and an optional true label for that
17/// instance (as an `f32` value).
18///
19/// Can be created files, or from dense or sparse
20/// ([CSR](https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format))
21/// or [CSC](https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_column_(CSC_or_CCS))) matrices.
22///
23/// # Examples
24///
25/// ## Load from file
26///
27/// Load matrix from file in [LIBSVM](https://www.csie.ntu.edu.tw/~cjlin/libsvm/) or binary format.
28///
29/// ```should_panic
30/// use xgboost::DMatrix;
31///
32/// let dmat = DMatrix::load("somefile.txt").unwrap();
33/// ```
34///
35/// ## Create from dense array
36///
37/// ```
38/// use xgboost::DMatrix;
39///
40/// let data = &[1.0, 0.5, 0.2, 0.2,
41///              0.7, 1.0, 0.1, 0.1,
42///              0.2, 0.0, 0.0, 1.0];
43/// let num_rows = 3;
44/// let mut dmat = DMatrix::from_dense(data, num_rows).unwrap();
45/// assert_eq!(dmat.shape(), (3, 4));
46///
47/// // set true labels for each row
48/// dmat.set_labels(&[1.0, 0.0, 1.0]);
49/// ```
50///
51/// ## Create from sparse CSR matrix
52///
53/// Create from sparse representation of
54/// ```text
55/// [[1.0, 0.0, 2.0],
56///  [0.0, 0.0, 3.0],
57///  [4.0, 5.0, 6.0]]
58/// ```
59///
60/// ```
61/// use xgboost::DMatrix;
62///
63/// let indptr = &[0, 2, 3, 6];
64/// let indices = &[0, 2, 2, 0, 1, 2];
65/// let data = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
66/// let dmat = DMatrix::from_csr(indptr, indices, data, None).unwrap();
67/// assert_eq!(dmat.shape(), (3, 3));
68/// ```
69pub struct DMatrix {
70    pub(super) handle: xgboost_sys::DMatrixHandle,
71    num_rows: usize,
72    num_cols: usize,
73}
74
75impl DMatrix {
76    /// Construct a new instance from a DMatrixHandle created by the XGBoost C API.
77    fn new(handle: xgboost_sys::DMatrixHandle) -> XGBResult<Self> {
78        // number of rows/cols are frequently read throughout applications, so more convenient to pull them out once
79        // when the matrix is created, instead of having to check errors each time XGDMatrixNum* is called
80        let mut out = 0;
81        xgb_call!(xgboost_sys::XGDMatrixNumRow(handle, &mut out))?;
82        let num_rows = out as usize;
83
84        let mut out = 0;
85        xgb_call!(xgboost_sys::XGDMatrixNumCol(handle, &mut out))?;
86        let num_cols = out as usize;
87
88        info!("Loaded DMatrix with shape: {}x{}", num_rows, num_cols);
89        Ok(DMatrix { handle, num_rows, num_cols })
90    }
91
92    /// Create a new `DMatrix` from dense array in row-major order.
93    ///
94    /// E.g. the matrix
95    /// ```text
96    /// [[1.0, 2.0],
97    ///  [3.0, 4.0],
98    ///  [5.0, 6.0]]
99    /// ```
100    /// would be represented converted into a `DMatrix` with
101    /// ```
102    /// use xgboost::DMatrix;
103    ///
104    /// let data = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
105    /// let num_rows = 3;
106    /// let dmat = DMatrix::from_dense(data, num_rows).unwrap();
107    /// ```
108    pub fn from_dense(data: &[f32], num_rows: usize) -> XGBResult<Self> {
109        let mut handle = ptr::null_mut();
110        xgb_call!(xgboost_sys::XGDMatrixCreateFromMat(data.as_ptr(),
111                                                      num_rows as xgboost_sys::bst_ulong,
112                                                      (data.len() / num_rows) as xgboost_sys::bst_ulong,
113                                                      0.0, // TODO: can values be missing here?
114                                                      &mut handle))?;
115        Ok(DMatrix::new(handle)?)
116    }
117
118    /// Create a new `DMatrix` from a sparse
119    /// [CSR](https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format)) matrix.
120    ///
121    /// Uses standard CSR representation where the column indices for row _i_ are stored in
122    /// `indices[indptr[i]:indptr[i+1]]` and their corresponding values are stored in
123    /// `data[indptr[i]:indptr[i+1]`.
124    ///
125    /// If `num_cols` is set to None, number of columns will be inferred from given data.
126    pub fn from_csr(indptr: &[usize], indices: &[usize], data: &[f32], num_cols: Option<usize>) -> XGBResult<Self> {
127        assert_eq!(indices.len(), data.len());
128        let mut handle = ptr::null_mut();
129        let indices: Vec<u32> = indices.iter().map(|x| *x as u32).collect();
130        let num_cols = num_cols.unwrap_or(0); // infer from data if 0
131        xgb_call!(xgboost_sys::XGDMatrixCreateFromCSREx(indptr.as_ptr(),
132                                                        indices.as_ptr(),
133                                                        data.as_ptr(),
134                                                        indptr.len(),
135                                                        data.len(),
136                                                        num_cols,
137                                                        &mut handle))?;
138        Ok(DMatrix::new(handle)?)
139    }
140
141    /// Create a new `DMatrix` from a sparse
142    /// [CSC](https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_column_(CSC_or_CCS))) matrix.
143    ///
144    /// Uses standard CSC representation where the row indices for column _i_ are stored in
145    /// `indices[indptr[i]:indptr[i+1]]` and their corresponding values are stored in
146    /// `data[indptr[i]:indptr[i+1]`.
147    ///
148    /// If `num_rows` is set to None, number of rows will be inferred from given data.
149    pub fn from_csc(indptr: &[usize], indices: &[usize], data: &[f32], num_rows: Option<usize>) -> XGBResult<Self> {
150        assert_eq!(indices.len(), data.len());
151        let mut handle = ptr::null_mut();
152        let indices: Vec<u32> = indices.iter().map(|x| *x as u32).collect();
153        let num_rows = num_rows.unwrap_or(0); // infer from data if 0
154        xgb_call!(xgboost_sys::XGDMatrixCreateFromCSCEx(indptr.as_ptr(),
155                                                        indices.as_ptr(),
156                                                        data.as_ptr(),
157                                                        indptr.len(),
158                                                        data.len(),
159                                                        num_rows,
160                                                        &mut handle))?;
161        Ok(DMatrix::new(handle)?)
162    }
163
164    /// Create a new `DMatrix` from given file.
165    ///
166    /// Supports text files in [LIBSVM](https://www.csie.ntu.edu.tw/~cjlin/libsvm/) format, CSV,
167    /// binary files written either by `save`, or from another XGBoost library.
168    ///
169    /// For more details on accepted formats, seem the
170    /// [XGBoost input format](https://xgboost.readthedocs.io/en/latest/tutorials/input_format.html)
171    /// documentation.
172    ///
173    /// # LIBSVM format
174    ///
175    /// Specified data in a sparse format as:
176    /// ```text
177    /// <label> <index>:<value> [<index>:<value> ...]
178    /// ```
179    ///
180    /// E.g.
181    /// ```text
182    /// 0 1:1 9:0 11:0
183    /// 1 9:1 11:0.375 15:1
184    /// 0 1:0 8:0.22 11:1
185    /// ```
186    pub fn load<P: AsRef<Path>>(path: P) -> XGBResult<Self> {
187        debug!("Loading DMatrix from: {}", path.as_ref().display());
188        let mut handle = ptr::null_mut();
189        let fname = ffi::CString::new(path.as_ref().as_os_str().as_bytes()).unwrap();
190        let silent = true;
191        xgb_call!(xgboost_sys::XGDMatrixCreateFromFile(fname.as_ptr(), silent as i32, &mut handle))?;
192        Ok(DMatrix::new(handle)?)
193    }
194
195    /// Serialise this `DMatrix` as a binary file to given path.
196    pub fn save<P: AsRef<Path>>(&self, path: P) -> XGBResult<()> {
197        debug!("Writing DMatrix to: {}", path.as_ref().display());
198        let fname = ffi::CString::new(path.as_ref().as_os_str().as_bytes()).unwrap();
199        let silent = true;
200        xgb_call!(xgboost_sys::XGDMatrixSaveBinary(self.handle, fname.as_ptr(), silent as i32))
201    }
202
203    /// Get the number of rows in this matrix.
204    pub fn num_rows(&self) -> usize {
205        self.num_rows
206    }
207
208    /// Get the number of columns in this matrix.
209    pub fn num_cols(&self) -> usize {
210        self.num_cols
211    }
212
213    /// Get the shape (rows x columns) of this matrix.
214    pub fn shape(&self) -> (usize, usize) {
215        (self.num_rows(), self.num_cols())
216    }
217
218    /// Get a new DMatrix as a containing only given indices.
219    pub fn slice(&self, indices: &[usize]) -> XGBResult<DMatrix> {
220        debug!("Slicing {} rows from DMatrix", indices.len());
221        let mut out_handle = ptr::null_mut();
222        let indices: Vec<i32> = indices.iter().map(|x| *x as i32).collect();
223        xgb_call!(xgboost_sys::XGDMatrixSliceDMatrix(self.handle,
224                                                     indices.as_ptr(),
225                                                     indices.len() as xgboost_sys::bst_ulong,
226                                                     &mut out_handle))?;
227        Ok(DMatrix::new(out_handle)?)
228    }
229
230    /// Gets the specified root index of each instance, can be used for multi task setting.
231    ///
232    /// See the XGBoost documentation for more information.
233    pub fn get_root_index(&self) -> XGBResult<&[u32]> {
234        self.get_uint_info(KEY_ROOT_INDEX)
235    }
236
237    /// Sets the specified root index of each instance, can be used for multi task setting.
238    ///
239    /// See the XGBoost documentation for more information.
240    pub fn set_root_index(&mut self, array: &[u32]) -> XGBResult<()> {
241        self.set_uint_info(KEY_ROOT_INDEX, array)
242    }
243
244    /// Get ground truth labels for each row of this matrix.
245    pub fn get_labels(&self) -> XGBResult<&[f32]> {
246        self.get_float_info(KEY_LABEL)
247    }
248
249    /// Set ground truth labels for each row of this matrix.
250    pub fn set_labels(&mut self, array: &[f32]) -> XGBResult<()> {
251        self.set_float_info(KEY_LABEL, array)
252    }
253
254    /// Get weights of each instance.
255    pub fn get_weights(&self) -> XGBResult<&[f32]> {
256        self.get_float_info(KEY_WEIGHT)
257    }
258
259    /// Set weights of each instance.
260    pub fn set_weights(&mut self, array: &[f32]) -> XGBResult<()> {
261        self.set_float_info(KEY_WEIGHT, array)
262    }
263
264    /// Get base margin.
265    pub fn get_base_margin(&self) -> XGBResult<&[f32]> {
266        self.get_float_info(KEY_BASE_MARGIN)
267    }
268
269    /// Set base margin.
270    ///
271    /// If specified, xgboost will start from this margin, can be used to specify initial prediction to boost from.
272    pub fn set_base_margin(&mut self, array: &[f32]) -> XGBResult<()> {
273        self.set_float_info(KEY_BASE_MARGIN, array)
274    }
275
276    /// Set the index for the beginning and end of a group.
277    ///
278    /// Needed when the learning task is ranking.
279    ///
280    /// See the XGBoost documentation for more information.
281    pub fn set_group(&mut self, group: &[u32]) -> XGBResult<()> {
282        xgb_call!(xgboost_sys::XGDMatrixSetGroup(self.handle, group.as_ptr(), group.len() as u64))
283    }
284
285    fn get_float_info(&self, field: &str) -> XGBResult<&[f32]> {
286        let field = ffi::CString::new(field).unwrap();
287        let mut out_len = 0;
288        let mut out_dptr = ptr::null();
289        xgb_call!(xgboost_sys::XGDMatrixGetFloatInfo(self.handle,
290                                                     field.as_ptr(),
291                                                     &mut out_len,
292                                                     &mut out_dptr))?;
293
294        Ok(unsafe { slice::from_raw_parts(out_dptr as *mut c_float, out_len as usize) })
295    }
296
297    fn set_float_info(&mut self, field: &str, array: &[f32]) -> XGBResult<()> {
298        let field = ffi::CString::new(field).unwrap();
299        xgb_call!(xgboost_sys::XGDMatrixSetFloatInfo(self.handle,
300                                                     field.as_ptr(),
301                                                     array.as_ptr(),
302                                                     array.len() as u64))
303    }
304
305    fn get_uint_info(&self, field: &str) -> XGBResult<&[u32]> {
306        let field = ffi::CString::new(field).unwrap();
307        let mut out_len = 0;
308        let mut out_dptr = ptr::null();
309        xgb_call!(xgboost_sys::XGDMatrixGetUIntInfo(self.handle,
310                                                    field.as_ptr(),
311                                                    &mut out_len,
312                                                    &mut out_dptr))?;
313
314        Ok(unsafe { slice::from_raw_parts(out_dptr as *mut c_uint, out_len as usize) })
315    }
316
317    fn set_uint_info(&mut self, field: &str, array: &[u32]) -> XGBResult<()> {
318        let field = ffi::CString::new(field).unwrap();
319        xgb_call!(xgboost_sys::XGDMatrixSetUIntInfo(self.handle,
320                                                    field.as_ptr(),
321                                                    array.as_ptr(),
322                                                    array.len() as u64))
323    }
324}
325
326impl Drop for DMatrix {
327    fn drop(&mut self) {
328        xgb_call!(xgboost_sys::XGDMatrixFree(self.handle)).unwrap();
329    }
330}
331
#[cfg(test)]
mod tests {
    use tempfile;
    use super::*;
    fn read_train_matrix() -> XGBResult<DMatrix> {
        DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train")
    }

    #[test]
    fn read_matrix() {
        assert!(read_train_matrix().is_ok());
    }

    #[test]
    fn read_num_rows() {
        assert_eq!(read_train_matrix().unwrap().num_rows(), 6513);
    }

    #[test]
    fn read_num_cols() {
        assert_eq!(read_train_matrix().unwrap().num_cols(), 127);
    }

    #[test]
    fn writing_and_reading() {
        let dmat = read_train_matrix().unwrap();

        let tmp_dir = tempfile::tempdir().expect("failed to create temp dir");
        let out_path = tmp_dir.path().join("dmat.bin");
        dmat.save(&out_path).unwrap();

        let dmat2 = DMatrix::load(&out_path).unwrap();

        assert_eq!(dmat.num_rows(), dmat2.num_rows());
        assert_eq!(dmat.num_cols(), dmat2.num_cols());
        // TODO: check contents as well, if possible
    }

    #[test]
    fn get_set_root_index() {
        let mut dmat = read_train_matrix().unwrap();
        assert_eq!(dmat.get_root_index().unwrap(), &[]);

        let root_index = [3, 22, 1];
        assert!(dmat.set_root_index(&root_index).is_ok());
        assert_eq!(dmat.get_root_index().unwrap(), &[3, 22, 1]);
    }

    #[test]
    fn get_set_labels() {
        let mut dmat = read_train_matrix().unwrap();
        assert_eq!(dmat.get_labels().unwrap().len(), 6513);

        // Five distinct labels; a missing comma previously turned `0.0 -4.5` into a single -4.5.
        let label = [0.1, 0.0, -4.5, 11.29842, 333333.33];
        assert!(dmat.set_labels(&label).is_ok());
        assert_eq!(dmat.get_labels().unwrap(), label);
    }

    #[test]
    fn get_set_weights() {
        let mut dmat = read_train_matrix().unwrap();
        assert_eq!(dmat.get_weights().unwrap(), &[]);

        let weight = [1.0, 10.0, -123.456789, 44.9555];
        assert!(dmat.set_weights(&weight).is_ok());
        assert_eq!(dmat.get_weights().unwrap(), weight);
    }

    #[test]
    fn get_set_base_margin() {
        let mut dmat = read_train_matrix().unwrap();
        assert_eq!(dmat.get_base_margin().unwrap(), &[]);

        let base_margin = [0.00001, 0.000002, 1.23];
        assert!(dmat.set_base_margin(&base_margin).is_ok());
        assert_eq!(dmat.get_base_margin().unwrap(), base_margin);
    }

    #[test]
    fn set_group() {
        let mut dmat = read_train_matrix().unwrap();

        let group = [1, 2, 3];
        assert!(dmat.set_group(&group).is_ok());
    }

    #[test]
    fn from_csr() {
        let indptr = [0, 2, 3, 6, 8];
        let indices = [0, 2, 2, 0, 1, 2, 1, 2];
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let dmat = DMatrix::from_csr(&indptr, &indices, &data, None).unwrap();
        assert_eq!(dmat.num_rows(), 4);
        assert_eq!(dmat.num_cols(), 3);

        let dmat = DMatrix::from_csr(&indptr, &indices, &data, Some(10)).unwrap();
        assert_eq!(dmat.num_rows(), 4);
        assert_eq!(dmat.num_cols(), 10);
    }

    #[test]
    fn from_csc() {
        let indptr = [0, 2, 3, 6, 8];
        let indices = [0, 2, 2, 0, 1, 2, 1, 2];
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let dmat = DMatrix::from_csc(&indptr, &indices, &data, None).unwrap();
        assert_eq!(dmat.num_rows(), 3);
        assert_eq!(dmat.num_cols(), 4);

        let dmat = DMatrix::from_csc(&indptr, &indices, &data, Some(10)).unwrap();
        assert_eq!(dmat.num_rows(), 10);
        assert_eq!(dmat.num_cols(), 4);
    }

    #[test]
    fn from_dense() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let num_rows = 2;

        let dmat = DMatrix::from_dense(&data, num_rows).unwrap();
        assert_eq!(dmat.num_rows(), 2);
        assert_eq!(dmat.num_cols(), 3);

        let data = vec![1.0, 2.0, 3.0];
        let num_rows = 3;

        let dmat = DMatrix::from_dense(&data, num_rows).unwrap();
        assert_eq!(dmat.num_rows(), 3);
        assert_eq!(dmat.num_cols(), 1);
    }

    #[test]
    fn slice_from_indices() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let num_rows = 4;

        let dmat = DMatrix::from_dense(&data, num_rows).unwrap();
        assert_eq!(dmat.shape(), (4, 2));

        assert_eq!(dmat.slice(&[]).unwrap().shape(), (0, 2));
        assert_eq!(dmat.slice(&[1]).unwrap().shape(), (1, 2));
        assert_eq!(dmat.slice(&[0, 1]).unwrap().shape(), (2, 2));
        assert_eq!(dmat.slice(&[3, 2, 1]).unwrap().shape(), (3, 2));
        assert!(dmat.slice(&[10, 11, 12]).is_err());
    }

    #[test]
    fn slice() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let num_rows = 4;

        let dmat = DMatrix::from_dense(&data, num_rows).unwrap();
        assert_eq!(dmat.shape(), (4, 3));

        assert_eq!(dmat.slice(&[0, 1, 2, 3]).unwrap().shape(), (4, 3));
        assert_eq!(dmat.slice(&[0, 1]).unwrap().shape(), (2, 3));
        assert_eq!(dmat.slice(&[1, 0]).unwrap().shape(), (2, 3));
        assert_eq!(dmat.slice(&[0, 1, 2]).unwrap().shape(), (3, 3));
        assert_eq!(dmat.slice(&[3, 2, 1]).unwrap().shape(), (3, 3));
    }
}