// serdes_ai_evals/dataset.rs

1//! Dataset management for evaluation cases.
2
3use crate::case::Case;
4use crate::error::{EvalError, EvalResult};
5use serde::{Deserialize, Serialize};
6use std::path::Path;
7
/// A collection of test cases.
///
/// The dataset is generic over the case input type (`Inputs`), the expected
/// output type (`Output`) and a per-case metadata type (`Metadata`); the
/// latter two default to `()`.
#[derive(Debug, Clone)]
pub struct Dataset<Inputs, Output = (), Metadata = ()> {
    /// Dataset name, if one has been set.
    pub name: Option<String>,
    /// Free-form description, if one has been set.
    pub description: Option<String>,
    /// Test cases, in insertion order.
    pub cases: Vec<Case<Inputs, Output, Metadata>>,
}
18
19impl<Inputs, Output, Metadata> Dataset<Inputs, Output, Metadata> {
20    /// Create a new empty dataset.
21    pub fn new() -> Self {
22        Self {
23            name: None,
24            description: None,
25            cases: Vec::new(),
26        }
27    }
28
29    /// Set the dataset name.
30    pub fn with_name(mut self, name: impl Into<String>) -> Self {
31        self.name = Some(name.into());
32        self
33    }
34
35    /// Set the description.
36    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
37        self.description = Some(desc.into());
38        self
39    }
40
41    /// Add a case.
42    pub fn case(mut self, case: Case<Inputs, Output, Metadata>) -> Self {
43        self.cases.push(case);
44        self
45    }
46
47    /// Add multiple cases.
48    pub fn cases(
49        mut self,
50        cases: impl IntoIterator<Item = Case<Inputs, Output, Metadata>>,
51    ) -> Self {
52        self.cases.extend(cases);
53        self
54    }
55
56    /// Get the number of cases.
57    pub fn len(&self) -> usize {
58        self.cases.len()
59    }
60
61    /// Check if empty.
62    pub fn is_empty(&self) -> bool {
63        self.cases.is_empty()
64    }
65
66    /// Filter cases by tag.
67    pub fn filter_by_tag(&self, tag: &str) -> Vec<&Case<Inputs, Output, Metadata>> {
68        self.cases.iter().filter(|c| c.has_tag(tag)).collect()
69    }
70
71    /// Filter cases by predicate.
72    pub fn filter<F>(&self, predicate: F) -> Vec<&Case<Inputs, Output, Metadata>>
73    where
74        F: Fn(&Case<Inputs, Output, Metadata>) -> bool,
75    {
76        self.cases.iter().filter(|c| predicate(c)).collect()
77    }
78
79    /// Get a subset of cases.
80    pub fn subset(&self, indices: &[usize]) -> Self
81    where
82        Inputs: Clone,
83        Output: Clone,
84        Metadata: Clone,
85    {
86        let cases = indices
87            .iter()
88            .filter_map(|&i| self.cases.get(i).cloned())
89            .collect();
90        Dataset {
91            name: self.name.clone(),
92            description: self.description.clone(),
93            cases,
94        }
95    }
96
97    /// Take first N cases.
98    pub fn take(&self, n: usize) -> Self
99    where
100        Inputs: Clone,
101        Output: Clone,
102        Metadata: Clone,
103    {
104        Dataset {
105            name: self.name.clone(),
106            description: self.description.clone(),
107            cases: self.cases.iter().take(n).cloned().collect(),
108        }
109    }
110
111    /// Shuffle cases (deterministically with seed).
112    pub fn shuffle(&self, seed: u64) -> Self
113    where
114        Inputs: Clone,
115        Output: Clone,
116        Metadata: Clone,
117    {
118        use std::collections::hash_map::DefaultHasher;
119        use std::hash::{Hash, Hasher};
120
121        let mut cases = self.cases.clone();
122        let n = cases.len();
123
124        for i in 0..n {
125            let mut hasher = DefaultHasher::new();
126            seed.hash(&mut hasher);
127            i.hash(&mut hasher);
128            let j = (hasher.finish() as usize) % n;
129            cases.swap(i, j);
130        }
131
132        Dataset {
133            name: self.name.clone(),
134            description: self.description.clone(),
135            cases,
136        }
137    }
138}
139
140impl<Inputs, Output, Metadata> Default for Dataset<Inputs, Output, Metadata> {
141    fn default() -> Self {
142        Self::new()
143    }
144}
145
/// String-based dataset for easy serialization.
///
/// `name` and `description` are optional in serialized form and default to
/// `None` when absent; `cases` is required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StringDataset {
    /// Dataset name.
    #[serde(default)]
    pub name: Option<String>,
    /// Description.
    #[serde(default)]
    pub description: Option<String>,
    /// Test cases, each a [`StringCase`] record.
    pub cases: Vec<StringCase>,
}
158
/// String-based case for serialization.
///
/// Only `input` is required in serialized form; every other field falls back
/// to its default (`None` or an empty list) when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StringCase {
    /// Case name.
    #[serde(default)]
    pub name: Option<String>,
    /// Input string.
    pub input: String,
    /// Expected output, if any.
    #[serde(default)]
    pub expected: Option<String>,
    /// Tags attached to the case.
    #[serde(default)]
    pub tags: Vec<String>,
}
174
175impl StringDataset {
176    /// Load from JSON file.
177    pub fn from_json(path: impl AsRef<Path>) -> EvalResult<Self> {
178        let content = std::fs::read_to_string(path.as_ref())?;
179        Self::from_json_str(&content)
180    }
181
182    /// Load from JSON string.
183    pub fn from_json_str(content: &str) -> EvalResult<Self> {
184        serde_json::from_str(content).map_err(|e| EvalError::Serialization(e.to_string()))
185    }
186
187    /// Load from YAML file.
188    pub fn from_yaml(path: impl AsRef<Path>) -> EvalResult<Self> {
189        let content = std::fs::read_to_string(path.as_ref())?;
190        Self::from_yaml_str(&content)
191    }
192
193    /// Load from YAML string.
194    pub fn from_yaml_str(content: &str) -> EvalResult<Self> {
195        serde_yaml::from_str(content).map_err(|e| EvalError::Yaml(e.to_string()))
196    }
197
198    /// Save to JSON file.
199    pub fn to_json(&self, path: impl AsRef<Path>) -> EvalResult<()> {
200        let content = self.to_json_string()?;
201        std::fs::write(path.as_ref(), content)?;
202        Ok(())
203    }
204
205    /// Serialize to JSON string.
206    pub fn to_json_string(&self) -> EvalResult<String> {
207        serde_json::to_string_pretty(self).map_err(|e| EvalError::Serialization(e.to_string()))
208    }
209
210    /// Save to YAML file.
211    pub fn to_yaml(&self, path: impl AsRef<Path>) -> EvalResult<()> {
212        let content = self.to_yaml_string()?;
213        std::fs::write(path.as_ref(), content)?;
214        Ok(())
215    }
216
217    /// Serialize to YAML string.
218    pub fn to_yaml_string(&self) -> EvalResult<String> {
219        serde_yaml::to_string(self).map_err(|e| EvalError::Yaml(e.to_string()))
220    }
221
222    /// Convert to generic Dataset.
223    pub fn to_dataset(&self) -> Dataset<String, String> {
224        let cases = self
225            .cases
226            .iter()
227            .map(|c| {
228                Case::new(c.input.clone())
229                    .with_name(c.name.clone().unwrap_or_default())
230                    .with_tags(c.tags.clone())
231                    .with_expected_output(c.expected.clone().unwrap_or_default())
232            })
233            .collect();
234
235        Dataset {
236            name: self.name.clone(),
237            description: self.description.clone(),
238            cases,
239        }
240    }
241}
242
/// Builder for creating datasets programmatically.
///
/// Holds the same three pieces of state as [`Dataset`] (name, description and
/// cases) in private fields until `build` is called.
#[derive(Debug)]
pub struct DatasetBuilder<Inputs, Output = (), Metadata = ()> {
    // Name to give the built dataset, if any.
    name: Option<String>,
    // Description to give the built dataset, if any.
    description: Option<String>,
    // Cases accumulated so far, in insertion order.
    cases: Vec<Case<Inputs, Output, Metadata>>,
}
250
251impl<Inputs, Output, Metadata> DatasetBuilder<Inputs, Output, Metadata> {
252    /// Create a new builder.
253    pub fn new() -> Self {
254        Self {
255            name: None,
256            description: None,
257            cases: Vec::new(),
258        }
259    }
260
261    /// Set the name.
262    pub fn name(mut self, name: impl Into<String>) -> Self {
263        self.name = Some(name.into());
264        self
265    }
266
267    /// Set the description.
268    pub fn description(mut self, desc: impl Into<String>) -> Self {
269        self.description = Some(desc.into());
270        self
271    }
272
273    /// Add a case.
274    pub fn case(mut self, case: Case<Inputs, Output, Metadata>) -> Self {
275        self.cases.push(case);
276        self
277    }
278
279    /// Build the dataset.
280    pub fn build(self) -> Dataset<Inputs, Output, Metadata> {
281        Dataset {
282            name: self.name,
283            description: self.description,
284            cases: self.cases,
285        }
286    }
287}
288
289impl<Inputs, Output, Metadata> Default for DatasetBuilder<Inputs, Output, Metadata> {
290    fn default() -> Self {
291        Self::new()
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Three cases tagged unit / integration / unit, shared by several tests.
    fn tagged_dataset() -> Dataset<String> {
        Dataset::new()
            .with_name("tagged")
            .case(Case::new("a".to_string()).with_tag("unit"))
            .case(Case::new("b".to_string()).with_tag("integration"))
            .case(Case::new("c".to_string()).with_tag("unit"))
    }

    #[test]
    fn test_dataset_new() {
        let dataset: Dataset<String> = Dataset::new()
            .with_name("test")
            .with_description("Test dataset");

        assert_eq!(dataset.name, Some("test".to_string()));
        assert_eq!(dataset.description, Some("Test dataset".to_string()));
        assert!(dataset.is_empty());
    }

    #[test]
    fn test_dataset_default_is_empty() {
        let dataset: Dataset<String> = Dataset::default();

        assert!(dataset.name.is_none());
        assert!(dataset.description.is_none());
        assert!(dataset.is_empty());
    }

    #[test]
    fn test_dataset_add_cases() {
        let dataset: Dataset<String> = Dataset::new()
            .case(Case::new("input1".to_string()))
            .case(Case::new("input2".to_string()));

        assert_eq!(dataset.len(), 2);
    }

    #[test]
    fn test_dataset_add_cases_in_bulk() {
        let dataset: Dataset<String> = Dataset::new()
            .case(Case::new("first".to_string()))
            .cases(vec![
                Case::new("second".to_string()),
                Case::new("third".to_string()),
            ]);

        assert_eq!(dataset.len(), 3);
    }

    #[test]
    fn test_dataset_filter_by_tag() {
        let dataset = tagged_dataset();

        let unit_tests = dataset.filter_by_tag("unit");
        assert_eq!(unit_tests.len(), 2);
        assert!(dataset.filter_by_tag("missing").is_empty());
    }

    #[test]
    fn test_dataset_filter_by_predicate() {
        let dataset = tagged_dataset();

        assert_eq!(dataset.filter(|c| c.has_tag("integration")).len(), 1);
        assert_eq!(dataset.filter(|_| true).len(), 3);
        assert!(dataset.filter(|_| false).is_empty());
    }

    #[test]
    fn test_dataset_subset_skips_out_of_range() {
        let dataset = tagged_dataset();

        // Index 7 does not exist and must be ignored rather than panic.
        let subset = dataset.subset(&[0, 7, 2]);
        assert_eq!(subset.len(), 2);
        assert_eq!(subset.filter_by_tag("unit").len(), 2);
        assert_eq!(subset.name, Some("tagged".to_string()));
    }

    #[test]
    fn test_dataset_take() {
        let dataset: Dataset<String> = Dataset::new()
            .case(Case::new("a".to_string()))
            .case(Case::new("b".to_string()))
            .case(Case::new("c".to_string()));

        let subset = dataset.take(2);
        assert_eq!(subset.len(), 2);
        // Asking for more than is available yields everything.
        assert_eq!(dataset.take(10).len(), 3);
        assert!(dataset.take(0).is_empty());
    }

    #[test]
    fn test_dataset_shuffle_preserves_cases() {
        let dataset = tagged_dataset();

        let shuffled = dataset.shuffle(42);
        assert_eq!(shuffled.len(), 3);
        assert_eq!(shuffled.filter_by_tag("unit").len(), 2);
        assert_eq!(shuffled.filter_by_tag("integration").len(), 1);
        assert_eq!(shuffled.name, Some("tagged".to_string()));

        let empty: Dataset<String> = Dataset::new();
        assert!(empty.shuffle(42).is_empty());
    }

    #[test]
    fn test_string_dataset_json_roundtrip() {
        let dataset = StringDataset {
            name: Some("test".to_string()),
            description: None,
            cases: vec![StringCase {
                name: Some("case1".to_string()),
                input: "hello".to_string(),
                expected: Some("world".to_string()),
                tags: vec![],
            }],
        };

        let json = dataset.to_json_string().unwrap();
        let loaded = StringDataset::from_json_str(&json).unwrap();

        assert_eq!(loaded.name, Some("test".to_string()));
        assert_eq!(loaded.cases.len(), 1);
    }

    #[test]
    fn test_string_dataset_json_defaults() {
        // Only `cases` and each case's `input` are required.
        let loaded = StringDataset::from_json_str(r#"{"cases":[{"input":"hi"}]}"#).unwrap();

        assert!(loaded.name.is_none());
        assert!(loaded.description.is_none());
        assert_eq!(loaded.cases.len(), 1);
        assert_eq!(loaded.cases[0].input, "hi");
        assert!(loaded.cases[0].name.is_none());
        assert!(loaded.cases[0].expected.is_none());
        assert!(loaded.cases[0].tags.is_empty());
    }

    #[test]
    fn test_string_dataset_rejects_invalid_json() {
        assert!(StringDataset::from_json_str("not json").is_err());
    }

    #[test]
    fn test_string_dataset_yaml_roundtrip() {
        let dataset = StringDataset {
            name: Some("yaml".to_string()),
            description: Some("round trip".to_string()),
            cases: vec![StringCase {
                name: None,
                input: "ping".to_string(),
                expected: Some("pong".to_string()),
                tags: vec!["smoke".to_string()],
            }],
        };

        let yaml = dataset.to_yaml_string().unwrap();
        let loaded = StringDataset::from_yaml_str(&yaml).unwrap();

        assert_eq!(loaded.name, Some("yaml".to_string()));
        assert_eq!(loaded.description, Some("round trip".to_string()));
        assert_eq!(loaded.cases.len(), 1);
        assert!(loaded.cases[0].name.is_none());
        assert_eq!(loaded.cases[0].input, "ping");
        assert_eq!(loaded.cases[0].expected, Some("pong".to_string()));
        assert_eq!(loaded.cases[0].tags, vec!["smoke".to_string()]);
    }

    #[test]
    fn test_string_dataset_to_dataset() {
        let dataset = StringDataset {
            name: Some("converted".to_string()),
            description: None,
            cases: vec![
                StringCase {
                    name: Some("named".to_string()),
                    input: "a".to_string(),
                    expected: Some("b".to_string()),
                    tags: vec!["unit".to_string()],
                },
                StringCase {
                    name: None,
                    input: "c".to_string(),
                    expected: None,
                    tags: vec![],
                },
            ],
        };

        let converted = dataset.to_dataset();

        assert_eq!(converted.name, Some("converted".to_string()));
        assert!(converted.description.is_none());
        assert_eq!(converted.len(), 2);
        assert_eq!(converted.filter_by_tag("unit").len(), 1);
    }

    #[test]
    fn test_dataset_builder() {
        let dataset: Dataset<String> = DatasetBuilder::new()
            .name("builder test")
            .case(Case::new("input".to_string()))
            .build();

        assert_eq!(dataset.name, Some("builder test".to_string()));
        assert_eq!(dataset.len(), 1);
    }

    #[test]
    fn test_dataset_builder_default_and_description() {
        let dataset: Dataset<String> = DatasetBuilder::default()
            .description("described")
            .build();

        assert!(dataset.name.is_none());
        assert_eq!(dataset.description, Some("described".to_string()));
        assert!(dataset.is_empty());
    }
}