1use crate::grouper::SemanticGroup;
2use serde::{Deserialize, Serialize};
3use std::collections::hash_map::DefaultHasher;
4use std::hash::{Hash, Hasher};
5use std::path::PathBuf;
6
7#[derive(Debug, Serialize, Deserialize)]
9struct CacheEntry {
10 diff_hash: u64,
12 groups: Vec<CachedGroup>,
13}
14
15#[derive(Debug, Serialize, Deserialize)]
17struct CachedGroup {
18 label: String,
19 description: String,
20 changes: Vec<CachedChange>,
21}
22
23#[derive(Debug, Serialize, Deserialize)]
25struct CachedChange {
26 file: String,
27 hunks: Vec<usize>,
28}
29
/// Computes the 64-bit fingerprint of a raw diff that serves as the cache key.
///
/// The value comes from the standard library's [`DefaultHasher`]. Equal input
/// always yields an equal hash within one build of the program. The algorithm
/// is not guaranteed to stay the same across Rust releases, so a persisted hash
/// may stop matching after a toolchain upgrade (which shows up as a cache miss).
pub fn diff_hash(raw_diff: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(raw_diff, &mut state);
    Hasher::finish(&state)
}
36
37pub fn load(hash: u64) -> Option<Vec<SemanticGroup>> {
40 let path = cache_path()?;
41
42 let metadata = std::fs::metadata(&path).ok()?;
44 if metadata.len() > 1_048_576 {
45 tracing::warn!("Cache file too large ({} bytes), ignoring", metadata.len());
47 return None;
48 }
49
50 let content = std::fs::read_to_string(&path).ok()?;
51 let entry: CacheEntry = serde_json::from_str(&content).ok()?;
52
53 if entry.groups.len() > 50 {
55 tracing::warn!(
56 "Cache has too many groups ({}), ignoring",
57 entry.groups.len()
58 );
59 return None;
60 }
61
62 if entry.diff_hash != hash {
63 tracing::debug!("Cache miss: hash mismatch");
64 return None;
65 }
66
67 tracing::info!("Cache hit: reusing {} groups", entry.groups.len());
68 Some(
69 entry
70 .groups
71 .into_iter()
72 .map(|g| SemanticGroup::new(
73 g.label,
74 g.description,
75 g.changes
76 .into_iter()
77 .map(|c| crate::grouper::GroupedChange {
78 file: c.file,
79 hunks: c.hunks,
80 })
81 .collect(),
82 ))
83 .collect(),
84 )
85}
86
87pub fn save(hash: u64, groups: &[SemanticGroup]) {
89 let Some(path) = cache_path() else { return };
90
91 let entry = CacheEntry {
92 diff_hash: hash,
93 groups: groups
94 .iter()
95 .map(|g| CachedGroup {
96 label: g.label.clone(),
97 description: g.description.clone(),
98 changes: g
99 .changes()
100 .iter()
101 .map(|c| CachedChange {
102 file: c.file.clone(),
103 hunks: c.hunks.clone(),
104 })
105 .collect(),
106 })
107 .collect(),
108 };
109
110 match serde_json::to_string(&entry) {
111 Ok(json) => {
112 if let Err(e) = std::fs::write(&path, json) {
113 tracing::warn!("Failed to write cache: {}", e);
114 } else {
115 tracing::debug!("Saved cache to {}", path.display());
116 }
117 }
118 Err(e) => tracing::warn!("Failed to serialize cache: {}", e),
119 }
120}
121
122fn cache_path() -> Option<PathBuf> {
125 let output = std::process::Command::new("git")
126 .args(["rev-parse", "--git-dir"])
127 .output()
128 .ok()?;
129 if !output.status.success() {
130 return None;
131 }
132 let git_dir = String::from_utf8(output.stdout).ok()?.trim().to_string();
133 let git_path = PathBuf::from(&git_dir);
134
135 let cwd = std::env::current_dir().ok()?;
138 let canonical_git = std::fs::canonicalize(&git_path).unwrap_or(git_path.clone());
139 let canonical_cwd = std::fs::canonicalize(&cwd).unwrap_or(cwd);
140 if !canonical_git.starts_with(&canonical_cwd) {
141 tracing::warn!(
142 "git-dir {} is outside repo root {}, refusing to use cache",
143 canonical_git.display(),
144 canonical_cwd.display()
145 );
146 return None;
147 }
148
149 Some(PathBuf::from(git_dir).join("semantic-diff-cache.json"))
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Hashing the same text twice gives the same value.
    #[test]
    fn test_diff_hash_deterministic() {
        assert_eq!(diff_hash("hello world"), diff_hash("hello world"));
    }

    /// Two different texts are expected to hash differently.
    #[test]
    fn test_diff_hash_changes() {
        assert_ne!(diff_hash("hello"), diff_hash("world"));
    }

    /// When a cache path is produced at all, it names the cache file.
    /// A `None` result is accepted without any assertion.
    #[test]
    fn test_cache_path_validates_git_dir_within_cwd() {
        let Some(resolved) = cache_path() else { return };
        let rendered = resolved.to_string_lossy();
        assert!(
            rendered.contains("semantic-diff-cache.json"),
            "cache path should contain cache filename, got: {}",
            resolved.display()
        );
    }

    /// Builds a file just over the 1 MiB limit and checks its reported size.
    /// NOTE: this only verifies the fixture; it does not call `load`.
    #[test]
    fn test_load_rejects_oversized_cache() {
        let scratch = tempfile::tempdir().unwrap();
        let oversized = scratch.path().join("oversized-cache.json");
        std::fs::write(&oversized, "x".repeat(1_048_577)).unwrap();

        let size = std::fs::metadata(&oversized).unwrap().len();
        assert!(size > 1_048_576, "Test file should be larger than 1MB");
    }

    /// A well-formed cache document deserializes into a `CacheEntry`.
    #[test]
    fn test_cache_entry_with_valid_groups_deserializes() {
        let json = r#"{
            "diff_hash": 12345,
            "groups": [
                {"label": "Auth", "description": "Auth changes", "changes": [{"file": "src/auth.rs", "hunks": [0]}]}
            ]
        }"#;
        let entry: CacheEntry = serde_json::from_str(json).unwrap();
        assert_eq!(entry.groups.len(), 1);
        assert_eq!(entry.groups[0].label, "Auth");
    }

    /// Constructs an entry with more groups than `load` accepts.
    /// NOTE: this only checks the constructed count; it does not call `load`.
    #[test]
    fn test_cache_entry_group_count_validation() {
        let groups: Vec<CachedGroup> = (0..60)
            .map(|i| CachedGroup {
                label: format!("Group {}", i),
                description: "desc".to_string(),
                changes: vec![],
            })
            .collect();
        let entry = CacheEntry {
            diff_hash: 99999,
            groups,
        };
        assert!(entry.groups.len() > 50);
    }
}