Skip to main content

sqlx_gen/
writer.rs

1use std::path::Path;
2
3use anyhow::Result;
4
5use crate::codegen::GeneratedFile;
6
/// Marker line written at the very top of every file this module produces (each standalone
/// generated file, the combined `models.rs`, and the generated `mod.rs`), flagging the file as
/// generated output that should not be edited by hand.
const HEADER: &str = "// Auto-generated by sqlx-gen. Do not edit.";
8
9pub fn write_files(
10    files: &[GeneratedFile],
11    output_dir: &Path,
12    single_file: bool,
13    dry_run: bool,
14) -> Result<()> {
15    if dry_run {
16        for f in files {
17            println!("{}", build_file_content(f));
18            println!();
19        }
20        return Ok(());
21    }
22
23    std::fs::create_dir_all(output_dir)?;
24
25    if single_file {
26        write_single_file(files, output_dir)?;
27    } else {
28        write_multi_files(files, output_dir)?;
29    }
30
31    Ok(())
32}
33
34fn build_file_content(f: &GeneratedFile) -> String {
35    let mut content = String::new();
36    content.push_str(HEADER);
37    content.push('\n');
38    if let Some(origin) = &f.origin {
39        content.push_str(&format!("// {}\n", origin));
40    }
41    content.push('\n');
42    content.push_str(&f.code);
43    content
44}
45
46fn write_single_file(files: &[GeneratedFile], output_dir: &Path) -> Result<()> {
47    let mut content = String::from(HEADER);
48    content.push_str("\n\n");
49
50    for f in files {
51        if let Some(origin) = &f.origin {
52            content.push_str(&format!("// --- {} ---\n\n", origin));
53        }
54        content.push_str(&f.code);
55        content.push('\n');
56    }
57
58    let path = output_dir.join("models.rs");
59    std::fs::write(&path, &content)?;
60    eprintln!("Wrote {}", path.display());
61
62    Ok(())
63}
64
65fn write_multi_files(files: &[GeneratedFile], output_dir: &Path) -> Result<()> {
66    let mut mod_entries = Vec::new();
67
68    for f in files {
69        let content = build_file_content(f);
70        let path = output_dir.join(&f.filename);
71        std::fs::write(&path, &content)?;
72        eprintln!("Wrote {}", path.display());
73
74        let mod_name = f.filename.strip_suffix(".rs").unwrap_or(&f.filename);
75        mod_entries.push(mod_name.to_string());
76    }
77
78    // Generate mod.rs
79    let mut mod_content = String::from(HEADER);
80    mod_content.push_str("\n\n");
81    for m in &mod_entries {
82        mod_content.push_str(&format!("pub mod {};\n", m));
83    }
84
85    let mod_path = output_dir.join("mod.rs");
86    std::fs::write(&mod_path, &mod_content)?;
87    eprintln!("Wrote {}", mod_path.display());
88
89    Ok(())
90}
91
92#[cfg(test)]
93mod tests {
94    use super::*;
95    use crate::codegen::GeneratedFile;
96
97    fn make_file(filename: &str, code: &str, origin: Option<&str>) -> GeneratedFile {
98        GeneratedFile {
99            filename: filename.to_string(),
100            origin: origin.map(|s| s.to_string()),
101            code: code.to_string(),
102        }
103    }
104
105    // ========== build_file_content ==========
106
107    #[test]
108    fn test_build_content_with_origin() {
109        let f = make_file("users.rs", "pub struct Users {}", Some("Table: public.users"));
110        let content = build_file_content(&f);
111        assert!(content.contains(HEADER));
112        assert!(content.contains("// Table: public.users"));
113        assert!(content.contains("pub struct Users {}"));
114    }
115
116    #[test]
117    fn test_build_content_without_origin() {
118        let f = make_file("types.rs", "pub enum Status {}", None);
119        let content = build_file_content(&f);
120        assert!(content.contains(HEADER));
121        assert!(!content.contains("// Table:"));
122        assert!(content.contains("pub enum Status {}"));
123    }
124
125    #[test]
126    fn test_build_content_header_value() {
127        let f = make_file("x.rs", "", None);
128        let content = build_file_content(&f);
129        assert!(content.starts_with("// Auto-generated by sqlx-gen. Do not edit."));
130    }
131
132    #[test]
133    fn test_build_content_preserves_code() {
134        let code = "use chrono::NaiveDateTime;\n\npub struct Foo {\n    pub x: i32,\n}";
135        let f = make_file("foo.rs", code, None);
136        let content = build_file_content(&f);
137        assert!(content.contains(code));
138    }
139
140    #[test]
141    fn test_build_content_origin_format() {
142        let f = make_file("x.rs", "code", Some("Table: public.users"));
143        let content = build_file_content(&f);
144        assert!(content.contains("// Table: public.users\n"));
145    }
146
147    #[test]
148    fn test_build_content_empty_code() {
149        let f = make_file("x.rs", "", Some("Table: public.x"));
150        let content = build_file_content(&f);
151        assert!(content.contains(HEADER));
152        assert!(content.contains("// Table: public.x"));
153    }
154
155    // ========== write_files dry_run ==========
156
157    #[test]
158    fn test_dry_run_returns_ok() {
159        let files = vec![make_file("users.rs", "code", Some("origin"))];
160        let dir = tempfile::tempdir().unwrap();
161        let result = write_files(&files, dir.path(), false, true);
162        assert!(result.is_ok());
163    }
164
165    #[test]
166    fn test_dry_run_no_files_created() {
167        let files = vec![make_file("users.rs", "code", Some("origin"))];
168        let dir = tempfile::tempdir().unwrap();
169        let sub = dir.path().join("output");
170        let _ = write_files(&files, &sub, false, true);
171        // Output dir should NOT be created in dry_run mode
172        assert!(!sub.exists());
173    }
174
175    #[test]
176    fn test_dry_run_empty_files() {
177        let dir = tempfile::tempdir().unwrap();
178        let result = write_files(&[], dir.path(), false, true);
179        assert!(result.is_ok());
180    }
181
182    // ========== write_multi_files ==========
183
184    #[test]
185    fn test_multi_creates_files_and_mod() {
186        let files = vec![
187            make_file("users.rs", "pub struct Users {}", Some("Table: public.users")),
188            make_file("posts.rs", "pub struct Posts {}", Some("Table: public.posts")),
189        ];
190        let dir = tempfile::tempdir().unwrap();
191        write_files(&files, dir.path(), false, false).unwrap();
192
193        assert!(dir.path().join("users.rs").exists());
194        assert!(dir.path().join("posts.rs").exists());
195        assert!(dir.path().join("mod.rs").exists());
196    }
197
198    #[test]
199    fn test_multi_mod_rs_content() {
200        let files = vec![
201            make_file("users.rs", "code", Some("origin")),
202            make_file("types.rs", "code", None),
203        ];
204        let dir = tempfile::tempdir().unwrap();
205        write_files(&files, dir.path(), false, false).unwrap();
206
207        let mod_content = std::fs::read_to_string(dir.path().join("mod.rs")).unwrap();
208        assert!(mod_content.contains("pub mod users;"));
209        assert!(mod_content.contains("pub mod types;"));
210    }
211
212    #[test]
213    fn test_multi_file_has_header() {
214        let files = vec![make_file("users.rs", "code", Some("Table: public.users"))];
215        let dir = tempfile::tempdir().unwrap();
216        write_files(&files, dir.path(), false, false).unwrap();
217
218        let content = std::fs::read_to_string(dir.path().join("users.rs")).unwrap();
219        assert!(content.starts_with(HEADER));
220    }
221
222    #[test]
223    fn test_multi_file_has_origin() {
224        let files = vec![make_file("users.rs", "code", Some("Table: public.users"))];
225        let dir = tempfile::tempdir().unwrap();
226        write_files(&files, dir.path(), false, false).unwrap();
227
228        let content = std::fs::read_to_string(dir.path().join("users.rs")).unwrap();
229        assert!(content.contains("// Table: public.users"));
230    }
231
232    #[test]
233    fn test_multi_creates_output_dir() {
234        let dir = tempfile::tempdir().unwrap();
235        let sub = dir.path().join("nested").join("output");
236        let files = vec![make_file("users.rs", "code", Some("origin"))];
237        write_files(&files, &sub, false, false).unwrap();
238        assert!(sub.join("users.rs").exists());
239    }
240
241    #[test]
242    fn test_multi_file_no_origin() {
243        let files = vec![make_file("types.rs", "code", None)];
244        let dir = tempfile::tempdir().unwrap();
245        write_files(&files, dir.path(), false, false).unwrap();
246
247        let content = std::fs::read_to_string(dir.path().join("types.rs")).unwrap();
248        assert!(content.contains(HEADER));
249        // No origin line
250        assert!(!content.contains("// Table:"));
251    }
252
253    // ========== write_single_file ==========
254
255    #[test]
256    fn test_single_creates_models_rs() {
257        let files = vec![make_file("users.rs", "pub struct Users {}", Some("Table: public.users"))];
258        let dir = tempfile::tempdir().unwrap();
259        write_files(&files, dir.path(), true, false).unwrap();
260        assert!(dir.path().join("models.rs").exists());
261    }
262
263    #[test]
264    fn test_single_starts_with_header() {
265        let files = vec![make_file("users.rs", "code", Some("origin"))];
266        let dir = tempfile::tempdir().unwrap();
267        write_files(&files, dir.path(), true, false).unwrap();
268
269        let content = std::fs::read_to_string(dir.path().join("models.rs")).unwrap();
270        assert!(content.starts_with(HEADER));
271    }
272
273    #[test]
274    fn test_single_has_section_separator() {
275        let files = vec![make_file("users.rs", "code", Some("Table: public.users"))];
276        let dir = tempfile::tempdir().unwrap();
277        write_files(&files, dir.path(), true, false).unwrap();
278
279        let content = std::fs::read_to_string(dir.path().join("models.rs")).unwrap();
280        assert!(content.contains("// --- Table: public.users ---"));
281    }
282
283    #[test]
284    fn test_single_concatenates_all_code() {
285        let files = vec![
286            make_file("users.rs", "struct Users;", Some("Table: public.users")),
287            make_file("posts.rs", "struct Posts;", Some("Table: public.posts")),
288        ];
289        let dir = tempfile::tempdir().unwrap();
290        write_files(&files, dir.path(), true, false).unwrap();
291
292        let content = std::fs::read_to_string(dir.path().join("models.rs")).unwrap();
293        assert!(content.contains("struct Users;"));
294        assert!(content.contains("struct Posts;"));
295    }
296
297    #[test]
298    fn test_single_no_origin_no_separator() {
299        let files = vec![make_file("types.rs", "code", None)];
300        let dir = tempfile::tempdir().unwrap();
301        write_files(&files, dir.path(), true, false).unwrap();
302
303        let content = std::fs::read_to_string(dir.path().join("models.rs")).unwrap();
304        assert!(!content.contains("// ---"));
305    }
306}