1use crate::case::Case;
4use crate::error::{EvalError, EvalResult};
5use serde::{Deserialize, Serialize};
6use std::path::Path;
7
8#[derive(Debug, Clone)]
10pub struct Dataset<Inputs, Output = (), Metadata = ()> {
11 pub name: Option<String>,
13 pub description: Option<String>,
15 pub cases: Vec<Case<Inputs, Output, Metadata>>,
17}
18
19impl<Inputs, Output, Metadata> Dataset<Inputs, Output, Metadata> {
20 pub fn new() -> Self {
22 Self {
23 name: None,
24 description: None,
25 cases: Vec::new(),
26 }
27 }
28
29 pub fn with_name(mut self, name: impl Into<String>) -> Self {
31 self.name = Some(name.into());
32 self
33 }
34
35 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
37 self.description = Some(desc.into());
38 self
39 }
40
41 pub fn case(mut self, case: Case<Inputs, Output, Metadata>) -> Self {
43 self.cases.push(case);
44 self
45 }
46
47 pub fn cases(
49 mut self,
50 cases: impl IntoIterator<Item = Case<Inputs, Output, Metadata>>,
51 ) -> Self {
52 self.cases.extend(cases);
53 self
54 }
55
56 pub fn len(&self) -> usize {
58 self.cases.len()
59 }
60
61 pub fn is_empty(&self) -> bool {
63 self.cases.is_empty()
64 }
65
66 pub fn filter_by_tag(&self, tag: &str) -> Vec<&Case<Inputs, Output, Metadata>> {
68 self.cases.iter().filter(|c| c.has_tag(tag)).collect()
69 }
70
71 pub fn filter<F>(&self, predicate: F) -> Vec<&Case<Inputs, Output, Metadata>>
73 where
74 F: Fn(&Case<Inputs, Output, Metadata>) -> bool,
75 {
76 self.cases.iter().filter(|c| predicate(c)).collect()
77 }
78
79 pub fn subset(&self, indices: &[usize]) -> Self
81 where
82 Inputs: Clone,
83 Output: Clone,
84 Metadata: Clone,
85 {
86 let cases = indices
87 .iter()
88 .filter_map(|&i| self.cases.get(i).cloned())
89 .collect();
90 Dataset {
91 name: self.name.clone(),
92 description: self.description.clone(),
93 cases,
94 }
95 }
96
97 pub fn take(&self, n: usize) -> Self
99 where
100 Inputs: Clone,
101 Output: Clone,
102 Metadata: Clone,
103 {
104 Dataset {
105 name: self.name.clone(),
106 description: self.description.clone(),
107 cases: self.cases.iter().take(n).cloned().collect(),
108 }
109 }
110
111 pub fn shuffle(&self, seed: u64) -> Self
113 where
114 Inputs: Clone,
115 Output: Clone,
116 Metadata: Clone,
117 {
118 use std::collections::hash_map::DefaultHasher;
119 use std::hash::{Hash, Hasher};
120
121 let mut cases = self.cases.clone();
122 let n = cases.len();
123
124 for i in 0..n {
125 let mut hasher = DefaultHasher::new();
126 seed.hash(&mut hasher);
127 i.hash(&mut hasher);
128 let j = (hasher.finish() as usize) % n;
129 cases.swap(i, j);
130 }
131
132 Dataset {
133 name: self.name.clone(),
134 description: self.description.clone(),
135 cases,
136 }
137 }
138}
139
140impl<Inputs, Output, Metadata> Default for Dataset<Inputs, Output, Metadata> {
141 fn default() -> Self {
142 Self::new()
143 }
144}
145
146#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct StringDataset {
149 #[serde(default)]
151 pub name: Option<String>,
152 #[serde(default)]
154 pub description: Option<String>,
155 pub cases: Vec<StringCase>,
157}
158
159#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct StringCase {
162 #[serde(default)]
164 pub name: Option<String>,
165 pub input: String,
167 #[serde(default)]
169 pub expected: Option<String>,
170 #[serde(default)]
172 pub tags: Vec<String>,
173}
174
175impl StringDataset {
176 pub fn from_json(path: impl AsRef<Path>) -> EvalResult<Self> {
178 let content = std::fs::read_to_string(path.as_ref())?;
179 Self::from_json_str(&content)
180 }
181
182 pub fn from_json_str(content: &str) -> EvalResult<Self> {
184 serde_json::from_str(content).map_err(|e| EvalError::Serialization(e.to_string()))
185 }
186
187 pub fn from_yaml(path: impl AsRef<Path>) -> EvalResult<Self> {
189 let content = std::fs::read_to_string(path.as_ref())?;
190 Self::from_yaml_str(&content)
191 }
192
193 pub fn from_yaml_str(content: &str) -> EvalResult<Self> {
195 serde_yaml::from_str(content).map_err(|e| EvalError::Yaml(e.to_string()))
196 }
197
198 pub fn to_json(&self, path: impl AsRef<Path>) -> EvalResult<()> {
200 let content = self.to_json_string()?;
201 std::fs::write(path.as_ref(), content)?;
202 Ok(())
203 }
204
205 pub fn to_json_string(&self) -> EvalResult<String> {
207 serde_json::to_string_pretty(self).map_err(|e| EvalError::Serialization(e.to_string()))
208 }
209
210 pub fn to_yaml(&self, path: impl AsRef<Path>) -> EvalResult<()> {
212 let content = self.to_yaml_string()?;
213 std::fs::write(path.as_ref(), content)?;
214 Ok(())
215 }
216
217 pub fn to_yaml_string(&self) -> EvalResult<String> {
219 serde_yaml::to_string(self).map_err(|e| EvalError::Yaml(e.to_string()))
220 }
221
222 pub fn to_dataset(&self) -> Dataset<String, String> {
224 let cases = self
225 .cases
226 .iter()
227 .map(|c| {
228 Case::new(c.input.clone())
229 .with_name(c.name.clone().unwrap_or_default())
230 .with_tags(c.tags.clone())
231 .with_expected_output(c.expected.clone().unwrap_or_default())
232 })
233 .collect();
234
235 Dataset {
236 name: self.name.clone(),
237 description: self.description.clone(),
238 cases,
239 }
240 }
241}
242
243#[derive(Debug)]
245pub struct DatasetBuilder<Inputs, Output = (), Metadata = ()> {
246 name: Option<String>,
247 description: Option<String>,
248 cases: Vec<Case<Inputs, Output, Metadata>>,
249}
250
251impl<Inputs, Output, Metadata> DatasetBuilder<Inputs, Output, Metadata> {
252 pub fn new() -> Self {
254 Self {
255 name: None,
256 description: None,
257 cases: Vec::new(),
258 }
259 }
260
261 pub fn name(mut self, name: impl Into<String>) -> Self {
263 self.name = Some(name.into());
264 self
265 }
266
267 pub fn description(mut self, desc: impl Into<String>) -> Self {
269 self.description = Some(desc.into());
270 self
271 }
272
273 pub fn case(mut self, case: Case<Inputs, Output, Metadata>) -> Self {
275 self.cases.push(case);
276 self
277 }
278
279 pub fn build(self) -> Dataset<Inputs, Output, Metadata> {
281 Dataset {
282 name: self.name,
283 description: self.description,
284 cases: self.cases,
285 }
286 }
287}
288
289impl<Inputs, Output, Metadata> Default for DatasetBuilder<Inputs, Output, Metadata> {
290 fn default() -> Self {
291 Self::new()
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a dataset with one case per tag, where each case's input equals its tag.
    fn tagged_dataset(tags: &[&'static str]) -> Dataset<String> {
        tags.iter().fold(Dataset::new(), |dataset, &tag| {
            dataset.case(Case::new(tag.to_string()).with_tag(tag))
        })
    }

    /// Small `StringDataset` fixture shared by the serialization tests.
    fn sample_string_dataset() -> StringDataset {
        StringDataset {
            name: Some("test".to_string()),
            description: None,
            cases: vec![StringCase {
                name: Some("case1".to_string()),
                input: "hello".to_string(),
                expected: Some("world".to_string()),
                tags: vec!["greeting".to_string()],
            }],
        }
    }

    #[test]
    fn test_dataset_new() {
        let dataset: Dataset<String> = Dataset::new()
            .with_name("test")
            .with_description("Test dataset");

        assert_eq!(dataset.name, Some("test".to_string()));
        assert_eq!(dataset.description, Some("Test dataset".to_string()));
        assert!(dataset.is_empty());
    }

    #[test]
    fn test_dataset_add_cases() {
        let dataset: Dataset<String> = Dataset::new()
            .case(Case::new("input1".to_string()))
            .case(Case::new("input2".to_string()));

        assert_eq!(dataset.len(), 2);
    }

    #[test]
    fn test_dataset_filter_by_tag() {
        let dataset: Dataset<String> = Dataset::new()
            .case(Case::new("a".to_string()).with_tag("unit"))
            .case(Case::new("b".to_string()).with_tag("integration"))
            .case(Case::new("c".to_string()).with_tag("unit"));

        let unit_tests = dataset.filter_by_tag("unit");
        assert_eq!(unit_tests.len(), 2);
    }

    #[test]
    fn test_dataset_take() {
        let dataset: Dataset<String> = Dataset::new()
            .case(Case::new("a".to_string()))
            .case(Case::new("b".to_string()))
            .case(Case::new("c".to_string()));

        let subset = dataset.take(2);
        assert_eq!(subset.len(), 2);
    }

    #[test]
    fn test_dataset_subset_keeps_order_and_skips_out_of_range() {
        let dataset = tagged_dataset(&["a", "b", "c"]);

        let subset = dataset.subset(&[2, 0, 99]);
        assert_eq!(subset.len(), 2);
        assert!(subset.cases[0].has_tag("c"));
        assert!(subset.cases[1].has_tag("a"));
    }

    #[test]
    fn test_dataset_shuffle_is_deterministic_permutation() {
        let tags = ["a", "b", "c", "d", "e", "f"];
        let dataset = tagged_dataset(&tags);

        let first = dataset.shuffle(42);
        let second = dataset.shuffle(42);

        // Every original case appears exactly once.
        assert_eq!(first.len(), tags.len());
        for &tag in &tags {
            assert_eq!(first.filter_by_tag(tag).len(), 1);
        }
        // The same seed yields the same order.
        for (a, b) in first.cases.iter().zip(&second.cases) {
            for &tag in &tags {
                assert_eq!(a.has_tag(tag), b.has_tag(tag));
            }
        }
    }

    #[test]
    fn test_dataset_shuffle_empty() {
        let dataset: Dataset<String> = Dataset::new();
        assert!(dataset.shuffle(7).is_empty());
    }

    #[test]
    fn test_string_dataset_json_roundtrip() {
        let dataset = sample_string_dataset();

        let json = dataset.to_json_string().unwrap();
        let loaded = StringDataset::from_json_str(&json).unwrap();

        assert_eq!(loaded.name, Some("test".to_string()));
        assert_eq!(loaded.cases.len(), 1);
    }

    #[test]
    fn test_string_dataset_yaml_roundtrip() {
        let dataset = sample_string_dataset();

        let yaml = dataset.to_yaml_string().unwrap();
        let loaded = StringDataset::from_yaml_str(&yaml).unwrap();

        assert_eq!(loaded.name, Some("test".to_string()));
        assert_eq!(loaded.cases.len(), 1);
        assert_eq!(loaded.cases[0].expected, Some("world".to_string()));
        assert_eq!(loaded.cases[0].tags, vec!["greeting".to_string()]);
    }

    #[test]
    fn test_string_dataset_to_dataset_keeps_cases_and_tags() {
        let dataset = sample_string_dataset().to_dataset();

        assert_eq!(dataset.name, Some("test".to_string()));
        assert_eq!(dataset.len(), 1);
        assert_eq!(dataset.filter_by_tag("greeting").len(), 1);
    }

    #[test]
    fn test_dataset_builder() {
        let dataset: Dataset<String> = DatasetBuilder::new()
            .name("builder test")
            .case(Case::new("input".to_string()))
            .build();

        assert_eq!(dataset.name, Some("builder test".to_string()));
        assert_eq!(dataset.len(), 1);
    }
}