// trueno_rag/eval/judge.rs

1//! LLM-as-judge for content-based relevance scoring
2
3use super::client::AnthropicClient;
4use super::types::{
5    AggregateMetrics, ChunkJudgment, EvalOutput, EvalRunConfig, JudgeCache, JudgeVerdict,
6    QueryResult, RetrievalResultEntry,
7};
8use std::collections::HashMap;
9
/// System prompt instructing the model to act as a binary relevance judge
/// and answer ONLY with JSON of the form
/// `{"relevant": bool, "reasoning": "brief explanation"}` — see
/// `parse_verdict` for how the response is decoded.
const JUDGE_SYSTEM: &str = "You judge relevance for information retrieval evaluation.
Given a QUERY and DOCUMENT (video transcript chunk), decide if the document
is RELEVANT — contains information that helps answer the query, even partially.
RELEVANT: discusses the specific topic with substantive content.
NOT RELEVANT: merely mentions a keyword, covers a different topic, or is navigational.
Respond ONLY with JSON: {\"relevant\": true, \"reasoning\": \"brief explanation\"} or {\"relevant\": false, \"reasoning\": \"brief explanation\"}";
16
/// LLM-based relevance judge
pub struct RelevanceJudge {
    /// API client used to request completions.
    client: AnthropicClient,
    /// Model identifier sent with every judge request (also used as the
    /// cache tag when storing verdicts).
    model: String,
    /// (query, content) -> verdict cache so repeated evaluations skip
    /// already-judged pairs.
    cache: JudgeCache,
}
23
24impl RelevanceJudge {
25    /// Create a new judge
26    pub fn new(client: AnthropicClient, model: &str, cache: JudgeCache) -> Self {
27        Self { client, model: model.to_string(), cache }
28    }
29
30    /// Judge whether a chunk is relevant to a query
31    pub async fn judge(&mut self, query: &str, content: &str) -> Result<JudgeVerdict, String> {
32        // Check cache first
33        if let Some(cached) = self.cache.get(query, content) {
34            return Ok(cached.clone());
35        }
36
37        let user_msg = format!("QUERY: {query}\nDOCUMENT:\n---\n{content}\n---");
38
39        let result = self.client.complete(&self.model, Some(JUDGE_SYSTEM), &user_msg, 200).await?;
40
41        let verdict = parse_verdict(&result.text)?;
42
43        // Cache the result
44        self.cache.insert(query, content, verdict.clone(), &self.model);
45
46        Ok(verdict)
47    }
48
49    /// Get the current cache (for saving)
50    pub fn cache(&self) -> &JudgeCache {
51        &self.cache
52    }
53
54    /// Run full evaluation: judge all retrieval results and compute metrics
55    pub async fn evaluate(
56        &mut self,
57        results: &[RetrievalResultEntry],
58        top_k: usize,
59    ) -> Result<EvalOutput, String> {
60        let total = results.len();
61        let mut per_query = Vec::new();
62        let mut cache_hits = 0usize;
63        let mut api_calls = 0usize;
64        let _cache_size_before = self.cache.entries.len();
65
66        for (i, entry) in results.iter().enumerate() {
67            eprint!("[{}/{}] {}...", i + 1, total, &entry.query[..entry.query.len().min(60)]);
68
69            let mut judgments = Vec::new();
70            let chunks_to_judge = entry.results.len().min(top_k);
71
72            for (rank, chunk) in entry.results.iter().take(chunks_to_judge).enumerate() {
73                let was_cached = self.cache.get(&entry.query, &chunk.content).is_some();
74
75                let verdict = self.judge(&entry.query, &chunk.content).await?;
76
77                if was_cached {
78                    cache_hits += 1;
79                } else {
80                    api_calls += 1;
81                }
82
83                judgments.push(ChunkJudgment {
84                    rank: rank + 1,
85                    score: chunk.score,
86                    source: chunk.source.clone(),
87                    relevant: verdict.relevant,
88                    reasoning: verdict.reasoning,
89                });
90            }
91
92            let relevant_count = judgments.iter().filter(|j| j.relevant).count();
93            let mrr = compute_mrr(&judgments);
94            let hit_5 = judgments.iter().take(5).any(|j| j.relevant);
95
96            let status = if hit_5 { "HIT" } else { "MISS" };
97            eprintln!(" [{status}] rel={relevant_count}/{chunks_to_judge} MRR={mrr:.2}");
98
99            per_query.push(QueryResult {
100                query: entry.query.clone(),
101                domain: entry.domain.clone(),
102                mrr,
103                hit_5,
104                relevant_count,
105                total_results: entry.results.len(),
106                latency_s: entry.latency_s,
107                judgments,
108            });
109        }
110
111        // Compute aggregates
112        let aggregate = compute_aggregate_metrics(&per_query);
113        let by_domain = compute_by_domain_metrics(&per_query);
114
115        let timestamp = chrono_now();
116
117        eprintln!("\n{}", format_summary(&aggregate, &by_domain));
118        eprintln!(
119            "Cache: {} hits, {} new calls ({} total cached)",
120            cache_hits,
121            api_calls,
122            self.cache.entries.len()
123        );
124
125        Ok(EvalOutput {
126            timestamp,
127            config: EvalRunConfig {
128                num_queries: total,
129                top_k,
130                judge_model: self.model.clone(),
131                cache_hits,
132                api_calls,
133            },
134            aggregate,
135            by_domain,
136            per_query,
137        })
138    }
139}
140
141fn parse_verdict(text: &str) -> Result<JudgeVerdict, String> {
142    // Try to extract JSON from the response
143    let trimmed = text.trim();
144
145    // Try direct parse first
146    if let Ok(v) = serde_json::from_str::<JudgeVerdict>(trimmed) {
147        return Ok(v);
148    }
149
150    // Try to find JSON in the response (model sometimes wraps in markdown)
151    if let Some(start) = trimmed.find('{') {
152        if let Some(end) = trimmed.rfind('}') {
153            let json_str = &trimmed[start..=end];
154            if let Ok(v) = serde_json::from_str::<JudgeVerdict>(json_str) {
155                return Ok(v);
156            }
157        }
158    }
159
160    // Fallback: check for keywords
161    let lower = trimmed.to_lowercase();
162    if lower.contains("not relevant") || lower.contains("\"relevant\": false") {
163        return Ok(JudgeVerdict { relevant: false, reasoning: trimmed.to_string() });
164    }
165    if lower.contains("relevant") || lower.contains("\"relevant\": true") {
166        return Ok(JudgeVerdict { relevant: true, reasoning: trimmed.to_string() });
167    }
168
169    Err(format!("Could not parse judge response: {trimmed}"))
170}
171
172fn compute_mrr(judgments: &[ChunkJudgment]) -> f64 {
173    for j in judgments {
174        if j.relevant {
175            return 1.0 / j.rank as f64;
176        }
177    }
178    0.0
179}
180
181fn compute_ndcg(judgments: &[ChunkJudgment], k: usize) -> f64 {
182    let dcg: f64 = judgments
183        .iter()
184        .take(k)
185        .filter(|j| j.relevant)
186        .map(|j| 1.0 / (j.rank as f64 + 1.0).log2())
187        .sum();
188
189    let relevant_count = judgments.iter().take(k).filter(|j| j.relevant).count();
190    let idcg: f64 = (0..relevant_count.min(k)).map(|r| 1.0 / (r as f64 + 2.0).log2()).sum();
191
192    if idcg == 0.0 {
193        0.0
194    } else {
195        dcg / idcg
196    }
197}
198
199fn compute_average_precision(judgments: &[ChunkJudgment]) -> f64 {
200    let mut sum = 0.0;
201    let mut rel_count: usize = 0;
202
203    for (i, j) in judgments.iter().enumerate() {
204        if j.relevant {
205            rel_count += 1;
206            sum += rel_count as f64 / (i + 1) as f64;
207        }
208    }
209
210    let total_relevant = judgments.iter().filter(|j| j.relevant).count();
211    if total_relevant == 0 {
212        0.0
213    } else {
214        sum / total_relevant as f64
215    }
216}
217
218/// Compute aggregate metrics across all queries (public for metrics module)
219pub fn compute_aggregate_metrics(queries: &[QueryResult]) -> AggregateMetrics {
220    if queries.is_empty() {
221        return AggregateMetrics::default();
222    }
223    let n = queries.len() as f64;
224
225    let mrr: f64 = queries.iter().map(|q| q.mrr).sum::<f64>() / n;
226    let hit_5: f64 = queries.iter().filter(|q| q.hit_5).count() as f64 / n;
227
228    let hit_10: f64 =
229        queries.iter().filter(|q| q.judgments.iter().take(10).any(|j| j.relevant)).count() as f64
230            / n;
231
232    let ndcg_5: f64 = queries.iter().map(|q| compute_ndcg(&q.judgments, 5)).sum::<f64>() / n;
233
234    let ndcg_10: f64 = queries.iter().map(|q| compute_ndcg(&q.judgments, 10)).sum::<f64>() / n;
235
236    let recall_5: f64 = queries
237        .iter()
238        .map(|q| {
239            let rel_in_5 = q.judgments.iter().take(5).filter(|j| j.relevant).count();
240            let total_rel = q.judgments.iter().filter(|j| j.relevant).count().max(1);
241            rel_in_5 as f64 / total_rel as f64
242        })
243        .sum::<f64>()
244        / n;
245
246    let precision_5: f64 = queries
247        .iter()
248        .map(|q| {
249            let k = q.judgments.len().min(5);
250            if k == 0 {
251                return 0.0;
252            }
253            q.judgments.iter().take(5).filter(|j| j.relevant).count() as f64 / k as f64
254        })
255        .sum::<f64>()
256        / n;
257
258    let map: f64 = queries.iter().map(|q| compute_average_precision(&q.judgments)).sum::<f64>() / n;
259
260    let mean_latency: f64 = queries.iter().map(|q| q.latency_s).sum::<f64>() / n;
261
262    AggregateMetrics {
263        num_queries: queries.len(),
264        mrr: round4(mrr),
265        ndcg_5: round4(ndcg_5),
266        ndcg_10: round4(ndcg_10),
267        recall_5: round4(recall_5),
268        precision_5: round4(precision_5),
269        hit_rate_5: round4(hit_5),
270        hit_rate_10: round4(hit_10),
271        map: round4(map),
272        mean_latency_s: round4(mean_latency),
273    }
274}
275
276/// Compute per-domain metrics (public for metrics module)
277pub fn compute_by_domain_metrics(queries: &[QueryResult]) -> HashMap<String, AggregateMetrics> {
278    let mut by_domain: HashMap<String, Vec<&QueryResult>> = HashMap::new();
279    for q in queries {
280        by_domain.entry(q.domain.clone()).or_default().push(q);
281    }
282
283    by_domain
284        .into_iter()
285        .map(|(domain, qs)| {
286            let owned: Vec<QueryResult> = qs.into_iter().cloned().collect();
287            (domain, compute_aggregate_metrics(&owned))
288        })
289        .collect()
290}
291
292fn format_summary(agg: &AggregateMetrics, by_domain: &HashMap<String, AggregateMetrics>) -> String {
293    use std::fmt::Write;
294    let mut s = String::new();
295    s.push_str(&"=".repeat(60));
296    s.push('\n');
297    s.push_str("AGGREGATE RESULTS\n");
298    s.push_str(&"=".repeat(60));
299    s.push('\n');
300    let _ = writeln!(s, "  Queries:       {}", agg.num_queries);
301    let _ = writeln!(s, "  MRR:           {:.4}", agg.mrr);
302    let _ = writeln!(s, "  NDCG@5:        {:.4}", agg.ndcg_5);
303    let _ = writeln!(s, "  NDCG@10:       {:.4}", agg.ndcg_10);
304    let _ = writeln!(s, "  Recall@5:      {:.4}", agg.recall_5);
305    let _ = writeln!(s, "  Precision@5:   {:.4}", agg.precision_5);
306    let _ = writeln!(s, "  Hit Rate@5:    {:.4}", agg.hit_rate_5);
307    let _ = writeln!(s, "  Hit Rate@10:   {:.4}", agg.hit_rate_10);
308    let _ = writeln!(s, "  MAP:           {:.4}", agg.map);
309    let _ = writeln!(s, "  Latency:       {:.3}s", agg.mean_latency_s);
310    s.push('\n');
311    s.push_str("BY DOMAIN:\n");
312
313    let mut domains: Vec<_> = by_domain.iter().collect();
314    domains.sort_by(|(a, _), (b, _)| a.cmp(b));
315    for (domain, m) in domains {
316        let _ = writeln!(
317            s,
318            "  {domain:12}  MRR={:.3}  NDCG@5={:.3}  Hit@5={:.3}  (n={})",
319            m.mrr, m.ndcg_5, m.hit_rate_5, m.num_queries
320        );
321    }
322
323    s
324}
325
/// Round to four decimal places (ties round half away from zero, per
/// `f64::round`).
fn round4(v: f64) -> f64 {
    let scaled = (v * 10000.0).round();
    scaled / 10000.0
}
329
/// Simple UTC timestamp without chrono crate (public for metrics module)
pub fn chrono_now() -> String {
    // Seconds since the Unix epoch; clock-before-epoch degrades to zero.
    let epoch_secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();

    let (year, month, day) = days_to_ymd(epoch_secs / 86400);
    let day_secs = epoch_secs % 86400;
    let (hours, minutes, seconds) = (day_secs / 3600, (day_secs % 3600) / 60, day_secs % 60);

    // Basic ISO 8601, e.g. "2024-01-01T12:34:56Z".
    format!("{year:04}-{month:02}-{day:02}T{hours:02}:{minutes:02}:{seconds:02}Z")
}

/// Convert days since 1970-01-01 into (year, month, day), Gregorian rules.
fn days_to_ymd(mut days: u64) -> (u64, u64, u64) {
    // Peel off whole calendar years.
    let mut year = 1970u64;
    let mut year_len = if is_leap(year) { 366 } else { 365 };
    while days >= year_len {
        days -= year_len;
        year += 1;
        year_len = if is_leap(year) { 366 } else { 365 };
    }

    const COMMON: [u64; 12] = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
    const LEAP: [u64; 12] = [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
    let table: &[u64; 12] = if is_leap(year) { &LEAP } else { &COMMON };

    // Walk months until the remainder fits; December is the defensive
    // default (the loop always breaks because days < days-in-year here).
    let mut month = 12u64;
    for (idx, &len) in table.iter().enumerate() {
        if days < len {
            month = idx as u64 + 1;
            break;
        }
        days -= len;
    }

    (year, month, days + 1)
}

/// Gregorian leap-year rule: divisible by 4, except centuries not
/// divisible by 400.
fn is_leap(year: u64) -> bool {
    year % 4 == 0 && (year % 100 != 0 || year % 400 == 0)
}
381
382/// Compare two eval outputs and print deltas
383pub fn compare_results(baseline: &EvalOutput, candidate: &EvalOutput) -> String {
384    use std::fmt::Write;
385    let b = &baseline.aggregate;
386    let c = &candidate.aggregate;
387
388    let mut s = String::new();
389    s.push_str(&"=".repeat(60));
390    s.push('\n');
391    s.push_str("COMPARISON: baseline \u{2192} candidate\n");
392    s.push_str(&"=".repeat(60));
393    s.push('\n');
394
395    let metrics = [
396        ("MRR", b.mrr, c.mrr),
397        ("NDCG@5", b.ndcg_5, c.ndcg_5),
398        ("NDCG@10", b.ndcg_10, c.ndcg_10),
399        ("Recall@5", b.recall_5, c.recall_5),
400        ("Precision@5", b.precision_5, c.precision_5),
401        ("Hit Rate@5", b.hit_rate_5, c.hit_rate_5),
402        ("Hit Rate@10", b.hit_rate_10, c.hit_rate_10),
403        ("MAP", b.map, c.map),
404    ];
405
406    for (name, base, cand) in metrics {
407        let delta = cand - base;
408        let arrow = if delta > 0.001 {
409            "^"
410        } else if delta < -0.001 {
411            "v"
412        } else {
413            "="
414        };
415        let _ = writeln!(s, "  {name:14}  {base:.4} \u{2192} {cand:.4}  ({delta:+.4}) {arrow}");
416    }
417
418    s
419}
420
421/// Check if results meet minimum thresholds (regression gate)
422pub fn check_gate(output: &EvalOutput, min_mrr: f64, min_hit5: f64) -> Result<(), String> {
423    let a = &output.aggregate;
424    let mut failures = Vec::new();
425
426    if a.mrr < min_mrr {
427        failures.push(format!("MRR {:.4} < {min_mrr:.4}", a.mrr));
428    }
429    if a.hit_rate_5 < min_hit5 {
430        failures.push(format!("Hit@5 {:.4} < {min_hit5:.4}", a.hit_rate_5));
431    }
432
433    if failures.is_empty() {
434        Ok(())
435    } else {
436        Err(format!("Regression gate FAILED: {}", failures.join(", ")))
437    }
438}
439
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: build a judgment row with empty reasoning and no source.
    fn judgment(rank: usize, score: f64, relevant: bool) -> ChunkJudgment {
        ChunkJudgment { rank, score, source: None, relevant, reasoning: String::new() }
    }

    /// Fixture: build an eval output with the given aggregate scores.
    fn output_with(mrr: f64, hit_rate_5: f64) -> EvalOutput {
        EvalOutput {
            timestamp: String::new(),
            config: EvalRunConfig {
                num_queries: 10,
                top_k: 10,
                judge_model: String::new(),
                cache_hits: 0,
                api_calls: 10,
            },
            aggregate: AggregateMetrics { num_queries: 10, mrr, hit_rate_5, ..Default::default() },
            by_domain: HashMap::new(),
            per_query: Vec::new(),
        }
    }

    #[test]
    fn test_parse_verdict_json() {
        let v = parse_verdict(r#"{"relevant": true, "reasoning": "discusses topic"}"#).unwrap();
        assert!(v.relevant);
        assert_eq!(v.reasoning, "discusses topic");
    }

    #[test]
    fn test_parse_verdict_wrapped() {
        // JSON preceded by prose: the brace-extraction path must find it.
        let wrapped = "Here is my judgment:\n{\"relevant\": false, \"reasoning\": \"off topic\"}\n";
        let v = parse_verdict(wrapped).unwrap();
        assert!(!v.relevant);
    }

    #[test]
    fn test_parse_verdict_markdown() {
        // JSON inside a markdown code fence.
        let fenced = "```json\n{\"relevant\": true, \"reasoning\": \"discusses AWS Lambda\"}\n```";
        let v = parse_verdict(fenced).unwrap();
        assert!(v.relevant);
    }

    #[test]
    fn test_compute_mrr_first() {
        let judgments = vec![judgment(1, 0.9, true), judgment(2, 0.8, false)];
        assert!((compute_mrr(&judgments) - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_compute_mrr_third() {
        let judgments =
            vec![judgment(1, 0.9, false), judgment(2, 0.8, false), judgment(3, 0.7, true)];
        assert!((compute_mrr(&judgments) - 1.0 / 3.0).abs() < 0.001);
    }

    #[test]
    fn test_compute_mrr_none() {
        let judgments = vec![judgment(1, 0.9, false)];
        assert!(compute_mrr(&judgments).abs() < 0.001);
    }

    #[test]
    fn test_check_gate_pass() {
        assert!(check_gate(&output_with(0.6, 0.8), 0.5, 0.7).is_ok());
    }

    #[test]
    fn test_check_gate_fail() {
        assert!(check_gate(&output_with(0.3, 0.4), 0.5, 0.7).is_err());
    }

    #[test]
    fn test_days_to_ymd() {
        // 2024-01-01 is day 19723
        let (y, m, d) = days_to_ymd(19723);
        assert_eq!((y, m, d), (2024, 1, 1));
    }
}