Skip to main content

sqlx_gen/
writer.rs

1use std::io::Write;
2use std::path::Path;
3
4use crate::error::Result;
5
6use crate::codegen::GeneratedFile;
7
/// Banner line placed at the very top of every file this module emits.
const COMMENT: &str = "// Auto-generated by sqlx-gen. Do not edit.";
/// Inner attribute emitted after the banner; it silences the `unused_attributes`
/// lint for the whole module whose file it opens.
const INNER_ATTR: &str = "#![allow(unused_attributes)]";
10
11/// Write `content` to `path` atomically: stream to a sibling temp file then rename.
12/// Avoids leaving partially-written files on Ctrl-C or disk-full errors.
13pub(crate) fn write_atomic(path: &Path, content: &[u8]) -> Result<()> {
14    let parent = path.parent().ok_or_else(|| {
15        crate::error::Error::Config(format!(
16            "Cannot determine parent directory of {}",
17            path.display()
18        ))
19    })?;
20    let mut tmp = tempfile::NamedTempFile::new_in(parent)?;
21    tmp.write_all(content)?;
22    tmp.flush()?;
23    tmp.persist(path).map_err(|e| e.error)?;
24    Ok(())
25}
26
27pub fn write_files(
28    files: &[GeneratedFile],
29    output_dir: &Path,
30    single_file: bool,
31    dry_run: bool,
32) -> Result<()> {
33    for f in files {
34        validate_safe_filename(&f.filename)?;
35    }
36
37    if dry_run {
38        for f in files {
39            println!("{}", build_file_content(f));
40            println!();
41        }
42        return Ok(());
43    }
44
45    std::fs::create_dir_all(output_dir)?;
46
47    if single_file {
48        write_single_file(files, output_dir)?;
49    } else {
50        write_multi_files(files, output_dir)?;
51    }
52
53    Ok(())
54}
55
56/// Reject filenames that could escape the output directory (`..`, path
57/// separators, absolute paths) or that aren't `.rs` files. Defends against
58/// malicious DB metadata in the rare case introspected table names flow into
59/// the file name.
60fn validate_safe_filename(filename: &str) -> Result<()> {
61    let p = Path::new(filename);
62    if filename.is_empty()
63        || p.components().count() != 1
64        || p.is_absolute()
65        || filename.contains("..")
66        || filename.contains('/')
67        || filename.contains('\\')
68        || !filename.ends_with(".rs")
69    {
70        return Err(crate::error::Error::Config(format!(
71            "Refusing to write generated file with unsafe name: {:?}",
72            filename
73        )));
74    }
75    Ok(())
76}
77
78fn build_file_content(f: &GeneratedFile) -> String {
79    let mut content = String::new();
80    content.push_str(COMMENT);
81    content.push('\n');
82    if let Some(origin) = &f.origin {
83        content.push_str(&format!("// {}\n", origin));
84    }
85    content.push('\n');
86    content.push_str(INNER_ATTR);
87    content.push_str("\n\n");
88    content.push_str(&f.code);
89    content
90}
91
92fn write_single_file(files: &[GeneratedFile], output_dir: &Path) -> Result<()> {
93    let mut content = format!("{}\n\n{}\n\n", COMMENT, INNER_ATTR);
94
95    for f in files {
96        if let Some(origin) = &f.origin {
97            content.push_str(&format!("// --- {} ---\n\n", origin));
98        }
99        content.push_str(&f.code);
100        content.push('\n');
101    }
102
103    let path = output_dir.join("models.rs");
104    write_atomic(&path, content.as_bytes())?;
105    log::info!("Wrote {}", path.display());
106
107    Ok(())
108}
109
110fn write_multi_files(files: &[GeneratedFile], output_dir: &Path) -> Result<()> {
111    let mut mod_entries = Vec::new();
112
113    for f in files {
114        let content = build_file_content(f);
115        let path = output_dir.join(&f.filename);
116        write_atomic(&path, content.as_bytes())?;
117        log::info!("Wrote {}", path.display());
118
119        let mod_name = f.filename.strip_suffix(".rs").unwrap_or(&f.filename);
120        mod_entries.push(mod_name.to_string());
121    }
122
123    // Generate mod.rs
124    let mut mod_content = format!("{}\n\n{}\n\n", COMMENT, INNER_ATTR);
125    for m in &mod_entries {
126        mod_content.push_str(&format!("pub mod {};\n", m));
127    }
128
129    let mod_path = output_dir.join("mod.rs");
130    write_atomic(&mod_path, mod_content.as_bytes())?;
131    log::info!("Wrote {}", mod_path.display());
132
133    Ok(())
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::GeneratedFile;

    fn make_file(filename: &str, code: &str, origin: Option<&str>) -> GeneratedFile {
        GeneratedFile {
            filename: filename.to_string(),
            origin: origin.map(|s| s.to_string()),
            code: code.to_string(),
        }
    }

    /// Names of the entries directly inside `dir`, sorted for stable comparison.
    fn dir_entries(dir: &Path) -> Vec<String> {
        let mut names: Vec<String> = std::fs::read_dir(dir)
            .unwrap()
            .map(|e| e.unwrap().file_name().to_string_lossy().into_owned())
            .collect();
        names.sort();
        names
    }

    // ========== build_file_content ==========

    #[test]
    fn test_build_content_with_origin() {
        let f = make_file(
            "users.rs",
            "pub struct Users {}",
            Some("Table: public.users"),
        );
        let content = build_file_content(&f);
        assert!(content.contains(COMMENT));
        assert!(content.contains(INNER_ATTR));
        assert!(content.contains("// Table: public.users"));
        assert!(content.contains("pub struct Users {}"));
    }

    #[test]
    fn test_build_content_without_origin() {
        let f = make_file("types.rs", "pub enum Status {}", None);
        let content = build_file_content(&f);
        assert!(content.contains(COMMENT));
        assert!(content.contains(INNER_ATTR));
        assert!(!content.contains("// Table:"));
        assert!(content.contains("pub enum Status {}"));
    }

    #[test]
    fn test_build_content_header_value() {
        let f = make_file("x.rs", "", None);
        let content = build_file_content(&f);
        assert!(content.starts_with("// Auto-generated by sqlx-gen. Do not edit."));
    }

    #[test]
    fn test_build_content_preserves_code() {
        let code = "use chrono::NaiveDateTime;\n\npub struct Foo {\n    pub x: i32,\n}";
        let f = make_file("foo.rs", code, None);
        let content = build_file_content(&f);
        assert!(content.contains(code));
    }

    #[test]
    fn test_build_content_origin_format() {
        let f = make_file("x.rs", "code", Some("Table: public.users"));
        let content = build_file_content(&f);
        assert!(content.contains("// Table: public.users\n"));
    }

    #[test]
    fn test_build_content_empty_code() {
        let f = make_file("x.rs", "", Some("Table: public.x"));
        let content = build_file_content(&f);
        assert!(content.contains(COMMENT));
        assert!(content.contains(INNER_ATTR));
        assert!(content.contains("// Table: public.x"));
    }

    #[test]
    fn test_build_content_exact_layout_with_origin() {
        let f = make_file("x.rs", "code", Some("Table: public.x"));
        let expected = format!("{}\n// Table: public.x\n\n{}\n\ncode", COMMENT, INNER_ATTR);
        assert_eq!(build_file_content(&f), expected);
    }

    #[test]
    fn test_build_content_exact_layout_without_origin() {
        let f = make_file("x.rs", "code", None);
        let expected = format!("{}\n\n{}\n\ncode", COMMENT, INNER_ATTR);
        assert_eq!(build_file_content(&f), expected);
    }

    // ========== write_files dry_run ==========

    #[test]
    fn test_dry_run_returns_ok() {
        let files = vec![make_file("users.rs", "code", Some("origin"))];
        let dir = tempfile::tempdir().unwrap();
        let result = write_files(&files, dir.path(), false, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_dry_run_no_files_created() {
        let files = vec![make_file("users.rs", "code", Some("origin"))];
        let dir = tempfile::tempdir().unwrap();
        let sub = dir.path().join("output");
        let result = write_files(&files, &sub, false, true);
        assert!(result.is_ok());
        // Output dir should NOT be created in dry_run mode
        assert!(!sub.exists());
    }

    #[test]
    fn test_dry_run_empty_files() {
        let dir = tempfile::tempdir().unwrap();
        let result = write_files(&[], dir.path(), false, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_dry_run_still_validates_filenames() {
        let files = vec![make_file("../escape.rs", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        assert!(write_files(&files, dir.path(), false, true).is_err());
    }

    // ========== write_multi_files ==========

    #[test]
    fn test_multi_creates_files_and_mod() {
        let files = vec![
            make_file(
                "users.rs",
                "pub struct Users {}",
                Some("Table: public.users"),
            ),
            make_file(
                "posts.rs",
                "pub struct Posts {}",
                Some("Table: public.posts"),
            ),
        ];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), false, false).unwrap();

        assert!(dir.path().join("users.rs").exists());
        assert!(dir.path().join("posts.rs").exists());
        assert!(dir.path().join("mod.rs").exists());
    }

    #[test]
    fn test_multi_writes_nothing_else() {
        let files = vec![
            make_file("users.rs", "code", None),
            make_file("posts.rs", "code", None),
        ];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), false, false).unwrap();
        assert_eq!(
            dir_entries(dir.path()),
            vec!["mod.rs", "posts.rs", "users.rs"]
        );
    }

    #[test]
    fn test_multi_mod_rs_content() {
        let files = vec![
            make_file("users.rs", "code", Some("origin")),
            make_file("types.rs", "code", None),
        ];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), false, false).unwrap();

        let mod_content = std::fs::read_to_string(dir.path().join("mod.rs")).unwrap();
        assert!(mod_content.contains("pub mod users;"));
        assert!(mod_content.contains("pub mod types;"));
    }

    #[test]
    fn test_multi_mod_rs_exact_content() {
        let files = vec![
            make_file("users.rs", "code", Some("origin")),
            make_file("types.rs", "code", None),
        ];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), false, false).unwrap();

        let mod_content = std::fs::read_to_string(dir.path().join("mod.rs")).unwrap();
        let expected = format!(
            "{}\n\n{}\n\npub mod users;\npub mod types;\n",
            COMMENT, INNER_ATTR
        );
        assert_eq!(mod_content, expected);
    }

    #[test]
    fn test_multi_file_has_header() {
        let files = vec![make_file("users.rs", "code", Some("Table: public.users"))];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), false, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("users.rs")).unwrap();
        assert!(content.starts_with(COMMENT));
    }

    #[test]
    fn test_multi_file_has_origin() {
        let files = vec![make_file("users.rs", "code", Some("Table: public.users"))];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), false, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("users.rs")).unwrap();
        assert!(content.contains("// Table: public.users"));
    }

    #[test]
    fn test_multi_creates_output_dir() {
        let dir = tempfile::tempdir().unwrap();
        let sub = dir.path().join("nested").join("output");
        let files = vec![make_file("users.rs", "code", Some("origin"))];
        write_files(&files, &sub, false, false).unwrap();
        assert!(sub.join("users.rs").exists());
    }

    #[test]
    fn test_multi_file_no_origin() {
        let files = vec![make_file("types.rs", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), false, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("types.rs")).unwrap();
        assert!(content.contains(COMMENT));
        assert!(content.contains(INNER_ATTR));
        // No origin line
        assert!(!content.contains("// Table:"));
    }

    // ========== write_single_file ==========

    #[test]
    fn test_single_creates_models_rs() {
        let files = vec![make_file(
            "users.rs",
            "pub struct Users {}",
            Some("Table: public.users"),
        )];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), true, false).unwrap();
        assert!(dir.path().join("models.rs").exists());
    }

    #[test]
    fn test_single_writes_only_models_rs() {
        let files = vec![
            make_file("users.rs", "code", Some("origin")),
            make_file("posts.rs", "code", None),
        ];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), true, false).unwrap();
        assert_eq!(dir_entries(dir.path()), vec!["models.rs"]);
    }

    #[test]
    fn test_single_starts_with_header() {
        let files = vec![make_file("users.rs", "code", Some("origin"))];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), true, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("models.rs")).unwrap();
        assert!(content.starts_with(COMMENT));
    }

    #[test]
    fn test_single_has_section_separator() {
        let files = vec![make_file("users.rs", "code", Some("Table: public.users"))];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), true, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("models.rs")).unwrap();
        assert!(content.contains("// --- Table: public.users ---"));
    }

    #[test]
    fn test_single_concatenates_all_code() {
        let files = vec![
            make_file("users.rs", "struct Users;", Some("Table: public.users")),
            make_file("posts.rs", "struct Posts;", Some("Table: public.posts")),
        ];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), true, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("models.rs")).unwrap();
        assert!(content.contains("struct Users;"));
        assert!(content.contains("struct Posts;"));
    }

    #[test]
    fn test_single_exact_content() {
        let files = vec![
            make_file("users.rs", "struct Users;", Some("Table: public.users")),
            make_file("types.rs", "struct Types;", None),
        ];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), true, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("models.rs")).unwrap();
        let expected = format!(
            "{}\n\n{}\n\n// --- Table: public.users ---\n\nstruct Users;\nstruct Types;\n",
            COMMENT, INNER_ATTR
        );
        assert_eq!(content, expected);
    }

    #[test]
    fn test_single_no_origin_no_separator() {
        let files = vec![make_file("types.rs", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), true, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("models.rs")).unwrap();
        assert!(!content.contains("// ---"));
    }

    // ========== write_atomic ==========

    #[test]
    fn test_atomic_creates_file_with_content() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("out.rs");
        write_atomic(&path, b"hello").unwrap();
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "hello");
    }

    #[test]
    fn test_atomic_overwrites_existing_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("out.rs");
        std::fs::write(&path, "old").unwrap();
        write_atomic(&path, b"new").unwrap();
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "new");
    }

    #[test]
    fn test_atomic_leaves_no_temp_artifacts_on_success() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("out.rs");
        write_atomic(&path, b"x").unwrap();
        assert_eq!(dir_entries(dir.path()), vec!["out.rs"]);
    }

    #[test]
    fn test_atomic_rejects_path_without_parent() {
        // The filesystem root has no parent, so no sibling temp file can be made.
        assert!(write_atomic(Path::new("/"), b"x").is_err());
    }

    // ========== validate_safe_filename ==========

    #[test]
    fn test_rejects_dot_dot_in_filename() {
        let files = vec![make_file("../escape.rs", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        assert!(write_files(&files, dir.path(), false, false).is_err());
    }

    #[test]
    fn test_rejects_absolute_path_filename() {
        let files = vec![make_file("/etc/passwd", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        assert!(write_files(&files, dir.path(), false, false).is_err());
    }

    #[test]
    fn test_rejects_path_separator_in_filename() {
        let files = vec![make_file("sub/dir/file.rs", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        assert!(write_files(&files, dir.path(), false, false).is_err());
    }

    #[test]
    fn test_rejects_backslash_in_filename() {
        let files = vec![make_file("sub\\file.rs", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        assert!(write_files(&files, dir.path(), false, false).is_err());
    }

    #[test]
    fn test_rejects_non_rs_extension() {
        let files = vec![make_file("evil.sh", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        assert!(write_files(&files, dir.path(), false, false).is_err());
    }

    #[test]
    fn test_rejects_empty_filename() {
        let files = vec![make_file("", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        assert!(write_files(&files, dir.path(), false, false).is_err());
    }

    #[test]
    fn test_unsafe_name_aborts_before_any_write() {
        let dir = tempfile::tempdir().unwrap();
        let sub = dir.path().join("output");
        let files = vec![
            make_file("users.rs", "code", None),
            make_file("../escape.rs", "code", None),
        ];
        assert!(write_files(&files, &sub, false, false).is_err());
        // Validation runs before the output directory is even created.
        assert!(!sub.exists());
        assert!(!dir.path().join("escape.rs").exists());
    }

    #[test]
    fn test_accepts_normal_rs_filename() {
        let files = vec![make_file("users.rs", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        assert!(write_files(&files, dir.path(), false, false).is_ok());
    }
}
446}