1use crate::{
2 ChunkInput, FusionStrategy, QueryProfile, Result, RuntimeConfig, SearchRequest, SqlRite,
3 SqlRiteError,
4};
5use serde::{Deserialize, Serialize};
6use std::cmp::min;
7use std::collections::{HashMap, HashSet};
8
/// Default evaluation cutoffs used when a dataset omits `k_values`.
fn default_k_values() -> Vec<usize> {
    Vec::from([1, 3, 5, 10])
}
12
/// Default fusion weight applied when a query omits `alpha`.
fn default_alpha() -> f32 {
    const ALPHA: f32 = 0.65;
    ALPHA
}
16
/// Default candidate pool size applied when a query omits `candidate_limit`.
fn default_candidate_limit() -> usize {
    const LIMIT: usize = 1_000;
    LIMIT
}
20
/// A self-contained evaluation fixture: the corpus to index plus the
/// labeled queries to run against it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalDataset {
    /// Chunks ingested into a fresh in-memory index before any query runs.
    pub corpus: Vec<ChunkInput>,
    /// Labeled queries; each carries its own relevance judgments.
    pub queries: Vec<EvalQuery>,
    /// Cutoffs at which metrics are computed
    /// (deserialization default: `[1, 3, 5, 10]`).
    #[serde(default = "default_k_values")]
    pub k_values: Vec<usize>,
}
28
/// One labeled query inside an `EvalDataset`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalQuery {
    /// Identifier echoed back in the per-query results.
    pub id: String,
    /// Text form of the query. Validation requires at least one of
    /// `query_text` / `query_embedding` to be present.
    pub query_text: Option<String>,
    /// Embedding form of the query; see `query_text`.
    pub query_embedding: Option<Vec<f32>>,
    /// Gold-standard chunk ids considered relevant; validation requires
    /// this to be non-empty.
    pub relevant_chunk_ids: Vec<String>,
    /// Metadata filters forwarded verbatim to the search request.
    #[serde(default)]
    pub metadata_filters: HashMap<String, String>,
    /// Optional document-id restriction forwarded to the search request.
    pub doc_id: Option<String>,
    /// Fusion weight forwarded to the search request; validated to lie in
    /// [0.0, 1.0] (deserialization default: 0.65).
    #[serde(default = "default_alpha")]
    pub alpha: f32,
    /// Candidate pool size; must be >= 1, and is raised to at least the
    /// effective top_k at search time (deserialization default: 1000).
    #[serde(default = "default_candidate_limit")]
    pub candidate_limit: usize,
    /// Optional retrieval depth; the harness always fetches at least the
    /// largest configured k so every cutoff can be scored.
    pub top_k: Option<usize>,
}
44
/// Rank-based retrieval metrics computed at a single cutoff k.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalMetricsAtK {
    /// Fraction of the relevant chunks that appear in the top k.
    pub recall: f32,
    /// Fraction of the top-k slots occupied by relevant chunks
    /// (denominator is k even when fewer results were returned).
    pub precision: f32,
    /// Reciprocal rank of the first relevant chunk within the cutoff;
    /// 0.0 when no relevant chunk is retrieved.
    pub mrr: f32,
    /// Binary-relevance normalized discounted cumulative gain.
    pub ndcg: f32,
    /// 1.0 if any relevant chunk appears in the top k, else 0.0.
    pub hit_rate: f32,
}
53
/// Metrics and raw rankings for a single evaluated query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryEvalResult {
    /// Id copied from the input `EvalQuery`.
    pub query_id: String,
    /// Chunk ids in the order the search returned them.
    pub retrieved_chunk_ids: Vec<String>,
    /// Gold labels copied from the input query.
    pub relevant_chunk_ids: Vec<String>,
    /// Metrics for this query, keyed by cutoff k.
    pub metrics_at_k: HashMap<usize, EvalMetricsAtK>,
}
61
/// High-level counts describing an evaluation run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalSummary {
    /// Number of chunks ingested from the dataset corpus.
    pub corpus_size: usize,
    /// Number of queries evaluated.
    pub query_count: usize,
    /// Normalized (positive, sorted, deduplicated) cutoffs that were scored.
    pub k_values: Vec<usize>,
}
68
/// Full output of `evaluate_dataset`: run summary, per-cutoff means, and
/// the per-query breakdown the means were computed from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalReport {
    /// Counts describing the run.
    pub summary: EvalSummary,
    /// Mean metrics across all queries, keyed by cutoff k.
    pub aggregate_metrics_at_k: HashMap<usize, EvalMetricsAtK>,
    /// One entry per input query, in input order.
    pub per_query: Vec<QueryEvalResult>,
}
75
76pub fn evaluate_dataset(dataset: EvalDataset, runtime_config: RuntimeConfig) -> Result<EvalReport> {
77 let k_values = normalized_k_values(&dataset.k_values)?;
78 validate_dataset(&dataset, &k_values)?;
79
80 let max_k = *k_values.last().expect("k_values cannot be empty");
81 let db = SqlRite::open_in_memory_with_config(runtime_config)?;
82 db.ingest_chunks(&dataset.corpus)?;
83
84 let mut per_query = Vec::with_capacity(dataset.queries.len());
85 let mut aggregate: HashMap<usize, EvalMetricAccumulator> = HashMap::new();
86 for &k in &k_values {
87 aggregate.insert(k, EvalMetricAccumulator::default());
88 }
89
90 for query in &dataset.queries {
91 let top_k = query.top_k.unwrap_or(max_k).max(max_k);
92 let request = SearchRequest {
93 query_text: query.query_text.clone(),
94 query_embedding: query.query_embedding.clone(),
95 top_k,
96 alpha: query.alpha,
97 candidate_limit: query.candidate_limit.max(top_k),
98 include_payloads: true,
99 query_profile: QueryProfile::Balanced,
100 metadata_filters: query.metadata_filters.clone(),
101 doc_id: query.doc_id.clone(),
102 fusion_strategy: FusionStrategy::Weighted,
103 };
104 let search_results = db.search(request)?;
105 let ranked_ids: Vec<String> = search_results.into_iter().map(|r| r.chunk_id).collect();
106
107 let relevant_set: HashSet<&str> = query
108 .relevant_chunk_ids
109 .iter()
110 .map(String::as_str)
111 .collect();
112 let mut metrics_at_k = HashMap::new();
113
114 for &k in &k_values {
115 let metrics = compute_metrics_at_k(&ranked_ids, &relevant_set, k);
116 metrics_at_k.insert(k, metrics.clone());
117
118 if let Some(acc) = aggregate.get_mut(&k) {
119 acc.add(&metrics);
120 }
121 }
122
123 per_query.push(QueryEvalResult {
124 query_id: query.id.clone(),
125 retrieved_chunk_ids: ranked_ids,
126 relevant_chunk_ids: query.relevant_chunk_ids.clone(),
127 metrics_at_k,
128 });
129 }
130
131 let mut aggregate_metrics_at_k = HashMap::new();
132 for &k in &k_values {
133 let metrics = aggregate
134 .remove(&k)
135 .expect("aggregate key exists")
136 .mean(dataset.queries.len());
137 aggregate_metrics_at_k.insert(k, metrics);
138 }
139
140 Ok(EvalReport {
141 summary: EvalSummary {
142 corpus_size: dataset.corpus.len(),
143 query_count: dataset.queries.len(),
144 k_values,
145 },
146 aggregate_metrics_at_k,
147 per_query,
148 })
149}
150
/// Running sums of per-query metrics at one cutoff; divided by the query
/// count via `mean` to produce the aggregate report entry.
#[derive(Debug, Default, Clone)]
struct EvalMetricAccumulator {
    recall: f32,
    precision: f32,
    mrr: f32,
    ndcg: f32,
    hit_rate: f32,
}
159
160impl EvalMetricAccumulator {
161 fn add(&mut self, metrics: &EvalMetricsAtK) {
162 self.recall += metrics.recall;
163 self.precision += metrics.precision;
164 self.mrr += metrics.mrr;
165 self.ndcg += metrics.ndcg;
166 self.hit_rate += metrics.hit_rate;
167 }
168
169 fn mean(self, count: usize) -> EvalMetricsAtK {
170 let denom = count as f32;
171 EvalMetricsAtK {
172 recall: self.recall / denom,
173 precision: self.precision / denom,
174 mrr: self.mrr / denom,
175 ndcg: self.ndcg / denom,
176 hit_rate: self.hit_rate / denom,
177 }
178 }
179}
180
181fn compute_metrics_at_k(
182 ranked_ids: &[String],
183 relevant_ids: &HashSet<&str>,
184 k: usize,
185) -> EvalMetricsAtK {
186 let relevant_count = relevant_ids.len();
187 if relevant_count == 0 || k == 0 {
188 return EvalMetricsAtK {
189 recall: 0.0,
190 precision: 0.0,
191 mrr: 0.0,
192 ndcg: 0.0,
193 hit_rate: 0.0,
194 };
195 }
196
197 let cutoff = min(k, ranked_ids.len());
198 let hits = ranked_ids
199 .iter()
200 .take(cutoff)
201 .filter(|id| relevant_ids.contains(id.as_str()))
202 .count();
203 let recall = hits as f32 / relevant_count as f32;
204 let precision = hits as f32 / k as f32;
205 let hit_rate = if hits > 0 { 1.0 } else { 0.0 };
206
207 let mut reciprocal_rank = 0.0;
208 for (idx, id) in ranked_ids.iter().take(cutoff).enumerate() {
209 if relevant_ids.contains(id.as_str()) {
210 reciprocal_rank = 1.0 / (idx as f32 + 1.0);
211 break;
212 }
213 }
214
215 let mut dcg = 0.0;
216 for (idx, id) in ranked_ids.iter().take(cutoff).enumerate() {
217 if relevant_ids.contains(id.as_str()) {
218 dcg += 1.0 / ((idx as f32 + 2.0).log2());
219 }
220 }
221
222 let ideal_hits = min(relevant_count, k);
223 let mut idcg = 0.0;
224 for idx in 0..ideal_hits {
225 idcg += 1.0 / ((idx as f32 + 2.0).log2());
226 }
227 let ndcg = if idcg > 0.0 { dcg / idcg } else { 0.0 };
228
229 EvalMetricsAtK {
230 recall,
231 precision,
232 mrr: reciprocal_rank,
233 ndcg,
234 hit_rate,
235 }
236}
237
238fn normalized_k_values(values: &[usize]) -> Result<Vec<usize>> {
239 let mut unique: Vec<usize> = values.iter().copied().filter(|v| *v > 0).collect();
240 unique.sort_unstable();
241 unique.dedup();
242 if unique.is_empty() {
243 return Err(SqlRiteError::InvalidEvaluationDataset(
244 "k_values must contain at least one positive integer".to_string(),
245 ));
246 }
247 Ok(unique)
248}
249
250fn validate_dataset(dataset: &EvalDataset, k_values: &[usize]) -> Result<()> {
251 if dataset.corpus.is_empty() {
252 return Err(SqlRiteError::InvalidEvaluationDataset(
253 "corpus cannot be empty".to_string(),
254 ));
255 }
256 if dataset.queries.is_empty() {
257 return Err(SqlRiteError::InvalidEvaluationDataset(
258 "queries cannot be empty".to_string(),
259 ));
260 }
261 if k_values.is_empty() {
262 return Err(SqlRiteError::InvalidEvaluationDataset(
263 "k_values cannot be empty".to_string(),
264 ));
265 }
266
267 for query in &dataset.queries {
268 if query.query_text.is_none() && query.query_embedding.is_none() {
269 return Err(SqlRiteError::InvalidEvaluationDataset(format!(
270 "query `{}` must contain query_text, query_embedding, or both",
271 query.id
272 )));
273 }
274 if query.relevant_chunk_ids.is_empty() {
275 return Err(SqlRiteError::InvalidEvaluationDataset(format!(
276 "query `{}` has no relevant_chunk_ids",
277 query.id
278 )));
279 }
280 if query.candidate_limit == 0 {
281 return Err(SqlRiteError::InvalidEvaluationDataset(format!(
282 "query `{}` candidate_limit must be >= 1",
283 query.id
284 )));
285 }
286 if !(0.0..=1.0).contains(&query.alpha) {
287 return Err(SqlRiteError::InvalidEvaluationDataset(format!(
288 "query `{}` alpha must be between 0.0 and 1.0",
289 query.id
290 )));
291 }
292 }
293
294 Ok(())
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds one corpus chunk with the fixture's shared tenant metadata.
    fn chunk(id: &str, doc_id: &str, content: &str, embedding: Vec<f32>) -> ChunkInput {
        ChunkInput {
            id: id.to_string(),
            doc_id: doc_id.to_string(),
            content: content.to_string(),
            embedding,
            metadata: json!({"tenant": "acme"}),
            source: None,
        }
    }

    /// Builds one labeled query with the fixture's shared retrieval settings.
    fn query(id: &str, text: &str, embedding: Vec<f32>, relevant: &str, alpha: f32) -> EvalQuery {
        EvalQuery {
            id: id.to_string(),
            query_text: Some(text.to_string()),
            query_embedding: Some(embedding),
            relevant_chunk_ids: vec![relevant.to_string()],
            metadata_filters: HashMap::new(),
            doc_id: None,
            alpha,
            candidate_limit: 10,
            top_k: Some(3),
        }
    }

    /// Three-chunk corpus with two queries, scored at k = 1 and k = 3.
    fn sample_dataset() -> EvalDataset {
        EvalDataset {
            corpus: vec![
                chunk("c1", "d1", "Rust for retrieval", vec![1.0, 0.0, 0.0]),
                chunk("c2", "d2", "Postgres transactions", vec![0.0, 1.0, 0.0]),
                chunk("c3", "d3", "SQLite local memory", vec![0.8, 0.2, 0.0]),
            ],
            queries: vec![
                query("q1", "rust retrieval", vec![0.95, 0.05, 0.0], "c1", 0.6),
                query("q2", "sqlite memory", vec![0.75, 0.25, 0.0], "c3", 0.5),
            ],
            k_values: vec![1, 3],
        }
    }

    #[test]
    fn compute_metrics_is_correct_for_simple_case() {
        // Single relevant item at rank 2 of 3.
        let ranked: Vec<String> = ["a", "b", "c"].iter().map(|s| s.to_string()).collect();
        let relevant: HashSet<&str> = ["b"].into_iter().collect();
        let m = compute_metrics_at_k(&ranked, &relevant, 3);
        assert!((m.recall - 1.0).abs() < 1e-6);
        assert!((m.precision - (1.0 / 3.0)).abs() < 1e-6);
        assert!((m.mrr - 0.5).abs() < 1e-6);
        assert!(m.ndcg > 0.6 && m.ndcg < 0.7);
        assert!((m.hit_rate - 1.0).abs() < 1e-6);
    }

    #[test]
    fn evaluation_report_has_aggregate_metrics() -> Result<()> {
        let report = evaluate_dataset(sample_dataset(), RuntimeConfig::default())?;
        assert_eq!(report.summary.corpus_size, 3);
        assert_eq!(report.summary.query_count, 2);
        assert_eq!(report.per_query.len(), 2);
        for k in [1usize, 3] {
            assert!(report.aggregate_metrics_at_k.contains_key(&k));
        }
        Ok(())
    }
}