1use sha2::{Digest, Sha256};
2use std::fs;
3use std::path::Path;
4
5use crate::error::Result;
6use crate::ir::SqlcxIR;
7
/// A SQL source file: an identifying path plus its full text.
///
/// Both fields are hashed by [`compute_hash`], so changing either one changes the cache key.
#[derive(Debug, Clone)]
pub struct SqlFile {
    /// Path identifying the file; also used to order files when hashing.
    pub path: String,
    /// Full text content of the file.
    pub content: String,
}
12
/// On-disk layout of the IR cache.
///
/// Serialized as JSON to `ir.json` by `write_cache` and parsed back by `read_cache`.
/// The field names are the JSON keys, so renaming a field changes the file format.
#[derive(serde::Serialize, serde::Deserialize)]
struct CacheFile {
    /// Cache key the IR was stored under; `read_cache` only returns `ir` when this equals
    /// the caller's expected hash.
    hash: String,
    /// The cached intermediate representation.
    ir: SqlcxIR,
}
18
19pub fn compute_hash(files: &[SqlFile], parser_name: &str) -> String {
22 let mut sorted: Vec<&SqlFile> = files.iter().collect();
23 sorted.sort_by(|a, b| a.path.cmp(&b.path));
24 let mut hasher = Sha256::new();
25 hasher.update(parser_name.as_bytes());
26 hasher.update(b"\0");
27 for f in &sorted {
28 hasher.update(f.path.as_bytes());
29 hasher.update(b"\0");
30 hasher.update(f.content.as_bytes());
31 hasher.update(b"\0");
32 }
33 format!("{:x}", hasher.finalize())
34}
35
36pub fn write_cache(cache_dir: &Path, ir: &SqlcxIR, hash: &str) -> Result<()> {
38 fs::create_dir_all(cache_dir)?;
39 let data = CacheFile {
40 hash: hash.to_string(),
41 ir: ir.clone(),
42 };
43 let cache_path = cache_dir.join("ir.json");
44 let temp_path = cache_path.with_extension("json.tmp");
45 let json = serde_json::to_string(&data)?;
46 fs::write(&temp_path, &json)?;
47 fs::rename(&temp_path, &cache_path)?;
48 Ok(())
49}
50
51pub fn read_cache(cache_dir: &Path, expected_hash: &str) -> Result<Option<SqlcxIR>> {
53 let cache_path = cache_dir.join("ir.json");
54 if !cache_path.exists() {
55 return Ok(None);
56 }
57 let content = fs::read_to_string(&cache_path)?;
58 let data: CacheFile = match serde_json::from_str(&content) {
59 Ok(d) => d,
60 Err(_) => return Ok(None),
61 };
62 if data.hash != expected_hash {
63 return Ok(None);
64 }
65 Ok(Some(data.ir))
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::*;

    /// An empty IR; sufficient for exercising the cache plumbing.
    fn sample_ir() -> SqlcxIR {
        SqlcxIR {
            tables: vec![],
            queries: vec![],
            enums: vec![],
        }
    }

    /// Shorthand for building a `SqlFile` from string slices.
    fn sql_file(path: &str, content: &str) -> SqlFile {
        SqlFile {
            path: path.to_string(),
            content: content.to_string(),
        }
    }

    #[test]
    fn compute_hash_deterministic() {
        let files = [sql_file("a.sql", "SELECT 1;"), sql_file("b.sql", "SELECT 2;")];
        assert_eq!(
            compute_hash(&files, "postgres"),
            compute_hash(&files, "postgres")
        );
    }

    #[test]
    fn compute_hash_order_independent() {
        let a = [sql_file("b.sql", "SELECT 2;"), sql_file("a.sql", "SELECT 1;")];
        let b = [sql_file("a.sql", "SELECT 1;"), sql_file("b.sql", "SELECT 2;")];
        assert_eq!(compute_hash(&a, "postgres"), compute_hash(&b, "postgres"));
    }

    #[test]
    fn compute_hash_changes_with_parser_name() {
        let files = [sql_file("a.sql", "SELECT 1;")];
        assert_ne!(
            compute_hash(&files, "postgres"),
            compute_hash(&files, "mysql")
        );
    }

    #[test]
    fn compute_hash_changes_with_content() {
        let before = [sql_file("a.sql", "SELECT 1;")];
        let after = [sql_file("a.sql", "SELECT 2;")];
        assert_ne!(
            compute_hash(&before, "postgres"),
            compute_hash(&after, "postgres")
        );
    }

    #[test]
    fn compute_hash_changes_with_path() {
        let before = [sql_file("a.sql", "SELECT 1;")];
        let after = [sql_file("b.sql", "SELECT 1;")];
        assert_ne!(
            compute_hash(&before, "postgres"),
            compute_hash(&after, "postgres")
        );
    }

    #[test]
    fn cache_round_trip() {
        let dir = tempfile::tempdir().unwrap();
        let cache_dir = dir.path().join(".sqlcx");
        let ir = sample_ir();
        write_cache(&cache_dir, &ir, "abc123").unwrap();
        let loaded = read_cache(&cache_dir, "abc123").unwrap();
        assert!(loaded.is_some());
    }

    #[test]
    fn cache_overwrite_replaces_previous_entry() {
        let dir = tempfile::tempdir().unwrap();
        let cache_dir = dir.path().join(".sqlcx");
        write_cache(&cache_dir, &sample_ir(), "v1").unwrap();
        write_cache(&cache_dir, &sample_ir(), "v2").unwrap();
        assert!(read_cache(&cache_dir, "v2").unwrap().is_some());
        assert!(read_cache(&cache_dir, "v1").unwrap().is_none());
        // The temporary file used for the write-then-rename must not linger.
        assert!(!cache_dir.join("ir.json.tmp").exists());
    }

    #[test]
    fn cache_miss_on_hash_mismatch() {
        let dir = tempfile::tempdir().unwrap();
        let cache_dir = dir.path().join(".sqlcx");
        write_cache(&cache_dir, &sample_ir(), "v1").unwrap();
        assert!(read_cache(&cache_dir, "v2").unwrap().is_none());
    }

    #[test]
    fn cache_miss_on_no_file() {
        let dir = tempfile::tempdir().unwrap();
        assert!(read_cache(&dir.path().join(".sqlcx"), "any")
            .unwrap()
            .is_none());
    }

    #[test]
    fn cache_miss_on_corrupt_file() {
        let dir = tempfile::tempdir().unwrap();
        let cache_dir = dir.path().join(".sqlcx");
        std::fs::create_dir_all(&cache_dir).unwrap();
        std::fs::write(cache_dir.join("ir.json"), "{ not valid json").unwrap();
        assert!(read_cache(&cache_dir, "any").unwrap().is_none());
    }
}
163}