// synth_ai_core/data/rubrics.rs

use std::collections::{HashMap, HashSet};

use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct Criterion {
11 pub id: String,
13 pub description: String,
15 #[serde(default = "default_weight")]
17 pub weight: f64,
18 #[serde(default)]
20 pub required: bool,
21 #[serde(default)]
23 pub scale_max: Option<f64>,
24 #[serde(default)]
26 pub examples: Vec<CriterionExample>,
27}
28
/// Serde fallback for `Criterion::weight`.
///
/// A criterion that omits `weight` gets `1.0`.
fn default_weight() -> f64 {
    1_f64
}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct CriterionExample {
36 pub content: String,
38 pub expected_score: f64,
40 #[serde(default)]
42 pub explanation: Option<String>,
43}
44
45impl Criterion {
46 pub fn new(id: impl Into<String>, description: impl Into<String>) -> Self {
48 Self {
49 id: id.into(),
50 description: description.into(),
51 weight: 1.0,
52 required: false,
53 scale_max: None,
54 examples: Vec::new(),
55 }
56 }
57
58 pub fn with_weight(mut self, weight: f64) -> Self {
60 self.weight = weight;
61 self
62 }
63
64 pub fn required(mut self) -> Self {
66 self.required = true;
67 self
68 }
69
70 pub fn validate(&self) -> Result<(), String> {
72 if self.id.is_empty() {
73 return Err("Criterion ID cannot be empty".to_string());
74 }
75 if self.weight <= 0.0 {
76 return Err(format!(
77 "Criterion '{}' weight must be positive, got {}",
78 self.id, self.weight
79 ));
80 }
81 if let Some(scale_max) = self.scale_max {
82 if scale_max <= 0.0 {
83 return Err(format!(
84 "Criterion '{}' scale_max must be positive, got {}",
85 self.id, scale_max
86 ));
87 }
88 }
89 Ok(())
90 }
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct Rubric {
96 pub version: String,
98 #[serde(default)]
100 pub goal_text: Option<String>,
101 #[serde(default)]
103 pub criteria: Vec<Criterion>,
104 #[serde(default = "default_aggregation")]
106 pub aggregation: String,
107 #[serde(default)]
109 pub metadata: std::collections::HashMap<String, serde_json::Value>,
110}
111
/// Serde fallback for `Rubric::aggregation`.
///
/// This is the same value `Rubric::new` assigns.
fn default_aggregation() -> String {
    String::from("weighted_sum")
}
115
116impl Rubric {
117 pub fn new(version: impl Into<String>) -> Self {
119 Self {
120 version: version.into(),
121 goal_text: None,
122 criteria: Vec::new(),
123 aggregation: "weighted_sum".to_string(),
124 metadata: std::collections::HashMap::new(),
125 }
126 }
127
128 pub fn with_goal(mut self, goal: impl Into<String>) -> Self {
130 self.goal_text = Some(goal.into());
131 self
132 }
133
134 pub fn with_criterion(mut self, criterion: Criterion) -> Self {
136 self.criteria.push(criterion);
137 self
138 }
139
140 pub fn with_aggregation(mut self, aggregation: impl Into<String>) -> Self {
142 self.aggregation = aggregation.into();
143 self
144 }
145
146 pub fn validate(&self) -> Result<(), String> {
148 const VALID_AGGREGATIONS: &[&str] = &[
150 "sum",
151 "weighted_sum",
152 "mean",
153 "weighted_mean",
154 "custom",
155 "inherit",
156 ];
157 if !VALID_AGGREGATIONS.contains(&self.aggregation.as_str()) {
158 return Err(format!(
159 "Invalid aggregation '{}'. Valid options: {:?}",
160 self.aggregation, VALID_AGGREGATIONS
161 ));
162 }
163
164 let mut seen = HashSet::new();
166 for criterion in &self.criteria {
167 if !seen.insert(&criterion.id) {
168 return Err(format!("Duplicate criterion ID: {}", criterion.id));
169 }
170 criterion.validate()?;
171 }
172
173 if self.criteria.is_empty() && self.aggregation != "inherit" {
175 return Err("Rubric must have at least one criterion".to_string());
176 }
177
178 Ok(())
179 }
180
181 pub fn total_weight(&self) -> f64 {
183 self.criteria.iter().map(|c| c.weight).sum()
184 }
185
186 pub fn get_criterion(&self, id: &str) -> Option<&Criterion> {
188 self.criteria.iter().find(|c| c.id == id)
189 }
190}
191
192impl Default for Rubric {
193 fn default() -> Self {
194 Self::new("1.0")
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_criterion_creation() {
        let criterion = Criterion::new("accuracy", "Response is factually correct")
            .with_weight(2.0)
            .required();

        assert_eq!(criterion.id, "accuracy");
        assert_eq!(criterion.weight, 2.0);
        assert!(criterion.required);
        assert!(criterion.validate().is_ok());
    }

    #[test]
    fn test_criterion_validation() {
        let invalid = Criterion::new("test", "desc").with_weight(-1.0);
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn test_criterion_rejects_empty_id() {
        let err = Criterion::new("", "desc").validate().unwrap_err();
        assert_eq!(err, "Criterion ID cannot be empty");
    }

    #[test]
    fn test_criterion_scale_max_validation() {
        let mut criterion = Criterion::new("score", "desc");
        criterion.scale_max = Some(10.0);
        assert!(criterion.validate().is_ok());

        criterion.scale_max = Some(0.0);
        let err = criterion.validate().unwrap_err();
        assert!(err.contains("scale_max"));
    }

    #[test]
    fn test_rubric_creation() {
        let rubric = Rubric::new("1.0")
            .with_goal("Evaluate response quality")
            .with_criterion(Criterion::new("clarity", "Response is clear"))
            .with_criterion(Criterion::new("accuracy", "Response is accurate"));

        assert_eq!(rubric.criteria.len(), 2);
        assert!(rubric.validate().is_ok());
    }

    #[test]
    fn test_rubric_duplicate_ids() {
        let rubric = Rubric::new("1.0")
            .with_criterion(Criterion::new("test", "First"))
            .with_criterion(Criterion::new("test", "Duplicate"));

        assert_eq!(
            rubric.validate().unwrap_err(),
            "Duplicate criterion ID: test"
        );
    }

    #[test]
    fn test_rubric_requires_criteria_unless_inherit() {
        assert!(Rubric::new("1.0").validate().is_err());
        assert!(Rubric::new("1.0")
            .with_aggregation("inherit")
            .validate()
            .is_ok());
    }

    #[test]
    fn test_rubric_invalid_aggregation() {
        let rubric = Rubric::new("1.0")
            .with_aggregation("median")
            .with_criterion(Criterion::new("test", "Test criterion"));

        let err = rubric.validate().unwrap_err();
        assert!(err.starts_with("Invalid aggregation 'median'"));
    }

    #[test]
    fn test_rubric_lookup_and_total_weight() {
        let rubric = Rubric::new("1.0")
            .with_criterion(Criterion::new("a", "A").with_weight(2.0))
            .with_criterion(Criterion::new("b", "B"));

        assert_eq!(rubric.total_weight(), 3.0);
        assert_eq!(
            rubric.get_criterion("a").map(|c| c.description.as_str()),
            Some("A")
        );
        assert!(rubric.get_criterion("missing").is_none());
        assert_eq!(Rubric::new("1.0").total_weight(), 0.0);
    }

    #[test]
    fn test_rubric_default() {
        let rubric = Rubric::default();
        assert_eq!(rubric.version, "1.0");
        assert_eq!(rubric.aggregation, "weighted_sum");
        assert!(rubric.goal_text.is_none());
        assert!(rubric.criteria.is_empty());
        assert!(rubric.metadata.is_empty());
    }

    #[test]
    fn test_rubric_serde() {
        let rubric = Rubric::new("1.0").with_criterion(Criterion::new("test", "Test criterion"));

        let json = serde_json::to_string(&rubric).unwrap();
        let parsed: Rubric = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.version, rubric.version);
        assert_eq!(parsed.aggregation, rubric.aggregation);
        assert_eq!(parsed.criteria.len(), 1);
        assert_eq!(parsed.criteria[0].id, "test");
    }

    #[test]
    fn test_serde_defaults() {
        let rubric: Rubric = serde_json::from_str(
            r#"{"version":"2.0","criteria":[{"id":"c","description":"d"}]}"#,
        )
        .unwrap();

        assert_eq!(rubric.version, "2.0");
        assert_eq!(rubric.aggregation, "weighted_sum");
        assert!(rubric.goal_text.is_none());
        assert!(rubric.metadata.is_empty());

        let criterion = &rubric.criteria[0];
        assert_eq!(criterion.weight, 1.0);
        assert!(!criterion.required);
        assert!(criterion.scale_max.is_none());
        assert!(criterion.examples.is_empty());
    }
}