1use std::{slice, ffi, ptr, path::Path};
2use libc::{c_uint, c_float};
3use std::os::unix::ffi::OsStrExt;
4
5use xgboost_sys;
6
7use super::{XGBResult, XGBError};
8
// Field names passed to XGBoost's `XGDMatrix{Get,Set}{Float,UInt}Info` calls below.
// The `'static` lifetime is implied for `static` items, so it is not spelled out.
static KEY_ROOT_INDEX: &str = "root_index";
static KEY_LABEL: &str = "label";
static KEY_WEIGHT: &str = "weight";
static KEY_BASE_MARGIN: &str = "base_margin";
13
/// An XGBoost data matrix backed by a native `DMatrixHandle`.
///
/// The native matrix is released when this value is dropped (see the `Drop` impl).
/// The shape is queried once at construction and cached in `num_rows` / `num_cols`.
pub struct DMatrix {
    /// Native XGBoost handle; visible to the parent module so other wrappers there can use it.
    pub(super) handle: xgboost_sys::DMatrixHandle,
    /// Row count, cached when the matrix is constructed.
    num_rows: usize,
    /// Column count, cached when the matrix is constructed.
    num_cols: usize,
}
74
75impl DMatrix {
76 fn new(handle: xgboost_sys::DMatrixHandle) -> XGBResult<Self> {
78 let mut out = 0;
81 xgb_call!(xgboost_sys::XGDMatrixNumRow(handle, &mut out))?;
82 let num_rows = out as usize;
83
84 let mut out = 0;
85 xgb_call!(xgboost_sys::XGDMatrixNumCol(handle, &mut out))?;
86 let num_cols = out as usize;
87
88 info!("Loaded DMatrix with shape: {}x{}", num_rows, num_cols);
89 Ok(DMatrix { handle, num_rows, num_cols })
90 }
91
92 pub fn from_dense(data: &[f32], num_rows: usize) -> XGBResult<Self> {
109 let mut handle = ptr::null_mut();
110 xgb_call!(xgboost_sys::XGDMatrixCreateFromMat(data.as_ptr(),
111 num_rows as xgboost_sys::bst_ulong,
112 (data.len() / num_rows) as xgboost_sys::bst_ulong,
113 0.0, &mut handle))?;
115 Ok(DMatrix::new(handle)?)
116 }
117
118 pub fn from_csr(indptr: &[usize], indices: &[usize], data: &[f32], num_cols: Option<usize>) -> XGBResult<Self> {
127 assert_eq!(indices.len(), data.len());
128 let mut handle = ptr::null_mut();
129 let indices: Vec<u32> = indices.iter().map(|x| *x as u32).collect();
130 let num_cols = num_cols.unwrap_or(0); xgb_call!(xgboost_sys::XGDMatrixCreateFromCSREx(indptr.as_ptr(),
132 indices.as_ptr(),
133 data.as_ptr(),
134 indptr.len(),
135 data.len(),
136 num_cols,
137 &mut handle))?;
138 Ok(DMatrix::new(handle)?)
139 }
140
141 pub fn from_csc(indptr: &[usize], indices: &[usize], data: &[f32], num_rows: Option<usize>) -> XGBResult<Self> {
150 assert_eq!(indices.len(), data.len());
151 let mut handle = ptr::null_mut();
152 let indices: Vec<u32> = indices.iter().map(|x| *x as u32).collect();
153 let num_rows = num_rows.unwrap_or(0); xgb_call!(xgboost_sys::XGDMatrixCreateFromCSCEx(indptr.as_ptr(),
155 indices.as_ptr(),
156 data.as_ptr(),
157 indptr.len(),
158 data.len(),
159 num_rows,
160 &mut handle))?;
161 Ok(DMatrix::new(handle)?)
162 }
163
164 pub fn load<P: AsRef<Path>>(path: P) -> XGBResult<Self> {
187 debug!("Loading DMatrix from: {}", path.as_ref().display());
188 let mut handle = ptr::null_mut();
189 let fname = ffi::CString::new(path.as_ref().as_os_str().as_bytes()).unwrap();
190 let silent = true;
191 xgb_call!(xgboost_sys::XGDMatrixCreateFromFile(fname.as_ptr(), silent as i32, &mut handle))?;
192 Ok(DMatrix::new(handle)?)
193 }
194
195 pub fn save<P: AsRef<Path>>(&self, path: P) -> XGBResult<()> {
197 debug!("Writing DMatrix to: {}", path.as_ref().display());
198 let fname = ffi::CString::new(path.as_ref().as_os_str().as_bytes()).unwrap();
199 let silent = true;
200 xgb_call!(xgboost_sys::XGDMatrixSaveBinary(self.handle, fname.as_ptr(), silent as i32))
201 }
202
203 pub fn num_rows(&self) -> usize {
205 self.num_rows
206 }
207
208 pub fn num_cols(&self) -> usize {
210 self.num_cols
211 }
212
213 pub fn shape(&self) -> (usize, usize) {
215 (self.num_rows(), self.num_cols())
216 }
217
218 pub fn slice(&self, indices: &[usize]) -> XGBResult<DMatrix> {
220 debug!("Slicing {} rows from DMatrix", indices.len());
221 let mut out_handle = ptr::null_mut();
222 let indices: Vec<i32> = indices.iter().map(|x| *x as i32).collect();
223 xgb_call!(xgboost_sys::XGDMatrixSliceDMatrix(self.handle,
224 indices.as_ptr(),
225 indices.len() as xgboost_sys::bst_ulong,
226 &mut out_handle))?;
227 Ok(DMatrix::new(out_handle)?)
228 }
229
230 pub fn get_root_index(&self) -> XGBResult<&[u32]> {
234 self.get_uint_info(KEY_ROOT_INDEX)
235 }
236
237 pub fn set_root_index(&mut self, array: &[u32]) -> XGBResult<()> {
241 self.set_uint_info(KEY_ROOT_INDEX, array)
242 }
243
244 pub fn get_labels(&self) -> XGBResult<&[f32]> {
246 self.get_float_info(KEY_LABEL)
247 }
248
249 pub fn set_labels(&mut self, array: &[f32]) -> XGBResult<()> {
251 self.set_float_info(KEY_LABEL, array)
252 }
253
254 pub fn get_weights(&self) -> XGBResult<&[f32]> {
256 self.get_float_info(KEY_WEIGHT)
257 }
258
259 pub fn set_weights(&mut self, array: &[f32]) -> XGBResult<()> {
261 self.set_float_info(KEY_WEIGHT, array)
262 }
263
264 pub fn get_base_margin(&self) -> XGBResult<&[f32]> {
266 self.get_float_info(KEY_BASE_MARGIN)
267 }
268
269 pub fn set_base_margin(&mut self, array: &[f32]) -> XGBResult<()> {
273 self.set_float_info(KEY_BASE_MARGIN, array)
274 }
275
276 pub fn set_group(&mut self, group: &[u32]) -> XGBResult<()> {
282 xgb_call!(xgboost_sys::XGDMatrixSetGroup(self.handle, group.as_ptr(), group.len() as u64))
283 }
284
285 fn get_float_info(&self, field: &str) -> XGBResult<&[f32]> {
286 let field = ffi::CString::new(field).unwrap();
287 let mut out_len = 0;
288 let mut out_dptr = ptr::null();
289 xgb_call!(xgboost_sys::XGDMatrixGetFloatInfo(self.handle,
290 field.as_ptr(),
291 &mut out_len,
292 &mut out_dptr))?;
293
294 Ok(unsafe { slice::from_raw_parts(out_dptr as *mut c_float, out_len as usize) })
295 }
296
297 fn set_float_info(&mut self, field: &str, array: &[f32]) -> XGBResult<()> {
298 let field = ffi::CString::new(field).unwrap();
299 xgb_call!(xgboost_sys::XGDMatrixSetFloatInfo(self.handle,
300 field.as_ptr(),
301 array.as_ptr(),
302 array.len() as u64))
303 }
304
305 fn get_uint_info(&self, field: &str) -> XGBResult<&[u32]> {
306 let field = ffi::CString::new(field).unwrap();
307 let mut out_len = 0;
308 let mut out_dptr = ptr::null();
309 xgb_call!(xgboost_sys::XGDMatrixGetUIntInfo(self.handle,
310 field.as_ptr(),
311 &mut out_len,
312 &mut out_dptr))?;
313
314 Ok(unsafe { slice::from_raw_parts(out_dptr as *mut c_uint, out_len as usize) })
315 }
316
317 fn set_uint_info(&mut self, field: &str, array: &[u32]) -> XGBResult<()> {
318 let field = ffi::CString::new(field).unwrap();
319 xgb_call!(xgboost_sys::XGDMatrixSetUIntInfo(self.handle,
320 field.as_ptr(),
321 array.as_ptr(),
322 array.len() as u64))
323 }
324}
325
326impl Drop for DMatrix {
327 fn drop(&mut self) {
328 xgb_call!(xgboost_sys::XGDMatrixFree(self.handle)).unwrap();
329 }
330}
331
#[cfg(test)]
mod tests {
    use tempfile;
    use super::*;

    /// Loads the agaricus training set shipped with the xgboost sources.
    fn read_train_matrix() -> XGBResult<DMatrix> {
        DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train")
    }

    #[test]
    fn read_matrix() {
        assert!(read_train_matrix().is_ok());
    }

    #[test]
    fn read_num_rows() {
        assert_eq!(read_train_matrix().unwrap().num_rows(), 6513);
    }

    #[test]
    fn read_num_cols() {
        assert_eq!(read_train_matrix().unwrap().num_cols(), 127);
    }

    #[test]
    fn writing_and_reading() {
        let dmat = read_train_matrix().unwrap();

        let tmp_dir = tempfile::tempdir().expect("failed to create temp dir");
        let out_path = tmp_dir.path().join("dmat.bin");
        dmat.save(&out_path).unwrap();

        let dmat2 = DMatrix::load(&out_path).unwrap();

        assert_eq!(dmat.num_rows(), dmat2.num_rows());
        assert_eq!(dmat.num_cols(), dmat2.num_cols());
    }

    #[test]
    fn get_set_root_index() {
        let mut dmat = read_train_matrix().unwrap();
        assert_eq!(dmat.get_root_index().unwrap(), &[]);

        let root_index = [3, 22, 1];
        assert!(dmat.set_root_index(&root_index).is_ok());
        assert_eq!(dmat.get_root_index().unwrap(), &[3, 22, 1]);
    }

    #[test]
    fn get_set_labels() {
        let mut dmat = read_train_matrix().unwrap();
        assert_eq!(dmat.get_labels().unwrap().len(), 6513);

        // Five distinct labels. A missing comma previously collapsed `0.0 -4.5` into one element.
        let label = [0.1, 0.0, -4.5, 11.29842, 333333.33];
        assert!(dmat.set_labels(&label).is_ok());
        assert_eq!(dmat.get_labels().unwrap(), label);
    }

    #[test]
    fn get_set_weights() {
        let mut dmat = read_train_matrix().unwrap();
        assert_eq!(dmat.get_weights().unwrap(), &[]);

        let weight = [1.0, 10.0, -123.456789, 44.9555];
        assert!(dmat.set_weights(&weight).is_ok());
        assert_eq!(dmat.get_weights().unwrap(), weight);
    }

    #[test]
    fn get_set_base_margin() {
        let mut dmat = read_train_matrix().unwrap();
        assert_eq!(dmat.get_base_margin().unwrap(), &[]);

        let base_margin = [0.00001, 0.000002, 1.23];
        assert!(dmat.set_base_margin(&base_margin).is_ok());
        assert_eq!(dmat.get_base_margin().unwrap(), base_margin);
    }

    #[test]
    fn set_group() {
        let mut dmat = read_train_matrix().unwrap();

        let group = [1, 2, 3];
        assert!(dmat.set_group(&group).is_ok());
    }

    #[test]
    fn from_csr() {
        let indptr = [0, 2, 3, 6, 8];
        let indices = [0, 2, 2, 0, 1, 2, 1, 2];
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let dmat = DMatrix::from_csr(&indptr, &indices, &data, None).unwrap();
        assert_eq!(dmat.num_rows(), 4);
        assert_eq!(dmat.num_cols(), 3);

        let dmat = DMatrix::from_csr(&indptr, &indices, &data, Some(10)).unwrap();
        assert_eq!(dmat.num_rows(), 4);
        assert_eq!(dmat.num_cols(), 10);
    }

    #[test]
    fn from_csc() {
        let indptr = [0, 2, 3, 6, 8];
        let indices = [0, 2, 2, 0, 1, 2, 1, 2];
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let dmat = DMatrix::from_csc(&indptr, &indices, &data, None).unwrap();
        assert_eq!(dmat.num_rows(), 3);
        assert_eq!(dmat.num_cols(), 4);

        let dmat = DMatrix::from_csc(&indptr, &indices, &data, Some(10)).unwrap();
        assert_eq!(dmat.num_rows(), 10);
        assert_eq!(dmat.num_cols(), 4);
    }

    #[test]
    fn from_dense() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let num_rows = 2;

        let dmat = DMatrix::from_dense(&data, num_rows).unwrap();
        assert_eq!(dmat.num_rows(), 2);
        assert_eq!(dmat.num_cols(), 3);

        let data = vec![1.0, 2.0, 3.0];
        let num_rows = 3;

        let dmat = DMatrix::from_dense(&data, num_rows).unwrap();
        assert_eq!(dmat.num_rows(), 3);
        assert_eq!(dmat.num_cols(), 1);
    }

    #[test]
    #[should_panic]
    fn from_dense_zero_rows_panics() {
        let _ = DMatrix::from_dense(&[], 0);
    }

    #[test]
    fn slice_from_indices() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let num_rows = 4;

        let dmat = DMatrix::from_dense(&data, num_rows).unwrap();
        assert_eq!(dmat.shape(), (4, 2));

        assert_eq!(dmat.slice(&[]).unwrap().shape(), (0, 2));
        assert_eq!(dmat.slice(&[1]).unwrap().shape(), (1, 2));
        assert_eq!(dmat.slice(&[0, 1]).unwrap().shape(), (2, 2));
        assert_eq!(dmat.slice(&[3, 2, 1]).unwrap().shape(), (3, 2));
        assert!(dmat.slice(&[10, 11, 12]).is_err());
    }

    #[test]
    fn slice() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let num_rows = 4;

        let dmat = DMatrix::from_dense(&data, num_rows).unwrap();
        assert_eq!(dmat.shape(), (4, 3));

        assert_eq!(dmat.slice(&[0, 1, 2, 3]).unwrap().shape(), (4, 3));
        assert_eq!(dmat.slice(&[0, 1]).unwrap().shape(), (2, 3));
        assert_eq!(dmat.slice(&[1, 0]).unwrap().shape(), (2, 3));
        assert_eq!(dmat.slice(&[0, 1, 2]).unwrap().shape(), (3, 3));
        assert_eq!(dmat.slice(&[3, 2, 1]).unwrap().shape(), (3, 3));
    }
}