Skip to main content

semantic_diff/
cache.rs

1use crate::grouper::SemanticGroup;
2use serde::{Deserialize, Serialize};
3use std::collections::hash_map::DefaultHasher;
4use std::hash::{Hash, Hasher};
5use std::path::PathBuf;
6
/// Cached grouping result stored in `.git/semantic-diff-cache.json`.
///
/// This is the top-level JSON object of the cache file; the field names are
/// the JSON keys, so renaming a field changes the on-disk format.
#[derive(Debug, Serialize, Deserialize)]
struct CacheEntry {
    /// Hash of the raw diff output — if this matches, the cache is valid.
    diff_hash: u64,
    /// The cached groups, converted to and from `SemanticGroup` by `save` and `load`.
    groups: Vec<CachedGroup>,
}
14
/// Serializable version of `SemanticGroup`.
///
/// `load` feeds these three fields, in this order, to `SemanticGroup::new`;
/// `save` fills them from a group's `label`, `description` and `changes()`.
#[derive(Debug, Serialize, Deserialize)]
struct CachedGroup {
    /// Group label (for example `"Auth"` in the tests below).
    label: String,
    /// Free-text description of the group.
    description: String,
    /// The changes that belong to this group.
    changes: Vec<CachedChange>,
}
22
/// Serializable version of `GroupedChange`.
///
/// Both fields are copied one-to-one to and from `crate::grouper::GroupedChange`.
#[derive(Debug, Serialize, Deserialize)]
struct CachedChange {
    /// File the change refers to, stored as a string (e.g. `"src/auth.rs"`).
    file: String,
    /// NOTE(review): presumably indices of the hunks within that file's diff —
    /// confirm against `grouper::GroupedChange`.
    hunks: Vec<usize>,
}
29
/// Fingerprint the raw diff text as a 64-bit value.
///
/// The text is fed through the standard library's [`DefaultHasher`], so the
/// same input always produces the same fingerprint within one build of the
/// program. The result is what `load` and `save` use to decide whether a
/// cached grouping still belongs to the current diff.
pub fn diff_hash(raw_diff: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(raw_diff, &mut state);
    Hasher::finish(&state)
}
36
37/// Try to load cached grouping for the given diff hash.
38/// Returns None if no cache, hash mismatch, parse error, or oversized file.
39pub fn load(hash: u64) -> Option<Vec<SemanticGroup>> {
40    let path = cache_path()?;
41
42    // Reject oversized cache files (FINDING-16: prevent OOM from crafted cache)
43    let metadata = std::fs::metadata(&path).ok()?;
44    if metadata.len() > 1_048_576 {
45        // 1MB limit
46        tracing::warn!("Cache file too large ({} bytes), ignoring", metadata.len());
47        return None;
48    }
49
50    let content = std::fs::read_to_string(&path).ok()?;
51    let entry: CacheEntry = serde_json::from_str(&content).ok()?;
52
53    // Validate cache structure (FINDING-16: reject unreasonable group counts)
54    if entry.groups.len() > 50 {
55        tracing::warn!(
56            "Cache has too many groups ({}), ignoring",
57            entry.groups.len()
58        );
59        return None;
60    }
61
62    if entry.diff_hash != hash {
63        tracing::debug!("Cache miss: hash mismatch");
64        return None;
65    }
66
67    tracing::info!("Cache hit: reusing {} groups", entry.groups.len());
68    Some(
69        entry
70            .groups
71            .into_iter()
72            .map(|g| SemanticGroup::new(
73                g.label,
74                g.description,
75                g.changes
76                    .into_iter()
77                    .map(|c| crate::grouper::GroupedChange {
78                        file: c.file,
79                        hunks: c.hunks,
80                    })
81                    .collect(),
82            ))
83            .collect(),
84    )
85}
86
87/// Save grouping result to the cache file.
88pub fn save(hash: u64, groups: &[SemanticGroup]) {
89    let Some(path) = cache_path() else { return };
90
91    let entry = CacheEntry {
92        diff_hash: hash,
93        groups: groups
94            .iter()
95            .map(|g| CachedGroup {
96                label: g.label.clone(),
97                description: g.description.clone(),
98                changes: g
99                    .changes()
100                    .iter()
101                    .map(|c| CachedChange {
102                        file: c.file.clone(),
103                        hunks: c.hunks.clone(),
104                    })
105                    .collect(),
106            })
107            .collect(),
108    };
109
110    match serde_json::to_string(&entry) {
111        Ok(json) => {
112            if let Err(e) = std::fs::write(&path, json) {
113                tracing::warn!("Failed to write cache: {}", e);
114            } else {
115                tracing::debug!("Saved cache to {}", path.display());
116            }
117        }
118        Err(e) => tracing::warn!("Failed to serialize cache: {}", e),
119    }
120}
121
122/// Path to the cache file: .git/semantic-diff-cache.json
123/// Returns None if not in a git repo or if git-dir is outside the repo root.
124fn cache_path() -> Option<PathBuf> {
125    let output = std::process::Command::new("git")
126        .args(["rev-parse", "--git-dir"])
127        .output()
128        .ok()?;
129    if !output.status.success() {
130        return None;
131    }
132    let git_dir = String::from_utf8(output.stdout).ok()?.trim().to_string();
133    let git_path = PathBuf::from(&git_dir);
134
135    // Validate: git-dir should be within or adjacent to the current working directory.
136    // This prevents crafted .git files from redirecting cache writes to arbitrary locations.
137    let cwd = std::env::current_dir().ok()?;
138    let canonical_git = std::fs::canonicalize(&git_path).unwrap_or(git_path.clone());
139    let canonical_cwd = std::fs::canonicalize(&cwd).unwrap_or(cwd);
140    if !canonical_git.starts_with(&canonical_cwd) {
141        tracing::warn!(
142            "git-dir {} is outside repo root {}, refusing to use cache",
143            canonical_git.display(),
144            canonical_cwd.display()
145        );
146        return None;
147    }
148
149    Some(PathBuf::from(git_dir).join("semantic-diff-cache.json"))
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Hashing the same text twice yields the same value.
    #[test]
    fn test_diff_hash_deterministic() {
        assert_eq!(diff_hash("hello world"), diff_hash("hello world"));
    }

    /// Different texts yield different hashes (for this pair of inputs).
    #[test]
    fn test_diff_hash_changes() {
        assert_ne!(diff_hash("hello"), diff_hash("world"));
    }

    /// Smoke test: `cache_path()` must not panic, and any path it returns must
    /// name the cache file. `None` is acceptable (not in a git repo, or the
    /// validation refused the git-dir).
    #[test]
    fn test_cache_path_validates_git_dir_within_cwd() {
        let Some(p) = cache_path() else { return };
        assert!(
            p.to_string_lossy().contains("semantic-diff-cache.json"),
            "cache path should contain cache filename, got: {}",
            p.display()
        );
    }

    /// Writes a file one byte over the 1MB limit and checks its reported size.
    ///
    /// This does not go through `load()` (that would need `cache_path()` to be
    /// mockable); it only confirms the fixture crosses the size threshold.
    #[test]
    fn test_load_rejects_oversized_cache() {
        let temp_dir = tempfile::tempdir().unwrap();
        let cache_file = temp_dir.path().join("oversized-cache.json");
        std::fs::write(&cache_file, "x".repeat(1_048_577)).unwrap();

        let size = std::fs::metadata(&cache_file).unwrap().len();
        assert!(size > 1_048_576, "Test file should be larger than 1MB");
    }

    /// A well-formed cache document deserializes into `CacheEntry`.
    #[test]
    fn test_cache_entry_with_valid_groups_deserializes() {
        let json = r#"{
            "diff_hash": 12345,
            "groups": [
                {"label": "Auth", "description": "Auth changes", "changes": [{"file": "src/auth.rs", "hunks": [0]}]}
            ]
        }"#;
        let entry = serde_json::from_str::<CacheEntry>(json).unwrap();
        assert_eq!(entry.groups.len(), 1);
        assert_eq!(entry.groups[0].label, "Auth");
    }

    /// Builds an entry with 60 groups, i.e. more than the 50 that `load()`
    /// accepts, and checks the count that the validation would look at.
    #[test]
    fn test_cache_entry_group_count_validation() {
        let groups: Vec<CachedGroup> = (0..60)
            .map(|i| CachedGroup {
                label: format!("Group {}", i),
                description: "desc".to_string(),
                changes: vec![],
            })
            .collect();
        let entry = CacheEntry {
            diff_hash: 99999,
            groups,
        };
        assert!(entry.groups.len() > 50);
    }
}