swink_agent_eval/evaluators/
structured.rs1use std::collections::HashSet;
12
13use jsonschema::Validator;
14
15use crate::evaluator::Evaluator;
16use crate::score::Score;
17use crate::types::{EvalCase, EvalMetricResult, Invocation};
18
/// Signature of the per-key scoring callback used by [`KeyStrategy::Rubric`].
///
/// The callback receives the key name, the expected value for that key, and the
/// actual value found under the same key (`None` when the response is not a JSON
/// object or lacks the key). It returns a raw score, which the evaluator clamps
/// to `[0.0, 1.0]` before aggregating.
type RubricScorer =
    dyn Fn(&str, &serde_json::Value, Option<&serde_json::Value>) -> f64 + Send + Sync;
21
/// How per-key comparison results are combined into a single score.
///
/// For every strategy except [`KeyStrategy::Rubric`], a key scores `1.0` when the
/// actual value equals the expected value and `0.0` otherwise.
#[derive(Clone)]
pub enum KeyStrategy {
    /// The final score is the arithmetic mean of the per-key scores.
    Average,
    /// The final score is `1.0` only if every compared key scored `1.0`; otherwise `0.0`.
    All,
    /// The final score is `1.0` only if every compared key scored `0.0`; otherwise `0.0`.
    None,
    /// Each key is scored by a caller-supplied callback; the clamped per-key
    /// scores are then averaged.
    Rubric {
        /// Scoring callback, held in an `Arc` so cloning the strategy shares it.
        scorer: std::sync::Arc<RubricScorer>,
    },
}
41
42impl std::fmt::Debug for KeyStrategy {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 match self {
45 Self::Average => f.debug_tuple("Average").finish(),
46 Self::All => f.debug_tuple("All").finish(),
47 Self::None => f.debug_tuple("None").finish(),
48 Self::Rubric { .. } => f.debug_struct("Rubric").field("scorer", &"<fn>").finish(),
49 }
50 }
51}
52
/// Evaluator that parses the final response as JSON and compares it against an
/// expected JSON value.
pub struct JsonMatchEvaluator {
    /// Name reported by `Evaluator::name` and stamped on every result.
    name: &'static str,
    /// Reference value the parsed response is compared against.
    expected: serde_json::Value,
    /// How per-key results are combined when `expected` is a JSON object.
    strategy: KeyStrategy,
    /// Top-level keys of `expected` that are skipped during comparison.
    exclude_keys: HashSet<String>,
}
64
65impl JsonMatchEvaluator {
66 #[must_use]
69 pub fn new(expected: serde_json::Value) -> Self {
70 Self {
71 name: "json_match",
72 expected,
73 strategy: KeyStrategy::Average,
74 exclude_keys: HashSet::new(),
75 }
76 }
77
78 #[must_use]
80 pub const fn with_name(mut self, name: &'static str) -> Self {
81 self.name = name;
82 self
83 }
84
85 #[must_use]
87 pub fn with_strategy(mut self, strategy: KeyStrategy) -> Self {
88 self.strategy = strategy;
89 self
90 }
91
92 #[must_use]
94 pub fn with_exclude_keys<I, S>(mut self, keys: I) -> Self
95 where
96 I: IntoIterator<Item = S>,
97 S: Into<String>,
98 {
99 self.exclude_keys = keys.into_iter().map(Into::into).collect();
100 self
101 }
102
103 fn compare(&self, actual: &serde_json::Value) -> (f64, String) {
104 let expected_obj = if let Some(obj) = self.expected.as_object() {
105 obj
106 } else {
107 let eq = self.expected == *actual;
109 return (
110 if eq { 1.0 } else { 0.0 },
111 if eq {
112 "match".into()
113 } else {
114 "mismatch".into()
115 },
116 );
117 };
118 let actual_obj = actual.as_object();
119
120 let mut per_key: Vec<(String, f64)> = Vec::new();
121 for (key, expected_value) in expected_obj {
122 if self.exclude_keys.contains(key) {
123 continue;
124 }
125 let actual_value = actual_obj.and_then(|obj| obj.get(key));
126 let score = match &self.strategy {
127 KeyStrategy::Average | KeyStrategy::All | KeyStrategy::None => {
128 if actual_value == Some(expected_value) {
129 1.0
130 } else {
131 0.0
132 }
133 }
134 KeyStrategy::Rubric { scorer } => {
135 scorer(key, expected_value, actual_value).clamp(0.0_f64, 1.0_f64)
136 }
137 };
138 per_key.push((key.clone(), score));
139 }
140
141 if per_key.is_empty() {
142 return (1.0, "no comparable keys".into());
143 }
144
145 let score = match &self.strategy {
146 KeyStrategy::Average | KeyStrategy::Rubric { .. } => {
147 let sum: f64 = per_key.iter().map(|(_, s)| *s).sum();
148 #[allow(clippy::cast_precision_loss)]
149 {
150 sum / per_key.len() as f64
151 }
152 }
153 KeyStrategy::All => {
154 if per_key.iter().all(|(_, s)| *s >= 1.0) {
155 1.0
156 } else {
157 0.0
158 }
159 }
160 KeyStrategy::None => {
161 if per_key.iter().all(|(_, s)| *s <= 0.0) {
162 1.0
163 } else {
164 0.0
165 }
166 }
167 };
168
169 let details = per_key
170 .iter()
171 .map(|(k, s)| format!("{k}={s:.2}"))
172 .collect::<Vec<_>>()
173 .join(", ");
174 (score, details)
175 }
176}
177
178impl Evaluator for JsonMatchEvaluator {
179 fn name(&self) -> &'static str {
180 self.name
181 }
182
183 fn evaluate(&self, _case: &EvalCase, invocation: &Invocation) -> Option<EvalMetricResult> {
184 let raw = invocation.final_response.as_ref()?;
185 let parsed: serde_json::Value = match serde_json::from_str(raw) {
186 Ok(value) => value,
187 Err(err) => {
188 return Some(EvalMetricResult {
189 evaluator_name: self.name.to_string(),
190 score: Score::fail(),
191 details: Some(format!("malformed JSON response: {err}")),
192 });
193 }
194 };
195
196 let (value, details) = self.compare(&parsed);
197 Some(EvalMetricResult {
198 evaluator_name: self.name.to_string(),
199 score: Score::new(value, 0.5),
200 details: Some(details),
201 })
202 }
203}
204
/// Evaluator that parses the final response as JSON and validates it against a
/// JSON Schema.
pub struct JsonSchemaEvaluator {
    /// Name reported by `Evaluator::name` and stamped on every result.
    name: &'static str,
    /// Validator built once from the schema at construction time.
    validator: Validator,
}
214
215impl JsonSchemaEvaluator {
216 pub fn new(schema: &serde_json::Value) -> Result<Self, String> {
219 let validator = jsonschema::validator_for(schema).map_err(|err| err.to_string())?;
220 Ok(Self {
221 name: "json_schema",
222 validator,
223 })
224 }
225
226 #[must_use]
228 pub const fn with_name(mut self, name: &'static str) -> Self {
229 self.name = name;
230 self
231 }
232}
233
234impl Evaluator for JsonSchemaEvaluator {
235 fn name(&self) -> &'static str {
236 self.name
237 }
238
239 fn evaluate(&self, _case: &EvalCase, invocation: &Invocation) -> Option<EvalMetricResult> {
240 let raw = invocation.final_response.as_ref()?;
241 let parsed: serde_json::Value = match serde_json::from_str(raw) {
242 Ok(value) => value,
243 Err(err) => {
244 return Some(EvalMetricResult {
245 evaluator_name: self.name.to_string(),
246 score: Score::fail(),
247 details: Some(format!("malformed JSON response: {err}")),
248 });
249 }
250 };
251
252 let errors: Vec<String> = self
253 .validator
254 .iter_errors(&parsed)
255 .map(|err| err.to_string())
256 .collect();
257
258 if errors.is_empty() {
259 Some(EvalMetricResult {
260 evaluator_name: self.name.to_string(),
261 score: Score::pass(),
262 details: Some("schema valid".into()),
263 })
264 } else {
265 Some(EvalMetricResult {
266 evaluator_name: self.name.to_string(),
267 score: Score::fail(),
268 details: Some(errors.join("; ")),
269 })
270 }
271 }
272}