1use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7
8use crate::error::{Error, Result};
9
/// Relation a numeric value must hold against its threshold.
///
/// Variants serialize under their lowercase names (`gte`, `gt`, `lte`, `lt`).
/// When deserializing, each variant additionally accepts its symbolic
/// spelling (`>=`, `>`, `<=`, `<`) as an alias.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum Comparator {
    /// `value >= threshold`; the variant produced by `Comparator::default()`.
    #[serde(alias = ">=")]
    #[default]
    Gte,
    /// `value > threshold`.
    #[serde(alias = ">")]
    Gt,
    /// `value <= threshold`.
    #[serde(alias = "<=")]
    Lte,
    /// `value < threshold`.
    #[serde(alias = "<")]
    Lt,
}
28
29impl Comparator {
30 fn satisfied(self, value: f64, threshold: f64) -> bool {
31 match self {
32 Comparator::Gte => value >= threshold,
33 Comparator::Gt => value > threshold,
34 Comparator::Lte => value <= threshold,
35 Comparator::Lt => value < threshold,
36 }
37 }
38
39 fn symbol(self) -> &'static str {
40 match self {
41 Comparator::Gte => ">=",
42 Comparator::Gt => ">",
43 Comparator::Lte => "<=",
44 Comparator::Lt => "<",
45 }
46 }
47}
48
/// Serde default hook that yields `true`.
///
/// Referenced by path from `#[serde(default = "default_true")]`, which needs
/// a function because `bool::default()` would give `false`.
fn default_true() -> bool {
    true
}
53
/// A single evaluation, either a boolean check or a numeric score with a
/// threshold.
///
/// Serialized as an internally tagged enum: the `type` field holds the
/// lowercase variant name (`boolean` or `numeric`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum Eval {
    /// Passes when the boolean verdict equals `expected`.
    Boolean {
        /// Criterion text; must not be blank (see `Eval::validate`).
        criterion: String,
        /// Verdict that counts as a pass. Defaults to `true` when omitted.
        #[serde(default = "default_true")]
        expected: bool,
        /// Optional label; when absent the criterion is used as the label.
        #[serde(default)]
        name: Option<String>,
    },
    /// Passes when the numeric verdict, clamped to `[min, max]`, satisfies
    /// `comparator` against `threshold`.
    Numeric {
        /// Criterion text; must not be blank (see `Eval::validate`).
        criterion: String,
        /// Lower end of the scale; must be strictly below `max`.
        min: f64,
        /// Upper end of the scale.
        max: f64,
        /// Value the verdict is compared against; must lie within the scale.
        threshold: f64,
        /// Relation required between verdict and threshold. Defaults to
        /// `Comparator::default()` when omitted.
        #[serde(default)]
        comparator: Comparator,
        /// Optional label; when absent the criterion is used as the label.
        #[serde(default)]
        name: Option<String>,
    },
}
89
90impl Eval {
91 #[must_use]
93 pub fn criterion(&self) -> &str {
94 match self {
95 Eval::Boolean { criterion, .. } | Eval::Numeric { criterion, .. } => criterion,
96 }
97 }
98
99 #[must_use]
102 pub fn label(&self) -> &str {
103 match self {
104 Eval::Boolean {
105 name, criterion, ..
106 }
107 | Eval::Numeric {
108 name, criterion, ..
109 } => name.as_deref().unwrap_or(criterion),
110 }
111 }
112
113 pub fn validate(&self) -> Result<()> {
119 if self.criterion().trim().is_empty() {
120 return Err(Error::Invalid("an eval has an empty `criterion`".into()));
121 }
122 if let Eval::Numeric {
123 min,
124 max,
125 threshold,
126 ..
127 } = self
128 {
129 if min >= max {
130 return Err(Error::Invalid(format!(
131 "numeric eval scale is degenerate: min ({min}) must be < max ({max})"
132 )));
133 }
134 if threshold < min || threshold > max {
135 return Err(Error::Invalid(format!(
136 "numeric eval threshold ({threshold}) is outside the scale [{min}, {max}]"
137 )));
138 }
139 }
140 Ok(())
141 }
142
143 pub fn outcome(&self, raw: &JudgeValue, reason: String) -> Result<EvalOutcome> {
152 match (self, raw) {
153 (Eval::Boolean { expected, .. }, JudgeValue::Bool(value)) => Ok(EvalOutcome {
154 label: self.label().to_string(),
155 passed: value == expected,
156 detail: EvalDetail::Boolean {
157 value: *value,
158 expected: *expected,
159 },
160 reason,
161 }),
162 (
163 Eval::Numeric {
164 min,
165 max,
166 threshold,
167 comparator,
168 ..
169 },
170 JudgeValue::Number(value),
171 ) => {
172 let clamped = value.clamp(*min, *max);
173 Ok(EvalOutcome {
174 label: self.label().to_string(),
175 passed: comparator.satisfied(clamped, *threshold),
176 detail: EvalDetail::Numeric {
177 value: clamped,
178 threshold: *threshold,
179 comparator: *comparator,
180 },
181 reason,
182 })
183 }
184 (Eval::Boolean { .. }, JudgeValue::Number(_)) => Err(Error::provider(
185 "judge",
186 "boolean eval received a numeric verdict",
187 )),
188 (Eval::Numeric { .. }, JudgeValue::Bool(_)) => Err(Error::provider(
189 "judge",
190 "numeric eval received a boolean verdict",
191 )),
192 }
193 }
194}
195
/// Raw verdict value handed to `Eval::outcome`: either a boolean or a number.
///
/// The enum is untagged, so the serialized form is the bare value with no
/// wrapper; on input the variants are tried in declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum JudgeValue {
    /// Verdict for a boolean eval.
    Bool(bool),
    /// Verdict for a numeric eval.
    Number(f64),
}
204
/// Kind-specific data recorded alongside an eval's pass/fail result.
///
/// Serialized as an internally tagged enum: the `kind` field holds the
/// lowercase variant name (`boolean` or `numeric`).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "kind", rename_all = "lowercase")]
pub enum EvalDetail {
    /// Detail of a boolean eval: the verdict and the value that was expected.
    #[schemars(title = "BooleanDetail")]
    Boolean { value: bool, expected: bool },
    /// Detail of a numeric eval. When produced by `Eval::outcome`, `value` is
    /// the verdict after clamping to the eval's scale.
    #[schemars(title = "NumericDetail")]
    Numeric {
        value: f64,
        threshold: f64,
        comparator: Comparator,
    },
}
221
222impl EvalDetail {
223 #[must_use]
226 pub fn summary(&self) -> String {
227 match self {
228 EvalDetail::Boolean { value, expected } => {
229 format!("{value} (expected {expected})")
230 }
231 EvalDetail::Numeric {
232 value,
233 threshold,
234 comparator,
235 } => format!("{value} {} {threshold}", comparator.symbol()),
236 }
237 }
238}
239
/// Result of scoring one verdict against one eval.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct EvalOutcome {
    /// Label of the eval: its `name` when set, otherwise its criterion.
    pub label: String,
    /// Whether the verdict met the eval's requirement.
    pub passed: bool,
    /// Kind-specific values behind the pass/fail decision.
    pub detail: EvalDetail,
    /// Reason text passed to `Eval::outcome`, stored unchanged.
    pub reason: String,
}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unnamed numeric eval using the `>=` comparator.
    fn numeric_gte(criterion: &str, min: f64, max: f64, threshold: f64) -> Eval {
        Eval::Numeric {
            criterion: criterion.into(),
            min,
            max,
            threshold,
            comparator: Comparator::Gte,
            name: None,
        }
    }

    /// Builds an unnamed boolean eval.
    fn boolean(criterion: &str, expected: bool) -> Eval {
        Eval::Boolean {
            criterion: criterion.into(),
            expected,
            name: None,
        }
    }

    /// Scores a boolean verdict with an empty reason and returns `passed`.
    fn passes_with(eval: &Eval, verdict: bool) -> bool {
        eval.outcome(&JudgeValue::Bool(verdict), String::new())
            .unwrap()
            .passed
    }

    #[test]
    fn numeric_threshold_gte_passes_at_boundary() {
        let eval = numeric_gte("polite", 0.0, 10.0, 7.0);
        let verdict = JudgeValue::Number(7.0);
        assert!(eval.outcome(&verdict, "ok".into()).unwrap().passed);
    }

    #[test]
    fn numeric_value_is_clamped_to_scale() {
        // A verdict above the scale is pulled down to `max` before comparing.
        let eval = numeric_gte("x", 0.0, 10.0, 9.0);
        let verdict = JudgeValue::Number(12.0);
        let outcome = eval.outcome(&verdict, String::new()).unwrap();
        assert!(outcome.passed);
        assert!(matches!(
            outcome.detail,
            EvalDetail::Numeric { value, .. } if (value - 10.0).abs() < f64::EPSILON
        ));
    }

    #[test]
    fn boolean_expected_false_inverts() {
        let eval = boolean("leaks a secret", false);
        assert!(passes_with(&eval, false));
        assert!(!passes_with(&eval, true));
    }

    #[test]
    fn kind_mismatch_is_provider_error() {
        let eval = boolean("x", true);
        let result = eval.outcome(&JudgeValue::Number(1.0), String::new());
        assert!(result.is_err());
    }

    #[test]
    fn degenerate_numeric_scale_is_invalid() {
        let eval = numeric_gte("x", 5.0, 5.0, 5.0);
        assert!(eval.validate().is_err());
    }

    #[test]
    fn comparator_parses_from_symbol() {
        // Both the symbolic alias and the lowercase name must deserialize.
        let cases = [("\">=\"", Comparator::Gte), ("lt", Comparator::Lt)];
        for (text, expected) in cases {
            let parsed: Comparator = serde_yaml::from_str(text).unwrap();
            assert_eq!(parsed, expected);
        }
    }
}