scirs2_io/ml_framework/
datasets.rs1#![allow(dead_code)]
3
4use crate::ml_framework::types::MLTensor;
5use std::collections::HashMap;
6
7#[derive(Clone)]
9pub struct MLDataset {
10 pub features: Vec<MLTensor>,
11 pub labels: Option<Vec<MLTensor>>,
12 pub metadata: HashMap<String, serde_json::Value>,
13}
14
15impl MLDataset {
16 pub fn new(features: Vec<MLTensor>) -> Self {
18 Self {
19 features,
20 labels: None,
21 metadata: HashMap::new(),
22 }
23 }
24
25 pub fn with_labels(mut self, labels: Vec<MLTensor>) -> Self {
27 self.labels = Some(labels);
28 self
29 }
30
31 pub fn len(&self) -> usize {
33 self.features.len()
34 }
35
36 pub fn is_empty(&self) -> bool {
38 self.features.is_empty()
39 }
40
41 pub fn train_test_split(&self, testratio: f32) -> (MLDataset, MLDataset) {
43 let n = self.len();
44 let test_size = (n as f32 * testratio) as usize;
45 let train_size = n - test_size;
46
47 let train_features = self.features[..train_size].to_vec();
48 let test_features = self.features[train_size..].to_vec();
49
50 let (train_labels, test_labels) = if let Some(labels) = &self.labels {
51 (
52 Some(labels[..train_size].to_vec()),
53 Some(labels[train_size..].to_vec()),
54 )
55 } else {
56 (None, None)
57 };
58
59 let train_dataset = MLDataset {
60 features: train_features,
61 labels: train_labels,
62 metadata: self.metadata.clone(),
63 };
64
65 let test_dataset = MLDataset {
66 features: test_features,
67 labels: test_labels,
68 metadata: self.metadata.clone(),
69 };
70
71 (train_dataset, test_dataset)
72 }
73}