sourcegen_cli/
generate.rs

1use crate::error::{Location, SourcegenError, SourcegenErrorKind};
2use crate::mods::ModResolver;
3use crate::{GeneratorsMap, SourceGenerator};
4use anyhow::Context;
5use proc_macro2::{LineColumn, TokenStream};
6use std::collections::{BTreeMap, HashMap};
7use std::path::Path;
8use syn::spanned::Spanned;
9use syn::{Attribute, AttributeArgs, File, Item, LitStr, Meta, NestedMeta};
10
/// Warning comment placed in front of each block generated for an individual item.
static ITEM_COMMENT: &str =
    "// Generated. All manual edits to the block annotated with #[sourcegen...] will be discarded.";
/// Warning comment placed in front of content generated for a whole file.
static FILE_COMMENT: &str = "// Generated. All manual edits below this line will be discarded.";
14
/// Span of the original source text that gets replaced by generated code.
///
/// The derived ordering compares `from` first, then `to`, then `indent`, so regions stored
/// in a `BTreeMap` are iterated in the order in which they start in the file.
#[derive(Debug, PartialOrd, Ord, PartialEq, Eq)]
struct Region {
    /// Offset into the source at which the replaced span starts.
    from: usize,
    /// Offset into the source at which the replaced span ends (exclusive).
    to: usize,
    /// Number of spaces put in front of every generated line except the first one.
    indent: usize,
}
21
22/// Replace a single file with the generated content
23pub fn process_single_file(path: &Path, tokens: TokenStream) -> Result<(), SourcegenError> {
24    let formatter = crate::rustfmt::Formatter::new(path.parent().unwrap())?;
25
26    let source = if path.exists() {
27        std::fs::read_to_string(path)
28            .with_context(|| SourcegenErrorKind::ProcessFile(path.display().to_string()))?
29    } else {
30        String::new()
31    };
32    let replacement = Replacement {
33        comment: FILE_COMMENT,
34        is_cr_lf: is_cr_lf(&source),
35        tokens: &tokens,
36    };
37    let output = formatter.format(path, replacement)?;
38    if source != output {
39        std::fs::write(path, output)
40            .with_context(|| SourcegenErrorKind::ProcessFile(path.display().to_string()))?;
41    }
42    Ok(())
43}
44
45pub fn process_source_file(
46    path: &Path,
47    generators: &HashMap<&str, &dyn SourceGenerator>,
48    mod_resolver: &ModResolver,
49) -> Result<(), SourcegenError> {
50    let source = std::fs::read_to_string(path)
51        .with_context(|| SourcegenErrorKind::ProcessFile(path.display().to_string()))?;
52    let mut file = syn::parse_file(&source)
53        .with_context(|| SourcegenErrorKind::ProcessFile(path.display().to_string()))?;
54
55    let output = if let Some(invoke) = detect_file_invocation(path, &mut file, generators)? {
56        if !invoke.is_file {
57            // Remove all attributes in front of the `#![sourcegen]` attribute
58            file.attrs.drain(0..invoke.sourcegen_attr_index + 1);
59        }
60
61        // Handle full file generation
62        let context_location = invoke.context_location;
63        let result = invoke
64            .generator
65            .generate_file(invoke.args, &file)
66            .with_context(|| SourcegenErrorKind::GeneratorError(context_location))?;
67        if let Some(expansion) = result {
68            let from_loc = if invoke.is_file {
69                crate::region::item_end_span(&file.items[0]).end()
70            } else {
71                invoke.sourcegen_attr.bracket_token.span.end()
72            };
73            let from = line_column_to_offset(&source, from_loc)?;
74            let from = from + skip_whitespaces(&source[from..]);
75            let region = Region {
76                from,
77                to: source.len(),
78                indent: 0,
79            };
80
81            // Replace the whole file
82            let mut replacements = BTreeMap::new();
83            replacements.insert(region, expansion);
84            render_expansions(path, &source, &replacements, FILE_COMMENT)?
85        } else {
86            // Nothing to replace
87            return Ok(());
88        }
89    } else {
90        let mut replacements = BTreeMap::new();
91        handle_content(
92            path,
93            &source,
94            &mut file.items,
95            &generators,
96            &mut replacements,
97            &mod_resolver,
98        )?;
99        render_expansions(path, &source, &replacements, ITEM_COMMENT)?
100    };
101
102    if source != output {
103        std::fs::write(path, output)
104            .with_context(|| SourcegenErrorKind::ProcessFile(path.display().to_string()))?;
105    }
106    Ok(())
107}
108
109/// Render given list of replacements into the source file. `basefile` is used to determine base
110/// directory to run `rustfmt` in (so it can use local overrides for formatting rules).
111///
112/// `comment` is the warning comment that will be added in front of each generated block.
113fn render_expansions(
114    basefile: &Path,
115    source: &str,
116    expansions: &BTreeMap<Region, TokenStream>,
117    comment: &str,
118) -> Result<String, SourcegenError> {
119    let mut output = String::with_capacity(source.len());
120    let formatter = crate::rustfmt::Formatter::new(basefile.parent().unwrap())?;
121
122    let mut offset = 0;
123    let is_cr_lf = is_cr_lf(source);
124    for (region, tokens) in expansions {
125        output += &source[offset..region.from];
126        offset = region.to;
127        let indent = format!("{:indent$}", "", indent = region.indent);
128        if !tokens.is_empty() {
129            let replacement = Replacement {
130                comment,
131                is_cr_lf,
132                tokens,
133            };
134            let formatted = formatter.format(basefile, replacement)?;
135            let mut first = true;
136            for line in formatted.lines() {
137                // We don't want newline on the last line (the captured region does not include the
138                // one) and also we don't want an indent on the first line (we splice after it).
139                if first {
140                    first = false
141                } else {
142                    if is_cr_lf {
143                        output.push('\r');
144                    }
145                    output.push('\n');
146                    output += &indent;
147                }
148                output += line;
149            }
150        }
151    }
152    // Insert newline at the end of the file!
153    if offset == source.len() {
154        if is_cr_lf {
155            output.push('\r');
156        }
157        output.push('\n');
158    }
159    output += &source[offset..];
160    Ok(output)
161}
162
163fn handle_content(
164    path: &Path,
165    source: &str,
166    items: &mut [Item],
167    generators: &GeneratorsMap,
168    replacements: &mut BTreeMap<Region, TokenStream>,
169    mod_resolver: &ModResolver,
170) -> Result<(), SourcegenError> {
171    let mut item_idx = 0;
172    while item_idx < items.len() {
173        item_idx += 1;
174        let (head, tail) = items.split_at_mut(item_idx);
175        let item = head.last_mut().unwrap();
176
177        let mut empty_attrs = Vec::new();
178        let attrs = crate::region::item_attributes(item).unwrap_or(&mut empty_attrs);
179        if let Some(invoke) = detect_invocation(path, attrs, generators)? {
180            // Remove all attributes in front of the `#[sourcegen]` attribute
181            attrs.drain(0..invoke.sourcegen_attr_index + 1);
182            let context_location = invoke.context_location;
183            let result = crate::region::invoke_generator(item, invoke.args, invoke.generator)
184                .with_context(|| SourcegenErrorKind::GeneratorError(context_location))?;
185            if let Some(expansion) = result {
186                let indent = invoke.sourcegen_attr.span().start().column;
187                let from_loc = invoke.sourcegen_attr.bracket_token.span.end();
188                let from = line_column_to_offset(source, from_loc)?;
189                let from = from + skip_whitespaces(&source[from..]);
190
191                // Find the first item that is not marked as "generated"
192                let skip_count = (0..tail.len())
193                    .find(|pos| {
194                        !is_generated(
195                            crate::region::item_attributes(&mut tail[*pos])
196                                .unwrap_or(&mut empty_attrs),
197                        )
198                    })
199                    .unwrap_or(tail.len());
200                let to_span = if skip_count == 0 {
201                    crate::region::item_end_span(item)
202                } else {
203                    // Skip consecutive items marked via `#[sourcegen::generated]`
204                    item_idx += skip_count;
205                    crate::region::item_end_span(&tail[skip_count - 1])
206                };
207                let to = line_column_to_offset(source, to_span.end())?;
208
209                let region = Region { from, to, indent };
210                replacements.insert(region, expansion);
211                continue;
212            }
213        }
214
215        if let Item::Mod(item) = item {
216            let nested_mod_resolved = mod_resolver.push_module(&item.ident.to_string());
217            if item.content.is_some() {
218                let items = &mut item.content.as_mut().unwrap().1;
219                handle_content(
220                    path,
221                    source,
222                    items,
223                    generators,
224                    replacements,
225                    &nested_mod_resolved,
226                )?;
227            } else {
228                let mod_file = mod_resolver.resolve_module_file(item)?;
229                process_source_file(&mod_file, generators, &nested_mod_resolved)?;
230            }
231        }
232    }
233    Ok(())
234}
235
236fn is_generated(attrs: &[Attribute]) -> bool {
237    let sourcegen_attr = attrs.iter().find(|attr| {
238        attr.path
239            .segments
240            .first()
241            .map_or(false, |segment| segment.ident == "sourcegen")
242    });
243    if let Some(sourcegen) = sourcegen_attr {
244        sourcegen
245            .path
246            .segments
247            .iter()
248            .skip(1)
249            .next()
250            .map_or(false, |segment| segment.ident == "generated")
251    } else {
252        false
253    }
254}
255
256fn detect_file_invocation<'a>(
257    path: &Path,
258    file: &mut File,
259    generators: &'a GeneratorsMap,
260) -> Result<Option<GeneratorInfo<'a>>, SourcegenError> {
261    if let Some(mut invoke) = detect_invocation(path, &mut file.attrs, generators)? {
262        // This flag should only be set when we are processing a special workaround
263        invoke.is_file = false;
264        return Ok(Some(invoke));
265    }
266
267    if let Some(item) = file.items.iter_mut().next() {
268        // Special case: if first item in the file has `sourcegen::sourcegen` attribute with `file` set
269        // to `true`, we treat it as file sourcegen.
270        let mut empty_attrs = Vec::new();
271        let attrs = crate::region::item_attributes(item).unwrap_or(&mut empty_attrs);
272        if let Some(invoke) = detect_invocation(path, &mut attrs.clone(), generators)? {
273            if invoke.is_file {
274                return Ok(Some(invoke));
275            }
276        }
277    }
278    Ok(None)
279}
280
281/// Collect parameters from `#[sourcegen]` attribute.
282fn detect_invocation<'a>(
283    path: &Path,
284    attrs: &[Attribute],
285    generators: &'a GeneratorsMap,
286) -> Result<Option<GeneratorInfo<'a>>, SourcegenError> {
287    let sourcegen_attr = attrs.iter().position(|attr| {
288        attr.path
289            .segments
290            .first()
291            .map_or(false, |segment| segment.ident == "sourcegen")
292    });
293    if let Some(attr_pos) = sourcegen_attr {
294        let invoke = detect_generator(path, attrs, attr_pos, generators)?;
295        Ok(Some(invoke))
296    } else {
297        Ok(None)
298    }
299}
300
301/// Map from the line number and column back to the offset.
302fn line_column_to_offset(text: &str, lc: LineColumn) -> Result<usize, SourcegenError> {
303    let mut line = lc.line as usize;
304
305    assert_ne!(line, 0, "line number must be 1-indexed");
306
307    let mut offset = 0;
308    for (idx, ch) in text.char_indices() {
309        offset = idx;
310        if line == 1 {
311            break;
312        }
313        if ch == '\n' {
314            line -= 1;
315        }
316    }
317    offset += lc.column;
318    Ok(offset.min(text.len()))
319}
320
/// Return the length in bytes of the leading whitespace of `text`.
fn skip_whitespaces(text: &str) -> usize {
    // Whatever `trim_start` removed is exactly the leading whitespace.
    text.len() - text.trim_start().len()
}
326
327struct GeneratorInfo<'a> {
328    /// Source generator to run
329    generator: &'a dyn SourceGenerator,
330    args: AttributeArgs,
331    /// `#[sourcegen]` attribute itself
332    sourcegen_attr: Attribute,
333    /// Index of `#[sourcegen]` attribute
334    sourcegen_attr_index: usize,
335    /// Location for error reporting
336    context_location: Location,
337    /// If this invocation should regenerate the whole block up to the end.
338    /// (this is used as a workaround for attributes not allowed on modules)
339    is_file: bool,
340}
341
342fn detect_generator<'a>(
343    path: &Path,
344    attrs: &[Attribute],
345    sourcegen_attr_index: usize,
346    generators: &'a GeneratorsMap,
347) -> Result<GeneratorInfo<'a>, SourcegenError> {
348    let sourcegen_attr = attrs[sourcegen_attr_index].clone();
349
350    let loc = Location::from_path_span(path, sourcegen_attr.span());
351    let meta = sourcegen_attr
352        .parse_meta()
353        .with_context(|| SourcegenErrorKind::GeneratorError(loc.clone()))?;
354
355    let meta_span = meta.span();
356    if let Meta::List(list) = meta {
357        let mut name: Option<&LitStr> = None;
358        let mut is_file = false;
359        for item in &list.nested {
360            match item {
361                NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("generator") => {
362                    if let syn::Lit::Str(ref value) = nv.lit {
363                        if name.is_some() {
364                            let loc = Location::from_path_span(path, item.span());
365                            return Err(SourcegenErrorKind::MultipleGeneratorAttributes(loc).into());
366                        }
367                        name = Some(value);
368                    } else {
369                        let loc = Location::from_path_span(path, item.span());
370                        return Err(SourcegenErrorKind::GeneratorAttributeMustBeString(loc).into());
371                    }
372                }
373                NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("file") => {
374                    if let syn::Lit::Bool(ref value) = nv.lit {
375                        is_file = value.value;
376                    }
377                }
378                _ => {}
379            }
380        }
381        if let Some(name) = name {
382            let name_span = name.span();
383            let name = name.value();
384            let args = list.nested.into_iter().collect::<Vec<_>>();
385            let context_location = Location::from_path_span(path, meta_span);
386            let generator = *generators.get(name.as_str()).ok_or_else(|| {
387                SourcegenErrorKind::GeneratorNotFound(
388                    Location::from_path_span(path, name_span),
389                    name,
390                )
391            })?;
392            return Ok(GeneratorInfo {
393                generator,
394                args,
395                sourcegen_attr_index,
396                sourcegen_attr,
397                context_location,
398                is_file,
399            });
400        }
401    }
402
403    let loc = Location::from_path_span(path, meta_span);
404    Err(SourcegenErrorKind::MissingGeneratorAttribute(loc).into())
405}
406
/// Decide whether `source` uses Windows line endings (`\r\n`).
///
/// Only the first line feed is inspected: the result is `true` when it is directly preceded
/// by a carriage return. Text without any line feed yields `false`.
fn is_cr_lf(source: &str) -> bool {
    source
        .find('\n')
        .map_or(false, |newline| source.as_bytes()[..newline].last() == Some(&b'\r'))
}
415
416/// Struct used to generate replacement code directly into stdin of `rustfmt`.
417struct Replacement<'a> {
418    comment: &'a str,
419    is_cr_lf: bool,
420    tokens: &'a TokenStream,
421}
422
423impl std::fmt::Display for Replacement<'_> {
424    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
425        use std::fmt::Write;
426
427        f.write_str(self.comment)?;
428        if self.is_cr_lf {
429            f.write_char('\r')?;
430        }
431        f.write_char('\n')?;
432
433        #[cfg(feature = "disable_normalize_doc_attributes")]
434        write!(f, "{}", self.tokens)?;
435
436        #[cfg(not(feature = "disable_normalize_doc_attributes"))]
437        crate::normalize::write_tokens_normalized(f, self.tokens.clone())?;
438
439        Ok(())
440    }
441}