//! Retrieval-quality evaluation harness for sqlrite (`eval.rs`).
1use crate::{
2    ChunkInput, FusionStrategy, QueryProfile, Result, RuntimeConfig, SearchRequest, SqlRite,
3    SqlRiteError,
4};
5use serde::{Deserialize, Serialize};
6use std::cmp::min;
7use std::collections::{HashMap, HashSet};
8
/// Serde default for `EvalDataset::k_values`: the cutoffs scored when a
/// dataset omits them.
fn default_k_values() -> Vec<usize> {
    [1, 3, 5, 10].to_vec()
}
12
/// Serde default for `EvalQuery::alpha` (hybrid-fusion weight).
fn default_alpha() -> f32 {
    const ALPHA: f32 = 0.65;
    ALPHA
}
16
/// Serde default for `EvalQuery::candidate_limit`.
fn default_candidate_limit() -> usize {
    const LIMIT: usize = 1000;
    LIMIT
}
20
/// A self-contained retrieval-evaluation dataset: the corpus to index plus
/// the labeled queries to run against it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalDataset {
    /// Chunks ingested into an in-memory index before any query runs.
    pub corpus: Vec<ChunkInput>,
    /// Labeled queries evaluated against the corpus.
    pub queries: Vec<EvalQuery>,
    /// Rank cutoffs at which metrics are computed; defaults to `[1, 3, 5, 10]`.
    #[serde(default = "default_k_values")]
    pub k_values: Vec<usize>,
}
28
/// One labeled query: the search inputs plus the ground-truth relevant ids.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalQuery {
    /// Stable identifier, echoed back as `QueryEvalResult::query_id`.
    pub id: String,
    /// Lexical query text; at least one of `query_text` / `query_embedding`
    /// must be present (enforced by `validate_dataset`).
    pub query_text: Option<String>,
    /// Query vector; see note on `query_text`.
    pub query_embedding: Option<Vec<f32>>,
    /// Ground-truth relevant chunk ids; must be non-empty.
    pub relevant_chunk_ids: Vec<String>,
    /// Optional metadata filters forwarded to the search request.
    #[serde(default)]
    pub metadata_filters: HashMap<String, String>,
    /// Optional restriction of the search to a single document.
    pub doc_id: Option<String>,
    /// Fusion weight forwarded to search; must lie in [0.0, 1.0]. Default 0.65.
    #[serde(default = "default_alpha")]
    pub alpha: f32,
    /// Candidate pool size forwarded to search; must be >= 1. Default 1000.
    #[serde(default = "default_candidate_limit")]
    pub candidate_limit: usize,
    /// Requested result depth; evaluation raises it to the largest k so every
    /// cutoff can be scored.
    pub top_k: Option<usize>,
}
44
/// Ranking metrics computed at a single cutoff k (all in [0.0, 1.0]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalMetricsAtK {
    /// Fraction of all relevant items found within the top k.
    pub recall: f32,
    /// Hits within the top k divided by k itself.
    pub precision: f32,
    /// Reciprocal rank of the first relevant result (0 if none in top k).
    pub mrr: f32,
    /// Binary-gain normalized discounted cumulative gain at k.
    pub ndcg: f32,
    /// 1.0 when at least one relevant item appears in the top k, else 0.0.
    pub hit_rate: f32,
}
53
/// Per-query evaluation outcome: what was retrieved, what was relevant, and
/// the metrics at each requested cutoff.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryEvalResult {
    /// Copied from `EvalQuery::id`.
    pub query_id: String,
    /// Retrieved chunk ids in rank order (best first).
    pub retrieved_chunk_ids: Vec<String>,
    /// Ground-truth ids, copied from the query for self-contained reports.
    pub relevant_chunk_ids: Vec<String>,
    /// Metrics keyed by cutoff k.
    pub metrics_at_k: HashMap<usize, EvalMetricsAtK>,
}
61
/// High-level shape of an evaluation run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalSummary {
    /// Number of chunks ingested.
    pub corpus_size: usize,
    /// Number of queries evaluated.
    pub query_count: usize,
    /// The normalized (sorted, deduplicated, positive) cutoffs used.
    pub k_values: Vec<usize>,
}
68
/// Full output of `evaluate_dataset`: summary, per-k aggregates, and
/// per-query detail.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalReport {
    /// Run-level counts and cutoffs.
    pub summary: EvalSummary,
    /// Metrics averaged over all queries, keyed by cutoff k.
    pub aggregate_metrics_at_k: HashMap<usize, EvalMetricsAtK>,
    /// One result per query, in dataset order.
    pub per_query: Vec<QueryEvalResult>,
}
75
76pub fn evaluate_dataset(dataset: EvalDataset, runtime_config: RuntimeConfig) -> Result<EvalReport> {
77    let k_values = normalized_k_values(&dataset.k_values)?;
78    validate_dataset(&dataset, &k_values)?;
79
80    let max_k = *k_values.last().expect("k_values cannot be empty");
81    let db = SqlRite::open_in_memory_with_config(runtime_config)?;
82    db.ingest_chunks(&dataset.corpus)?;
83
84    let mut per_query = Vec::with_capacity(dataset.queries.len());
85    let mut aggregate: HashMap<usize, EvalMetricAccumulator> = HashMap::new();
86    for &k in &k_values {
87        aggregate.insert(k, EvalMetricAccumulator::default());
88    }
89
90    for query in &dataset.queries {
91        let top_k = query.top_k.unwrap_or(max_k).max(max_k);
92        let request = SearchRequest {
93            query_text: query.query_text.clone(),
94            query_embedding: query.query_embedding.clone(),
95            top_k,
96            alpha: query.alpha,
97            candidate_limit: query.candidate_limit.max(top_k),
98            include_payloads: true,
99            query_profile: QueryProfile::Balanced,
100            metadata_filters: query.metadata_filters.clone(),
101            doc_id: query.doc_id.clone(),
102            fusion_strategy: FusionStrategy::Weighted,
103        };
104        let search_results = db.search(request)?;
105        let ranked_ids: Vec<String> = search_results.into_iter().map(|r| r.chunk_id).collect();
106
107        let relevant_set: HashSet<&str> = query
108            .relevant_chunk_ids
109            .iter()
110            .map(String::as_str)
111            .collect();
112        let mut metrics_at_k = HashMap::new();
113
114        for &k in &k_values {
115            let metrics = compute_metrics_at_k(&ranked_ids, &relevant_set, k);
116            metrics_at_k.insert(k, metrics.clone());
117
118            if let Some(acc) = aggregate.get_mut(&k) {
119                acc.add(&metrics);
120            }
121        }
122
123        per_query.push(QueryEvalResult {
124            query_id: query.id.clone(),
125            retrieved_chunk_ids: ranked_ids,
126            relevant_chunk_ids: query.relevant_chunk_ids.clone(),
127            metrics_at_k,
128        });
129    }
130
131    let mut aggregate_metrics_at_k = HashMap::new();
132    for &k in &k_values {
133        let metrics = aggregate
134            .remove(&k)
135            .expect("aggregate key exists")
136            .mean(dataset.queries.len());
137        aggregate_metrics_at_k.insert(k, metrics);
138    }
139
140    Ok(EvalReport {
141        summary: EvalSummary {
142            corpus_size: dataset.corpus.len(),
143            query_count: dataset.queries.len(),
144            k_values,
145        },
146        aggregate_metrics_at_k,
147        per_query,
148    })
149}
150
/// Running sums of per-query metrics, later divided by the query count to
/// produce the per-k aggregate means.
#[derive(Debug, Default, Clone)]
struct EvalMetricAccumulator {
    recall: f32,
    precision: f32,
    mrr: f32,
    ndcg: f32,
    hit_rate: f32,
}
159
160impl EvalMetricAccumulator {
161    fn add(&mut self, metrics: &EvalMetricsAtK) {
162        self.recall += metrics.recall;
163        self.precision += metrics.precision;
164        self.mrr += metrics.mrr;
165        self.ndcg += metrics.ndcg;
166        self.hit_rate += metrics.hit_rate;
167    }
168
169    fn mean(self, count: usize) -> EvalMetricsAtK {
170        let denom = count as f32;
171        EvalMetricsAtK {
172            recall: self.recall / denom,
173            precision: self.precision / denom,
174            mrr: self.mrr / denom,
175            ndcg: self.ndcg / denom,
176            hit_rate: self.hit_rate / denom,
177        }
178    }
179}
180
181fn compute_metrics_at_k(
182    ranked_ids: &[String],
183    relevant_ids: &HashSet<&str>,
184    k: usize,
185) -> EvalMetricsAtK {
186    let relevant_count = relevant_ids.len();
187    if relevant_count == 0 || k == 0 {
188        return EvalMetricsAtK {
189            recall: 0.0,
190            precision: 0.0,
191            mrr: 0.0,
192            ndcg: 0.0,
193            hit_rate: 0.0,
194        };
195    }
196
197    let cutoff = min(k, ranked_ids.len());
198    let hits = ranked_ids
199        .iter()
200        .take(cutoff)
201        .filter(|id| relevant_ids.contains(id.as_str()))
202        .count();
203    let recall = hits as f32 / relevant_count as f32;
204    let precision = hits as f32 / k as f32;
205    let hit_rate = if hits > 0 { 1.0 } else { 0.0 };
206
207    let mut reciprocal_rank = 0.0;
208    for (idx, id) in ranked_ids.iter().take(cutoff).enumerate() {
209        if relevant_ids.contains(id.as_str()) {
210            reciprocal_rank = 1.0 / (idx as f32 + 1.0);
211            break;
212        }
213    }
214
215    let mut dcg = 0.0;
216    for (idx, id) in ranked_ids.iter().take(cutoff).enumerate() {
217        if relevant_ids.contains(id.as_str()) {
218            dcg += 1.0 / ((idx as f32 + 2.0).log2());
219        }
220    }
221
222    let ideal_hits = min(relevant_count, k);
223    let mut idcg = 0.0;
224    for idx in 0..ideal_hits {
225        idcg += 1.0 / ((idx as f32 + 2.0).log2());
226    }
227    let ndcg = if idcg > 0.0 { dcg / idcg } else { 0.0 };
228
229    EvalMetricsAtK {
230        recall,
231        precision,
232        mrr: reciprocal_rank,
233        ndcg,
234        hit_rate,
235    }
236}
237
238fn normalized_k_values(values: &[usize]) -> Result<Vec<usize>> {
239    let mut unique: Vec<usize> = values.iter().copied().filter(|v| *v > 0).collect();
240    unique.sort_unstable();
241    unique.dedup();
242    if unique.is_empty() {
243        return Err(SqlRiteError::InvalidEvaluationDataset(
244            "k_values must contain at least one positive integer".to_string(),
245        ));
246    }
247    Ok(unique)
248}
249
250fn validate_dataset(dataset: &EvalDataset, k_values: &[usize]) -> Result<()> {
251    if dataset.corpus.is_empty() {
252        return Err(SqlRiteError::InvalidEvaluationDataset(
253            "corpus cannot be empty".to_string(),
254        ));
255    }
256    if dataset.queries.is_empty() {
257        return Err(SqlRiteError::InvalidEvaluationDataset(
258            "queries cannot be empty".to_string(),
259        ));
260    }
261    if k_values.is_empty() {
262        return Err(SqlRiteError::InvalidEvaluationDataset(
263            "k_values cannot be empty".to_string(),
264        ));
265    }
266
267    for query in &dataset.queries {
268        if query.query_text.is_none() && query.query_embedding.is_none() {
269            return Err(SqlRiteError::InvalidEvaluationDataset(format!(
270                "query `{}` must contain query_text, query_embedding, or both",
271                query.id
272            )));
273        }
274        if query.relevant_chunk_ids.is_empty() {
275            return Err(SqlRiteError::InvalidEvaluationDataset(format!(
276                "query `{}` has no relevant_chunk_ids",
277                query.id
278            )));
279        }
280        if query.candidate_limit == 0 {
281            return Err(SqlRiteError::InvalidEvaluationDataset(format!(
282                "query `{}` candidate_limit must be >= 1",
283                query.id
284            )));
285        }
286        if !(0.0..=1.0).contains(&query.alpha) {
287            return Err(SqlRiteError::InvalidEvaluationDataset(format!(
288                "query `{}` alpha must be between 0.0 and 1.0",
289                query.id
290            )));
291        }
292    }
293
294    Ok(())
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a corpus chunk tagged with the "acme" tenant and no source.
    fn chunk(id: &str, doc_id: &str, content: &str, embedding: Vec<f32>) -> ChunkInput {
        ChunkInput {
            id: id.to_string(),
            doc_id: doc_id.to_string(),
            content: content.to_string(),
            embedding,
            metadata: json!({"tenant": "acme"}),
            source: None,
        }
    }

    /// Builds a labeled query: no filters, no doc scope, limit 10, top-3.
    fn labeled_query(
        id: &str,
        text: &str,
        embedding: Vec<f32>,
        relevant: &str,
        alpha: f32,
    ) -> EvalQuery {
        EvalQuery {
            id: id.to_string(),
            query_text: Some(text.to_string()),
            query_embedding: Some(embedding),
            relevant_chunk_ids: vec![relevant.to_string()],
            metadata_filters: HashMap::new(),
            doc_id: None,
            alpha,
            candidate_limit: 10,
            top_k: Some(3),
        }
    }

    /// Three-chunk corpus with two labeled queries, scored at k = 1 and 3.
    fn sample_dataset() -> EvalDataset {
        EvalDataset {
            corpus: vec![
                chunk("c1", "d1", "Rust for retrieval", vec![1.0, 0.0, 0.0]),
                chunk("c2", "d2", "Postgres transactions", vec![0.0, 1.0, 0.0]),
                chunk("c3", "d3", "SQLite local memory", vec![0.8, 0.2, 0.0]),
            ],
            queries: vec![
                labeled_query("q1", "rust retrieval", vec![0.95, 0.05, 0.0], "c1", 0.6),
                labeled_query("q2", "sqlite memory", vec![0.75, 0.25, 0.0], "c3", 0.5),
            ],
            k_values: vec![1, 3],
        }
    }

    #[test]
    fn compute_metrics_is_correct_for_simple_case() {
        // One relevant item ("b") sitting at rank 2 of 3.
        let ranked: Vec<String> = ["a", "b", "c"].iter().map(|s| s.to_string()).collect();
        let relevant: HashSet<&str> = HashSet::from(["b"]);
        let m = compute_metrics_at_k(&ranked, &relevant, 3);
        assert!((m.recall - 1.0).abs() < 1e-6);
        assert!((m.precision - (1.0 / 3.0)).abs() < 1e-6);
        assert!((m.mrr - 0.5).abs() < 1e-6);
        assert!(m.ndcg > 0.6 && m.ndcg < 0.7);
        assert!((m.hit_rate - 1.0).abs() < 1e-6);
    }

    #[test]
    fn evaluation_report_has_aggregate_metrics() -> Result<()> {
        let report = evaluate_dataset(sample_dataset(), RuntimeConfig::default())?;
        assert_eq!(report.summary.corpus_size, 3);
        assert_eq!(report.summary.query_count, 2);
        assert_eq!(report.per_query.len(), 2);
        assert!(report.aggregate_metrics_at_k.contains_key(&1));
        assert!(report.aggregate_metrics_at_k.contains_key(&3));
        Ok(())
    }
}