synth_ai_core/data/
rubrics.rs1use serde::{Deserialize, Serialize};
6use std::collections::HashSet;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct Criterion {
11 pub id: String,
13 pub description: String,
15 #[serde(default = "default_weight")]
17 pub weight: f64,
18 #[serde(default)]
20 pub required: bool,
21 #[serde(default)]
23 pub scale_max: Option<f64>,
24 #[serde(default)]
26 pub examples: Vec<CriterionExample>,
27}
28
/// Serde fallback for the `weight` field of `Criterion`: a criterion that
/// does not specify a weight gets `1.0`.
fn default_weight() -> f64 {
    const UNIT_WEIGHT: f64 = 1.0;
    UNIT_WEIGHT
}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct CriterionExample {
36 pub content: String,
38 pub expected_score: f64,
40 #[serde(default)]
42 pub explanation: Option<String>,
43}
44
45impl Criterion {
46 pub fn new(id: impl Into<String>, description: impl Into<String>) -> Self {
48 Self {
49 id: id.into(),
50 description: description.into(),
51 weight: 1.0,
52 required: false,
53 scale_max: None,
54 examples: Vec::new(),
55 }
56 }
57
58 pub fn with_weight(mut self, weight: f64) -> Self {
60 self.weight = weight;
61 self
62 }
63
64 pub fn required(mut self) -> Self {
66 self.required = true;
67 self
68 }
69
70 pub fn validate(&self) -> Result<(), String> {
72 if self.id.is_empty() {
73 return Err("Criterion ID cannot be empty".to_string());
74 }
75 if self.weight <= 0.0 {
76 return Err(format!(
77 "Criterion '{}' weight must be positive, got {}",
78 self.id, self.weight
79 ));
80 }
81 if let Some(scale_max) = self.scale_max {
82 if scale_max <= 0.0 {
83 return Err(format!(
84 "Criterion '{}' scale_max must be positive, got {}",
85 self.id, scale_max
86 ));
87 }
88 }
89 Ok(())
90 }
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct Rubric {
96 pub version: String,
98 #[serde(default)]
100 pub goal_text: Option<String>,
101 #[serde(default)]
103 pub criteria: Vec<Criterion>,
104 #[serde(default = "default_aggregation")]
106 pub aggregation: String,
107 #[serde(default)]
109 pub metadata: std::collections::HashMap<String, serde_json::Value>,
110}
111
/// Serde fallback for the `aggregation` field of `Rubric`: a rubric that does
/// not name a strategy uses `"weighted_sum"`.
fn default_aggregation() -> String {
    String::from("weighted_sum")
}
115
116impl Rubric {
117 pub fn new(version: impl Into<String>) -> Self {
119 Self {
120 version: version.into(),
121 goal_text: None,
122 criteria: Vec::new(),
123 aggregation: "weighted_sum".to_string(),
124 metadata: std::collections::HashMap::new(),
125 }
126 }
127
128 pub fn with_goal(mut self, goal: impl Into<String>) -> Self {
130 self.goal_text = Some(goal.into());
131 self
132 }
133
134 pub fn with_criterion(mut self, criterion: Criterion) -> Self {
136 self.criteria.push(criterion);
137 self
138 }
139
140 pub fn with_aggregation(mut self, aggregation: impl Into<String>) -> Self {
142 self.aggregation = aggregation.into();
143 self
144 }
145
146 pub fn validate(&self) -> Result<(), String> {
148 const VALID_AGGREGATIONS: &[&str] = &["sum", "weighted_sum", "mean", "weighted_mean", "custom", "inherit"];
150 if !VALID_AGGREGATIONS.contains(&self.aggregation.as_str()) {
151 return Err(format!(
152 "Invalid aggregation '{}'. Valid options: {:?}",
153 self.aggregation, VALID_AGGREGATIONS
154 ));
155 }
156
157 let mut seen = HashSet::new();
159 for criterion in &self.criteria {
160 if !seen.insert(&criterion.id) {
161 return Err(format!("Duplicate criterion ID: {}", criterion.id));
162 }
163 criterion.validate()?;
164 }
165
166 if self.criteria.is_empty() && self.aggregation != "inherit" {
168 return Err("Rubric must have at least one criterion".to_string());
169 }
170
171 Ok(())
172 }
173
174 pub fn total_weight(&self) -> f64 {
176 self.criteria.iter().map(|c| c.weight).sum()
177 }
178
179 pub fn get_criterion(&self, id: &str) -> Option<&Criterion> {
181 self.criteria.iter().find(|c| c.id == id)
182 }
183}
184
185impl Default for Rubric {
186 fn default() -> Self {
187 Self::new("1.0")
188 }
189}
190
#[cfg(test)]
mod tests {
    //! Unit tests for `Criterion` and `Rubric`: builders, validation rules,
    //! lookups and serde round-trips/defaults.

    use super::*;

    #[test]
    fn test_criterion_creation() {
        let criterion = Criterion::new("accuracy", "Response is factually correct")
            .with_weight(2.0)
            .required();

        assert_eq!(criterion.id, "accuracy");
        assert_eq!(criterion.weight, 2.0);
        assert!(criterion.required);
        assert!(criterion.validate().is_ok());
    }

    #[test]
    fn test_criterion_validation() {
        let invalid = Criterion::new("test", "desc").with_weight(-1.0);
        assert!(invalid.validate().is_err());

        // Zero is the boundary of the "must be positive" rule.
        let zero = Criterion::new("test", "desc").with_weight(0.0);
        assert!(zero.validate().is_err());

        let unnamed = Criterion::new("", "desc");
        assert!(unnamed.validate().is_err());
    }

    #[test]
    fn test_criterion_scale_max_validation() {
        let mut criterion = Criterion::new("test", "desc");
        assert!(criterion.validate().is_ok());

        criterion.scale_max = Some(5.0);
        assert!(criterion.validate().is_ok());

        criterion.scale_max = Some(0.0);
        assert!(criterion.validate().is_err());

        criterion.scale_max = Some(-2.0);
        assert!(criterion.validate().is_err());
    }

    #[test]
    fn test_rubric_creation() {
        let rubric = Rubric::new("1.0")
            .with_goal("Evaluate response quality")
            .with_criterion(Criterion::new("clarity", "Response is clear"))
            .with_criterion(Criterion::new("accuracy", "Response is accurate"));

        assert_eq!(rubric.criteria.len(), 2);
        assert_eq!(rubric.goal_text.as_deref(), Some("Evaluate response quality"));
        assert!(rubric.validate().is_ok());
    }

    #[test]
    fn test_rubric_duplicate_ids() {
        let rubric = Rubric::new("1.0")
            .with_criterion(Criterion::new("test", "First"))
            .with_criterion(Criterion::new("test", "Duplicate"));

        assert!(rubric.validate().is_err());
    }

    #[test]
    fn test_rubric_invalid_aggregation() {
        let rubric = Rubric::new("1.0")
            .with_criterion(Criterion::new("test", "Test criterion"))
            .with_aggregation("median");

        assert!(rubric.validate().is_err());
    }

    #[test]
    fn test_rubric_requires_criteria_unless_inherit() {
        let empty = Rubric::new("1.0");
        assert!(empty.validate().is_err());

        let inheriting = Rubric::new("1.0").with_aggregation("inherit");
        assert!(inheriting.validate().is_ok());
    }

    #[test]
    fn test_rubric_propagates_criterion_errors() {
        let rubric = Rubric::new("1.0")
            .with_criterion(Criterion::new("test", "Test criterion").with_weight(-1.0));

        assert!(rubric.validate().is_err());
    }

    #[test]
    fn test_rubric_lookup_and_total_weight() {
        let rubric = Rubric::new("1.0")
            .with_criterion(Criterion::new("clarity", "Response is clear").with_weight(2.0))
            .with_criterion(Criterion::new("accuracy", "Response is accurate"));

        assert_eq!(rubric.total_weight(), 3.0);
        assert_eq!(
            rubric.get_criterion("clarity").map(|c| c.weight),
            Some(2.0)
        );
        assert!(rubric.get_criterion("missing").is_none());
    }

    #[test]
    fn test_rubric_default() {
        let rubric = Rubric::default();

        assert_eq!(rubric.version, "1.0");
        assert_eq!(rubric.aggregation, "weighted_sum");
        assert!(rubric.criteria.is_empty());
    }

    #[test]
    fn test_rubric_serde() {
        let rubric = Rubric::new("1.0")
            .with_criterion(Criterion::new("test", "Test criterion"));

        let json = serde_json::to_string(&rubric).unwrap();
        let parsed: Rubric = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.version, rubric.version);
        assert_eq!(parsed.criteria.len(), 1);
    }

    #[test]
    fn test_serde_defaults() {
        // Only the mandatory keys are present; everything else must fall
        // back to the declared serde defaults.
        let parsed: Rubric = serde_json::from_str(
            r#"{"version": "1.0", "criteria": [{"id": "a", "description": "b"}]}"#,
        )
        .unwrap();

        assert_eq!(parsed.aggregation, "weighted_sum");
        assert!(parsed.goal_text.is_none());
        assert!(parsed.metadata.is_empty());

        let criterion = &parsed.criteria[0];
        assert_eq!(criterion.weight, 1.0);
        assert!(!criterion.required);
        assert!(criterion.scale_max.is_none());
        assert!(criterion.examples.is_empty());
        assert!(parsed.validate().is_ok());
    }
}