Skip to main content

the_code_graph_eval/
runner.rs

1use std::path::Path;
2
3use domain::error::{CodeGraphError, Result};
4use domain::model::{Confidence, HybridSearchConfig, ImpactTarget, SearchMode};
5use domain::ports::GraphStore;
6use domain::use_cases::impact::ImpactUseCase;
7use domain::use_cases::index::IndexUseCase;
8use domain::use_cases::query::QueryUseCase;
9use storage::SqliteStore;
10
11use crate::adapters::{EvalFileSystem, EvalParseProvider, NoOpGitProvider};
12use crate::dataset::{ImpactScenario, SearchQuery};
13use crate::report::{CategoryMrr, ImpactSuiteResult, SearchSuiteResult};
14use crate::{metrics, SuiteConfig};
15
/// Threshold the overall search MRR is compared against (`mrr >= MRR_TARGET`)
/// when the search suite fills in `mrr_passed`.
const MRR_TARGET: f64 = 0.30;

/// Threshold the averaged blast-radius precision is compared against
/// (`precision >= BLAST_PRECISION_TARGET`) when the impact suite fills in
/// `precision_passed`.
const BLAST_PRECISION_TARGET: f64 = 0.40;

/// A list of qualified-name lists: one inner list per query or scenario.
type NameLists = Vec<Vec<String>>;

/// Ranked results paired with ground-truth expectations.
type RankedVsTruth = (NameLists, NameLists);

/// Per-category bucket: (ranked lists, truth lists).
type CategoryBucket = (NameLists, NameLists);
24
25pub fn confidence_from_str(s: &str) -> Result<Confidence> {
26    match s.to_lowercase().as_str() {
27        "high" => Ok(Confidence::High),
28        "medium" => Ok(Confidence::Medium),
29        "low" => Ok(Confidence::Low),
30        "structural" => Ok(Confidence::Structural),
31        _ => Err(CodeGraphError::Other(format!("Unknown confidence: {s}"))),
32    }
33}
34
35/// Validate that all expected qualified names exist in the indexed graph.
36pub fn validate_ground_truth(
37    store: &SqliteStore,
38    expected_qnames: &[String],
39    repo_name: &str,
40) -> Result<Vec<String>> {
41    let mut missing = Vec::new();
42    for qname in expected_qnames {
43        if store.get_symbol(qname)?.is_none() {
44            missing.push(format!(
45                "SETUP_ERROR: '{}' not found in indexed graph for repo '{}'",
46                qname, repo_name
47            ));
48        }
49    }
50    Ok(missing)
51}
52
53/// Index a cloned repo into an isolated temp database.
54pub fn index_repo(clone_path: &Path) -> Result<(SqliteStore, tempfile::TempDir)> {
55    let temp_dir =
56        tempfile::tempdir().map_err(|e| CodeGraphError::Other(format!("tempdir: {e}")))?;
57    let db_path = temp_dir.path().join("eval.db");
58    let store = SqliteStore::open(&db_path)?;
59    let fs = EvalFileSystem;
60    let parser = EvalParseProvider::new();
61    let git = NoOpGitProvider;
62    let use_case = IndexUseCase::new(store.clone(), parser, fs, git);
63    use_case.full_index(clone_path)?;
64    Ok((store, temp_dir))
65}
66
67pub fn run_search_queries(
68    store: &SqliteStore,
69    queries: &[SearchQuery],
70    limit: usize,
71    mode: Option<SearchMode>,
72) -> Result<RankedVsTruth> {
73    let query_uc = QueryUseCase::new(store.clone(), store.clone());
74    let config = HybridSearchConfig::default();
75    let mut all_ranked = Vec::new();
76    let mut all_truth = Vec::new();
77    for q in queries {
78        let results = match mode {
79            Some(m) => query_uc.hybrid_search(&q.query, limit, m, &config)?,
80            None => query_uc.search(&q.query, limit)?,
81        };
82        let ranked: Vec<String> = results.iter().map(|r| r.qualified_name.clone()).collect();
83        all_ranked.push(ranked);
84        all_truth.push(q.expected.clone());
85    }
86    Ok((all_ranked, all_truth))
87}
88
89pub fn run_impact_scenarios(
90    store: &SqliteStore,
91    scenarios: &[ImpactScenario],
92) -> Result<RankedVsTruth> {
93    let impact_uc = ImpactUseCase::new(store.clone());
94    let mut all_predicted = Vec::new();
95    let mut all_actual = Vec::new();
96    for s in scenarios {
97        let target = ImpactTarget::Symbol(s.target.clone());
98        let confidence = confidence_from_str(&s.confidence)?;
99        let report = impact_uc.blast_radius(&[target], s.depth, confidence)?;
100        let predicted: Vec<String> = report
101            .affected
102            .iter()
103            .map(|a| a.qualified_name.clone())
104            .collect();
105        all_predicted.push(predicted);
106        all_actual.push(s.expected_affected.clone());
107    }
108    Ok((all_predicted, all_actual))
109}
110
111pub fn aggregate_impact_metrics(
112    all_predicted: &[Vec<String>],
113    all_actual: &[Vec<String>],
114) -> (f64, f64, f64) {
115    if all_predicted.is_empty() {
116        return (0.0, 0.0, 0.0);
117    }
118    let (total_p, total_r) = all_predicted
119        .iter()
120        .zip(all_actual.iter())
121        .map(|(pred, actual)| {
122            (
123                metrics::blast_precision(pred, actual),
124                metrics::blast_recall(pred, actual),
125            )
126        })
127        .fold((0.0, 0.0), |(sp, sr), (p, r)| (sp + p, sr + r));
128    let n = all_predicted.len() as f64;
129    let avg_p = total_p / n;
130    let avg_r = total_r / n;
131    (avg_p, avg_r, metrics::f1(avg_p, avg_r))
132}
133
134/// Run the full search evaluation suite.
135pub fn run_search_suite(config: &SuiteConfig) -> Result<SearchSuiteResult> {
136    let manifest_path = config.suites_dir.join("search").join("manifest.json");
137    let manifest = crate::dataset::parse_manifest(&manifest_path)?;
138    let queries_dir = config.suites_dir.join("search").join("queries");
139
140    let mut all_ranked = Vec::new();
141    let mut all_truth = Vec::new();
142    // Per-category buckets: category -> (ranked_lists, truth_lists)
143    let mut category_buckets: std::collections::HashMap<String, CategoryBucket> =
144        std::collections::HashMap::new();
145    let mut total_queries = 0;
146    let mut setup_errors = Vec::new();
147
148    for repo in &manifest.repos {
149        tracing::info!(repo = %repo.name, "Processing search eval repo");
150        let clone_path = crate::dataset::clone_or_cache(repo, config.no_cache)?;
151        let (store, _temp_dir) = index_repo(&clone_path)?;
152
153        for lang in &repo.languages {
154            let query_file = queries_dir.join(format!("{lang}.json"));
155            if !query_file.exists() {
156                continue;
157            }
158            let queries = crate::dataset::parse_search_queries(&query_file)?;
159            let repo_queries: Vec<_> = queries.iter().filter(|q| q.repo == repo.name).collect();
160
161            // Validate ground truth
162            let all_expected: Vec<String> = repo_queries
163                .iter()
164                .flat_map(|q| q.expected.iter().cloned())
165                .collect();
166            let missing = validate_ground_truth(&store, &all_expected, &repo.name)?;
167            setup_errors.extend(missing);
168
169            // Run queries
170            let filtered: Vec<SearchQuery> = repo_queries.into_iter().cloned().collect();
171            let (ranked, truth) = run_search_queries(&store, &filtered, config.search_limit, None)?;
172            total_queries += ranked.len();
173
174            // Collect per-category data
175            for (q, r, t) in filtered
176                .iter()
177                .zip(ranked.iter())
178                .zip(truth.iter())
179                .map(|((q, r), t)| (q, r, t))
180            {
181                let cat = q
182                    .category
183                    .clone()
184                    .unwrap_or_else(|| "uncategorized".to_string());
185                let bucket = category_buckets.entry(cat).or_default();
186                bucket.0.push(r.clone());
187                bucket.1.push(t.clone());
188            }
189
190            all_ranked.extend(ranked);
191            all_truth.extend(truth);
192        }
193    }
194
195    if !setup_errors.is_empty() {
196        tracing::warn!(
197            "Ground truth validation issues:\n{}",
198            setup_errors.join("\n")
199        );
200    }
201
202    let mrr = metrics::mrr(&all_ranked, &all_truth);
203    let p5 = metrics::precision_at_k(&all_ranked, &all_truth, 5);
204    let p10 = metrics::precision_at_k(&all_ranked, &all_truth, 10);
205
206    // Build sorted per-category breakdown
207    let mut per_category: Vec<CategoryMrr> = category_buckets
208        .into_iter()
209        .map(|(cat, (ranked, truth))| {
210            let cat_mrr = metrics::mrr(&ranked, &truth);
211            CategoryMrr {
212                queries: ranked.len(),
213                category: cat,
214                mrr: cat_mrr,
215            }
216        })
217        .collect();
218    per_category.sort_by(|a, b| a.category.cmp(&b.category));
219
220    Ok(SearchSuiteResult {
221        repos: manifest.repos.len(),
222        queries: total_queries,
223        mrr,
224        precision_at_5: p5,
225        precision_at_10: p10,
226        mrr_target: MRR_TARGET,
227        mrr_passed: mrr >= MRR_TARGET,
228        per_category,
229    })
230}
231
232/// Run the full impact evaluation suite.
233pub fn run_impact_suite(config: &SuiteConfig) -> Result<ImpactSuiteResult> {
234    let manifest_path = config.suites_dir.join("impact").join("manifest.json");
235    let manifest = crate::dataset::parse_manifest(&manifest_path)?;
236    let queries_dir = config.suites_dir.join("impact").join("queries");
237
238    let mut all_predicted = Vec::new();
239    let mut all_actual = Vec::new();
240    let mut total_scenarios = 0;
241    let mut setup_errors = Vec::new();
242
243    for repo in &manifest.repos {
244        tracing::info!(repo = %repo.name, "Processing impact eval repo");
245        let clone_path = crate::dataset::clone_or_cache(repo, config.no_cache)?;
246        let (store, _temp_dir) = index_repo(&clone_path)?;
247
248        for lang in &repo.languages {
249            let query_file = queries_dir.join(format!("{lang}.json"));
250            if !query_file.exists() {
251                continue;
252            }
253            let scenarios = crate::dataset::parse_impact_queries(&query_file)?;
254            let repo_scenarios: Vec<_> = scenarios.iter().filter(|s| s.repo == repo.name).collect();
255
256            // Validate ground truth
257            let all_expected: Vec<String> = repo_scenarios
258                .iter()
259                .flat_map(|s| {
260                    let mut v = s.expected_affected.clone();
261                    v.push(s.target.clone());
262                    v
263                })
264                .collect();
265            let missing = validate_ground_truth(&store, &all_expected, &repo.name)?;
266            setup_errors.extend(missing);
267
268            let filtered: Vec<ImpactScenario> = repo_scenarios.into_iter().cloned().collect();
269            let (predicted, actual) = run_impact_scenarios(&store, &filtered)?;
270            total_scenarios += predicted.len();
271            all_predicted.extend(predicted);
272            all_actual.extend(actual);
273        }
274    }
275
276    if !setup_errors.is_empty() {
277        tracing::warn!(
278            "Ground truth validation issues:\n{}",
279            setup_errors.join("\n")
280        );
281    }
282
283    let (precision, recall, f1) = aggregate_impact_metrics(&all_predicted, &all_actual);
284
285    Ok(ImpactSuiteResult {
286        repos: manifest.repos.len(),
287        scenarios: total_scenarios,
288        precision,
289        recall,
290        f1,
291        precision_target: BLAST_PRECISION_TARGET,
292        precision_passed: precision >= BLAST_PRECISION_TARGET,
293    })
294}
295
/// Unit tests for the pure helpers in this module. The suite runners need
/// cloned repositories and are not exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn confidence_from_string_high() {
        let c = confidence_from_str("high").unwrap();
        assert!(matches!(c, Confidence::High));
    }

    #[test]
    fn confidence_from_string_medium() {
        let c = confidence_from_str("medium").unwrap();
        assert!(matches!(c, Confidence::Medium));
    }

    #[test]
    fn confidence_from_string_low() {
        let c = confidence_from_str("low").unwrap();
        assert!(matches!(c, Confidence::Low));
    }

    #[test]
    fn confidence_from_string_structural() {
        let c = confidence_from_str("structural").unwrap();
        assert!(matches!(c, Confidence::Structural));
    }

    #[test]
    fn confidence_from_string_ignores_case() {
        // The parser lowercases its input, so mixed-case labels must be accepted.
        assert!(matches!(confidence_from_str("HIGH"), Ok(Confidence::High)));
        assert!(matches!(
            confidence_from_str("Structural"),
            Ok(Confidence::Structural)
        ));
    }

    #[test]
    fn confidence_from_string_invalid() {
        let err = confidence_from_str("unknown");
        assert!(err.is_err());
        // The message should echo the rejected label.
        match err {
            Err(CodeGraphError::Other(msg)) => {
                assert_eq!(msg, "Unknown confidence: unknown");
            }
            _ => panic!("expected CodeGraphError::Other for an unknown label"),
        }
    }

    #[test]
    fn aggregate_impact_empty() {
        let (p, r, f) = aggregate_impact_metrics(&[], &[]);
        assert!((p - 0.0).abs() < f64::EPSILON);
        assert!((r - 0.0).abs() < f64::EPSILON);
        assert!((f - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn aggregate_impact_perfect() {
        let predicted = vec![vec!["a".into(), "b".into()]];
        let actual = vec![vec!["a".into(), "b".into()]];
        let (p, r, f) = aggregate_impact_metrics(&predicted, &actual);
        assert!((p - 1.0).abs() < f64::EPSILON);
        assert!((r - 1.0).abs() < f64::EPSILON);
        assert!((f - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn validate_ground_truth_empty() {
        let store = SqliteStore::open_in_memory().unwrap();
        let missing = validate_ground_truth(&store, &[], "test-repo").unwrap();
        assert!(missing.is_empty());
    }
}