scirs2_io/ml_framework/
datasets.rs

1//! Dataset utilities for ML frameworks
2#![allow(dead_code)]
3
4use crate::ml_framework::types::MLTensor;
5use std::collections::HashMap;
6
/// ML dataset container
///
/// Bundles a list of feature tensors with an optional list of label tensors
/// and free-form JSON metadata. All fields are public and no invariants are
/// enforced by the type itself.
#[derive(Clone)]
pub struct MLDataset {
    /// Feature tensors; the dataset's sample count is the length of this vector.
    pub features: Vec<MLTensor>,
    /// Optional label tensors.
    ///
    /// NOTE(review): presumably one label per entry in `features`; nothing in
    /// this type checks that the lengths match — confirm against callers.
    pub labels: Option<Vec<MLTensor>>,
    /// Arbitrary metadata, keyed by name, stored as JSON values.
    pub metadata: HashMap<String, serde_json::Value>,
}
14
15impl MLDataset {
16    /// Create new dataset
17    pub fn new(features: Vec<MLTensor>) -> Self {
18        Self {
19            features,
20            labels: None,
21            metadata: HashMap::new(),
22        }
23    }
24
25    /// Add labels
26    pub fn with_labels(mut self, labels: Vec<MLTensor>) -> Self {
27        self.labels = Some(labels);
28        self
29    }
30
31    /// Get number of samples
32    pub fn len(&self) -> usize {
33        self.features.len()
34    }
35
36    /// Check if empty
37    pub fn is_empty(&self) -> bool {
38        self.features.is_empty()
39    }
40
41    /// Split into train/test sets
42    pub fn train_test_split(&self, testratio: f32) -> (MLDataset, MLDataset) {
43        let n = self.len();
44        let test_size = (n as f32 * testratio) as usize;
45        let train_size = n - test_size;
46
47        let train_features = self.features[..train_size].to_vec();
48        let test_features = self.features[train_size..].to_vec();
49
50        let (train_labels, test_labels) = if let Some(labels) = &self.labels {
51            (
52                Some(labels[..train_size].to_vec()),
53                Some(labels[train_size..].to_vec()),
54            )
55        } else {
56            (None, None)
57        };
58
59        let train_dataset = MLDataset {
60            features: train_features,
61            labels: train_labels,
62            metadata: self.metadata.clone(),
63        };
64
65        let test_dataset = MLDataset {
66            features: test_features,
67            labels: test_labels,
68            metadata: self.metadata.clone(),
69        };
70
71        (train_dataset, test_dataset)
72    }
73}