use std::io::Write;
use std::path::Path;

use anyhow::Result;

use crate::codegen::GeneratedFile;
6
/// Banner comment placed at the very top of every file this module writes
/// (per-model files, the merged `models.rs`, and the `mod.rs` index).
const HEADER: &str = "// Auto-generated by sqlx-gen. Do not edit.";
8
9pub fn write_files(
10 files: &[GeneratedFile],
11 output_dir: &Path,
12 single_file: bool,
13 dry_run: bool,
14) -> Result<()> {
15 if dry_run {
16 for f in files {
17 println!("{}", build_file_content(f));
18 println!();
19 }
20 return Ok(());
21 }
22
23 std::fs::create_dir_all(output_dir)?;
24
25 if single_file {
26 write_single_file(files, output_dir)?;
27 } else {
28 write_multi_files(files, output_dir)?;
29 }
30
31 Ok(())
32}
33
34fn build_file_content(f: &GeneratedFile) -> String {
35 let mut content = String::new();
36 content.push_str(HEADER);
37 content.push('\n');
38 if let Some(origin) = &f.origin {
39 content.push_str(&format!("// {}\n", origin));
40 }
41 content.push('\n');
42 content.push_str(&f.code);
43 content
44}
45
46fn write_single_file(files: &[GeneratedFile], output_dir: &Path) -> Result<()> {
47 let mut content = String::from(HEADER);
48 content.push_str("\n\n");
49
50 for f in files {
51 if let Some(origin) = &f.origin {
52 content.push_str(&format!("// --- {} ---\n\n", origin));
53 }
54 content.push_str(&f.code);
55 content.push('\n');
56 }
57
58 let path = output_dir.join("models.rs");
59 std::fs::write(&path, &content)?;
60 eprintln!("Wrote {}", path.display());
61
62 Ok(())
63}
64
65fn write_multi_files(files: &[GeneratedFile], output_dir: &Path) -> Result<()> {
66 let mut mod_entries = Vec::new();
67
68 for f in files {
69 let content = build_file_content(f);
70 let path = output_dir.join(&f.filename);
71 std::fs::write(&path, &content)?;
72 eprintln!("Wrote {}", path.display());
73
74 let mod_name = f.filename.strip_suffix(".rs").unwrap_or(&f.filename);
75 mod_entries.push(mod_name.to_string());
76 }
77
78 let mut mod_content = String::from(HEADER);
80 mod_content.push_str("\n\n");
81 for m in &mod_entries {
82 mod_content.push_str(&format!("pub mod {};\n", m));
83 }
84
85 let mod_path = output_dir.join("mod.rs");
86 std::fs::write(&mod_path, &mod_content)?;
87 eprintln!("Wrote {}", mod_path.display());
88
89 Ok(())
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::GeneratedFile;

    /// Builds a `GeneratedFile` fixture from borrowed pieces.
    fn make_file(filename: &str, code: &str, origin: Option<&str>) -> GeneratedFile {
        GeneratedFile {
            filename: filename.to_owned(),
            origin: origin.map(str::to_owned),
            code: code.to_owned(),
        }
    }

    /// Runs a real (non-dry-run) `write_files` into a fresh temporary
    /// directory and hands the directory back so it outlives the assertions.
    fn emit(files: &[GeneratedFile], single_file: bool) -> tempfile::TempDir {
        let dir = tempfile::tempdir().unwrap();
        write_files(files, dir.path(), single_file, false).unwrap();
        dir
    }

    /// Reads a file written into the temporary directory.
    fn read(dir: &tempfile::TempDir, name: &str) -> String {
        std::fs::read_to_string(dir.path().join(name)).unwrap()
    }

    // --- build_file_content ---

    #[test]
    fn test_build_content_with_origin() {
        let rendered = build_file_content(&make_file(
            "users.rs",
            "pub struct Users {}",
            Some("Table: public.users"),
        ));
        for expected in [HEADER, "// Table: public.users", "pub struct Users {}"] {
            assert!(rendered.contains(expected));
        }
    }

    #[test]
    fn test_build_content_without_origin() {
        let rendered = build_file_content(&make_file("types.rs", "pub enum Status {}", None));
        assert!(rendered.contains(HEADER));
        assert!(!rendered.contains("// Table:"));
        assert!(rendered.contains("pub enum Status {}"));
    }

    #[test]
    fn test_build_content_header_value() {
        let rendered = build_file_content(&make_file("x.rs", "", None));
        assert!(rendered.starts_with("// Auto-generated by sqlx-gen. Do not edit."));
    }

    #[test]
    fn test_build_content_preserves_code() {
        let code = "use chrono::NaiveDateTime;\n\npub struct Foo {\n pub x: i32,\n}";
        let rendered = build_file_content(&make_file("foo.rs", code, None));
        assert!(rendered.contains(code));
    }

    #[test]
    fn test_build_content_origin_format() {
        let rendered = build_file_content(&make_file("x.rs", "code", Some("Table: public.users")));
        assert!(rendered.contains("// Table: public.users\n"));
    }

    #[test]
    fn test_build_content_empty_code() {
        let rendered = build_file_content(&make_file("x.rs", "", Some("Table: public.x")));
        assert!(rendered.contains(HEADER));
        assert!(rendered.contains("// Table: public.x"));
    }

    // --- dry run ---

    #[test]
    fn test_dry_run_returns_ok() {
        let files = [make_file("users.rs", "code", Some("origin"))];
        let dir = tempfile::tempdir().unwrap();
        assert!(write_files(&files, dir.path(), false, true).is_ok());
    }

    #[test]
    fn test_dry_run_no_files_created() {
        let files = [make_file("users.rs", "code", Some("origin"))];
        let dir = tempfile::tempdir().unwrap();
        let target = dir.path().join("output");
        // Only the absence of the directory matters here; the result is ignored.
        let _ = write_files(&files, &target, false, true);
        assert!(!target.exists());
    }

    #[test]
    fn test_dry_run_empty_files() {
        let dir = tempfile::tempdir().unwrap();
        assert!(write_files(&[], dir.path(), false, true).is_ok());
    }

    // --- multi-file mode ---

    #[test]
    fn test_multi_creates_files_and_mod() {
        let dir = emit(
            &[
                make_file("users.rs", "pub struct Users {}", Some("Table: public.users")),
                make_file("posts.rs", "pub struct Posts {}", Some("Table: public.posts")),
            ],
            false,
        );

        for name in ["users.rs", "posts.rs", "mod.rs"] {
            assert!(dir.path().join(name).exists());
        }
    }

    #[test]
    fn test_multi_mod_rs_content() {
        let dir = emit(
            &[
                make_file("users.rs", "code", Some("origin")),
                make_file("types.rs", "code", None),
            ],
            false,
        );

        let index = read(&dir, "mod.rs");
        assert!(index.contains("pub mod users;"));
        assert!(index.contains("pub mod types;"));
    }

    #[test]
    fn test_multi_file_has_header() {
        let dir = emit(
            &[make_file("users.rs", "code", Some("Table: public.users"))],
            false,
        );
        assert!(read(&dir, "users.rs").starts_with(HEADER));
    }

    #[test]
    fn test_multi_file_has_origin() {
        let dir = emit(
            &[make_file("users.rs", "code", Some("Table: public.users"))],
            false,
        );
        assert!(read(&dir, "users.rs").contains("// Table: public.users"));
    }

    #[test]
    fn test_multi_creates_output_dir() {
        let dir = tempfile::tempdir().unwrap();
        let nested = dir.path().join("nested").join("output");
        let files = [make_file("users.rs", "code", Some("origin"))];
        write_files(&files, &nested, false, false).unwrap();
        assert!(nested.join("users.rs").exists());
    }

    #[test]
    fn test_multi_file_no_origin() {
        let dir = emit(&[make_file("types.rs", "code", None)], false);

        let written = read(&dir, "types.rs");
        assert!(written.contains(HEADER));
        assert!(!written.contains("// Table:"));
    }

    // --- single-file mode ---

    #[test]
    fn test_single_creates_models_rs() {
        let dir = emit(
            &[make_file("users.rs", "pub struct Users {}", Some("Table: public.users"))],
            true,
        );
        assert!(dir.path().join("models.rs").exists());
    }

    #[test]
    fn test_single_starts_with_header() {
        let dir = emit(&[make_file("users.rs", "code", Some("origin"))], true);
        assert!(read(&dir, "models.rs").starts_with(HEADER));
    }

    #[test]
    fn test_single_has_section_separator() {
        let dir = emit(
            &[make_file("users.rs", "code", Some("Table: public.users"))],
            true,
        );
        assert!(read(&dir, "models.rs").contains("// --- Table: public.users ---"));
    }

    #[test]
    fn test_single_concatenates_all_code() {
        let dir = emit(
            &[
                make_file("users.rs", "struct Users;", Some("Table: public.users")),
                make_file("posts.rs", "struct Posts;", Some("Table: public.posts")),
            ],
            true,
        );

        let merged = read(&dir, "models.rs");
        assert!(merged.contains("struct Users;"));
        assert!(merged.contains("struct Posts;"));
    }

    #[test]
    fn test_single_no_origin_no_separator() {
        let dir = emit(&[make_file("types.rs", "code", None)], true);
        assert!(!read(&dir, "models.rs").contains("// ---"));
    }
}