1use super::judge::{compute_aggregate_metrics, compute_by_domain_metrics};
7use super::types::{
8 AggregateMetrics, ChunkJudgment, EvalOutput, EvalRunConfig, JudgmentEntry, QueryResult,
9 RetrievalResultEntry,
10};
11use std::collections::HashMap;
12
13pub fn compute_metrics_from_judgments(
18 retrieval_results: &[RetrievalResultEntry],
19 judgments: &[JudgmentEntry],
20) -> EvalOutput {
21 let mut judgment_map: HashMap<(&str, usize), &JudgmentEntry> = HashMap::new();
23 for j in judgments {
24 judgment_map.insert((&j.query, j.rank), j);
25 }
26
27 let mut per_query = Vec::new();
28
29 for entry in retrieval_results {
30 let mut chunk_judgments = Vec::new();
31
32 for (rank_idx, chunk) in entry.results.iter().enumerate() {
33 let rank = rank_idx + 1;
34
35 let (relevant, reasoning) =
37 if let Some(j) = judgment_map.get(&(entry.query.as_str(), rank)) {
38 (j.relevant, j.reasoning.clone())
39 } else {
40 (false, "no judgment provided".to_string())
42 };
43
44 chunk_judgments.push(ChunkJudgment {
45 rank,
46 score: chunk.score,
47 source: chunk.source.clone(),
48 relevant,
49 reasoning,
50 });
51 }
52
53 let mrr = compute_mrr(&chunk_judgments);
54 let hit_5 = chunk_judgments.iter().take(5).any(|j| j.relevant);
55 let relevant_count = chunk_judgments.iter().filter(|j| j.relevant).count();
56
57 per_query.push(QueryResult {
58 query: entry.query.clone(),
59 domain: entry.domain.clone(),
60 mrr,
61 hit_5,
62 relevant_count,
63 total_results: entry.results.len(),
64 latency_s: entry.latency_s,
65 judgments: chunk_judgments,
66 });
67 }
68
69 let aggregate = compute_aggregate_metrics(&per_query);
70 let by_domain = compute_by_domain_metrics(&per_query);
71 let timestamp = super::judge::chrono_now();
72
73 EvalOutput {
74 timestamp,
75 config: EvalRunConfig {
76 num_queries: retrieval_results.len(),
77 top_k: retrieval_results.first().map(|r| r.results.len()).unwrap_or(10),
78 judge_model: "claude-code".to_string(),
79 cache_hits: 0,
80 api_calls: 0,
81 },
82 aggregate,
83 by_domain,
84 per_query,
85 }
86}
87
88fn compute_mrr(judgments: &[ChunkJudgment]) -> f64 {
89 for j in judgments {
90 if j.relevant {
91 return 1.0 / j.rank as f64;
92 }
93 }
94 0.0
95}
96
97#[allow(clippy::implicit_hasher)]
99pub fn format_metrics_summary(
100 agg: &AggregateMetrics,
101 by_domain: &HashMap<String, AggregateMetrics>,
102) -> String {
103 use std::fmt::Write;
104 let mut s = String::new();
105 s.push_str(&"=".repeat(60));
106 s.push('\n');
107 s.push_str("AGGREGATE RESULTS\n");
108 s.push_str(&"=".repeat(60));
109 s.push('\n');
110 let _ = writeln!(s, " Queries: {}", agg.num_queries);
111 let _ = writeln!(s, " MRR: {:.4}", agg.mrr);
112 let _ = writeln!(s, " NDCG@5: {:.4}", agg.ndcg_5);
113 let _ = writeln!(s, " NDCG@10: {:.4}", agg.ndcg_10);
114 let _ = writeln!(s, " Recall@5: {:.4}", agg.recall_5);
115 let _ = writeln!(s, " Precision@5: {:.4}", agg.precision_5);
116 let _ = writeln!(s, " Hit Rate@5: {:.4}", agg.hit_rate_5);
117 let _ = writeln!(s, " Hit Rate@10: {:.4}", agg.hit_rate_10);
118 let _ = writeln!(s, " MAP: {:.4}", agg.map);
119 let _ = writeln!(s, " Latency: {:.3}s", agg.mean_latency_s);
120 s.push('\n');
121 s.push_str("BY DOMAIN:\n");
122
123 let mut domains: Vec<_> = by_domain.iter().collect();
124 domains.sort_by(|(a, _), (b, _)| a.cmp(b));
125 for (domain, m) in domains {
126 let _ = writeln!(
127 s,
128 " {domain:12} MRR={:.3} NDCG@5={:.3} Hit@5={:.3} (n={})",
129 m.mrr, m.ndcg_5, m.hit_rate_5, m.num_queries
130 );
131 }
132
133 s
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::eval::types::RetrievedChunk;

    /// Fixture: a retrieval entry for `query` with `num_results` synthetic
    /// chunks, scores descending from 1.0 in steps of 0.1.
    fn make_retrieval_entry(query: &str, num_results: usize) -> RetrievalResultEntry {
        let results = (0..num_results)
            .map(|i| RetrievedChunk {
                content: format!("chunk {i}"),
                source: Some(format!("/test/chunk{i}.srt")),
                score: 1.0 - i as f32 * 0.1,
                title: None,
                start_secs: None,
                end_secs: None,
            })
            .collect();
        RetrievalResultEntry {
            query: query.to_string(),
            domain: "test".to_string(),
            course: "test-course".to_string(),
            results,
            latency_s: 0.5,
        }
    }

    /// Fixture: a judgment entry with no source/score metadata.
    fn judgment(query: &str, rank: usize, relevant: bool, reasoning: &str) -> JudgmentEntry {
        JudgmentEntry {
            query: query.to_string(),
            rank,
            relevant,
            reasoning: reasoning.to_string(),
            source: None,
            score: None,
        }
    }

    #[test]
    fn test_metrics_basic() {
        let results = vec![make_retrieval_entry("what is kubernetes?", 5)];
        let judgments = vec![
            judgment("what is kubernetes?", 1, false, "off topic"),
            judgment("what is kubernetes?", 2, true, "discusses k8s"),
        ];

        let output = compute_metrics_from_judgments(&results, &judgments);
        assert_eq!(output.per_query.len(), 1);
        // First relevant result sits at rank 2, so MRR must be 1/2.
        assert!((output.per_query[0].mrr - 0.5).abs() < 0.001);
        assert!(output.per_query[0].hit_5);
    }

    #[test]
    fn test_metrics_no_relevant() {
        let results = vec![make_retrieval_entry("obscure query", 3)];
        let judgments = vec![judgment("obscure query", 1, false, "not relevant")];

        let output = compute_metrics_from_judgments(&results, &judgments);
        // No relevant chunk anywhere: MRR is 0 and there is no hit@5.
        assert!((output.per_query[0].mrr).abs() < 0.001);
        assert!(!output.per_query[0].hit_5);
    }

    #[test]
    fn test_metrics_missing_judgments() {
        let results = vec![make_retrieval_entry("test query", 5)];
        let judgments: Vec<JudgmentEntry> = Vec::new();

        // Unjudged chunks default to "not relevant", so aggregates are zero.
        let output = compute_metrics_from_judgments(&results, &judgments);
        assert!((output.aggregate.mrr).abs() < 0.001);
        assert!((output.aggregate.hit_rate_5).abs() < 0.001);
    }
}