Skip to main content

synth_ai_core/data/
rubrics.rs

1//! Rubric and criterion types for evaluation.
2//!
3//! Rubrics define evaluation criteria with weights and descriptions.
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashSet;
7
8/// A single evaluation criterion.
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct Criterion {
11    /// Unique identifier for this criterion.
12    pub id: String,
13    /// Human-readable description of what this criterion evaluates.
14    pub description: String,
15    /// Weight for aggregation (must be positive).
16    #[serde(default = "default_weight")]
17    pub weight: f64,
18    /// Whether this criterion must be satisfied.
19    #[serde(default)]
20    pub required: bool,
21    /// Optional scoring scale (e.g., 0-10, 0-100).
22    #[serde(default)]
23    pub scale_max: Option<f64>,
24    /// Optional examples of good/bad responses.
25    #[serde(default)]
26    pub examples: Vec<CriterionExample>,
27}
28
29fn default_weight() -> f64 {
30    1.0
31}
32
33/// Example for a criterion showing expected scores.
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct CriterionExample {
36    /// The input/output being evaluated.
37    pub content: String,
38    /// Expected score for this example.
39    pub expected_score: f64,
40    /// Explanation of why this score is appropriate.
41    #[serde(default)]
42    pub explanation: Option<String>,
43}
44
45impl Criterion {
46    /// Create a new criterion with the given ID and description.
47    pub fn new(id: impl Into<String>, description: impl Into<String>) -> Self {
48        Self {
49            id: id.into(),
50            description: description.into(),
51            weight: 1.0,
52            required: false,
53            scale_max: None,
54            examples: Vec::new(),
55        }
56    }
57
58    /// Set the weight for this criterion.
59    pub fn with_weight(mut self, weight: f64) -> Self {
60        self.weight = weight;
61        self
62    }
63
64    /// Mark this criterion as required.
65    pub fn required(mut self) -> Self {
66        self.required = true;
67        self
68    }
69
70    /// Validate this criterion's configuration.
71    pub fn validate(&self) -> Result<(), String> {
72        if self.id.is_empty() {
73            return Err("Criterion ID cannot be empty".to_string());
74        }
75        if self.weight <= 0.0 {
76            return Err(format!(
77                "Criterion '{}' weight must be positive, got {}",
78                self.id, self.weight
79            ));
80        }
81        if let Some(scale_max) = self.scale_max {
82            if scale_max <= 0.0 {
83                return Err(format!(
84                    "Criterion '{}' scale_max must be positive, got {}",
85                    self.id, scale_max
86                ));
87            }
88        }
89        Ok(())
90    }
91}
92
/// A rubric containing multiple evaluation criteria.
///
/// Build one with [`Rubric::new`] and the `with_*` builders, or deserialize it.
/// Either way, call [`Rubric::validate`] to check the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Rubric {
    /// Version identifier for this rubric (no serde default, so it must be
    /// present when deserializing).
    pub version: String,
    /// High-level goal or purpose of the evaluation; `None` when omitted.
    #[serde(default)]
    pub goal_text: Option<String>,
    /// List of criteria to evaluate; empty when omitted.
    ///
    /// IDs must be unique, and the list may only be empty when `aggregation`
    /// is `"inherit"` (both enforced by [`Rubric::validate`]).
    #[serde(default)]
    pub criteria: Vec<Criterion>,
    /// How to aggregate criterion scores.
    ///
    /// [`Rubric::validate`] accepts `sum`, `weighted_sum`, `mean`,
    /// `weighted_mean`, `custom`, or `inherit`. Defaults to `weighted_sum`.
    #[serde(default = "default_aggregation")]
    pub aggregation: String,
    /// Optional free-form JSON metadata; empty when omitted.
    /// Not read by any code in this module.
    #[serde(default)]
    pub metadata: std::collections::HashMap<String, serde_json::Value>,
}

/// Serde default for [`Rubric::aggregation`]; matches what [`Rubric::new`] sets.
fn default_aggregation() -> String {
    "weighted_sum".to_string()
}
115
116impl Rubric {
117    /// Create a new rubric with the given version.
118    pub fn new(version: impl Into<String>) -> Self {
119        Self {
120            version: version.into(),
121            goal_text: None,
122            criteria: Vec::new(),
123            aggregation: "weighted_sum".to_string(),
124            metadata: std::collections::HashMap::new(),
125        }
126    }
127
128    /// Set the goal text for this rubric.
129    pub fn with_goal(mut self, goal: impl Into<String>) -> Self {
130        self.goal_text = Some(goal.into());
131        self
132    }
133
134    /// Add a criterion to this rubric.
135    pub fn with_criterion(mut self, criterion: Criterion) -> Self {
136        self.criteria.push(criterion);
137        self
138    }
139
140    /// Set the aggregation method.
141    pub fn with_aggregation(mut self, aggregation: impl Into<String>) -> Self {
142        self.aggregation = aggregation.into();
143        self
144    }
145
146    /// Validate this rubric's configuration.
147    pub fn validate(&self) -> Result<(), String> {
148        // Check aggregation method
149        const VALID_AGGREGATIONS: &[&str] = &["sum", "weighted_sum", "mean", "weighted_mean", "custom", "inherit"];
150        if !VALID_AGGREGATIONS.contains(&self.aggregation.as_str()) {
151            return Err(format!(
152                "Invalid aggregation '{}'. Valid options: {:?}",
153                self.aggregation, VALID_AGGREGATIONS
154            ));
155        }
156
157        // Check for duplicate criterion IDs
158        let mut seen = HashSet::new();
159        for criterion in &self.criteria {
160            if !seen.insert(&criterion.id) {
161                return Err(format!("Duplicate criterion ID: {}", criterion.id));
162            }
163            criterion.validate()?;
164        }
165
166        // Check that at least one criterion exists (unless inheriting)
167        if self.criteria.is_empty() && self.aggregation != "inherit" {
168            return Err("Rubric must have at least one criterion".to_string());
169        }
170
171        Ok(())
172    }
173
174    /// Get total weight of all criteria.
175    pub fn total_weight(&self) -> f64 {
176        self.criteria.iter().map(|c| c.weight).sum()
177    }
178
179    /// Get a criterion by ID.
180    pub fn get_criterion(&self, id: &str) -> Option<&Criterion> {
181        self.criteria.iter().find(|c| c.id == id)
182    }
183}
184
185impl Default for Rubric {
186    fn default() -> Self {
187        Self::new("1.0")
188    }
189}
190
// Unit tests covering criterion/rubric construction, validation rules,
// lookups, and serde round-trips/defaults.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_criterion_creation() {
        let criterion = Criterion::new("accuracy", "Response is factually correct")
            .with_weight(2.0)
            .required();

        assert_eq!(criterion.id, "accuracy");
        assert_eq!(criterion.weight, 2.0);
        assert!(criterion.required);
        assert!(criterion.validate().is_ok());
    }

    #[test]
    fn test_criterion_validation() {
        let invalid = Criterion::new("test", "desc").with_weight(-1.0);
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn test_criterion_validation_edge_cases() {
        assert!(Criterion::new("", "desc").validate().is_err());
        assert!(Criterion::new("test", "desc").with_weight(0.0).validate().is_err());

        let mut scaled = Criterion::new("test", "desc");
        scaled.scale_max = Some(10.0);
        assert!(scaled.validate().is_ok());
        scaled.scale_max = Some(0.0);
        assert!(scaled.validate().is_err());
    }

    #[test]
    fn test_rubric_creation() {
        let rubric = Rubric::new("1.0")
            .with_goal("Evaluate response quality")
            .with_criterion(Criterion::new("clarity", "Response is clear"))
            .with_criterion(Criterion::new("accuracy", "Response is accurate"));

        assert_eq!(rubric.criteria.len(), 2);
        assert!(rubric.validate().is_ok());
    }

    #[test]
    fn test_rubric_duplicate_ids() {
        let rubric = Rubric::new("1.0")
            .with_criterion(Criterion::new("test", "First"))
            .with_criterion(Criterion::new("test", "Duplicate"));

        assert!(rubric.validate().is_err());
    }

    #[test]
    fn test_rubric_requires_criteria_unless_inheriting() {
        assert!(Rubric::new("1.0").validate().is_err());
        assert!(Rubric::new("1.0").with_aggregation("inherit").validate().is_ok());
    }

    #[test]
    fn test_rubric_invalid_aggregation() {
        let rubric = Rubric::new("1.0")
            .with_criterion(Criterion::new("test", "Test criterion"))
            .with_aggregation("median");

        assert!(rubric.validate().is_err());
    }

    #[test]
    fn test_rubric_propagates_criterion_errors() {
        let rubric = Rubric::new("1.0")
            .with_criterion(Criterion::new("test", "Test criterion").with_weight(-1.0));

        assert!(rubric.validate().is_err());
    }

    #[test]
    fn test_rubric_total_weight_and_lookup() {
        let rubric = Rubric::new("1.0")
            .with_criterion(Criterion::new("a", "First").with_weight(2.0))
            .with_criterion(Criterion::new("b", "Second").with_weight(0.5));

        assert_eq!(rubric.total_weight(), 2.5);
        assert_eq!(rubric.get_criterion("b").map(|c| c.weight), Some(0.5));
        assert!(rubric.get_criterion("missing").is_none());
    }

    #[test]
    fn test_default_rubric() {
        let rubric = Rubric::default();

        assert_eq!(rubric.version, "1.0");
        assert_eq!(rubric.aggregation, "weighted_sum");
        assert!(rubric.criteria.is_empty());
    }

    #[test]
    fn test_rubric_serde() {
        let rubric = Rubric::new("1.0")
            .with_criterion(Criterion::new("test", "Test criterion"));

        let json = serde_json::to_string(&rubric).unwrap();
        let parsed: Rubric = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.version, rubric.version);
        assert_eq!(parsed.criteria.len(), 1);
    }

    #[test]
    fn test_rubric_serde_defaults() {
        // Only the mandatory fields are supplied; everything else must fall
        // back to its serde default.
        let json = r#"{"version": "2.0", "criteria": [{"id": "a", "description": "d"}]}"#;
        let parsed: Rubric = serde_json::from_str(json).unwrap();

        assert_eq!(parsed.version, "2.0");
        assert_eq!(parsed.aggregation, "weighted_sum");
        assert!(parsed.goal_text.is_none());
        assert!(parsed.metadata.is_empty());

        let criterion = &parsed.criteria[0];
        assert_eq!(criterion.weight, 1.0);
        assert!(!criterion.required);
        assert!(criterion.scale_max.is_none());
        assert!(criterion.examples.is_empty());
        assert!(parsed.validate().is_ok());
    }
}