1use std::sync::OnceLock;
8
9use crate::error::{Result, ScryLearnError};
10
11use crate::matrix::DenseMatrix;
12use crate::sparse::CscMatrix;
13
14#[derive(Clone, Debug, Default)]
16pub(crate) enum Storage {
17 #[default]
19 Dense,
20 Sparse(CscMatrix),
22}
23
/// Descriptive statistics for one column, as produced by `Dataset::summary`.
///
/// Only finite values contribute; NaN and ±inf are ignored. When a column has
/// no finite values, `count` is 0 and every float field is NaN. Because of
/// those NaN fields, two such empty-column stats never compare equal under
/// `PartialEq`.
#[derive(Clone, Debug, PartialEq)]
pub struct ColumnStats {
    /// Column name (feature name or target name).
    pub name: String,
    /// Number of finite values in the column.
    pub count: usize,
    /// Arithmetic mean of the finite values.
    pub mean: f64,
    /// Sample standard deviation (divides by `count - 1`); 0 for a single value.
    pub std: f64,
    /// Smallest finite value.
    pub min: f64,
    /// 25th percentile (linear interpolation between closest ranks).
    pub q25: f64,
    /// 50th percentile (median).
    pub median: f64,
    /// 75th percentile (linear interpolation between closest ranks).
    pub q75: f64,
    /// Largest finite value.
    pub max: f64,
}
46
47#[derive(Clone, Debug)]
52#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
53#[non_exhaustive]
54pub struct Dataset {
55 pub features: Vec<Vec<f64>>,
57 pub target: Vec<f64>,
59 pub feature_names: Vec<String>,
61 pub target_name: String,
63 pub class_labels: Option<Vec<String>>,
65 #[cfg_attr(feature = "serde", serde(skip))]
70 matrix: OnceLock<DenseMatrix>,
71 #[cfg_attr(feature = "serde", serde(skip))]
76 row_major_cache: Option<Vec<f64>>,
77 #[cfg_attr(feature = "serde", serde(skip))]
79 storage: Storage,
80}
81
82impl Dataset {
83 pub fn new(
90 features: Vec<Vec<f64>>,
91 target: Vec<f64>,
92 feature_names: Vec<String>,
93 target_name: impl Into<String>,
94 ) -> Self {
95 assert!(
96 feature_names.len() == features.len(),
97 "feature_names.len()={} but features.len()={}",
98 feature_names.len(),
99 features.len(),
100 );
101 if let Some(first) = features.first() {
102 for (i, col) in features.iter().enumerate().skip(1) {
103 assert!(
104 col.len() == first.len(),
105 "feature column {i} has {} rows but column 0 has {}",
106 col.len(),
107 first.len(),
108 );
109 }
110 }
111 Self {
112 features,
113 target,
114 feature_names,
115 target_name: target_name.into(),
116 class_labels: None,
117 matrix: OnceLock::new(),
118 row_major_cache: None,
119 storage: Storage::Dense,
120 }
121 }
122
123 pub fn from_matrix(
127 matrix: DenseMatrix,
128 target: Vec<f64>,
129 feature_names: Vec<String>,
130 target_name: impl Into<String>,
131 ) -> Self {
132 let features = matrix.to_col_vecs();
133 let cell = OnceLock::new();
134 let _ = cell.set(matrix);
135 Self {
136 features,
137 target,
138 feature_names,
139 target_name: target_name.into(),
140 class_labels: None,
141 matrix: cell,
142 row_major_cache: None,
143 storage: Storage::Dense,
144 }
145 }
146
147 #[inline]
152 pub fn matrix(&self) -> &DenseMatrix {
153 self.matrix.get_or_init(|| {
154 DenseMatrix::from_col_major_ref(&self.features)
155 .expect("DenseMatrix build from features failed")
156 })
157 }
158
159 #[cfg(feature = "csv")]
166 pub fn from_csv(path: &str, target_column: &str) -> Result<Self> {
167 let file = std::fs::File::open(path).map_err(ScryLearnError::Io)?;
168 Self::from_csv_reader(file, target_column)
169 }
170
171 #[cfg(feature = "csv")]
175 pub fn from_csv_reader(rdr: impl std::io::Read, target_column: &str) -> Result<Self> {
176 let mut csv_rdr = csv::ReaderBuilder::new()
177 .has_headers(true)
178 .flexible(true)
179 .from_reader(rdr);
180
181 let headers: Vec<String> = csv_rdr
182 .headers()
183 .map_err(|e| ScryLearnError::Csv(e.to_string()))?
184 .iter()
185 .map(std::string::ToString::to_string)
186 .collect();
187
188 let target_idx = headers
189 .iter()
190 .position(|h| h.eq_ignore_ascii_case(target_column))
191 .ok_or_else(|| ScryLearnError::InvalidColumn(target_column.to_string()))?;
192
193 let mut rows: Vec<Vec<String>> = Vec::new();
195 for result in csv_rdr.records() {
196 let record = result.map_err(|e| ScryLearnError::Csv(e.to_string()))?;
197 rows.push(
198 record
199 .iter()
200 .map(std::string::ToString::to_string)
201 .collect(),
202 );
203 }
204
205 if rows.is_empty() {
206 return Err(ScryLearnError::EmptyDataset);
207 }
208
209 let feature_indices: Vec<usize> = (0..headers.len()).filter(|&i| i != target_idx).collect();
211
212 let n_samples = rows.len();
213 let n_features = feature_indices.len();
214
215 let (target, class_labels) = parse_target_column(&rows, target_idx);
217
218 let mut features = vec![vec![0.0; n_samples]; n_features];
220 let mut feature_names = Vec::with_capacity(n_features);
221
222 for (feat_col, &col_idx) in feature_indices.iter().enumerate() {
223 feature_names.push(headers[col_idx].clone());
224 for (row_idx, row) in rows.iter().enumerate() {
225 let val = row.get(col_idx).map_or("", std::string::String::as_str);
226 features[feat_col][row_idx] = val.parse::<f64>().unwrap_or(f64::NAN);
227 }
228 }
229
230 Ok(Self {
231 features,
232 target,
233 feature_names,
234 target_name: headers[target_idx].clone(),
235 class_labels,
236 matrix: OnceLock::new(),
237 row_major_cache: None,
238 storage: Storage::Dense,
239 })
240 }
241
242 #[inline]
244 pub fn n_samples(&self) -> usize {
245 self.target.len()
246 }
247
248 #[inline]
250 pub fn n_features(&self) -> usize {
251 match &self.storage {
252 Storage::Sparse(csc) => csc.n_cols(),
253 Storage::Dense => self.features.len(),
254 }
255 }
256
257 pub fn n_classes(&self) -> usize {
259 self.class_labels.as_ref().map_or_else(
260 || {
261 let mut vals: Vec<i64> = self.target.iter().map(|&v| v as i64).collect();
262 vals.sort_unstable();
263 vals.dedup();
264 vals.len()
265 },
266 Vec::len,
267 )
268 }
269
270 pub fn feature(&self, idx: usize) -> &[f64] {
272 &self.features[idx]
273 }
274
275 pub fn sample(&self, idx: usize) -> Vec<f64> {
277 self.features.iter().map(|col| col[idx]).collect()
278 }
279
280 pub fn feature_matrix(&self) -> Vec<Vec<f64>> {
282 let n = self.n_samples();
283 let m = self.n_features();
284 let mut matrix = vec![vec![0.0; m]; n];
285 for (j, feat_col) in self.features.iter().enumerate() {
286 for (i, &val) in feat_col.iter().enumerate() {
287 matrix[i][j] = val;
288 }
289 }
290 matrix
291 }
292
293 pub fn flat_feature_matrix(&mut self) -> &[f64] {
298 if self.row_major_cache.is_none() {
299 let n = self.n_samples();
300 let m = self.n_features();
301 let mut buf = vec![0.0; n * m];
302 if let Some(mat) = self.matrix.get() {
303 let src = mat.as_slice();
304 for j in 0..m {
305 let col_off = j * n;
306 for i in 0..n {
307 buf[i * m + j] = src[col_off + i];
308 }
309 }
310 } else {
311 for j in 0..m {
312 for i in 0..n {
313 buf[i * m + j] = self.features[j][i];
314 }
315 }
316 }
317 self.row_major_cache = Some(buf);
318 }
319 self.row_major_cache
320 .as_ref()
321 .expect("row_major_cache populated above")
322 }
323
324 #[inline]
328 pub fn sample_row<'a>(&self, cache: &'a [f64], idx: usize) -> &'a [f64] {
329 let m = self.n_features();
330 &cache[idx * m..(idx + 1) * m]
331 }
332
333 pub fn subset(&self, indices: &[usize]) -> Self {
335 let target: Vec<f64> = indices.iter().map(|&i| self.target[i]).collect();
336
337 if let Storage::Sparse(csc) = &self.storage {
338 let new_csc = subset_csc(csc, indices);
339 return Self {
340 features: Vec::new(),
341 target,
342 feature_names: self.feature_names.clone(),
343 target_name: self.target_name.clone(),
344 class_labels: self.class_labels.clone(),
345 matrix: OnceLock::new(),
346 row_major_cache: None,
347 storage: Storage::Sparse(new_csc),
348 };
349 }
350
351 let features: Vec<Vec<f64>> = self
352 .features
353 .iter()
354 .map(|col| indices.iter().map(|&i| col[i]).collect())
355 .collect();
356 Self {
357 features,
358 target,
359 feature_names: self.feature_names.clone(),
360 target_name: self.target_name.clone(),
361 class_labels: self.class_labels.clone(),
362 matrix: OnceLock::new(),
363 row_major_cache: None,
364 storage: Storage::Dense,
365 }
366 }
367
368 pub fn sync_matrix(&mut self) {
374 self.matrix = OnceLock::new();
375 self.row_major_cache = None;
376 }
377
378 #[inline]
382 pub fn invalidate_matrix(&mut self) {
383 self.matrix = OnceLock::new();
384 self.row_major_cache = None;
385 }
386
387 pub fn validate_finite(&self) -> Result<()> {
389 if let Storage::Sparse(csc) = &self.storage {
391 for j in 0..csc.n_cols() {
392 for (i, v) in csc.col(j).iter() {
393 if !v.is_finite() {
394 let name = self
395 .feature_names
396 .get(j)
397 .map_or_else(|| format!("feature[{j}]"), std::clone::Clone::clone);
398 return Err(ScryLearnError::InvalidData(format!(
399 "non-finite value ({v}) in {name} at sample {i}"
400 )));
401 }
402 }
403 }
404 } else {
405 for (j, col) in self.features.iter().enumerate() {
406 for (i, &v) in col.iter().enumerate() {
407 if !v.is_finite() {
408 let name = self
409 .feature_names
410 .get(j)
411 .map_or_else(|| format!("feature[{j}]"), std::clone::Clone::clone);
412 return Err(ScryLearnError::InvalidData(format!(
413 "non-finite value ({v}) in {name} at sample {i}"
414 )));
415 }
416 }
417 }
418 }
419 for (i, &v) in self.target.iter().enumerate() {
420 if !v.is_finite() {
421 return Err(ScryLearnError::InvalidData(format!(
422 "non-finite value ({v}) in target at sample {i}"
423 )));
424 }
425 }
426 Ok(())
427 }
428
429 pub fn validate_no_inf(&self) -> Result<()> {
434 if let Storage::Sparse(csc) = &self.storage {
435 for j in 0..csc.n_cols() {
436 for (i, v) in csc.col(j).iter() {
437 if v.is_infinite() {
438 let name = self
439 .feature_names
440 .get(j)
441 .map_or_else(|| format!("feature[{j}]"), std::clone::Clone::clone);
442 return Err(ScryLearnError::InvalidData(format!(
443 "infinite value ({v}) in {name} at sample {i}"
444 )));
445 }
446 }
447 }
448 } else {
449 for (j, col) in self.features.iter().enumerate() {
450 for (i, &v) in col.iter().enumerate() {
451 if v.is_infinite() {
452 let name = self
453 .feature_names
454 .get(j)
455 .map_or_else(|| format!("feature[{j}]"), std::clone::Clone::clone);
456 return Err(ScryLearnError::InvalidData(format!(
457 "infinite value ({v}) in {name} at sample {i}"
458 )));
459 }
460 }
461 }
462 }
463 for (i, &v) in self.target.iter().enumerate() {
464 if v.is_infinite() {
465 return Err(ScryLearnError::InvalidData(format!(
466 "infinite value ({v}) in target at sample {i}"
467 )));
468 }
469 }
470 Ok(())
471 }
472
473 pub fn with_class_labels(mut self, labels: Vec<String>) -> Self {
475 self.class_labels = Some(labels);
476 self
477 }
478
479 pub fn from_sparse(
484 csc: CscMatrix,
485 target: Vec<f64>,
486 feature_names: Vec<String>,
487 target_name: impl Into<String>,
488 ) -> Self {
489 Self {
490 features: Vec::new(),
491 target,
492 feature_names,
493 target_name: target_name.into(),
494 class_labels: None,
495 matrix: OnceLock::new(),
496 row_major_cache: None,
497 storage: Storage::Sparse(csc),
498 }
499 }
500
501 #[inline]
503 pub fn is_sparse(&self) -> bool {
504 matches!(self.storage, Storage::Sparse(_))
505 }
506
507 pub fn sparse_csc(&self) -> Option<&CscMatrix> {
509 match &self.storage {
510 Storage::Sparse(m) => Some(m),
511 Storage::Dense => None,
512 }
513 }
514
515 pub fn sparse_csr(&self) -> Option<crate::sparse::CsrMatrix> {
517 self.sparse_csc().map(CscMatrix::to_csr)
518 }
519
520 pub fn summary(&self) -> Vec<ColumnStats> {
526 let n_feat = self.n_features();
527 let mut stats = Vec::with_capacity(n_feat + 1);
528
529 for j in 0..n_feat {
530 let name = self
531 .feature_names
532 .get(j)
533 .cloned()
534 .unwrap_or_else(|| format!("feature[{j}]"));
535
536 let col: Vec<f64> = if let Some(csc) = self.sparse_csc() {
537 let n_rows = csc.n_rows();
539 let mut dense = vec![0.0_f64; n_rows];
540 for (i, v) in csc.col(j).iter() {
541 dense[i] = v;
542 }
543 dense
544 } else {
545 self.features[j].clone()
546 };
547
548 stats.push(compute_column_stats(&name, &col));
549 }
550
551 stats.push(compute_column_stats(&self.target_name, &self.target));
552 stats
553 }
554
555 pub fn describe(&self) {
559 let stats = self.summary();
560 if stats.is_empty() {
561 return;
562 }
563
564 let labels = ["count", "mean", "std", "min", "25%", "50%", "75%", "max"];
565 let label_width = labels.iter().map(|l| l.len()).max().unwrap_or(0);
566
567 let col_widths: Vec<usize> = stats.iter().map(|s| s.name.len().max(12)).collect();
568
569 print!("{:>width$}", "", width = label_width);
571 for (i, s) in stats.iter().enumerate() {
572 print!(" {:>width$}", s.name, width = col_widths[i]);
573 }
574 println!();
575
576 for (row_idx, label) in labels.iter().enumerate() {
578 print!("{:>width$}", label, width = label_width);
579 for (i, s) in stats.iter().enumerate() {
580 let val = match row_idx {
581 0 => s.count as f64,
582 1 => s.mean,
583 2 => s.std,
584 3 => s.min,
585 4 => s.q25,
586 5 => s.median,
587 6 => s.q75,
588 7 => s.max,
589 _ => unreachable!(),
590 };
591 print!(" {:>width$.6}", val, width = col_widths[i]);
592 }
593 println!();
594 }
595 }
596
597 pub fn ensure_dense(&mut self) {
602 if let Storage::Sparse(csc) = &self.storage {
603 let n_cols = csc.n_cols();
604 let n_rows = csc.n_rows();
605 let mut features = vec![vec![0.0; n_rows]; n_cols];
606 for (j, feat_col) in features.iter_mut().enumerate() {
607 for (i, v) in csc.col(j).iter() {
608 feat_col[i] = v;
609 }
610 }
611 self.features = features;
612 self.matrix = OnceLock::new();
613 }
614 }
615}
616
617fn subset_csc(csc: &CscMatrix, indices: &[usize]) -> CscMatrix {
625 let n_new_rows = indices.len();
626 let n_cols = csc.n_cols();
627
628 let mut row_map = std::collections::HashMap::with_capacity(n_new_rows);
630 for (new_idx, &old_idx) in indices.iter().enumerate() {
631 row_map.insert(old_idx, new_idx);
632 }
633
634 let mut cols: Vec<Vec<f64>> = vec![vec![0.0; n_new_rows]; n_cols];
636 for (j, col) in cols.iter_mut().enumerate() {
637 for (old_row, val) in csc.col(j).iter() {
638 if let Some(&new_row) = row_map.get(&old_row) {
639 col[new_row] = val;
640 }
641 }
642 }
643
644 CscMatrix::from_dense(&cols)
645}
646
647fn compute_column_stats(name: &str, values: &[f64]) -> ColumnStats {
649 let mut sorted: Vec<f64> = values.iter().copied().filter(|v| v.is_finite()).collect();
650 sorted.sort_unstable_by(|a, b| a.total_cmp(b));
651
652 let count = sorted.len();
653 if count == 0 {
654 return ColumnStats {
655 name: name.to_string(),
656 count: 0,
657 mean: f64::NAN,
658 std: f64::NAN,
659 min: f64::NAN,
660 q25: f64::NAN,
661 median: f64::NAN,
662 q75: f64::NAN,
663 max: f64::NAN,
664 };
665 }
666
667 let sum: f64 = sorted.iter().sum();
668 let mean = sum / count as f64;
669
670 let std = if count <= 1 {
671 0.0
672 } else {
673 let var = sorted.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (count - 1) as f64;
674 var.sqrt()
675 };
676
677 let min = sorted[0];
678 let max = sorted[count - 1];
679 let q25 = percentile(&sorted, 0.25);
680 let median = percentile(&sorted, 0.50);
681 let q75 = percentile(&sorted, 0.75);
682
683 ColumnStats {
684 name: name.to_string(),
685 count,
686 mean,
687 std,
688 min,
689 q25,
690 median,
691 q75,
692 max,
693 }
694}
695
/// Linear-interpolation percentile of an ascending-sorted slice.
///
/// `p` is a fraction in `[0, 1]`; values outside that range are clamped, so
/// `p <= 0` yields the minimum and `p >= 1` the maximum. The result
/// interpolates between the two closest ranks at position `p * (n - 1)`,
/// which is numpy's default "linear" method. An empty slice yields NaN
/// instead of underflowing `n - 1`.
fn percentile(sorted: &[f64], p: f64) -> f64 {
    let n = sorted.len();
    match n {
        0 => return f64::NAN,
        1 => return sorted[0],
        _ => {}
    }
    let p = p.clamp(0.0, 1.0);
    let idx = p * (n - 1) as f64;
    let lo = idx.floor() as usize;
    let hi = lo + 1;
    let frac = idx - lo as f64;
    if hi >= n {
        sorted[n - 1]
    } else {
        sorted[lo] * (1.0 - frac) + sorted[hi] * frac
    }
}
712
/// Decode the target column (index `col_idx`) of buffered CSV `rows`.
///
/// If every row has a cell there that parses as `f64`, those numbers are
/// returned with no labels. Otherwise each distinct string is label-encoded
/// as its index in order of first appearance, with missing cells treated as
/// `""`. The labels are returned alongside the codes.
#[cfg(feature = "csv")]
fn parse_target_column(rows: &[Vec<String>], col_idx: usize) -> (Vec<f64>, Option<Vec<String>>) {
    // Collecting into `Option` stops at the first missing or non-numeric cell.
    let numeric: Option<Vec<f64>> = rows
        .iter()
        .map(|row| row.get(col_idx).and_then(|s| s.parse::<f64>().ok()))
        .collect();
    if let Some(values) = numeric {
        return (values, None);
    }

    // A hash lookup keeps encoding O(rows) even for high-cardinality targets.
    // `labels` preserves first-appearance order for the codes.
    let mut labels: Vec<String> = Vec::new();
    let mut code_of: std::collections::HashMap<&str, usize> = std::collections::HashMap::new();
    let mut encoded = Vec::with_capacity(rows.len());

    for row in rows {
        let val = row.get(col_idx).map_or("", String::as_str);
        let code = *code_of.entry(val).or_insert_with(|| {
            labels.push(val.to_string());
            labels.len() - 1
        });
        encoded.push(code as f64);
    }

    (encoded, Some(labels))
}
744
#[cfg(test)]
#[allow(clippy::float_cmp)] // exact literals are compared on purpose
mod tests {
    use super::*;

    /// Owned feature names from string literals.
    fn names(labels: &[&str]) -> Vec<String> {
        labels.iter().map(|s| (*s).to_string()).collect()
    }

    /// 3-row, 2-column CSC fixture: column 0 = [1, 0, 3], column 1 = [0, 2, 0].
    fn sample_csc() -> CscMatrix {
        CscMatrix::from_dense(&[vec![1.0, 0.0, 3.0], vec![0.0, 2.0, 0.0]])
    }

    /// Sparse dataset over [`sample_csc`] with features "a", "b" and target "t".
    fn sparse_dataset(target: Vec<f64>) -> Dataset {
        Dataset::from_sparse(sample_csc(), target, names(&["a", "b"]), "t")
    }

    #[test]
    fn test_dataset_new() {
        let ds = Dataset::new(
            vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]],
            vec![0.0, 1.0, 0.0],
            names(&["f1", "f2"]),
            "label",
        );
        assert_eq!(ds.n_samples(), 3);
        assert_eq!(ds.n_features(), 2);
        assert_eq!(ds.feature(0), &[1.0, 2.0, 3.0]);
        assert_eq!(ds.sample(1), vec![2.0, 5.0]);
    }

    #[cfg(feature = "csv")]
    #[test]
    fn test_dataset_from_csv_reader() {
        let text = "f1,f2,target\n1.0,4.0,a\n2.0,5.0,b\n3.0,6.0,a\n";
        let ds = Dataset::from_csv_reader(text.as_bytes(), "target").unwrap();
        assert_eq!(ds.n_samples(), 3);
        assert_eq!(ds.n_features(), 2);
        // Non-numeric target -> label-encoded in first-appearance order.
        assert_eq!(ds.target, vec![0.0, 1.0, 0.0]);
        assert_eq!(ds.class_labels, Some(names(&["a", "b"])));
    }

    #[test]
    fn test_dataset_subset() {
        let ds = Dataset::new(
            vec![vec![1.0, 2.0, 3.0, 4.0], vec![10.0, 20.0, 30.0, 40.0]],
            vec![0.0, 1.0, 0.0, 1.0],
            names(&["a", "b"]),
            "t",
        );
        let picked = ds.subset(&[0, 2]);
        assert_eq!(picked.n_samples(), 2);
        assert_eq!(picked.feature(0), &[1.0, 3.0]);
        assert_eq!(picked.target, vec![0.0, 0.0]);
    }

    #[cfg(feature = "csv")]
    #[test]
    fn test_empty_csv() {
        let header_only = "f1,target\n";
        assert!(Dataset::from_csv_reader(header_only.as_bytes(), "target").is_err());
    }

    #[test]
    fn test_n_classes() {
        let ds = Dataset::new(vec![vec![1.0, 2.0, 3.0]], vec![0.0, 1.0, 2.0], names(&["f"]), "t");
        assert_eq!(ds.n_classes(), 3);
    }

    #[test]
    fn test_matrix_accessor() {
        let ds = Dataset::new(
            vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            vec![0.0, 1.0],
            names(&["a", "b"]),
            "t",
        );
        let mat = ds.matrix();
        assert_eq!(mat.n_rows(), 2);
        assert_eq!(mat.n_cols(), 2);
        assert_eq!(mat.col(0), &[1.0, 2.0]);
        assert_eq!(mat.col(1), &[3.0, 4.0]);
    }

    #[test]
    fn test_from_matrix() {
        let dense = DenseMatrix::from_col_major(vec![vec![1.0, 2.0], vec![3.0, 4.0]]).unwrap();
        let ds = Dataset::from_matrix(dense, vec![0.0, 1.0], names(&["a", "b"]), "t");
        assert_eq!(ds.n_samples(), 2);
        assert_eq!(ds.n_features(), 2);
        assert_eq!(ds.feature(0), &[1.0, 2.0]);
        assert_eq!(ds.matrix().col(1), &[3.0, 4.0]);
    }

    #[test]
    fn test_from_sparse_basic() {
        let ds = sparse_dataset(vec![0.0, 1.0, 0.0]);
        assert!(ds.is_sparse());
        assert_eq!(ds.n_samples(), 3);
        assert_eq!(ds.n_features(), 2);
    }

    #[test]
    fn test_sparse_csc_accessor() {
        let ds = sparse_dataset(vec![0.0, 1.0, 0.0]);
        let csc = ds.sparse_csc().expect("should have CSC");
        assert_eq!(csc.n_rows(), 3);
        assert_eq!(csc.n_cols(), 2);
        assert_eq!(csc.get(0, 0), 1.0);
        assert_eq!(csc.get(1, 1), 2.0);
        assert_eq!(csc.get(1, 0), 0.0);
    }

    #[test]
    fn test_sparse_csr_conversion() {
        let ds = sparse_dataset(vec![0.0, 1.0, 0.0]);
        let csr = ds.sparse_csr().expect("should convert to CSR");
        assert_eq!(csr.n_rows(), 3);
        assert_eq!(csr.n_cols(), 2);
        assert_eq!(csr.get(0, 0), 1.0);
        assert_eq!(csr.get(2, 0), 3.0);
        assert_eq!(csr.get(1, 1), 2.0);
    }

    #[test]
    fn test_sparse_subset() {
        let ds = sparse_dataset(vec![0.0, 1.0, 2.0]);
        let picked = ds.subset(&[0, 2]);
        assert!(picked.is_sparse());
        assert_eq!(picked.n_samples(), 2);
        assert_eq!(picked.n_features(), 2);
        assert_eq!(picked.target, vec![0.0, 2.0]);
        let csc = picked.sparse_csc().unwrap();
        // New row 0 is old row 0; new row 1 is old row 2.
        assert_eq!(csc.get(0, 0), 1.0);
        assert_eq!(csc.get(1, 0), 3.0);
    }

    #[test]
    fn test_sparse_with_class_labels() {
        let ds = sparse_dataset(vec![0.0, 1.0, 0.0]).with_class_labels(names(&["cat", "dog"]));
        assert!(ds.is_sparse());
        assert_eq!(ds.class_labels, Some(names(&["cat", "dog"])));
    }

    #[test]
    fn test_n_features_consistency() {
        let dense = Dataset::new(
            vec![vec![1.0, 0.0, 3.0], vec![0.0, 2.0, 0.0]],
            vec![0.0, 1.0, 0.0],
            names(&["a", "b"]),
            "t",
        );
        let sparse = sparse_dataset(vec![0.0, 1.0, 0.0]);
        assert_eq!(dense.n_features(), sparse.n_features());
    }

    #[test]
    fn test_ensure_dense() {
        let mut ds = sparse_dataset(vec![0.0, 1.0, 0.0]);
        assert!(ds.features.is_empty());
        ds.ensure_dense();
        assert_eq!(ds.features.len(), 2);
        assert_eq!(ds.features[0], vec![1.0, 0.0, 3.0]);
        assert_eq!(ds.features[1], vec![0.0, 2.0, 0.0]);
    }

    #[test]
    fn test_dense_not_sparse() {
        let ds = Dataset::new(vec![vec![1.0, 2.0]], vec![0.0, 1.0], names(&["x"]), "y");
        assert!(!ds.is_sparse());
        assert!(ds.sparse_csc().is_none());
        assert!(ds.sparse_csr().is_none());
    }

    #[test]
    fn test_matrix_lazy_rebuild_after_invalidate() {
        let mut ds = Dataset::new(
            vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            vec![0.0, 1.0],
            names(&["a", "b"]),
            "t",
        );

        // First access builds the cache.
        assert_eq!(ds.matrix().col(0), &[1.0, 2.0]);

        ds.invalidate_matrix();

        // Next access rebuilds it from `features`.
        assert_eq!(ds.matrix().col(0), &[1.0, 2.0]);
        assert_eq!(ds.matrix().col(1), &[3.0, 4.0]);
    }

    #[test]
    fn test_describe_summary() {
        let ds = Dataset::new(
            vec![vec![1.0, 2.0, 3.0, 4.0], vec![10.0, 20.0, 30.0, 40.0]],
            vec![0.0, 1.0, 0.0, 1.0],
            names(&["a", "b"]),
            "t",
        );

        let stats = ds.summary();
        // Two feature columns plus the target.
        assert_eq!(stats.len(), 3);

        let near = |got: f64, want: f64| (got - want).abs() < 1e-10;

        assert_eq!(stats[0].name, "a");
        assert_eq!(stats[0].count, 4);
        assert!(near(stats[0].mean, 2.5));
        assert!(near(stats[0].min, 1.0));
        assert!(near(stats[0].max, 4.0));

        assert_eq!(stats[1].name, "b");
        assert_eq!(stats[1].count, 4);
        assert!(near(stats[1].mean, 25.0));
        assert!(near(stats[1].min, 10.0));
        assert!(near(stats[1].max, 40.0));

        assert_eq!(stats[2].name, "t");
        assert_eq!(stats[2].count, 4);
        assert!(near(stats[2].mean, 0.5));

        // Smoke test: printing must not panic.
        ds.describe();
    }

    #[test]
    fn test_matrix_lazy_rebuild_reflects_feature_mutation() {
        let mut ds = Dataset::new(
            vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            vec![0.0, 1.0],
            names(&["a", "b"]),
            "t",
        );

        ds.features[0][0] = 99.0;
        ds.invalidate_matrix();

        assert_eq!(ds.matrix().col(0), &[99.0, 2.0]);
    }
}