Skip to main content

trueno_rag/eval/
metrics.rs

1//! Compute IR metrics from pre-judged results (no API calls needed)
2//!
3//! Reads judgments JSONL (produced by Claude Code or external judge)
4//! and retrieval results JSONL, computes standard IR metrics.
5
6use super::judge::{compute_aggregate_metrics, compute_by_domain_metrics};
7use super::types::{
8    AggregateMetrics, ChunkJudgment, EvalOutput, EvalRunConfig, JudgmentEntry, QueryResult,
9    RetrievalResultEntry,
10};
11use std::collections::HashMap;
12
13/// Compute metrics from pre-judged results
14///
15/// Takes retrieval results (with chunks) and judgment entries (with verdicts),
16/// correlates them by (query, rank), and computes standard IR metrics.
17pub fn compute_metrics_from_judgments(
18    retrieval_results: &[RetrievalResultEntry],
19    judgments: &[JudgmentEntry],
20) -> EvalOutput {
21    // Index judgments by (query, rank) for fast lookup
22    let mut judgment_map: HashMap<(&str, usize), &JudgmentEntry> = HashMap::new();
23    for j in judgments {
24        judgment_map.insert((&j.query, j.rank), j);
25    }
26
27    let mut per_query = Vec::new();
28
29    for entry in retrieval_results {
30        let mut chunk_judgments = Vec::new();
31
32        for (rank_idx, chunk) in entry.results.iter().enumerate() {
33            let rank = rank_idx + 1;
34
35            // Look up the judgment for this (query, rank) pair
36            let (relevant, reasoning) =
37                if let Some(j) = judgment_map.get(&(entry.query.as_str(), rank)) {
38                    (j.relevant, j.reasoning.clone())
39                } else {
40                    // No judgment = not relevant (unjudged chunks count against)
41                    (false, "no judgment provided".to_string())
42                };
43
44            chunk_judgments.push(ChunkJudgment {
45                rank,
46                score: chunk.score,
47                source: chunk.source.clone(),
48                relevant,
49                reasoning,
50            });
51        }
52
53        let mrr = compute_mrr(&chunk_judgments);
54        let hit_5 = chunk_judgments.iter().take(5).any(|j| j.relevant);
55        let relevant_count = chunk_judgments.iter().filter(|j| j.relevant).count();
56
57        per_query.push(QueryResult {
58            query: entry.query.clone(),
59            domain: entry.domain.clone(),
60            mrr,
61            hit_5,
62            relevant_count,
63            total_results: entry.results.len(),
64            latency_s: entry.latency_s,
65            judgments: chunk_judgments,
66        });
67    }
68
69    let aggregate = compute_aggregate_metrics(&per_query);
70    let by_domain = compute_by_domain_metrics(&per_query);
71    let timestamp = super::judge::chrono_now();
72
73    EvalOutput {
74        timestamp,
75        config: EvalRunConfig {
76            num_queries: retrieval_results.len(),
77            top_k: retrieval_results.first().map(|r| r.results.len()).unwrap_or(10),
78            judge_model: "claude-code".to_string(),
79            cache_hits: 0,
80            api_calls: 0,
81        },
82        aggregate,
83        by_domain,
84        per_query,
85    }
86}
87
88fn compute_mrr(judgments: &[ChunkJudgment]) -> f64 {
89    for j in judgments {
90        if j.relevant {
91            return 1.0 / j.rank as f64;
92        }
93    }
94    0.0
95}
96
97/// Format metrics summary to stdout
98#[allow(clippy::implicit_hasher)]
99pub fn format_metrics_summary(
100    agg: &AggregateMetrics,
101    by_domain: &HashMap<String, AggregateMetrics>,
102) -> String {
103    use std::fmt::Write;
104    let mut s = String::new();
105    s.push_str(&"=".repeat(60));
106    s.push('\n');
107    s.push_str("AGGREGATE RESULTS\n");
108    s.push_str(&"=".repeat(60));
109    s.push('\n');
110    let _ = writeln!(s, "  Queries:       {}", agg.num_queries);
111    let _ = writeln!(s, "  MRR:           {:.4}", agg.mrr);
112    let _ = writeln!(s, "  NDCG@5:        {:.4}", agg.ndcg_5);
113    let _ = writeln!(s, "  NDCG@10:       {:.4}", agg.ndcg_10);
114    let _ = writeln!(s, "  Recall@5:      {:.4}", agg.recall_5);
115    let _ = writeln!(s, "  Precision@5:   {:.4}", agg.precision_5);
116    let _ = writeln!(s, "  Hit Rate@5:    {:.4}", agg.hit_rate_5);
117    let _ = writeln!(s, "  Hit Rate@10:   {:.4}", agg.hit_rate_10);
118    let _ = writeln!(s, "  MAP:           {:.4}", agg.map);
119    let _ = writeln!(s, "  Latency:       {:.3}s", agg.mean_latency_s);
120    s.push('\n');
121    s.push_str("BY DOMAIN:\n");
122
123    let mut domains: Vec<_> = by_domain.iter().collect();
124    domains.sort_by(|(a, _), (b, _)| a.cmp(b));
125    for (domain, m) in domains {
126        let _ = writeln!(
127            s,
128            "  {domain:12}  MRR={:.3}  NDCG@5={:.3}  Hit@5={:.3}  (n={})",
129            m.mrr, m.ndcg_5, m.hit_rate_5, m.num_queries
130        );
131    }
132
133    s
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::eval::types::RetrievedChunk;

    /// Build a retrieval entry with `num_results` synthetic chunks for `query`.
    fn make_retrieval_entry(query: &str, num_results: usize) -> RetrievalResultEntry {
        let results = (0..num_results)
            .map(|i| RetrievedChunk {
                content: format!("chunk {i}"),
                source: Some(format!("/test/chunk{i}.srt")),
                score: 1.0 - i as f32 * 0.1,
                title: None,
                start_secs: None,
                end_secs: None,
            })
            .collect();
        RetrievalResultEntry {
            query: query.to_string(),
            domain: "test".to_string(),
            course: "test-course".to_string(),
            results,
            latency_s: 0.5,
        }
    }

    /// Shorthand for constructing a judgment entry in tests.
    fn judgment(query: &str, rank: usize, relevant: bool, reasoning: &str) -> JudgmentEntry {
        JudgmentEntry {
            query: query.to_string(),
            rank,
            relevant,
            reasoning: reasoning.to_string(),
            source: None,
            score: None,
        }
    }

    #[test]
    fn test_metrics_basic() {
        let results = vec![make_retrieval_entry("what is kubernetes?", 5)];
        let judgments = vec![
            judgment("what is kubernetes?", 1, false, "off topic"),
            judgment("what is kubernetes?", 2, true, "discusses k8s"),
        ];

        let output = compute_metrics_from_judgments(&results, &judgments);
        // First relevant hit is at rank 2, so MRR = 1/2 and Hit@5 is true.
        assert_eq!(output.per_query.len(), 1);
        assert!((output.per_query[0].mrr - 0.5).abs() < 0.001);
        assert!(output.per_query[0].hit_5);
    }

    #[test]
    fn test_metrics_no_relevant() {
        let results = vec![make_retrieval_entry("obscure query", 3)];
        let judgments = vec![judgment("obscure query", 1, false, "not relevant")];

        let output = compute_metrics_from_judgments(&results, &judgments);
        // No relevant chunks anywhere: MRR is 0 and there is no hit in the top 5.
        assert!((output.per_query[0].mrr).abs() < 0.001);
        assert!(!output.per_query[0].hit_5);
    }

    #[test]
    fn test_metrics_missing_judgments() {
        let results = vec![make_retrieval_entry("test query", 5)];
        // No judgments at all — everything defaults to not relevant
        let judgments = vec![];

        let output = compute_metrics_from_judgments(&results, &judgments);
        assert!((output.aggregate.mrr).abs() < 0.001);
        assert!((output.aggregate.hit_rate_5).abs() < 0.001);
    }
}
215}