1use super::client::AnthropicClient;
4use super::types::{
5 AggregateMetrics, ChunkJudgment, EvalOutput, EvalRunConfig, JudgeCache, JudgeVerdict,
6 QueryResult, RetrievalResultEntry,
7};
8use std::collections::HashMap;
9
/// System prompt for the LLM relevance judge.
///
/// Asks the model to return strict JSON (`{"relevant": bool, "reasoning": str}`);
/// `parse_verdict` falls back to span extraction and keyword heuristics when
/// the model wraps or deviates from this format.
const JUDGE_SYSTEM: &str = "You judge relevance for information retrieval evaluation.
Given a QUERY and DOCUMENT (video transcript chunk), decide if the document
is RELEVANT — contains information that helps answer the query, even partially.
RELEVANT: discusses the specific topic with substantive content.
NOT RELEVANT: merely mentions a keyword, covers a different topic, or is navigational.
Respond ONLY with JSON: {\"relevant\": true, \"reasoning\": \"brief explanation\"} or {\"relevant\": false, \"reasoning\": \"brief explanation\"}";
/// LLM-backed binary relevance judge with a verdict cache.
pub struct RelevanceJudge {
    // Client used to issue completion requests to the judge model.
    client: AnthropicClient,
    // Model identifier sent with every completion request.
    model: String,
    // (query, content)-keyed cache of prior verdicts; avoids repeat API calls.
    cache: JudgeCache,
}
23
24impl RelevanceJudge {
25 pub fn new(client: AnthropicClient, model: &str, cache: JudgeCache) -> Self {
27 Self { client, model: model.to_string(), cache }
28 }
29
30 pub async fn judge(&mut self, query: &str, content: &str) -> Result<JudgeVerdict, String> {
32 if let Some(cached) = self.cache.get(query, content) {
34 return Ok(cached.clone());
35 }
36
37 let user_msg = format!("QUERY: {query}\nDOCUMENT:\n---\n{content}\n---");
38
39 let result = self.client.complete(&self.model, Some(JUDGE_SYSTEM), &user_msg, 200).await?;
40
41 let verdict = parse_verdict(&result.text)?;
42
43 self.cache.insert(query, content, verdict.clone(), &self.model);
45
46 Ok(verdict)
47 }
48
49 pub fn cache(&self) -> &JudgeCache {
51 &self.cache
52 }
53
54 pub async fn evaluate(
56 &mut self,
57 results: &[RetrievalResultEntry],
58 top_k: usize,
59 ) -> Result<EvalOutput, String> {
60 let total = results.len();
61 let mut per_query = Vec::new();
62 let mut cache_hits = 0usize;
63 let mut api_calls = 0usize;
64 let _cache_size_before = self.cache.entries.len();
65
66 for (i, entry) in results.iter().enumerate() {
67 eprint!("[{}/{}] {}...", i + 1, total, &entry.query[..entry.query.len().min(60)]);
68
69 let mut judgments = Vec::new();
70 let chunks_to_judge = entry.results.len().min(top_k);
71
72 for (rank, chunk) in entry.results.iter().take(chunks_to_judge).enumerate() {
73 let was_cached = self.cache.get(&entry.query, &chunk.content).is_some();
74
75 let verdict = self.judge(&entry.query, &chunk.content).await?;
76
77 if was_cached {
78 cache_hits += 1;
79 } else {
80 api_calls += 1;
81 }
82
83 judgments.push(ChunkJudgment {
84 rank: rank + 1,
85 score: chunk.score,
86 source: chunk.source.clone(),
87 relevant: verdict.relevant,
88 reasoning: verdict.reasoning,
89 });
90 }
91
92 let relevant_count = judgments.iter().filter(|j| j.relevant).count();
93 let mrr = compute_mrr(&judgments);
94 let hit_5 = judgments.iter().take(5).any(|j| j.relevant);
95
96 let status = if hit_5 { "HIT" } else { "MISS" };
97 eprintln!(" [{status}] rel={relevant_count}/{chunks_to_judge} MRR={mrr:.2}");
98
99 per_query.push(QueryResult {
100 query: entry.query.clone(),
101 domain: entry.domain.clone(),
102 mrr,
103 hit_5,
104 relevant_count,
105 total_results: entry.results.len(),
106 latency_s: entry.latency_s,
107 judgments,
108 });
109 }
110
111 let aggregate = compute_aggregate_metrics(&per_query);
113 let by_domain = compute_by_domain_metrics(&per_query);
114
115 let timestamp = chrono_now();
116
117 eprintln!("\n{}", format_summary(&aggregate, &by_domain));
118 eprintln!(
119 "Cache: {} hits, {} new calls ({} total cached)",
120 cache_hits,
121 api_calls,
122 self.cache.entries.len()
123 );
124
125 Ok(EvalOutput {
126 timestamp,
127 config: EvalRunConfig {
128 num_queries: total,
129 top_k,
130 judge_model: self.model.clone(),
131 cache_hits,
132 api_calls,
133 },
134 aggregate,
135 by_domain,
136 per_query,
137 })
138 }
139}
140
141fn parse_verdict(text: &str) -> Result<JudgeVerdict, String> {
142 let trimmed = text.trim();
144
145 if let Ok(v) = serde_json::from_str::<JudgeVerdict>(trimmed) {
147 return Ok(v);
148 }
149
150 if let Some(start) = trimmed.find('{') {
152 if let Some(end) = trimmed.rfind('}') {
153 let json_str = &trimmed[start..=end];
154 if let Ok(v) = serde_json::from_str::<JudgeVerdict>(json_str) {
155 return Ok(v);
156 }
157 }
158 }
159
160 let lower = trimmed.to_lowercase();
162 if lower.contains("not relevant") || lower.contains("\"relevant\": false") {
163 return Ok(JudgeVerdict { relevant: false, reasoning: trimmed.to_string() });
164 }
165 if lower.contains("relevant") || lower.contains("\"relevant\": true") {
166 return Ok(JudgeVerdict { relevant: true, reasoning: trimmed.to_string() });
167 }
168
169 Err(format!("Could not parse judge response: {trimmed}"))
170}
171
172fn compute_mrr(judgments: &[ChunkJudgment]) -> f64 {
173 for j in judgments {
174 if j.relevant {
175 return 1.0 / j.rank as f64;
176 }
177 }
178 0.0
179}
180
181fn compute_ndcg(judgments: &[ChunkJudgment], k: usize) -> f64 {
182 let dcg: f64 = judgments
183 .iter()
184 .take(k)
185 .filter(|j| j.relevant)
186 .map(|j| 1.0 / (j.rank as f64 + 1.0).log2())
187 .sum();
188
189 let relevant_count = judgments.iter().take(k).filter(|j| j.relevant).count();
190 let idcg: f64 = (0..relevant_count.min(k)).map(|r| 1.0 / (r as f64 + 2.0).log2()).sum();
191
192 if idcg == 0.0 {
193 0.0
194 } else {
195 dcg / idcg
196 }
197}
198
199fn compute_average_precision(judgments: &[ChunkJudgment]) -> f64 {
200 let mut sum = 0.0;
201 let mut rel_count: usize = 0;
202
203 for (i, j) in judgments.iter().enumerate() {
204 if j.relevant {
205 rel_count += 1;
206 sum += rel_count as f64 / (i + 1) as f64;
207 }
208 }
209
210 let total_relevant = judgments.iter().filter(|j| j.relevant).count();
211 if total_relevant == 0 {
212 0.0
213 } else {
214 sum / total_relevant as f64
215 }
216}
217
218pub fn compute_aggregate_metrics(queries: &[QueryResult]) -> AggregateMetrics {
220 if queries.is_empty() {
221 return AggregateMetrics::default();
222 }
223 let n = queries.len() as f64;
224
225 let mrr: f64 = queries.iter().map(|q| q.mrr).sum::<f64>() / n;
226 let hit_5: f64 = queries.iter().filter(|q| q.hit_5).count() as f64 / n;
227
228 let hit_10: f64 =
229 queries.iter().filter(|q| q.judgments.iter().take(10).any(|j| j.relevant)).count() as f64
230 / n;
231
232 let ndcg_5: f64 = queries.iter().map(|q| compute_ndcg(&q.judgments, 5)).sum::<f64>() / n;
233
234 let ndcg_10: f64 = queries.iter().map(|q| compute_ndcg(&q.judgments, 10)).sum::<f64>() / n;
235
236 let recall_5: f64 = queries
237 .iter()
238 .map(|q| {
239 let rel_in_5 = q.judgments.iter().take(5).filter(|j| j.relevant).count();
240 let total_rel = q.judgments.iter().filter(|j| j.relevant).count().max(1);
241 rel_in_5 as f64 / total_rel as f64
242 })
243 .sum::<f64>()
244 / n;
245
246 let precision_5: f64 = queries
247 .iter()
248 .map(|q| {
249 let k = q.judgments.len().min(5);
250 if k == 0 {
251 return 0.0;
252 }
253 q.judgments.iter().take(5).filter(|j| j.relevant).count() as f64 / k as f64
254 })
255 .sum::<f64>()
256 / n;
257
258 let map: f64 = queries.iter().map(|q| compute_average_precision(&q.judgments)).sum::<f64>() / n;
259
260 let mean_latency: f64 = queries.iter().map(|q| q.latency_s).sum::<f64>() / n;
261
262 AggregateMetrics {
263 num_queries: queries.len(),
264 mrr: round4(mrr),
265 ndcg_5: round4(ndcg_5),
266 ndcg_10: round4(ndcg_10),
267 recall_5: round4(recall_5),
268 precision_5: round4(precision_5),
269 hit_rate_5: round4(hit_5),
270 hit_rate_10: round4(hit_10),
271 map: round4(map),
272 mean_latency_s: round4(mean_latency),
273 }
274}
275
276pub fn compute_by_domain_metrics(queries: &[QueryResult]) -> HashMap<String, AggregateMetrics> {
278 let mut by_domain: HashMap<String, Vec<&QueryResult>> = HashMap::new();
279 for q in queries {
280 by_domain.entry(q.domain.clone()).or_default().push(q);
281 }
282
283 by_domain
284 .into_iter()
285 .map(|(domain, qs)| {
286 let owned: Vec<QueryResult> = qs.into_iter().cloned().collect();
287 (domain, compute_aggregate_metrics(&owned))
288 })
289 .collect()
290}
291
292fn format_summary(agg: &AggregateMetrics, by_domain: &HashMap<String, AggregateMetrics>) -> String {
293 use std::fmt::Write;
294 let mut s = String::new();
295 s.push_str(&"=".repeat(60));
296 s.push('\n');
297 s.push_str("AGGREGATE RESULTS\n");
298 s.push_str(&"=".repeat(60));
299 s.push('\n');
300 let _ = writeln!(s, " Queries: {}", agg.num_queries);
301 let _ = writeln!(s, " MRR: {:.4}", agg.mrr);
302 let _ = writeln!(s, " NDCG@5: {:.4}", agg.ndcg_5);
303 let _ = writeln!(s, " NDCG@10: {:.4}", agg.ndcg_10);
304 let _ = writeln!(s, " Recall@5: {:.4}", agg.recall_5);
305 let _ = writeln!(s, " Precision@5: {:.4}", agg.precision_5);
306 let _ = writeln!(s, " Hit Rate@5: {:.4}", agg.hit_rate_5);
307 let _ = writeln!(s, " Hit Rate@10: {:.4}", agg.hit_rate_10);
308 let _ = writeln!(s, " MAP: {:.4}", agg.map);
309 let _ = writeln!(s, " Latency: {:.3}s", agg.mean_latency_s);
310 s.push('\n');
311 s.push_str("BY DOMAIN:\n");
312
313 let mut domains: Vec<_> = by_domain.iter().collect();
314 domains.sort_by(|(a, _), (b, _)| a.cmp(b));
315 for (domain, m) in domains {
316 let _ = writeln!(
317 s,
318 " {domain:12} MRR={:.3} NDCG@5={:.3} Hit@5={:.3} (n={})",
319 m.mrr, m.ndcg_5, m.hit_rate_5, m.num_queries
320 );
321 }
322
323 s
324}
325
/// Rounds to four decimal places (ties away from zero, like `f64::round`).
fn round4(v: f64) -> f64 {
    let scaled = v * 1e4;
    scaled.round() / 1e4
}
329
330pub fn chrono_now() -> String {
332 let dur =
334 std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap_or_default();
335 let secs = dur.as_secs();
336 let days = secs / 86400;
338 let remaining = secs % 86400;
339 let hours = remaining / 3600;
340 let minutes = (remaining % 3600) / 60;
341 let seconds = remaining % 60;
342
343 let (year, month, day) = days_to_ymd(days);
345 format!("{year:04}-{month:02}-{day:02}T{hours:02}:{minutes:02}:{seconds:02}Z")
346}
347
/// Converts a count of days since 1970-01-01 into a `(year, month, day)`
/// triple (month and day are 1-based).
fn days_to_ymd(mut days: u64) -> (u64, u64, u64) {
    // Peel off whole years from 1970 onward.
    let mut year = 1970u64;
    loop {
        let year_len = if is_leap(year) { 366 } else { 365 };
        if days < year_len {
            break;
        }
        days -= year_len;
        year += 1;
    }

    // Peel off whole months within the year; February length depends on leap.
    const COMMON: [u64; 12] = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
    let feb_len = if is_leap(year) { 29 } else { 28 };

    let mut month = 1u64;
    for (idx, &common_len) in COMMON.iter().enumerate() {
        let month_len = if idx == 1 { feb_len } else { common_len };
        if days < month_len {
            break;
        }
        days -= month_len;
        month += 1;
    }

    (year, month, days + 1)
}

/// Gregorian leap-year rule: divisible by 4, except centuries not divisible
/// by 400.
fn is_leap(year: u64) -> bool {
    if year % 400 == 0 {
        return true;
    }
    if year % 100 == 0 {
        return false;
    }
    year % 4 == 0
}
381
382pub fn compare_results(baseline: &EvalOutput, candidate: &EvalOutput) -> String {
384 use std::fmt::Write;
385 let b = &baseline.aggregate;
386 let c = &candidate.aggregate;
387
388 let mut s = String::new();
389 s.push_str(&"=".repeat(60));
390 s.push('\n');
391 s.push_str("COMPARISON: baseline \u{2192} candidate\n");
392 s.push_str(&"=".repeat(60));
393 s.push('\n');
394
395 let metrics = [
396 ("MRR", b.mrr, c.mrr),
397 ("NDCG@5", b.ndcg_5, c.ndcg_5),
398 ("NDCG@10", b.ndcg_10, c.ndcg_10),
399 ("Recall@5", b.recall_5, c.recall_5),
400 ("Precision@5", b.precision_5, c.precision_5),
401 ("Hit Rate@5", b.hit_rate_5, c.hit_rate_5),
402 ("Hit Rate@10", b.hit_rate_10, c.hit_rate_10),
403 ("MAP", b.map, c.map),
404 ];
405
406 for (name, base, cand) in metrics {
407 let delta = cand - base;
408 let arrow = if delta > 0.001 {
409 "^"
410 } else if delta < -0.001 {
411 "v"
412 } else {
413 "="
414 };
415 let _ = writeln!(s, " {name:14} {base:.4} \u{2192} {cand:.4} ({delta:+.4}) {arrow}");
416 }
417
418 s
419}
420
421pub fn check_gate(output: &EvalOutput, min_mrr: f64, min_hit5: f64) -> Result<(), String> {
423 let a = &output.aggregate;
424 let mut failures = Vec::new();
425
426 if a.mrr < min_mrr {
427 failures.push(format!("MRR {:.4} < {min_mrr:.4}", a.mrr));
428 }
429 if a.hit_rate_5 < min_hit5 {
430 failures.push(format!("Hit@5 {:.4} < {min_hit5:.4}", a.hit_rate_5));
431 }
432
433 if failures.is_empty() {
434 Ok(())
435 } else {
436 Err(format!("Regression gate FAILED: {}", failures.join(", ")))
437 }
438}
439
// Unit tests covering verdict parsing, MRR, the regression gate, and the
// hand-rolled epoch-days-to-date conversion.
#[cfg(test)]
mod tests {
    use super::*;

    // A well-formed JSON response parses directly on the fast path.
    #[test]
    fn test_parse_verdict_json() {
        let v = parse_verdict(r#"{"relevant": true, "reasoning": "discusses topic"}"#).unwrap();
        assert!(v.relevant);
        assert_eq!(v.reasoning, "discusses topic");
    }

    // JSON embedded in surrounding prose is recovered by the {...} extraction.
    #[test]
    fn test_parse_verdict_wrapped() {
        let v = parse_verdict(
            r#"Here is my judgment:
{"relevant": false, "reasoning": "off topic"}
"#,
        )
        .unwrap();
        assert!(!v.relevant);
    }

    // JSON inside a markdown code fence is also recovered via {...} extraction.
    #[test]
    fn test_parse_verdict_markdown() {
        let v = parse_verdict(
            r#"```json
{"relevant": true, "reasoning": "discusses AWS Lambda"}
```"#,
        )
        .unwrap();
        assert!(v.relevant);
    }

    // Relevant at rank 1 yields MRR = 1.0.
    #[test]
    fn test_compute_mrr_first() {
        let judgments = vec![
            ChunkJudgment {
                rank: 1,
                score: 0.9,
                source: None,
                relevant: true,
                reasoning: String::new(),
            },
            ChunkJudgment {
                rank: 2,
                score: 0.8,
                source: None,
                relevant: false,
                reasoning: String::new(),
            },
        ];
        assert!((compute_mrr(&judgments) - 1.0).abs() < 0.001);
    }

    // First relevant at rank 3 yields MRR = 1/3.
    #[test]
    fn test_compute_mrr_third() {
        let judgments = vec![
            ChunkJudgment {
                rank: 1,
                score: 0.9,
                source: None,
                relevant: false,
                reasoning: String::new(),
            },
            ChunkJudgment {
                rank: 2,
                score: 0.8,
                source: None,
                relevant: false,
                reasoning: String::new(),
            },
            ChunkJudgment {
                rank: 3,
                score: 0.7,
                source: None,
                relevant: true,
                reasoning: String::new(),
            },
        ];
        assert!((compute_mrr(&judgments) - 1.0 / 3.0).abs() < 0.001);
    }

    // No relevant judgments yields MRR = 0.0.
    #[test]
    fn test_compute_mrr_none() {
        let judgments = vec![ChunkJudgment {
            rank: 1,
            score: 0.9,
            source: None,
            relevant: false,
            reasoning: String::new(),
        }];
        assert!((compute_mrr(&judgments)).abs() < 0.001);
    }

    // Both metrics above their floors: the gate passes.
    #[test]
    fn test_check_gate_pass() {
        let output = EvalOutput {
            timestamp: String::new(),
            config: EvalRunConfig {
                num_queries: 10,
                top_k: 10,
                judge_model: String::new(),
                cache_hits: 0,
                api_calls: 10,
            },
            aggregate: AggregateMetrics {
                num_queries: 10,
                mrr: 0.6,
                hit_rate_5: 0.8,
                ..Default::default()
            },
            by_domain: HashMap::new(),
            per_query: Vec::new(),
        };
        assert!(check_gate(&output, 0.5, 0.7).is_ok());
    }

    // Both metrics below their floors: the gate fails.
    #[test]
    fn test_check_gate_fail() {
        let output = EvalOutput {
            timestamp: String::new(),
            config: EvalRunConfig {
                num_queries: 10,
                top_k: 10,
                judge_model: String::new(),
                cache_hits: 0,
                api_calls: 10,
            },
            aggregate: AggregateMetrics {
                num_queries: 10,
                mrr: 0.3,
                hit_rate_5: 0.4,
                ..Default::default()
            },
            by_domain: HashMap::new(),
            per_query: Vec::new(),
        };
        assert!(check_gate(&output, 0.5, 0.7).is_err());
    }

    // 19723 days after the epoch (54 years, 13 of them leap) is 2024-01-01.
    #[test]
    fn test_days_to_ymd() {
        let (y, m, d) = days_to_ymd(19723);
        assert_eq!((y, m, d), (2024, 1, 1));
    }
}