Skip to main content

typewriter_engine/
scan.rs

1//! Source scanning for `#[derive(TypeWriter)]` items.
2
3use anyhow::{Context, Result};
4use std::path::{Path, PathBuf};
5use syn::{Data, DeriveInput, Item, ItemEnum, ItemStruct, ItemUnion};
6use walkdir::WalkDir;
7
8use crate::{parser, TypeSpec};
9
10pub fn scan_project(project_root: &Path) -> Result<Vec<TypeSpec>> {
11    let mut specs = Vec::new();
12    for file in discover_rust_files(project_root) {
13        specs.extend(scan_file(&file)?);
14    }
15    Ok(specs)
16}
17
18pub fn scan_file(path: &Path) -> Result<Vec<TypeSpec>> {
19    let content = std::fs::read_to_string(path)
20        .with_context(|| format!("failed to read source file {}", path.display()))?;
21    let parsed = syn::parse_file(&content)
22        .with_context(|| format!("failed to parse Rust source {}", path.display()))?;
23
24    let mut specs = Vec::new();
25    collect_items(&parsed.items, path, &mut specs)?;
26    Ok(specs)
27}
28
29pub fn discover_rust_files(project_root: &Path) -> Vec<PathBuf> {
30    WalkDir::new(project_root)
31        .into_iter()
32        .filter_entry(|entry| {
33            let name = entry.file_name().to_string_lossy();
34            !(name == ".git" || name == "target")
35        })
36        .filter_map(|entry| entry.ok())
37        .filter(|entry| {
38            entry.file_type().is_file()
39                && entry
40                    .path()
41                    .extension()
42                    .map(|ext| ext == "rs")
43                    .unwrap_or(false)
44        })
45        .map(|entry| entry.into_path())
46        .collect()
47}
48
49fn collect_items(items: &[Item], source_path: &Path, specs: &mut Vec<TypeSpec>) -> Result<()> {
50    for item in items {
51        match item {
52            Item::Struct(item_struct) => {
53                maybe_collect_from_derive_input(
54                    item_struct_to_derive(item_struct),
55                    source_path,
56                    specs,
57                )?;
58            }
59            Item::Enum(item_enum) => {
60                maybe_collect_from_derive_input(
61                    item_enum_to_derive(item_enum),
62                    source_path,
63                    specs,
64                )?;
65            }
66            Item::Union(item_union) => {
67                maybe_collect_from_derive_input(
68                    item_union_to_derive(item_union),
69                    source_path,
70                    specs,
71                )?;
72            }
73            Item::Mod(item_mod) => {
74                if let Some((_, inline_items)) = &item_mod.content {
75                    collect_items(inline_items, source_path, specs)?;
76                }
77            }
78            _ => {}
79        }
80    }
81
82    Ok(())
83}
84
85fn maybe_collect_from_derive_input(
86    input: DeriveInput,
87    source_path: &Path,
88    specs: &mut Vec<TypeSpec>,
89) -> Result<()> {
90    if !parser::has_typewriter_derive(&input.attrs) {
91        return Ok(());
92    }
93
94    let type_def = parser::parse_type_def(&input)
95        .map_err(|err| anyhow::anyhow!("{} ({})", err, source_path.display()))?;
96    let targets = parser::parse_sync_to_attr(&input)
97        .map_err(|err| anyhow::anyhow!("{} ({})", err, source_path.display()))?;
98    let zod_schema = parser::parse_tw_zod_attr(&input)
99        .map_err(|err| anyhow::anyhow!("{} ({})", err, source_path.display()))?;
100
101    if targets.is_empty() {
102        return Err(anyhow::anyhow!(
103            "typewriter: #[sync_to(...)] attribute is required. Example: #[sync_to(typescript, python)] ({})",
104            source_path.display()
105        ));
106    }
107
108    specs.push(TypeSpec {
109        type_def,
110        targets,
111        source_path: source_path.to_path_buf(),
112        zod_schema,
113    });
114
115    Ok(())
116}
117
118fn item_struct_to_derive(item: &ItemStruct) -> DeriveInput {
119    DeriveInput {
120        attrs: item.attrs.clone(),
121        vis: item.vis.clone(),
122        ident: item.ident.clone(),
123        generics: item.generics.clone(),
124        data: Data::Struct(syn::DataStruct {
125            struct_token: item.struct_token,
126            fields: item.fields.clone(),
127            semi_token: item.semi_token,
128        }),
129    }
130}
131
132fn item_enum_to_derive(item: &ItemEnum) -> DeriveInput {
133    DeriveInput {
134        attrs: item.attrs.clone(),
135        vis: item.vis.clone(),
136        ident: item.ident.clone(),
137        generics: item.generics.clone(),
138        data: Data::Enum(syn::DataEnum {
139            enum_token: item.enum_token,
140            brace_token: item.brace_token,
141            variants: item.variants.clone(),
142        }),
143    }
144}
145
146fn item_union_to_derive(item: &ItemUnion) -> DeriveInput {
147    DeriveInput {
148        attrs: item.attrs.clone(),
149        vis: item.vis.clone(),
150        ident: item.ident.clone(),
151        generics: item.generics.clone(),
152        data: Data::Union(syn::DataUnion {
153            union_token: item.union_token,
154            fields: item.fields.clone(),
155        }),
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `source` to a temporary `mod.rs` and scans it with `scan_file`.
    /// The temporary directory stays alive until scanning has finished.
    fn scan_source(source: &str) -> Result<Vec<TypeSpec>> {
        let temp = tempfile::tempdir().unwrap();
        let file = temp.path().join("mod.rs");
        std::fs::write(&file, source).unwrap();
        scan_file(&file)
    }

    #[test]
    fn scans_typewriter_items_from_file() {
        let specs = scan_source(
            r#"
            #[derive(TypeWriter)]
            #[sync_to(typescript, python)]
            struct User {
                id: String,
            }
            "#,
        )
        .unwrap();

        assert_eq!(specs.len(), 1);
        assert_eq!(specs[0].type_def.name(), "User");
        assert_eq!(specs[0].zod_schema, None);
    }

    #[test]
    fn scans_type_level_zod_override() {
        let specs = scan_source(
            r#"
            #[derive(TypeWriter)]
            #[sync_to(typescript)]
            #[tw(zod = false)]
            struct Address {
                id: String,
            }
            "#,
        )
        .unwrap();

        assert_eq!(specs.len(), 1);
        assert_eq!(specs[0].type_def.name(), "Address");
        assert_eq!(specs[0].zod_schema, Some(false));
    }

    #[test]
    fn scans_items_inside_inline_modules() {
        let specs = scan_source(
            r#"
            mod api {
                #[derive(TypeWriter)]
                #[sync_to(typescript)]
                struct Inner {
                    id: String,
                }
            }
            "#,
        )
        .unwrap();

        assert_eq!(specs.len(), 1);
        assert_eq!(specs[0].type_def.name(), "Inner");
    }

    #[test]
    fn ignores_items_without_typewriter_derive() {
        let specs = scan_source(
            r#"
            #[derive(Debug, Clone)]
            struct Plain {
                id: String,
            }

            enum Mode {
                A,
                B,
            }

            #[derive(TypeWriter)]
            #[sync_to(python)]
            struct Tracked {
                id: String,
            }
            "#,
        )
        .unwrap();

        assert_eq!(specs.len(), 1);
        assert_eq!(specs[0].type_def.name(), "Tracked");
    }

    #[test]
    fn missing_sync_to_is_an_error_naming_the_file() {
        let err = scan_source(
            r#"
            #[derive(TypeWriter)]
            struct Orphan {
                id: String,
            }
            "#,
        )
        .unwrap_err();

        assert!(
            err.to_string().contains("mod.rs"),
            "error should mention the source file: {}",
            err
        );
    }

    #[test]
    fn discovery_skips_target_and_git_directories() {
        let temp = tempfile::tempdir().unwrap();
        let root = temp.path();
        for dir in ["src", "target", ".git"] {
            std::fs::create_dir(root.join(dir)).unwrap();
            std::fs::write(root.join(dir).join("lib.rs"), "").unwrap();
        }
        std::fs::write(root.join("src").join("notes.txt"), "").unwrap();

        let files = discover_rust_files(root);
        assert_eq!(files, vec![root.join("src").join("lib.rs")]);
    }
}