//! trustformers_core/evaluation/datasets.rs
// Dataset loading and management for evaluation
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
6
/// Dataset sample for evaluation: one input/target pair plus free-form metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetSample {
    /// Input text presented to the model.
    pub input: String,
    /// Expected output / gold label for the input.
    pub target: String,
    /// Extra per-sample fields carried over from the source record.
    pub metadata: HashMap<String, serde_json::Value>,
}
14
/// Dataset collection for evaluation tasks.
#[derive(Debug, Clone)]
pub struct EvaluationDataset {
    /// Dataset identifier (loaders use the "{dataset}_{split}" convention).
    pub name: String,
    /// Ordered list of evaluation samples.
    pub samples: Vec<DatasetSample>,
    /// Dataset-level metadata (e.g. source, path, split).
    pub metadata: HashMap<String, serde_json::Value>,
}
22
23impl EvaluationDataset {
24    pub fn new(name: String) -> Self {
25        Self {
26            name,
27            samples: Vec::new(),
28            metadata: HashMap::new(),
29        }
30    }
31
32    pub fn add_sample(&mut self, sample: DatasetSample) {
33        self.samples.push(sample);
34    }
35
36    pub fn add_samples(&mut self, samples: Vec<DatasetSample>) {
37        self.samples.extend(samples);
38    }
39
40    pub fn len(&self) -> usize {
41        self.samples.len()
42    }
43
44    pub fn is_empty(&self) -> bool {
45        self.samples.is_empty()
46    }
47
48    pub fn get_inputs(&self) -> Vec<String> {
49        self.samples.iter().map(|s| s.input.clone()).collect()
50    }
51
52    pub fn get_targets(&self) -> Vec<String> {
53        self.samples.iter().map(|s| s.target.clone()).collect()
54    }
55
56    pub fn sample(&self, n: usize) -> EvaluationDataset {
57        let mut sampled = self.clone();
58        if n < self.samples.len() {
59            sampled.samples.truncate(n);
60        }
61        sampled
62    }
63
64    pub fn shuffle(&mut self, seed: Option<u64>) {
65        use scirs2_core::random::*;
66
67        if let Some(seed) = seed {
68            let mut rng = StdRng::seed_from_u64(seed);
69            self.samples.shuffle(&mut rng);
70        } else {
71            let mut rng = thread_rng();
72            self.samples.shuffle(&mut rng);
73        }
74    }
75
76    pub fn split(&self, train_ratio: f64) -> (EvaluationDataset, EvaluationDataset) {
77        let split_idx = (self.samples.len() as f64 * train_ratio) as usize;
78
79        let mut train_dataset = EvaluationDataset::new(format!("{}_train", self.name));
80        train_dataset.samples = self.samples[..split_idx].to_vec();
81        train_dataset.metadata = self.metadata.clone();
82
83        let mut test_dataset = EvaluationDataset::new(format!("{}_test", self.name));
84        test_dataset.samples = self.samples[split_idx..].to_vec();
85        test_dataset.metadata = self.metadata.clone();
86
87        (train_dataset, test_dataset)
88    }
89}
90
/// Dataset loader trait: a common interface over dataset sources
/// (filesystem, in-memory, ...).
pub trait DatasetLoader {
    /// Load the given split of a dataset, or an error if it cannot be found.
    fn load(&self, dataset_name: &str, split: &str) -> Result<EvaluationDataset>;
    /// Names of all datasets this loader can serve.
    fn available_datasets(&self) -> Vec<String>;
    /// Names of the splits available for `dataset_name` (empty if unknown).
    fn available_splits(&self, dataset_name: &str) -> Vec<String>;
}
97
/// File-based dataset loader.
///
/// Reads JSONL files laid out as `<data_dir>/<dataset_name>/<split>.jsonl`.
pub struct FileDatasetLoader {
    /// Root directory containing one subdirectory per dataset.
    data_dir: String,
}
102
103impl FileDatasetLoader {
104    pub fn new<P: AsRef<Path>>(data_dir: P) -> Self {
105        Self {
106            data_dir: data_dir.as_ref().to_string_lossy().to_string(),
107        }
108    }
109
110    fn get_dataset_path(&self, dataset_name: &str, split: &str) -> String {
111        format!("{}/{}/{}.jsonl", self.data_dir, dataset_name, split)
112    }
113
114    fn load_jsonl(&self, path: &str) -> Result<Vec<DatasetSample>> {
115        let content = std::fs::read_to_string(path)?;
116        let mut samples = Vec::new();
117
118        for line in content.lines() {
119            if line.trim().is_empty() {
120                continue;
121            }
122
123            let json_value: serde_json::Value = serde_json::from_str(line)?;
124
125            let input = json_value
126                .get("input")
127                .or_else(|| json_value.get("text"))
128                .or_else(|| json_value.get("sentence"))
129                .and_then(|v| v.as_str())
130                .unwrap_or("")
131                .to_string();
132
133            let target = json_value
134                .get("target")
135                .or_else(|| json_value.get("label"))
136                .or_else(|| json_value.get("output"))
137                .and_then(|v| v.as_str())
138                .unwrap_or("")
139                .to_string();
140
141            let mut metadata = HashMap::new();
142            if let Some(obj) = json_value.as_object() {
143                for (key, value) in obj {
144                    if key != "input"
145                        && key != "text"
146                        && key != "sentence"
147                        && key != "target"
148                        && key != "label"
149                        && key != "output"
150                    {
151                        metadata.insert(key.clone(), value.clone());
152                    }
153                }
154            }
155
156            samples.push(DatasetSample {
157                input,
158                target,
159                metadata,
160            });
161        }
162
163        Ok(samples)
164    }
165}
166
167impl DatasetLoader for FileDatasetLoader {
168    fn load(&self, dataset_name: &str, split: &str) -> Result<EvaluationDataset> {
169        let path = self.get_dataset_path(dataset_name, split);
170        let samples = self.load_jsonl(&path)?;
171
172        let mut dataset = EvaluationDataset::new(format!("{}_{}", dataset_name, split));
173        dataset.add_samples(samples);
174
175        // Add dataset metadata
176        dataset.metadata.insert(
177            "source".to_string(),
178            serde_json::Value::String("file".to_string()),
179        );
180        dataset.metadata.insert("path".to_string(), serde_json::Value::String(path));
181        dataset.metadata.insert(
182            "dataset_name".to_string(),
183            serde_json::Value::String(dataset_name.to_string()),
184        );
185        dataset.metadata.insert(
186            "split".to_string(),
187            serde_json::Value::String(split.to_string()),
188        );
189
190        Ok(dataset)
191    }
192
193    fn available_datasets(&self) -> Vec<String> {
194        let data_path = Path::new(&self.data_dir);
195        if !data_path.exists() {
196            return Vec::new();
197        }
198
199        let mut datasets = Vec::new();
200        if let Ok(entries) = std::fs::read_dir(data_path) {
201            for entry in entries.flatten() {
202                if entry.file_type().map(|ft| ft.is_dir()).unwrap_or(false) {
203                    if let Some(name) = entry.file_name().to_str() {
204                        datasets.push(name.to_string());
205                    }
206                }
207            }
208        }
209
210        datasets.sort();
211        datasets
212    }
213
214    fn available_splits(&self, dataset_name: &str) -> Vec<String> {
215        let dataset_path = Path::new(&self.data_dir).join(dataset_name);
216        if !dataset_path.exists() {
217            return Vec::new();
218        }
219
220        let mut splits = Vec::new();
221        if let Ok(entries) = std::fs::read_dir(dataset_path) {
222            for entry in entries.flatten() {
223                if entry.file_type().map(|ft| ft.is_file()).unwrap_or(false) {
224                    if let Some(name) = entry.file_name().to_str() {
225                        if name.ends_with(".jsonl") {
226                            let split_name = name.strip_suffix(".jsonl").unwrap_or(name);
227                            splits.push(split_name.to_string());
228                        }
229                    }
230                }
231            }
232        }
233
234        splits.sort();
235        splits
236    }
237}
238
/// In-memory dataset loader for testing.
pub struct MemoryDatasetLoader {
    /// dataset name -> (split name -> dataset)
    datasets: HashMap<String, HashMap<String, EvaluationDataset>>,
}
243
244impl MemoryDatasetLoader {
245    pub fn new() -> Self {
246        Self {
247            datasets: HashMap::new(),
248        }
249    }
250
251    pub fn add_dataset(&mut self, dataset: EvaluationDataset, dataset_name: &str, split: &str) {
252        self.datasets
253            .entry(dataset_name.to_string())
254            .or_default()
255            .insert(split.to_string(), dataset);
256    }
257
258    pub fn create_dummy_glue_datasets(&mut self) {
259        // Create dummy GLUE datasets for testing
260        self.create_dummy_classification_dataset("cola", "train", 1000, vec!["0", "1"]);
261        self.create_dummy_classification_dataset("cola", "validation", 200, vec!["0", "1"]);
262
263        self.create_dummy_classification_dataset(
264            "sst2",
265            "train",
266            2000,
267            vec!["negative", "positive"],
268        );
269        self.create_dummy_classification_dataset(
270            "sst2",
271            "validation",
272            400,
273            vec!["negative", "positive"],
274        );
275
276        self.create_dummy_classification_dataset(
277            "mrpc",
278            "train",
279            1500,
280            vec!["not_equivalent", "equivalent"],
281        );
282        self.create_dummy_classification_dataset(
283            "mrpc",
284            "validation",
285            300,
286            vec!["not_equivalent", "equivalent"],
287        );
288
289        self.create_dummy_classification_dataset(
290            "mnli",
291            "train",
292            10000,
293            vec!["entailment", "neutral", "contradiction"],
294        );
295        self.create_dummy_classification_dataset(
296            "mnli",
297            "validation_matched",
298            2000,
299            vec!["entailment", "neutral", "contradiction"],
300        );
301        self.create_dummy_classification_dataset(
302            "mnli",
303            "validation_mismatched",
304            2000,
305            vec!["entailment", "neutral", "contradiction"],
306        );
307    }
308
309    fn create_dummy_classification_dataset(
310        &mut self,
311        name: &str,
312        split: &str,
313        size: usize,
314        labels: Vec<&str>,
315    ) {
316        let mut samples = Vec::new();
317
318        for i in 0..size {
319            let input = match name {
320                "cola" => format!("This is sentence number {} for acceptability.", i),
321                "sst2" => {
322                    if i % 2 == 0 {
323                        format!("This is a positive movie review {}.", i)
324                    } else {
325                        format!("This is a negative movie review {}.", i)
326                    }
327                },
328                "mrpc" => format!("Sentence A {}. [SEP] Sentence B {}.", i, i + 1),
329                "mnli" => format!("Premise sentence {}. [SEP] Hypothesis sentence {}.", i, i),
330                _ => format!("Input text {} for task {}.", i, name),
331            };
332
333            let target = labels[i % labels.len()].to_string();
334
335            let mut metadata = HashMap::new();
336            metadata.insert("idx".to_string(), serde_json::Value::Number(i.into()));
337            metadata.insert(
338                "task".to_string(),
339                serde_json::Value::String(name.to_string()),
340            );
341
342            samples.push(DatasetSample {
343                input,
344                target,
345                metadata,
346            });
347        }
348
349        let mut dataset = EvaluationDataset::new(format!("{}_{}", name, split));
350        dataset.add_samples(samples);
351        dataset.metadata.insert(
352            "source".to_string(),
353            serde_json::Value::String("memory".to_string()),
354        );
355        dataset.metadata.insert(
356            "task_type".to_string(),
357            serde_json::Value::String("classification".to_string()),
358        );
359        dataset.metadata.insert(
360            "num_labels".to_string(),
361            serde_json::Value::Number(labels.len().into()),
362        );
363
364        self.add_dataset(dataset, name, split);
365    }
366}
367
368impl DatasetLoader for MemoryDatasetLoader {
369    fn load(&self, dataset_name: &str, split: &str) -> Result<EvaluationDataset> {
370        self.datasets
371            .get(dataset_name)
372            .and_then(|splits| splits.get(split))
373            .cloned()
374            .ok_or_else(|| anyhow::anyhow!("Dataset {}:{} not found", dataset_name, split))
375    }
376
377    fn available_datasets(&self) -> Vec<String> {
378        let mut datasets: Vec<String> = self.datasets.keys().cloned().collect();
379        datasets.sort();
380        datasets
381    }
382
383    fn available_splits(&self, dataset_name: &str) -> Vec<String> {
384        self.datasets
385            .get(dataset_name)
386            .map(|splits| {
387                let mut split_names: Vec<String> = splits.keys().cloned().collect();
388                split_names.sort();
389                split_names
390            })
391            .unwrap_or_default()
392    }
393}
394
impl Default for MemoryDatasetLoader {
    /// Equivalent to [`MemoryDatasetLoader::new`]: an empty loader.
    fn default() -> Self {
        Self::new()
    }
}
400
/// Dataset manager for coordinating multiple loaders.
pub struct DatasetManager {
    /// Registered loaders keyed by name (e.g. "memory", "file").
    loaders: HashMap<String, Box<dyn DatasetLoader>>,
    /// Name of the loader used when a call does not specify one.
    default_loader: String,
}
406
407impl DatasetManager {
408    pub fn new() -> Self {
409        let mut manager = Self {
410            loaders: HashMap::new(),
411            default_loader: "memory".to_string(),
412        };
413
414        // Register default memory loader
415        manager.register_loader("memory".to_string(), Box::new(MemoryDatasetLoader::new()));
416
417        manager
418    }
419
420    pub fn register_loader(&mut self, name: String, loader: Box<dyn DatasetLoader>) {
421        self.loaders.insert(name, loader);
422    }
423
424    pub fn register_file_loader<P: AsRef<Path>>(&mut self, name: String, data_dir: P) {
425        let loader = FileDatasetLoader::new(data_dir);
426        self.loaders.insert(name, Box::new(loader));
427    }
428
429    pub fn set_default_loader(&mut self, name: String) {
430        if self.loaders.contains_key(&name) {
431            self.default_loader = name;
432        }
433    }
434
435    pub fn load_dataset(
436        &self,
437        dataset_name: &str,
438        split: &str,
439        loader_name: Option<&str>,
440    ) -> Result<EvaluationDataset> {
441        let loader_name = loader_name.unwrap_or(&self.default_loader);
442        let loader = self
443            .loaders
444            .get(loader_name)
445            .ok_or_else(|| anyhow::anyhow!("Unknown loader: {}", loader_name))?;
446
447        loader.load(dataset_name, split)
448    }
449
450    pub fn list_datasets(&self, loader_name: Option<&str>) -> Vec<String> {
451        let loader_name = loader_name.unwrap_or(&self.default_loader);
452        self.loaders
453            .get(loader_name)
454            .map(|loader| loader.available_datasets())
455            .unwrap_or_default()
456    }
457
458    pub fn list_splits(&self, dataset_name: &str, loader_name: Option<&str>) -> Vec<String> {
459        let loader_name = loader_name.unwrap_or(&self.default_loader);
460        self.loaders
461            .get(loader_name)
462            .map(|loader| loader.available_splits(dataset_name))
463            .unwrap_or_default()
464    }
465}
466
impl Default for DatasetManager {
    /// Equivalent to [`DatasetManager::new`]: a manager with only the
    /// "memory" loader registered.
    fn default() -> Self {
        Self::new()
    }
}
472
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// DatasetSample fields are stored and retrievable as constructed.
    #[test]
    fn test_dataset_sample() {
        let mut metadata = HashMap::new();
        metadata.insert("idx".to_string(), serde_json::Value::Number(0.into()));

        let sample = DatasetSample {
            input: "Test input".to_string(),
            target: "Test target".to_string(),
            metadata,
        };

        assert_eq!(sample.input, "Test input");
        assert_eq!(sample.target, "Test target");
        assert_eq!(sample.metadata.len(), 1);
    }

    /// Basic dataset accessors: len/is_empty, add_sample, input/target views.
    #[test]
    fn test_evaluation_dataset() {
        let mut dataset = EvaluationDataset::new("test_dataset".to_string());
        assert_eq!(dataset.name, "test_dataset");
        assert_eq!(dataset.len(), 0);
        assert!(dataset.is_empty());

        let sample = DatasetSample {
            input: "Input 1".to_string(),
            target: "Target 1".to_string(),
            metadata: HashMap::new(),
        };
        dataset.add_sample(sample);

        assert_eq!(dataset.len(), 1);
        assert!(!dataset.is_empty());

        let inputs = dataset.get_inputs();
        let targets = dataset.get_targets();
        assert_eq!(inputs, vec!["Input 1"]);
        assert_eq!(targets, vec!["Target 1"]);
    }

    /// sample(n) keeps the first n samples and preserves the dataset name.
    #[test]
    fn test_dataset_sampling() {
        let mut dataset = EvaluationDataset::new("test".to_string());

        for i in 0..10 {
            dataset.add_sample(DatasetSample {
                input: format!("Input {}", i),
                target: format!("Target {}", i),
                metadata: HashMap::new(),
            });
        }

        let sampled = dataset.sample(5);
        assert_eq!(sampled.len(), 5);
        assert_eq!(sampled.name, "test");
    }

    /// split(0.7) on 10 samples gives a 7/3 partition with suffixed names.
    #[test]
    fn test_dataset_split() {
        let mut dataset = EvaluationDataset::new("test".to_string());

        for i in 0..10 {
            dataset.add_sample(DatasetSample {
                input: format!("Input {}", i),
                target: format!("Target {}", i),
                metadata: HashMap::new(),
            });
        }

        let (train, test) = dataset.split(0.7);
        assert_eq!(train.len(), 7);
        assert_eq!(test.len(), 3);
        assert_eq!(train.name, "test_train");
        assert_eq!(test.name, "test_test");
    }

    /// Registering a dataset makes it visible via all three trait methods.
    #[test]
    fn test_memory_dataset_loader() {
        let mut loader = MemoryDatasetLoader::new();

        let mut dataset = EvaluationDataset::new("test_train".to_string());
        dataset.add_sample(DatasetSample {
            input: "Test input".to_string(),
            target: "Test target".to_string(),
            metadata: HashMap::new(),
        });

        loader.add_dataset(dataset, "test", "train");

        let available_datasets = loader.available_datasets();
        assert_eq!(available_datasets, vec!["test"]);

        let available_splits = loader.available_splits("test");
        assert_eq!(available_splits, vec!["train"]);

        let loaded_dataset = loader.load("test", "train").expect("operation failed in test");
        assert_eq!(loaded_dataset.len(), 1);
        assert_eq!(loaded_dataset.name, "test_train");
    }

    /// The synthetic GLUE fixtures expose the expected tasks, splits, and sizes.
    #[test]
    fn test_dummy_glue_datasets() {
        let mut loader = MemoryDatasetLoader::new();
        loader.create_dummy_glue_datasets();

        let datasets = loader.available_datasets();
        assert!(datasets.contains(&"cola".to_string()));
        assert!(datasets.contains(&"sst2".to_string()));
        assert!(datasets.contains(&"mrpc".to_string()));
        assert!(datasets.contains(&"mnli".to_string()));

        let cola_splits = loader.available_splits("cola");
        assert!(cola_splits.contains(&"train".to_string()));
        assert!(cola_splits.contains(&"validation".to_string()));

        let cola_train = loader.load("cola", "train").expect("operation failed in test");
        assert_eq!(cola_train.len(), 1000);
    }

    /// Manager starts empty and errors cleanly on a nonexistent file dataset.
    #[test]
    fn test_dataset_manager() {
        let mut manager = DatasetManager::new();

        // Test with memory loader
        let datasets = manager.list_datasets(None);
        assert_eq!(datasets.len(), 0); // Empty by default

        // Add a file loader
        manager.register_file_loader("file".to_string(), "/tmp");

        // Test loading (will fail since /tmp doesn't have datasets, but tests the interface)
        let result = manager.load_dataset("nonexistent", "train", Some("file"));
        assert!(result.is_err());
    }
}
611}