Skip to main content

synth_ai_core/data/
rubrics.rs

//! Rubric and criterion types for evaluation.
//!
//! Rubrics define evaluation criteria with weights and descriptions.
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashSet;
7
8/// A single evaluation criterion.
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct Criterion {
11    /// Unique identifier for this criterion.
12    pub id: String,
13    /// Human-readable description of what this criterion evaluates.
14    pub description: String,
15    /// Weight for aggregation (must be positive).
16    #[serde(default = "default_weight")]
17    pub weight: f64,
18    /// Whether this criterion must be satisfied.
19    #[serde(default)]
20    pub required: bool,
21    /// Optional scoring scale (e.g., 0-10, 0-100).
22    #[serde(default)]
23    pub scale_max: Option<f64>,
24    /// Optional examples of good/bad responses.
25    #[serde(default)]
26    pub examples: Vec<CriterionExample>,
27}
28
/// Weight given to a criterion whose serialized form has no `weight` field.
fn default_weight() -> f64 {
    1.0_f64
}
32
33/// Example for a criterion showing expected scores.
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct CriterionExample {
36    /// The input/output being evaluated.
37    pub content: String,
38    /// Expected score for this example.
39    pub expected_score: f64,
40    /// Explanation of why this score is appropriate.
41    #[serde(default)]
42    pub explanation: Option<String>,
43}
44
45impl Criterion {
46    /// Create a new criterion with the given ID and description.
47    pub fn new(id: impl Into<String>, description: impl Into<String>) -> Self {
48        Self {
49            id: id.into(),
50            description: description.into(),
51            weight: 1.0,
52            required: false,
53            scale_max: None,
54            examples: Vec::new(),
55        }
56    }
57
58    /// Set the weight for this criterion.
59    pub fn with_weight(mut self, weight: f64) -> Self {
60        self.weight = weight;
61        self
62    }
63
64    /// Mark this criterion as required.
65    pub fn required(mut self) -> Self {
66        self.required = true;
67        self
68    }
69
70    /// Validate this criterion's configuration.
71    pub fn validate(&self) -> Result<(), String> {
72        if self.id.is_empty() {
73            return Err("Criterion ID cannot be empty".to_string());
74        }
75        if self.weight <= 0.0 {
76            return Err(format!(
77                "Criterion '{}' weight must be positive, got {}",
78                self.id, self.weight
79            ));
80        }
81        if let Some(scale_max) = self.scale_max {
82            if scale_max <= 0.0 {
83                return Err(format!(
84                    "Criterion '{}' scale_max must be positive, got {}",
85                    self.id, scale_max
86                ));
87            }
88        }
89        Ok(())
90    }
91}
92
93/// A rubric containing multiple evaluation criteria.
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct Rubric {
96    /// Version identifier for this rubric.
97    pub version: String,
98    /// High-level goal or purpose of the evaluation.
99    #[serde(default)]
100    pub goal_text: Option<String>,
101    /// List of criteria to evaluate.
102    #[serde(default)]
103    pub criteria: Vec<Criterion>,
104    /// How to aggregate criterion scores.
105    #[serde(default = "default_aggregation")]
106    pub aggregation: String,
107    /// Optional metadata.
108    #[serde(default)]
109    pub metadata: std::collections::HashMap<String, serde_json::Value>,
110}
111
/// Aggregation method given to a rubric whose serialized form has no
/// `aggregation` field.
fn default_aggregation() -> String {
    String::from("weighted_sum")
}
115
116impl Rubric {
117    /// Create a new rubric with the given version.
118    pub fn new(version: impl Into<String>) -> Self {
119        Self {
120            version: version.into(),
121            goal_text: None,
122            criteria: Vec::new(),
123            aggregation: "weighted_sum".to_string(),
124            metadata: std::collections::HashMap::new(),
125        }
126    }
127
128    /// Set the goal text for this rubric.
129    pub fn with_goal(mut self, goal: impl Into<String>) -> Self {
130        self.goal_text = Some(goal.into());
131        self
132    }
133
134    /// Add a criterion to this rubric.
135    pub fn with_criterion(mut self, criterion: Criterion) -> Self {
136        self.criteria.push(criterion);
137        self
138    }
139
140    /// Set the aggregation method.
141    pub fn with_aggregation(mut self, aggregation: impl Into<String>) -> Self {
142        self.aggregation = aggregation.into();
143        self
144    }
145
146    /// Validate this rubric's configuration.
147    pub fn validate(&self) -> Result<(), String> {
148        // Check aggregation method
149        const VALID_AGGREGATIONS: &[&str] = &[
150            "sum",
151            "weighted_sum",
152            "mean",
153            "weighted_mean",
154            "custom",
155            "inherit",
156        ];
157        if !VALID_AGGREGATIONS.contains(&self.aggregation.as_str()) {
158            return Err(format!(
159                "Invalid aggregation '{}'. Valid options: {:?}",
160                self.aggregation, VALID_AGGREGATIONS
161            ));
162        }
163
164        // Check for duplicate criterion IDs
165        let mut seen = HashSet::new();
166        for criterion in &self.criteria {
167            if !seen.insert(&criterion.id) {
168                return Err(format!("Duplicate criterion ID: {}", criterion.id));
169            }
170            criterion.validate()?;
171        }
172
173        // Check that at least one criterion exists (unless inheriting)
174        if self.criteria.is_empty() && self.aggregation != "inherit" {
175            return Err("Rubric must have at least one criterion".to_string());
176        }
177
178        Ok(())
179    }
180
181    /// Get total weight of all criteria.
182    pub fn total_weight(&self) -> f64 {
183        self.criteria.iter().map(|c| c.weight).sum()
184    }
185
186    /// Get a criterion by ID.
187    pub fn get_criterion(&self, id: &str) -> Option<&Criterion> {
188        self.criteria.iter().find(|c| c.id == id)
189    }
190}
191
192impl Default for Rubric {
193    fn default() -> Self {
194        Self::new("1.0")
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods set the fields they name and yield a valid criterion.
    #[test]
    fn test_criterion_creation() {
        let built = Criterion::new("accuracy", "Response is factually correct")
            .with_weight(2.0)
            .required();

        assert_eq!(built.id, "accuracy");
        assert_eq!(built.weight, 2.0);
        assert!(built.required);
        assert_eq!(built.validate(), Ok(()));
    }

    /// A negative weight is rejected by validation.
    #[test]
    fn test_criterion_validation() {
        let negative_weight = Criterion::new("test", "desc").with_weight(-1.0);
        let outcome = negative_weight.validate();
        assert!(outcome.is_err());
    }

    /// A rubric with two distinct criteria is valid.
    #[test]
    fn test_rubric_creation() {
        let clarity = Criterion::new("clarity", "Response is clear");
        let accuracy = Criterion::new("accuracy", "Response is accurate");
        let rubric = Rubric::new("1.0")
            .with_goal("Evaluate response quality")
            .with_criterion(clarity)
            .with_criterion(accuracy);

        assert_eq!(rubric.criteria.len(), 2);
        assert_eq!(rubric.validate(), Ok(()));
    }

    /// Two criteria sharing one ID make the rubric invalid.
    #[test]
    fn test_rubric_duplicate_ids() {
        let first = Criterion::new("test", "First");
        let second = Criterion::new("test", "Duplicate");
        let rubric = Rubric::new("1.0")
            .with_criterion(first)
            .with_criterion(second);

        let outcome = rubric.validate();
        assert!(outcome.is_err());
    }

    /// A rubric survives a JSON round trip with its version and criteria.
    #[test]
    fn test_rubric_serde() {
        let original = Rubric::new("1.0").with_criterion(Criterion::new("test", "Test criterion"));

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Rubric = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.version, original.version);
        assert_eq!(decoded.criteria.len(), 1);
    }
}