// the_code_graph_eval/dataset.rs

1use serde::Deserialize;
2use std::path::{Path, PathBuf};
3
4use domain::error::{CodeGraphError, Result};
5
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
9
10#[derive(Debug, Deserialize)]
11pub struct SuiteManifest {
12    pub suite: SuiteInfo,
13    pub repos: Vec<ManifestRepo>,
14}
15
16#[derive(Debug, Deserialize)]
17pub struct SuiteInfo {
18    pub name: String,
19    pub description: String,
20}
21
22#[derive(Debug, Clone, Deserialize)]
23pub struct ManifestRepo {
24    pub name: String,
25    pub url: String,
26    pub revision: String,
27    pub languages: Vec<String>,
28}
29
30#[derive(Debug, Deserialize)]
31pub struct SearchQueryFile {
32    pub queries: Vec<SearchQuery>,
33}
34
35#[derive(Debug, Clone, Deserialize)]
36pub struct SearchQuery {
37    pub repo: String,
38    pub query: String,
39    pub expected: Vec<String>,
40    #[serde(default)]
41    pub category: Option<String>,
42}
43
44#[derive(Debug, Deserialize)]
45pub struct ImpactQueryFile {
46    pub scenarios: Vec<ImpactScenario>,
47}
48
49#[derive(Debug, Clone, Deserialize)]
50pub struct ImpactScenario {
51    pub repo: String,
52    pub description: String,
53    pub target: String,
54    pub depth: usize,
55    pub confidence: String,
56    pub expected_affected: Vec<String>,
57}
58
// ---------------------------------------------------------------------------
// Cache management
// ---------------------------------------------------------------------------
62
63pub fn eval_cache_dir() -> Result<PathBuf> {
64    if let Ok(xdg) = std::env::var("XDG_CACHE_HOME") {
65        return Ok(PathBuf::from(xdg).join("code-graph-eval"));
66    }
67    let home = std::env::var("HOME").map_err(|_| CodeGraphError::Other("HOME not set".into()))?;
68    Ok(PathBuf::from(home).join(".cache").join("code-graph-eval"))
69}
70
71pub fn repo_cache_path(repo: &ManifestRepo) -> Result<PathBuf> {
72    Ok(eval_cache_dir()?.join(&repo.name).join(&repo.revision))
73}
74
75pub fn validate_cache(repo: &ManifestRepo) -> Result<bool> {
76    let path = repo_cache_path(repo)?;
77    if !path.exists() {
78        return Ok(false);
79    }
80    let marker = path.join(".revision");
81    if !marker.exists() {
82        return Ok(false);
83    }
84    let stored = std::fs::read_to_string(&marker)
85        .map_err(|e| CodeGraphError::Other(format!("read .revision: {e}")))?;
86    Ok(stored.trim() == repo.revision)
87}
88
89pub fn clone_or_cache(repo: &ManifestRepo, no_cache: bool) -> Result<PathBuf> {
90    let cache_path = repo_cache_path(repo)?;
91
92    if no_cache {
93        if cache_path.exists() {
94            std::fs::remove_dir_all(&cache_path)
95                .map_err(|e| CodeGraphError::Other(format!("remove cache: {e}")))?;
96        }
97    } else if validate_cache(repo)? {
98        tracing::info!(repo = %repo.name, "Using cached clone");
99        return Ok(cache_path);
100    }
101
102    // Validate inputs to prevent git argument injection
103    if repo.revision.starts_with('-') {
104        return Err(CodeGraphError::Other(format!(
105            "invalid revision: '{}' (must not start with '-')",
106            repo.revision
107        )));
108    }
109    if !repo.url.starts_with("https://") && !repo.url.starts_with("http://") {
110        return Err(CodeGraphError::Other(format!(
111            "invalid repo URL: '{}' (must be an HTTP(S) URL)",
112            repo.url
113        )));
114    }
115
116    tracing::info!(repo = %repo.name, revision = %repo.revision, "Cloning");
117    if cache_path.exists() {
118        std::fs::remove_dir_all(&cache_path)
119            .map_err(|e| CodeGraphError::Other(format!("remove stale cache: {e}")))?;
120    }
121    std::fs::create_dir_all(&cache_path)
122        .map_err(|e| CodeGraphError::Other(format!("mkdir: {e}")))?;
123
124    let output = std::process::Command::new("git")
125        .args([
126            "clone",
127            "--depth",
128            "1",
129            "--branch",
130            &repo.revision,
131            &repo.url,
132        ])
133        .arg(&cache_path)
134        .output()
135        .map_err(|e| CodeGraphError::Other(format!("git clone failed: {e}")))?;
136
137    if !output.status.success() {
138        let stderr = String::from_utf8_lossy(&output.stderr);
139        return Err(CodeGraphError::Other(format!("git clone failed: {stderr}")));
140    }
141
142    std::fs::write(cache_path.join(".revision"), &repo.revision)
143        .map_err(|e| CodeGraphError::Other(format!("write .revision: {e}")))?;
144
145    Ok(cache_path)
146}
147
148pub fn clear_cache(repo: &ManifestRepo) -> Result<()> {
149    let path = repo_cache_path(repo)?;
150    if path.exists() {
151        std::fs::remove_dir_all(&path)
152            .map_err(|e| CodeGraphError::Other(format!("clear cache: {e}")))?;
153    }
154    Ok(())
155}
156
// ---------------------------------------------------------------------------
// Manifest / query parsing
// ---------------------------------------------------------------------------
160
161pub fn parse_manifest(path: &Path) -> Result<SuiteManifest> {
162    let content = std::fs::read_to_string(path)
163        .map_err(|e| CodeGraphError::Other(format!("Failed to read manifest: {e}")))?;
164    serde_json::from_str(&content)
165        .map_err(|e| CodeGraphError::Other(format!("Invalid manifest JSON: {e}")))
166}
167
168pub fn parse_search_queries(path: &Path) -> Result<Vec<SearchQuery>> {
169    let content = std::fs::read_to_string(path)
170        .map_err(|e| CodeGraphError::Other(format!("Failed to read queries: {e}")))?;
171    let file: SearchQueryFile = serde_json::from_str(&content)
172        .map_err(|e| CodeGraphError::Other(format!("Invalid query JSON: {e}")))?;
173    Ok(file.queries)
174}
175
176pub fn parse_impact_queries(path: &Path) -> Result<Vec<ImpactScenario>> {
177    let content = std::fs::read_to_string(path)
178        .map_err(|e| CodeGraphError::Other(format!("Failed to read scenarios: {e}")))?;
179    let file: ImpactQueryFile = serde_json::from_str(&content)
180        .map_err(|e| CodeGraphError::Other(format!("Invalid scenario JSON: {e}")))?;
181    Ok(file.scenarios)
182}
183
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
187
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes every test in this module that reads or writes environment
    /// variables, directly or through `eval_cache_dir`.
    ///
    /// Concurrent `set_var`/`var` calls can race on some platforms, which is
    /// why `std::env::set_var` is `unsafe`.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering from poisoning.
    ///
    /// The lock protects no data, so poisoning only means an earlier test
    /// panicked; recovering keeps one failure from cascading into every other
    /// environment-dependent test.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Points `XDG_CACHE_HOME` at a test directory while holding [`ENV_LOCK`].
    ///
    /// The previous value (or its absence) is restored on drop, so the override
    /// is undone even when the test panics.
    struct XdgCacheOverride {
        previous: Option<OsString>,
        /// Released only after `Drop::drop` has restored the variable.
        _lock: MutexGuard<'static, ()>,
    }

    impl XdgCacheOverride {
        fn set(path: &Path) -> Self {
            let lock = lock_env();
            let previous = std::env::var_os("XDG_CACHE_HOME");
            // SAFETY: ENV_LOCK is held, and every test in this module that
            // touches the environment acquires it first.
            unsafe { std::env::set_var("XDG_CACHE_HOME", path) };
            Self {
                previous,
                _lock: lock,
            }
        }
    }

    impl Drop for XdgCacheOverride {
        fn drop(&mut self) {
            // SAFETY: `_lock` is still held; struct fields are dropped only after
            // this method returns.
            unsafe {
                match self.previous.take() {
                    Some(value) => std::env::set_var("XDG_CACHE_HOME", value),
                    None => std::env::remove_var("XDG_CACHE_HOME"),
                }
            }
        }
    }

    fn test_repo() -> ManifestRepo {
        ManifestRepo {
            name: "test-repo".into(),
            url: "https://github.com/example/test-repo.git".into(),
            revision: "abc123".into(),
            languages: vec!["rust".into()],
        }
    }

    /// Creates `<root>/code-graph-eval/<name>/<revision>` with a `.revision`
    /// marker containing `marker`, mirroring the layout `repo_cache_path`
    /// produces under `XDG_CACHE_HOME=<root>`.
    fn seed_cache(root: &Path, repo: &ManifestRepo, marker: &str) -> PathBuf {
        let cache_dir = root
            .join("code-graph-eval")
            .join(&repo.name)
            .join(&repo.revision);
        std::fs::create_dir_all(&cache_dir).unwrap();
        std::fs::write(cache_dir.join(".revision"), marker).unwrap();
        cache_dir
    }

    const MANIFEST_JSON: &str = r#"{
        "suite": {
            "name": "search-v1",
            "description": "Search evaluation suite"
        },
        "repos": [
            {
                "name": "sample-repo",
                "url": "https://github.com/example/sample.git",
                "revision": "v1.0.0",
                "languages": ["rust", "python"]
            }
        ]
    }"#;

    const SEARCH_QUERIES_JSON: &str = r#"{
        "queries": [
            {
                "repo": "sample-repo",
                "query": "find all error handlers",
                "expected": ["src/error.rs", "src/handler.rs"]
            }
        ]
    }"#;

    const IMPACT_QUERIES_JSON: &str = r#"{
        "scenarios": [
            {
                "repo": "sample-repo",
                "description": "Change error type",
                "target": "src/error.rs::AppError",
                "depth": 3,
                "confidence": "high",
                "expected_affected": ["src/handler.rs", "src/main.rs"]
            }
        ]
    }"#;

    // -- SearchQuery category field ----------------------------------------

    #[test]
    fn search_query_category_deserialization() {
        let json =
            r#"{"repo": "test", "query": "foo", "expected": ["a::b"], "category": "semantic"}"#;
        let q: SearchQuery = serde_json::from_str(json).unwrap();
        assert_eq!(q.category.unwrap(), "semantic");
    }

    #[test]
    fn search_query_category_optional() {
        let json = r#"{"repo": "test", "query": "foo", "expected": ["a::b"]}"#;
        let q: SearchQuery = serde_json::from_str(json).unwrap();
        assert!(q.category.is_none());
    }

    // -- Manifest parsing ---------------------------------------------------

    #[test]
    fn parse_search_manifest() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("manifest.json");
        std::fs::write(&path, MANIFEST_JSON).unwrap();

        let manifest = parse_manifest(&path).unwrap();
        assert_eq!(manifest.suite.name, "search-v1");
        assert_eq!(manifest.suite.description, "Search evaluation suite");
        assert_eq!(manifest.repos.len(), 1);
        assert_eq!(manifest.repos[0].name, "sample-repo");
        assert_eq!(manifest.repos[0].revision, "v1.0.0");
        assert_eq!(manifest.repos[0].languages, vec!["rust", "python"]);
    }

    #[test]
    fn parse_search_manifest_invalid_json() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("bad.json");
        std::fs::write(&path, "{ not valid json }").unwrap();

        let err = parse_manifest(&path).unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("Invalid manifest JSON"),
            "expected clear error, got: {msg}"
        );
    }

    // -- Query parsing ------------------------------------------------------

    #[test]
    fn parse_search_queries() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("search.json");
        std::fs::write(&path, SEARCH_QUERIES_JSON).unwrap();

        let queries = super::parse_search_queries(&path).unwrap();
        assert_eq!(queries.len(), 1);
        assert_eq!(queries[0].repo, "sample-repo");
        assert_eq!(queries[0].query, "find all error handlers");
        assert_eq!(queries[0].expected, vec!["src/error.rs", "src/handler.rs"]);
    }

    #[test]
    fn parse_impact_queries() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("impact.json");
        std::fs::write(&path, IMPACT_QUERIES_JSON).unwrap();

        let scenarios = super::parse_impact_queries(&path).unwrap();
        assert_eq!(scenarios.len(), 1);
        assert_eq!(scenarios[0].repo, "sample-repo");
        assert_eq!(scenarios[0].description, "Change error type");
        assert_eq!(scenarios[0].target, "src/error.rs::AppError");
        assert_eq!(scenarios[0].depth, 3);
        assert_eq!(scenarios[0].confidence, "high");
        assert_eq!(
            scenarios[0].expected_affected,
            vec!["src/handler.rs", "src/main.rs"]
        );
    }

    // -- Cache dir resolution -----------------------------------------------

    #[test]
    fn cache_dir_resolution() {
        // Reads the environment, so it must run under the env lock; pinning
        // XDG_CACHE_HOME also makes the expected path exact.
        let dir = tempfile::tempdir().unwrap();
        let _xdg = XdgCacheOverride::set(dir.path());

        let path = repo_cache_path(&test_repo()).unwrap();
        assert_eq!(
            path,
            dir.path()
                .join("code-graph-eval")
                .join("test-repo")
                .join("abc123"),
            "unexpected cache path: {path:?}"
        );
    }

    #[test]
    fn cache_dir_respects_xdg() {
        let dir = tempfile::tempdir().unwrap();
        let _xdg = XdgCacheOverride::set(dir.path());

        let result = eval_cache_dir().unwrap();
        assert_eq!(
            result,
            dir.path().join("code-graph-eval"),
            "XDG_CACHE_HOME should be respected"
        );
    }

    // -- Cache validation ---------------------------------------------------

    #[test]
    fn validate_cache_missing_dir() {
        let dir = tempfile::tempdir().unwrap();
        let _xdg = XdgCacheOverride::set(dir.path());

        let valid = validate_cache(&test_repo()).unwrap();
        assert!(!valid, "cache should be invalid when directory is missing");
    }

    #[test]
    fn validate_cache_missing_marker() {
        let dir = tempfile::tempdir().unwrap();
        let repo = test_repo();
        std::fs::create_dir_all(
            dir.path()
                .join("code-graph-eval")
                .join(&repo.name)
                .join(&repo.revision),
        )
        .unwrap();
        let _xdg = XdgCacheOverride::set(dir.path());

        let valid = validate_cache(&repo).unwrap();
        assert!(!valid, "cache should be invalid without a .revision marker");
    }

    #[test]
    fn validate_cache_wrong_revision() {
        let dir = tempfile::tempdir().unwrap();
        let repo = test_repo();
        seed_cache(dir.path(), &repo, "wrong-revision");
        let _xdg = XdgCacheOverride::set(dir.path());

        let valid = validate_cache(&repo).unwrap();
        assert!(
            !valid,
            "cache should be invalid when revision doesn't match"
        );
    }

    #[test]
    fn validate_cache_valid() {
        let dir = tempfile::tempdir().unwrap();
        let repo = test_repo();
        seed_cache(dir.path(), &repo, &repo.revision);
        let _xdg = XdgCacheOverride::set(dir.path());

        let valid = validate_cache(&repo).unwrap();
        assert!(
            valid,
            "cache should be valid when dir exists and revision matches"
        );
    }

    // -- Clone argument validation -------------------------------------------

    #[test]
    fn clone_or_cache_rejects_option_like_revision() {
        let dir = tempfile::tempdir().unwrap();
        let _xdg = XdgCacheOverride::set(dir.path());
        let repo = ManifestRepo {
            revision: "--upload-pack=evil".into(),
            ..test_repo()
        };

        let msg = format!("{}", clone_or_cache(&repo, false).unwrap_err());
        assert!(msg.contains("invalid revision"), "got: {msg}");
    }

    #[test]
    fn clone_or_cache_rejects_non_http_url() {
        let dir = tempfile::tempdir().unwrap();
        let _xdg = XdgCacheOverride::set(dir.path());
        let repo = ManifestRepo {
            url: "ssh://git@github.com/example/test-repo.git".into(),
            ..test_repo()
        };

        let msg = format!("{}", clone_or_cache(&repo, false).unwrap_err());
        assert!(msg.contains("invalid repo URL"), "got: {msg}");
    }

    // -- Clear cache --------------------------------------------------------

    #[test]
    fn clear_cache_removes_dir() {
        let dir = tempfile::tempdir().unwrap();
        let repo = test_repo();
        let cache_dir = seed_cache(dir.path(), &repo, &repo.revision);
        assert!(cache_dir.exists(), "setup: cache dir should exist");
        let _xdg = XdgCacheOverride::set(dir.path());

        clear_cache(&repo).unwrap();
        assert!(
            !cache_dir.exists(),
            "cache dir should be removed after clear"
        );
    }
}