Skip to main content

the_code_graph_eval/
metrics.rs

1use std::collections::HashSet;
2
/// Mean Reciprocal Rank (MRR) across multiple queries.
///
/// For each query the reciprocal rank is `1 / r`, where `r` is the 1-based
/// position of the first entry in the ranked list that also appears in that
/// query's ground truth, or `0.0` when no ranked entry is correct. The result
/// is the mean of these per-query values.
///
/// * `ranked_results` — for each query, the ranked list of qualified names returned.
/// * `ground_truth` — for each query, the set of correct qualified names.
///
/// Returns `0.0` when `ranked_results` is empty.
///
/// The two slices must be parallel (one entry per query, same order). This is
/// checked with `debug_assert_eq!`. In release builds a mismatch is tolerated:
/// - queries with no ground-truth entry contribute `0.0` but still count
///   toward the denominator;
/// - surplus ground-truth entries are ignored.
pub fn mrr(ranked_results: &[Vec<String>], ground_truth: &[Vec<String>]) -> f64 {
    // Misaligned inputs would silently skew the score; catch them in debug builds.
    debug_assert_eq!(
        ranked_results.len(),
        ground_truth.len(),
        "mrr: ranked_results and ground_truth must have one entry per query"
    );
    if ranked_results.is_empty() {
        return 0.0;
    }
    let sum: f64 = ranked_results
        .iter()
        .zip(ground_truth)
        .map(|(ranked, truth)| {
            let truth_set: HashSet<&str> = truth.iter().map(String::as_str).collect();
            ranked
                .iter()
                .position(|name| truth_set.contains(name.as_str()))
                // `position` is 0-based; reciprocal rank uses the 1-based rank.
                .map_or(0.0, |i| 1.0 / (i as f64 + 1.0))
        })
        .sum();
    sum / ranked_results.len() as f64
}
25
/// Precision at K for a single query.
///
/// Counts how many of the first `k` ranked names appear in `truth` and divides
/// by the number of positions actually inspected, `min(k, ranked.len())`. A
/// query that returned fewer than `k` results is therefore not penalised for
/// the missing positions. Returns `0.0` when `k == 0` or `ranked` is empty.
fn precision_at_k_single(ranked: &[String], truth: &[String], k: usize) -> f64 {
    let effective_k = k.min(ranked.len());
    if effective_k == 0 {
        return 0.0;
    }
    let truth_set: HashSet<&str> = truth.iter().map(String::as_str).collect();
    let relevant = ranked
        .iter()
        .take(effective_k)
        .filter(|name| truth_set.contains(name.as_str()))
        .count();
    relevant as f64 / effective_k as f64
}

/// Precision at K — the mean of per-query precision over all queries.
///
/// * `ranked_results` — for each query, the ranked list of qualified names returned.
/// * `ground_truth` — for each query, the set of correct qualified names.
/// * `k` — cutoff; see `precision_at_k_single` for how short lists are handled.
///
/// Returns `0.0` when `ranked_results` is empty.
///
/// The two slices must be parallel (one entry per query, same order). This is
/// checked with `debug_assert_eq!`. In release builds a mismatch is tolerated:
/// - queries with no ground-truth entry contribute `0.0` but still count
///   toward the denominator;
/// - surplus ground-truth entries are ignored.
pub fn precision_at_k(
    ranked_results: &[Vec<String>],
    ground_truth: &[Vec<String>],
    k: usize,
) -> f64 {
    // Misaligned inputs would silently skew the score; catch them in debug builds.
    debug_assert_eq!(
        ranked_results.len(),
        ground_truth.len(),
        "precision_at_k: ranked_results and ground_truth must have one entry per query"
    );
    if ranked_results.is_empty() {
        return 0.0;
    }
    let sum: f64 = ranked_results
        .iter()
        .zip(ground_truth)
        .map(|(ranked, truth)| precision_at_k_single(ranked, truth, k))
        .sum();
    sum / ranked_results.len() as f64
}
56
/// Blast-radius precision: `|predicted ∩ actual| / |predicted|`.
///
/// Both inputs are treated as sets of qualified names, so a name listed more
/// than once is counted once on either side. Without this deduplication a
/// repeated correct prediction would be rewarded multiple times.
///
/// Returns `0.0` when `predicted` is empty.
pub fn blast_precision(predicted: &[String], actual: &[String]) -> f64 {
    let predicted_set: HashSet<&str> = predicted.iter().map(String::as_str).collect();
    if predicted_set.is_empty() {
        return 0.0;
    }
    let actual_set: HashSet<&str> = actual.iter().map(String::as_str).collect();
    let intersection = predicted_set.intersection(&actual_set).count();
    intersection as f64 / predicted_set.len() as f64
}
69
/// Blast-radius recall: `|predicted ∩ actual| / |actual|`.
///
/// Both inputs are treated as sets of qualified names, so a name listed more
/// than once is counted once on either side. This keeps the result within
/// `[0.0, 1.0]`: counting duplicate predictions could otherwise push recall
/// above 1.0, and duplicate actual entries could deflate it.
///
/// Returns `0.0` when `actual` is empty.
pub fn blast_recall(predicted: &[String], actual: &[String]) -> f64 {
    let actual_set: HashSet<&str> = actual.iter().map(String::as_str).collect();
    if actual_set.is_empty() {
        return 0.0;
    }
    let predicted_set: HashSet<&str> = predicted.iter().map(String::as_str).collect();
    let intersection = actual_set.intersection(&predicted_set).count();
    intersection as f64 / actual_set.len() as f64
}
82
/// F1 score: the harmonic mean of `precision` and `recall`.
///
/// Computed as `2 * precision * recall / (precision + recall)`. When the sum
/// is exactly zero the score is defined as `0.0` instead of dividing by zero.
pub fn f1(precision: f64, recall: f64) -> f64 {
    let denominator = precision + recall;
    if denominator == 0.0 {
        0.0
    } else {
        2.0 * precision * recall / denominator
    }
}
90
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned list of qualified names from string literals.
    fn names(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Asserts that `actual` lies strictly within `tolerance` of `expected`.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!((actual - expected).abs() < tolerance);
    }

    // ── MRR tests ──────────────────────────────────────────────

    #[test]
    fn mrr_perfect_ranking() {
        // Every query has a correct hit at rank 1, so MRR = 1.0.
        let ranked = vec![names(&["a", "b"]), names(&["c", "d"])];
        let truth = vec![names(&["a"]), names(&["c"])];
        assert_close(mrr(&ranked, &truth), 1.0, f64::EPSILON);
    }

    #[test]
    fn mrr_second_position() {
        // The first correct hit is at rank 2, giving a reciprocal rank of 0.5.
        let ranked = vec![names(&["x", "a", "b"])];
        let truth = vec![names(&["a"])];
        assert_close(mrr(&ranked, &truth), 0.5, f64::EPSILON);
    }

    #[test]
    fn mrr_no_match() {
        // Nothing in the ranking is correct, so the reciprocal rank is 0.0.
        let ranked = vec![names(&["x", "y"])];
        let truth = vec![names(&["a"])];
        assert_close(mrr(&ranked, &truth), 0.0, f64::EPSILON);
    }

    #[test]
    fn mrr_mixed() {
        // Query 1 hits at rank 1 (1.0) and query 2 at rank 3 (1/3).
        // The mean is (1.0 + 1/3) / 2 = 2/3.
        let ranked = vec![names(&["a", "b", "c"]), names(&["x", "y", "a"])];
        let truth = vec![names(&["a"]), names(&["a"])];
        let expected = (1.0 + 1.0 / 3.0) / 2.0;
        assert_close(mrr(&ranked, &truth), expected, 1e-10);
    }

    #[test]
    fn mrr_empty_queries() {
        let ranked: Vec<Vec<String>> = Vec::new();
        let truth: Vec<Vec<String>> = Vec::new();
        assert_close(mrr(&ranked, &truth), 0.0, f64::EPSILON);
    }

    // ── Precision@K tests ──────────────────────────────────────

    #[test]
    fn precision_at_k_all_relevant() {
        // Every item in the top k is relevant, so precision is 1.0.
        let ranked = vec![names(&["a", "b", "c"])];
        let truth = vec![names(&["a", "b", "c"])];
        assert_close(precision_at_k(&ranked, &truth, 3), 1.0, f64::EPSILON);
    }

    #[test]
    fn precision_at_k_none_relevant() {
        // No item in the top k is relevant, so precision is 0.0.
        let ranked = vec![names(&["x", "y", "z"])];
        let truth = vec![names(&["a", "b", "c"])];
        assert_close(precision_at_k(&ranked, &truth, 3), 0.0, f64::EPSILON);
    }

    #[test]
    fn precision_at_k_partial() {
        // Three of the five inspected items are relevant: 0.6.
        let ranked = vec![names(&["a", "x", "b", "y", "c"])];
        let truth = vec![names(&["a", "b", "c"])];
        assert_close(precision_at_k(&ranked, &truth, 5), 0.6, f64::EPSILON);
    }

    #[test]
    fn precision_at_k_fewer_results_than_k() {
        // Only 3 results with k = 5: the denominator is the actual count (3).
        // All three are relevant, so precision is 1.0.
        let ranked = vec![names(&["a", "b", "c"])];
        let truth = vec![names(&["a", "b", "c", "d", "e"])];
        assert_close(precision_at_k(&ranked, &truth, 5), 1.0, f64::EPSILON);
    }

    // ── Blast radius precision/recall tests ────────────────────

    #[test]
    fn blast_precision_perfect() {
        let predicted = names(&["a", "b"]);
        let actual = names(&["a", "b"]);
        assert_close(blast_precision(&predicted, &actual), 1.0, f64::EPSILON);
    }

    #[test]
    fn blast_precision_empty_predicted() {
        let predicted: Vec<String> = Vec::new();
        let actual = names(&["a"]);
        assert_close(blast_precision(&predicted, &actual), 0.0, f64::EPSILON);
    }

    #[test]
    fn blast_recall_perfect() {
        let predicted = names(&["a", "b"]);
        let actual = names(&["a", "b"]);
        assert_close(blast_recall(&predicted, &actual), 1.0, f64::EPSILON);
    }

    #[test]
    fn blast_recall_empty_actual() {
        let predicted = names(&["a"]);
        let actual: Vec<String> = Vec::new();
        assert_close(blast_recall(&predicted, &actual), 0.0, f64::EPSILON);
    }

    // ── F1 tests ───────────────────────────────────────────────

    #[test]
    fn f1_balanced() {
        // When precision equals recall, F1 equals that shared value.
        assert_close(f1(0.75, 0.75), 0.75, f64::EPSILON);
    }

    #[test]
    fn f1_zero_both() {
        assert_close(f1(0.0, 0.0), 0.0, f64::EPSILON);
    }

    #[test]
    fn f1_typical() {
        // precision = 0.8, recall = 0.6 gives F1 = 2*0.8*0.6 / (0.8+0.6) ≈ 0.6857142857…
        let expected = 2.0 * 0.8 * 0.6 / (0.8 + 0.6);
        assert_close(f1(0.8, 0.6), expected, 1e-10);
    }
}