1use anyhow::Result;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::path::Path;
6
/// One evaluation example: an input string, the expected target string,
/// and any additional per-sample fields preserved as raw JSON metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetSample {
    /// The sample's input text.
    pub input: String,
    /// The expected output for this input (e.g. a label string).
    pub target: String,
    /// Extra fields carried alongside the sample, as arbitrary JSON values.
    pub metadata: HashMap<String, serde_json::Value>,
}
14
/// A named collection of [`DatasetSample`]s plus dataset-level metadata.
#[derive(Debug, Clone)]
pub struct EvaluationDataset {
    /// Dataset identifier (loaders build names like "cola_train").
    pub name: String,
    /// The samples making up this dataset, in insertion order.
    pub samples: Vec<DatasetSample>,
    /// Dataset-level metadata (source, task type, ...), as JSON values.
    pub metadata: HashMap<String, serde_json::Value>,
}
22
23impl EvaluationDataset {
24 pub fn new(name: String) -> Self {
25 Self {
26 name,
27 samples: Vec::new(),
28 metadata: HashMap::new(),
29 }
30 }
31
32 pub fn add_sample(&mut self, sample: DatasetSample) {
33 self.samples.push(sample);
34 }
35
36 pub fn add_samples(&mut self, samples: Vec<DatasetSample>) {
37 self.samples.extend(samples);
38 }
39
40 pub fn len(&self) -> usize {
41 self.samples.len()
42 }
43
44 pub fn is_empty(&self) -> bool {
45 self.samples.is_empty()
46 }
47
48 pub fn get_inputs(&self) -> Vec<String> {
49 self.samples.iter().map(|s| s.input.clone()).collect()
50 }
51
52 pub fn get_targets(&self) -> Vec<String> {
53 self.samples.iter().map(|s| s.target.clone()).collect()
54 }
55
56 pub fn sample(&self, n: usize) -> EvaluationDataset {
57 let mut sampled = self.clone();
58 if n < self.samples.len() {
59 sampled.samples.truncate(n);
60 }
61 sampled
62 }
63
64 pub fn shuffle(&mut self, seed: Option<u64>) {
65 use scirs2_core::random::*;
66
67 if let Some(seed) = seed {
68 let mut rng = StdRng::seed_from_u64(seed);
69 self.samples.shuffle(&mut rng);
70 } else {
71 let mut rng = thread_rng();
72 self.samples.shuffle(&mut rng);
73 }
74 }
75
76 pub fn split(&self, train_ratio: f64) -> (EvaluationDataset, EvaluationDataset) {
77 let split_idx = (self.samples.len() as f64 * train_ratio) as usize;
78
79 let mut train_dataset = EvaluationDataset::new(format!("{}_train", self.name));
80 train_dataset.samples = self.samples[..split_idx].to_vec();
81 train_dataset.metadata = self.metadata.clone();
82
83 let mut test_dataset = EvaluationDataset::new(format!("{}_test", self.name));
84 test_dataset.samples = self.samples[split_idx..].to_vec();
85 test_dataset.metadata = self.metadata.clone();
86
87 (train_dataset, test_dataset)
88 }
89}
90
/// Abstraction over dataset sources (JSONL files on disk, in-memory stores).
pub trait DatasetLoader {
    /// Loads the given dataset/split, erroring when it is unavailable.
    fn load(&self, dataset_name: &str, split: &str) -> Result<EvaluationDataset>;
    /// Names of all datasets this loader can serve.
    fn available_datasets(&self) -> Vec<String>;
    /// Names of the splits available for `dataset_name` (empty if unknown).
    fn available_splits(&self, dataset_name: &str) -> Vec<String>;
}
97
/// Loads datasets from JSONL files laid out as
/// `<data_dir>/<dataset>/<split>.jsonl`.
pub struct FileDatasetLoader {
    // Root directory, stored lossily as a String (non-UTF-8 path
    // components are replaced during construction).
    data_dir: String,
}
102
103impl FileDatasetLoader {
104 pub fn new<P: AsRef<Path>>(data_dir: P) -> Self {
105 Self {
106 data_dir: data_dir.as_ref().to_string_lossy().to_string(),
107 }
108 }
109
110 fn get_dataset_path(&self, dataset_name: &str, split: &str) -> String {
111 format!("{}/{}/{}.jsonl", self.data_dir, dataset_name, split)
112 }
113
114 fn load_jsonl(&self, path: &str) -> Result<Vec<DatasetSample>> {
115 let content = std::fs::read_to_string(path)?;
116 let mut samples = Vec::new();
117
118 for line in content.lines() {
119 if line.trim().is_empty() {
120 continue;
121 }
122
123 let json_value: serde_json::Value = serde_json::from_str(line)?;
124
125 let input = json_value
126 .get("input")
127 .or_else(|| json_value.get("text"))
128 .or_else(|| json_value.get("sentence"))
129 .and_then(|v| v.as_str())
130 .unwrap_or("")
131 .to_string();
132
133 let target = json_value
134 .get("target")
135 .or_else(|| json_value.get("label"))
136 .or_else(|| json_value.get("output"))
137 .and_then(|v| v.as_str())
138 .unwrap_or("")
139 .to_string();
140
141 let mut metadata = HashMap::new();
142 if let Some(obj) = json_value.as_object() {
143 for (key, value) in obj {
144 if key != "input"
145 && key != "text"
146 && key != "sentence"
147 && key != "target"
148 && key != "label"
149 && key != "output"
150 {
151 metadata.insert(key.clone(), value.clone());
152 }
153 }
154 }
155
156 samples.push(DatasetSample {
157 input,
158 target,
159 metadata,
160 });
161 }
162
163 Ok(samples)
164 }
165}
166
167impl DatasetLoader for FileDatasetLoader {
168 fn load(&self, dataset_name: &str, split: &str) -> Result<EvaluationDataset> {
169 let path = self.get_dataset_path(dataset_name, split);
170 let samples = self.load_jsonl(&path)?;
171
172 let mut dataset = EvaluationDataset::new(format!("{}_{}", dataset_name, split));
173 dataset.add_samples(samples);
174
175 dataset.metadata.insert(
177 "source".to_string(),
178 serde_json::Value::String("file".to_string()),
179 );
180 dataset.metadata.insert("path".to_string(), serde_json::Value::String(path));
181 dataset.metadata.insert(
182 "dataset_name".to_string(),
183 serde_json::Value::String(dataset_name.to_string()),
184 );
185 dataset.metadata.insert(
186 "split".to_string(),
187 serde_json::Value::String(split.to_string()),
188 );
189
190 Ok(dataset)
191 }
192
193 fn available_datasets(&self) -> Vec<String> {
194 let data_path = Path::new(&self.data_dir);
195 if !data_path.exists() {
196 return Vec::new();
197 }
198
199 let mut datasets = Vec::new();
200 if let Ok(entries) = std::fs::read_dir(data_path) {
201 for entry in entries.flatten() {
202 if entry.file_type().map(|ft| ft.is_dir()).unwrap_or(false) {
203 if let Some(name) = entry.file_name().to_str() {
204 datasets.push(name.to_string());
205 }
206 }
207 }
208 }
209
210 datasets.sort();
211 datasets
212 }
213
214 fn available_splits(&self, dataset_name: &str) -> Vec<String> {
215 let dataset_path = Path::new(&self.data_dir).join(dataset_name);
216 if !dataset_path.exists() {
217 return Vec::new();
218 }
219
220 let mut splits = Vec::new();
221 if let Ok(entries) = std::fs::read_dir(dataset_path) {
222 for entry in entries.flatten() {
223 if entry.file_type().map(|ft| ft.is_file()).unwrap_or(false) {
224 if let Some(name) = entry.file_name().to_str() {
225 if name.ends_with(".jsonl") {
226 let split_name = name.strip_suffix(".jsonl").unwrap_or(name);
227 splits.push(split_name.to_string());
228 }
229 }
230 }
231 }
232 }
233
234 splits.sort();
235 splits
236 }
237}
238
/// Serves datasets held entirely in memory, useful for tests and demos.
pub struct MemoryDatasetLoader {
    // Nested map: dataset name -> (split name -> dataset).
    datasets: HashMap<String, HashMap<String, EvaluationDataset>>,
}
243
244impl MemoryDatasetLoader {
245 pub fn new() -> Self {
246 Self {
247 datasets: HashMap::new(),
248 }
249 }
250
251 pub fn add_dataset(&mut self, dataset: EvaluationDataset, dataset_name: &str, split: &str) {
252 self.datasets
253 .entry(dataset_name.to_string())
254 .or_default()
255 .insert(split.to_string(), dataset);
256 }
257
258 pub fn create_dummy_glue_datasets(&mut self) {
259 self.create_dummy_classification_dataset("cola", "train", 1000, vec!["0", "1"]);
261 self.create_dummy_classification_dataset("cola", "validation", 200, vec!["0", "1"]);
262
263 self.create_dummy_classification_dataset(
264 "sst2",
265 "train",
266 2000,
267 vec!["negative", "positive"],
268 );
269 self.create_dummy_classification_dataset(
270 "sst2",
271 "validation",
272 400,
273 vec!["negative", "positive"],
274 );
275
276 self.create_dummy_classification_dataset(
277 "mrpc",
278 "train",
279 1500,
280 vec!["not_equivalent", "equivalent"],
281 );
282 self.create_dummy_classification_dataset(
283 "mrpc",
284 "validation",
285 300,
286 vec!["not_equivalent", "equivalent"],
287 );
288
289 self.create_dummy_classification_dataset(
290 "mnli",
291 "train",
292 10000,
293 vec!["entailment", "neutral", "contradiction"],
294 );
295 self.create_dummy_classification_dataset(
296 "mnli",
297 "validation_matched",
298 2000,
299 vec!["entailment", "neutral", "contradiction"],
300 );
301 self.create_dummy_classification_dataset(
302 "mnli",
303 "validation_mismatched",
304 2000,
305 vec!["entailment", "neutral", "contradiction"],
306 );
307 }
308
309 fn create_dummy_classification_dataset(
310 &mut self,
311 name: &str,
312 split: &str,
313 size: usize,
314 labels: Vec<&str>,
315 ) {
316 let mut samples = Vec::new();
317
318 for i in 0..size {
319 let input = match name {
320 "cola" => format!("This is sentence number {} for acceptability.", i),
321 "sst2" => {
322 if i % 2 == 0 {
323 format!("This is a positive movie review {}.", i)
324 } else {
325 format!("This is a negative movie review {}.", i)
326 }
327 },
328 "mrpc" => format!("Sentence A {}. [SEP] Sentence B {}.", i, i + 1),
329 "mnli" => format!("Premise sentence {}. [SEP] Hypothesis sentence {}.", i, i),
330 _ => format!("Input text {} for task {}.", i, name),
331 };
332
333 let target = labels[i % labels.len()].to_string();
334
335 let mut metadata = HashMap::new();
336 metadata.insert("idx".to_string(), serde_json::Value::Number(i.into()));
337 metadata.insert(
338 "task".to_string(),
339 serde_json::Value::String(name.to_string()),
340 );
341
342 samples.push(DatasetSample {
343 input,
344 target,
345 metadata,
346 });
347 }
348
349 let mut dataset = EvaluationDataset::new(format!("{}_{}", name, split));
350 dataset.add_samples(samples);
351 dataset.metadata.insert(
352 "source".to_string(),
353 serde_json::Value::String("memory".to_string()),
354 );
355 dataset.metadata.insert(
356 "task_type".to_string(),
357 serde_json::Value::String("classification".to_string()),
358 );
359 dataset.metadata.insert(
360 "num_labels".to_string(),
361 serde_json::Value::Number(labels.len().into()),
362 );
363
364 self.add_dataset(dataset, name, split);
365 }
366}
367
368impl DatasetLoader for MemoryDatasetLoader {
369 fn load(&self, dataset_name: &str, split: &str) -> Result<EvaluationDataset> {
370 self.datasets
371 .get(dataset_name)
372 .and_then(|splits| splits.get(split))
373 .cloned()
374 .ok_or_else(|| anyhow::anyhow!("Dataset {}:{} not found", dataset_name, split))
375 }
376
377 fn available_datasets(&self) -> Vec<String> {
378 let mut datasets: Vec<String> = self.datasets.keys().cloned().collect();
379 datasets.sort();
380 datasets
381 }
382
383 fn available_splits(&self, dataset_name: &str) -> Vec<String> {
384 self.datasets
385 .get(dataset_name)
386 .map(|splits| {
387 let mut split_names: Vec<String> = splits.keys().cloned().collect();
388 split_names.sort();
389 split_names
390 })
391 .unwrap_or_default()
392 }
393}
394
impl Default for MemoryDatasetLoader {
    /// Delegates to [`MemoryDatasetLoader::new`]: an empty loader.
    fn default() -> Self {
        Self::new()
    }
}
400
/// Routes dataset requests to named [`DatasetLoader`] implementations.
pub struct DatasetManager {
    // Registered loaders, keyed by the name given at registration.
    loaders: HashMap<String, Box<dyn DatasetLoader>>,
    // Loader used when a call supplies no explicit loader name.
    default_loader: String,
}
406
407impl DatasetManager {
408 pub fn new() -> Self {
409 let mut manager = Self {
410 loaders: HashMap::new(),
411 default_loader: "memory".to_string(),
412 };
413
414 manager.register_loader("memory".to_string(), Box::new(MemoryDatasetLoader::new()));
416
417 manager
418 }
419
420 pub fn register_loader(&mut self, name: String, loader: Box<dyn DatasetLoader>) {
421 self.loaders.insert(name, loader);
422 }
423
424 pub fn register_file_loader<P: AsRef<Path>>(&mut self, name: String, data_dir: P) {
425 let loader = FileDatasetLoader::new(data_dir);
426 self.loaders.insert(name, Box::new(loader));
427 }
428
429 pub fn set_default_loader(&mut self, name: String) {
430 if self.loaders.contains_key(&name) {
431 self.default_loader = name;
432 }
433 }
434
435 pub fn load_dataset(
436 &self,
437 dataset_name: &str,
438 split: &str,
439 loader_name: Option<&str>,
440 ) -> Result<EvaluationDataset> {
441 let loader_name = loader_name.unwrap_or(&self.default_loader);
442 let loader = self
443 .loaders
444 .get(loader_name)
445 .ok_or_else(|| anyhow::anyhow!("Unknown loader: {}", loader_name))?;
446
447 loader.load(dataset_name, split)
448 }
449
450 pub fn list_datasets(&self, loader_name: Option<&str>) -> Vec<String> {
451 let loader_name = loader_name.unwrap_or(&self.default_loader);
452 self.loaders
453 .get(loader_name)
454 .map(|loader| loader.available_datasets())
455 .unwrap_or_default()
456 }
457
458 pub fn list_splits(&self, dataset_name: &str, loader_name: Option<&str>) -> Vec<String> {
459 let loader_name = loader_name.unwrap_or(&self.default_loader);
460 self.loaders
461 .get(loader_name)
462 .map(|loader| loader.available_splits(dataset_name))
463 .unwrap_or_default()
464 }
465}
466
impl Default for DatasetManager {
    /// Delegates to [`DatasetManager::new`]: a manager with only the
    /// "memory" loader registered.
    fn default() -> Self {
        Self::new()
    }
}
472
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // Builds a throwaway sample with no metadata.
    fn make_sample(input: &str, target: &str) -> DatasetSample {
        DatasetSample {
            input: input.to_string(),
            target: target.to_string(),
            metadata: HashMap::new(),
        }
    }

    // Populates `dataset` with `count` numbered input/target pairs.
    fn fill(dataset: &mut EvaluationDataset, count: usize) {
        for i in 0..count {
            dataset.add_sample(make_sample(
                &format!("Input {}", i),
                &format!("Target {}", i),
            ));
        }
    }

    #[test]
    fn test_dataset_sample() {
        let mut metadata = HashMap::new();
        metadata.insert("idx".to_string(), serde_json::Value::Number(0.into()));

        let sample = DatasetSample {
            input: "Test input".to_string(),
            target: "Test target".to_string(),
            metadata,
        };

        assert_eq!(sample.input, "Test input");
        assert_eq!(sample.target, "Test target");
        assert_eq!(sample.metadata.len(), 1);
    }

    #[test]
    fn test_evaluation_dataset() {
        let mut dataset = EvaluationDataset::new("test_dataset".to_string());
        assert_eq!(dataset.name, "test_dataset");
        assert_eq!(dataset.len(), 0);
        assert!(dataset.is_empty());

        dataset.add_sample(make_sample("Input 1", "Target 1"));

        assert_eq!(dataset.len(), 1);
        assert!(!dataset.is_empty());
        assert_eq!(dataset.get_inputs(), vec!["Input 1"]);
        assert_eq!(dataset.get_targets(), vec!["Target 1"]);
    }

    #[test]
    fn test_dataset_sampling() {
        let mut dataset = EvaluationDataset::new("test".to_string());
        fill(&mut dataset, 10);

        // Sampling keeps the name but truncates to the requested size.
        let sampled = dataset.sample(5);
        assert_eq!(sampled.len(), 5);
        assert_eq!(sampled.name, "test");
    }

    #[test]
    fn test_dataset_split() {
        let mut dataset = EvaluationDataset::new("test".to_string());
        fill(&mut dataset, 10);

        let (train, test) = dataset.split(0.7);
        assert_eq!(train.len(), 7);
        assert_eq!(test.len(), 3);
        assert_eq!(train.name, "test_train");
        assert_eq!(test.name, "test_test");
    }

    #[test]
    fn test_memory_dataset_loader() {
        let mut loader = MemoryDatasetLoader::new();

        let mut dataset = EvaluationDataset::new("test_train".to_string());
        dataset.add_sample(make_sample("Test input", "Test target"));
        loader.add_dataset(dataset, "test", "train");

        assert_eq!(loader.available_datasets(), vec!["test"]);
        assert_eq!(loader.available_splits("test"), vec!["train"]);

        let loaded_dataset = loader
            .load("test", "train")
            .expect("operation failed in test");
        assert_eq!(loaded_dataset.len(), 1);
        assert_eq!(loaded_dataset.name, "test_train");
    }

    #[test]
    fn test_dummy_glue_datasets() {
        let mut loader = MemoryDatasetLoader::new();
        loader.create_dummy_glue_datasets();

        let datasets = loader.available_datasets();
        for expected in ["cola", "sst2", "mrpc", "mnli"].iter() {
            assert!(datasets.contains(&expected.to_string()));
        }

        let cola_splits = loader.available_splits("cola");
        assert!(cola_splits.contains(&"train".to_string()));
        assert!(cola_splits.contains(&"validation".to_string()));

        let cola_train = loader
            .load("cola", "train")
            .expect("operation failed in test");
        assert_eq!(cola_train.len(), 1000);
    }

    #[test]
    fn test_dataset_manager() {
        let mut manager = DatasetManager::new();

        // The default in-memory loader starts with no datasets.
        let datasets = manager.list_datasets(None);
        assert_eq!(datasets.len(), 0);

        manager.register_file_loader("file".to_string(), "/tmp");

        // A dataset with no backing file surfaces an error, not a panic.
        let result = manager.load_dataset("nonexistent", "train", Some("file"));
        assert!(result.is_err());
    }
}