syntaxfmt_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    parse_macro_input, Data, DeriveInput, Fields, Lit, Meta, MetaNameValue,
5};
6
7#[proc_macro_derive(SyntaxFmt, attributes(syntax))]
8pub fn derive_syntax_fmt(input: TokenStream) -> TokenStream {
9    let input = parse_macro_input!(input as DeriveInput);
10    let name = &input.ident;
11    let mut generics = input.generics.clone();
12    let (_, ty_generics, _) = input.generics.split_for_impl();
13
14    let (delim, pretty_delim) = parse_delimiters(&input.attrs);
15    let state_bound = parse_state_bound(&input.attrs);
16    let outer_format = parse_outer_format(&input.attrs);
17    let field_types = collect_field_types(&input.data);
18
19    let fmt_body = match &input.data {
20        Data::Struct(data_struct) => generate_struct_fmt(&data_struct.fields),
21        Data::Enum(data_enum) => generate_enum_fmt(name, &data_enum.variants),
22        Data::Union(_) => {
23            return syn::Error::new_spanned(name, "SyntaxFmt cannot be derived for unions")
24                .to_compile_error()
25                .into();
26        }
27    };
28
29    let fmt_body = wrap_with_outer_format(fmt_body, &outer_format);
30
31    let delim_const = delim.map(|d| quote! { const DELIM: &'static str = #d; });
32    let pretty_delim_const = pretty_delim.map(|d| quote! { const PRETTY_DELIM: &'static str = #d; });
33
34    generics.params.push(syn::parse_quote! { __SyntaxFmtState });
35
36    let where_clause = build_where_clause(&mut generics, &field_types, state_bound.as_ref());
37    let (impl_generics_with_state, _, _) = generics.split_for_impl();
38
39    let expanded = quote! {
40        impl #impl_generics_with_state ::syntaxfmt::SyntaxFmt<__SyntaxFmtState> for #name #ty_generics #where_clause {
41            #delim_const
42            #pretty_delim_const
43
44            fn syntax_fmt(&self, ctx: &mut ::syntaxfmt::SyntaxFormatter<__SyntaxFmtState>) -> ::std::fmt::Result {
45                #fmt_body
46            }
47        }
48    };
49
50    TokenStream::from(expanded)
51}
52
53fn build_where_clause(
54    generics: &mut syn::Generics,
55    field_types: &[syn::Type],
56    state_bound: Option<&syn::TraitBound>,
57) -> syn::WhereClause {
58    let mut where_clause = generics.make_where_clause().clone();
59
60    if let Some(bound) = state_bound {
61        where_clause.predicates.push(syn::parse_quote! {
62            __SyntaxFmtState: #bound
63        });
64    }
65
66    for field_ty in field_types {
67        where_clause.predicates.push(syn::parse_quote! {
68            #field_ty: ::syntaxfmt::SyntaxFmt<__SyntaxFmtState>
69        });
70    }
71    where_clause
72}
73
74fn collect_field_types(data: &Data) -> Vec<syn::Type> {
75    let mut types = Vec::new();
76    match data {
77        Data::Struct(data_struct) => collect_struct_field_types(&data_struct.fields, &mut types),
78        Data::Enum(data_enum) => {
79            for variant in &data_enum.variants {
80                collect_struct_field_types(&variant.fields, &mut types);
81            }
82        }
83        Data::Union(_) => {}
84    }
85    types
86}
87
88fn collect_struct_field_types(fields: &Fields, types: &mut Vec<syn::Type>) {
89    for field in fields.iter() {
90        let attrs = parse_field_attrs(&field.attrs);
91        if attrs.skip || is_type_ident(&field.ty, "bool") {
92            continue;
93        }
94
95        let ty = extract_option_inner(&field.ty);
96        types.push(extract_collection_inner(&ty).unwrap_or(ty));
97    }
98}
99
100fn extract_option_inner(ty: &syn::Type) -> syn::Type {
101    if let syn::Type::Path(type_path) = ty {
102        if let Some(segment) = type_path.path.segments.last() {
103            if segment.ident == "Option" {
104                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
105                    if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
106                        return inner_ty.clone();
107                    }
108                }
109            }
110        }
111    }
112    ty.clone()
113}
114
115fn extract_collection_inner(ty: &syn::Type) -> Option<syn::Type> {
116    match ty {
117        syn::Type::Path(type_path) => {
118            let segment = type_path.path.segments.last()?;
119            if segment.ident == "Vec" {
120                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
121                    if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
122                        return Some(inner_ty.clone());
123                    }
124                }
125            }
126            None
127        }
128        syn::Type::Reference(type_ref) => {
129            if let syn::Type::Slice(slice) = &*type_ref.elem {
130                return Some((*slice.elem).clone());
131            }
132            None
133        }
134        syn::Type::Array(array) => Some((*array.elem).clone()),
135        _ => None,
136    }
137}
138
139fn generate_collection_iteration(
140    field_expr: &proc_macro2::TokenStream,
141    inner_ty: &syn::Type,
142) -> proc_macro2::TokenStream {
143    quote! {
144        {
145            let delim = if ctx.is_pretty() {
146                <#inner_ty as ::syntaxfmt::SyntaxFmt<__SyntaxFmtState>>::PRETTY_DELIM
147            } else {
148                <#inner_ty as ::syntaxfmt::SyntaxFmt<__SyntaxFmtState>>::DELIM
149            };
150            let fold = |r: ::std::fmt::Result, (i, e): (usize, &#inner_ty)| {
151                r?;
152                if i > 0 {
153                    write!(ctx, "{}", delim)?;
154                }
155                if ctx.is_pretty() {
156                    ctx.indent()?;
157                }
158                e.syntax_fmt(ctx)?;
159                Ok(())
160            };
161            (#field_expr).iter()
162                .enumerate()
163                .fold(Ok(()), fold)?;
164        }
165    }
166}
167
168fn extract_str_literal(value: &syn::Expr) -> Option<String> {
169    if let syn::Expr::Lit(syn::ExprLit { lit: Lit::Str(s), .. }) = value {
170        Some(s.value())
171    } else {
172        None
173    }
174}
175
176fn parse_pretty_string_attrs(
177    attrs: &[syn::Attribute],
178    normal_name: &str,
179    pretty_name: &str,
180) -> PrettyString {
181    let mut result = PrettyString::default();
182
183    parse_syntax_attrs(attrs, |meta| {
184        if let Meta::NameValue(MetaNameValue { path, value, .. }) = meta {
185            if let Some(s) = extract_str_literal(value) {
186                if path.is_ident(normal_name) {
187                    result.normal = Some(s);
188                } else if path.is_ident(pretty_name) {
189                    result.pretty = Some(s);
190                }
191            }
192        }
193    });
194
195    result
196}
197
198fn parse_delimiters(attrs: &[syn::Attribute]) -> (Option<String>, Option<String>) {
199    let result = parse_pretty_string_attrs(attrs, "delim", "pretty_delim");
200    (result.normal, result.pretty)
201}
202
203fn parse_outer_format(attrs: &[syn::Attribute]) -> PrettyString {
204    parse_pretty_string_attrs(attrs, "format", "pretty_format")
205}
206
207fn parse_state_bound(attrs: &[syn::Attribute]) -> Option<syn::TraitBound> {
208    let mut state_bound = None;
209
210    parse_syntax_attrs(attrs, |meta| {
211        if let Meta::NameValue(MetaNameValue { path, value, .. }) = meta {
212            if path.is_ident("state_bound") {
213                if let Some(s) = extract_str_literal(value) {
214                    if let Ok(bound) = syn::parse_str::<syn::TraitBound>(&s) {
215                        state_bound = Some(bound);
216                    }
217                }
218            }
219        }
220    });
221
222    state_bound
223}
224
225fn parse_field_attrs(attrs: &[syn::Attribute]) -> FieldAttrs {
226    let mut field_attrs = FieldAttrs::default();
227
228    parse_syntax_attrs(attrs, |meta| match meta {
229        Meta::NameValue(MetaNameValue { path, value, .. }) => {
230            if path.is_ident("content") {
231                field_attrs.content = Some(value.clone());
232            } else if let Some(s) = extract_str_literal(value) {
233                if path.is_ident("format") {
234                    field_attrs.format.normal = Some(s);
235                } else if path.is_ident("pretty_format") {
236                    field_attrs.format.pretty = Some(s);
237                } else if path.is_ident("empty_suffix") {
238                    field_attrs.empty_suffix = Some(s);
239                }
240            }
241        }
242        Meta::Path(path) => {
243            if path.is_ident("skip") {
244                field_attrs.skip = true;
245            } else if path.is_ident("indent_region") {
246                field_attrs.indent_region = true;
247            } else if path.is_ident("indent") {
248                field_attrs.indent = true;
249            }
250        }
251        _ => {}
252    });
253
254    field_attrs
255}
256
257fn parse_syntax_attrs(attrs: &[syn::Attribute], mut f: impl FnMut(&Meta)) {
258    for attr in attrs {
259        if attr.path().is_ident("syntax") {
260            if let Ok(meta_list) = attr.meta.require_list() {
261                if let Ok(nested_list) = meta_list.parse_args_with(
262                    syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated
263                ) {
264                    for nested in &nested_list {
265                        f(nested);
266                    }
267                }
268            }
269        }
270    }
271}
272
/// A string option with an optional override for pretty-printing mode
/// (e.g. `format` / `pretty_format`, `delim` / `pretty_delim`).
#[derive(Default)]
struct PrettyString {
    /// Value for normal mode, if given.
    normal: Option<String>,
    /// Value for pretty mode, if given.
    pretty: Option<String>,
}

impl PrettyString {
    /// Returns `(normal, pretty)` with fallbacks applied.
    ///
    /// A missing `normal` becomes `""`; a missing `pretty` falls back to `normal`
    /// (and therefore to `""` when both are missing).
    fn get_pair(&self) -> (String, String) {
        let normal = self.normal.clone().unwrap_or_default();
        let pretty = match &self.pretty {
            Some(text) => text.clone(),
            None => normal.clone(),
        };
        (normal, pretty)
    }
}
286
287#[derive(Default)]
288struct FieldAttrs {
289    format: PrettyString,
290    content: Option<syn::Expr>,
291    empty_suffix: Option<String>,
292    indent_region: bool,
293    indent: bool,
294    skip: bool,
295}
296
/// Splits a format template around its first `{content}` placeholder.
///
/// Returns `(before, after, true)` when the placeholder is present, or
/// `(format_str, "", false)` when it is not. Only the first placeholder is split on; any
/// later occurrence stays inside `after`.
fn split_format_string(format_str: &str) -> (&str, &str, bool) {
    const PLACEHOLDER: &str = "{content}";
    match format_str.split_once(PLACEHOLDER) {
        Some((before, after)) => (before, after, true),
        None => (format_str, "", false),
    }
}
304
305fn generate_default_content(
306    field_expr: &proc_macro2::TokenStream,
307    content_expr: Option<&syn::Expr>,
308    field_ty: Option<&syn::Type>,
309) -> proc_macro2::TokenStream {
310    if let Some(content_fn) = content_expr {
311        return quote! { (#content_fn)(&#field_expr, ctx)?; };
312    }
313
314    if let Some(ty) = field_ty {
315        if let Some(inner_ty) = extract_collection_inner(ty) {
316            return generate_collection_iteration(field_expr, &inner_ty);
317        }
318    }
319
320    quote! { #field_expr.syntax_fmt(ctx)?; }
321}
322
323fn expand_format_string(
324    format_str: &str,
325    field_expr: &proc_macro2::TokenStream,
326    content_expr: Option<&syn::Expr>,
327    field_ty: Option<&syn::Type>,
328) -> proc_macro2::TokenStream {
329    let (before, after, has_placeholder) = split_format_string(format_str);
330    let mut statements = Vec::new();
331
332    if !before.is_empty() {
333        statements.push(quote! { write!(ctx, #before)?; });
334    }
335
336    if has_placeholder {
337        statements.push(generate_default_content(field_expr, content_expr, field_ty));
338    }
339
340    if !after.is_empty() {
341        statements.push(quote! { write!(ctx, #after)?; });
342    }
343
344    quote! { #(#statements)* }
345}
346
347fn pretty_conditional(
348    normal: proc_macro2::TokenStream,
349    pretty: proc_macro2::TokenStream,
350) -> proc_macro2::TokenStream {
351    quote! {
352        if ctx.is_pretty() {
353            #pretty
354        } else {
355            #normal
356        }
357    }
358}
359
360fn wrap_with_outer_format(
361    fmt_body: proc_macro2::TokenStream,
362    outer_format: &PrettyString,
363) -> proc_macro2::TokenStream {
364    if outer_format.normal.is_none() && outer_format.pretty.is_none() {
365        return fmt_body;
366    }
367
368    let (normal_fmt, pretty_fmt) = outer_format.get_pair();
369
370    let wrap_body = |format_str: &str| -> proc_macro2::TokenStream {
371        let (before, after, has_placeholder) = split_format_string(format_str);
372
373        if !has_placeholder {
374            return quote! {
375                write!(ctx, #format_str)?;
376                #fmt_body
377            };
378        }
379
380        if before.is_empty() && after.is_empty() {
381            return fmt_body.clone();
382        }
383
384        if after.is_empty() {
385            return quote! {
386                write!(ctx, #before)?;
387                #fmt_body
388            };
389        }
390
391        quote! {
392            write!(ctx, #before)?;
393            (|| -> ::std::fmt::Result { #fmt_body })()?;
394            write!(ctx, #after)?;
395            Ok(())
396        }
397    };
398
399    if normal_fmt == pretty_fmt {
400        wrap_body(&normal_fmt)
401    } else {
402        pretty_conditional(wrap_body(&normal_fmt), wrap_body(&pretty_fmt))
403    }
404}
405
406fn generate_format_output(
407    field_expr: &proc_macro2::TokenStream,
408    format: &PrettyString,
409    content_expr: Option<&syn::Expr>,
410    field_ty: Option<&syn::Type>,
411) -> proc_macro2::TokenStream {
412    // No format specified - use default
413    if format.normal.is_none() && format.pretty.is_none() {
414        return generate_default_content(field_expr, content_expr, field_ty);
415    }
416
417    let (normal_fmt, pretty_fmt) = format.get_pair();
418
419    // Only pretty_format specified
420    if format.normal.is_none() {
421        let default_content = generate_default_content(field_expr, content_expr, field_ty);
422        let pretty_write = expand_format_string(&pretty_fmt, field_expr, content_expr, field_ty);
423        return quote! {
424            if ctx.is_pretty() {
425                #pretty_write
426            } else {
427                #default_content
428            }
429        };
430    }
431
432    // Normal format (with optional different pretty format)
433    let normal_write = expand_format_string(&normal_fmt, field_expr, content_expr, field_ty);
434
435    if normal_fmt == pretty_fmt {
436        normal_write
437    } else {
438        let pretty_write = expand_format_string(&pretty_fmt, field_expr, content_expr, field_ty);
439        pretty_conditional(normal_write, pretty_write)
440    }
441}
442
443fn generate_struct_fmt(fields: &Fields) -> proc_macro2::TokenStream {
444    match fields {
445        Fields::Named(fields_named) => generate_named_fields_fmt(&fields_named.named),
446        Fields::Unnamed(fields_unnamed) if fields_unnamed.unnamed.len() == 1 => {
447            let field = fields_unnamed.unnamed.first().unwrap();
448            let attrs = parse_field_attrs(&field.attrs);
449            let format_output = generate_format_output(
450                &quote! { self.0 },
451                &attrs.format,
452                attrs.content.as_ref(),
453                Some(&field.ty),
454            );
455            quote! {
456                #format_output
457                Ok(())
458            }
459        }
460        Fields::Unnamed(_) | Fields::Unit => quote! { Ok(()) },
461    }
462}
463
464fn generate_named_fields_fmt(
465    fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
466) -> proc_macro2::TokenStream {
467    let mut statements = Vec::new();
468
469    for field in fields {
470        let field_name = field.ident.as_ref().unwrap();
471        let attrs = parse_field_attrs(&field.attrs);
472
473        if attrs.skip {
474            continue;
475        }
476
477        if is_type_ident(&field.ty, "bool") {
478            let format_output = generate_format_output(
479                &quote! { &true },
480                &attrs.format,
481                attrs.content.as_ref(),
482                None,
483            );
484            statements.push(quote! {
485                if self.#field_name {
486                    #format_output
487                }
488            });
489        } else if is_type_ident(&field.ty, "Option") {
490            let field_expr = quote! { #field_name };
491            let inner_ty = extract_option_inner(&field.ty);
492            let format_output = generate_format_output(
493                &field_expr,
494                &attrs.format,
495                attrs.content.as_ref(),
496                Some(&inner_ty),
497            );
498            statements.push(quote! {
499                if let Some(#field_name) = &self.#field_name {
500                    #format_output
501                }
502            });
503        } else {
504            let field_expr = quote! { self.#field_name };
505            let mut field_statements = Vec::new();
506
507            if attrs.indent {
508                field_statements.push(quote! {
509                    if ctx.is_pretty() {
510                        ctx.indent()?;
511                    }
512                });
513            }
514
515            if attrs.indent_region {
516                field_statements.push(quote! {
517                    if ctx.is_pretty() {
518                        ctx.inc_indent();
519                    }
520                });
521            }
522
523            let format_output = generate_format_output(
524                &field_expr,
525                &attrs.format,
526                attrs.content.as_ref(),
527                Some(&field.ty),
528            );
529
530            field_statements.push(format_output);
531
532            if attrs.indent_region {
533                field_statements.push(quote! {
534                    if ctx.is_pretty() {
535                        ctx.dec_indent();
536                    }
537                });
538            }
539
540            if let Some(empty_suffix) = &attrs.empty_suffix {
541                statements.push(quote! {
542                    if self.#field_name.is_empty() {
543                        write!(ctx, #empty_suffix)?;
544                    } else {
545                        #(#field_statements)*
546                    }
547                });
548            } else {
549                statements.extend(field_statements);
550            }
551        }
552    }
553
554    statements.push(quote! { Ok(()) });
555    quote! { #(#statements)* }
556}
557
558fn generate_enum_fmt(
559    name: &syn::Ident,
560    variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
561) -> proc_macro2::TokenStream {
562    let match_arms: Vec<_> = variants.iter().map(|variant| {
563        let variant_name = &variant.ident;
564        let attrs = parse_field_attrs(&variant.attrs);
565
566        match &variant.fields {
567            Fields::Named(_) => {
568                quote! {
569                    #name::#variant_name { .. } => todo!("Named enum variants not yet supported")
570                }
571            }
572            Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
573                let field = fields.unnamed.first().unwrap();
574                let format_output = generate_format_output(
575                    &quote! { inner },
576                    &attrs.format,
577                    attrs.content.as_ref(),
578                    Some(&field.ty),
579                );
580                quote! {
581                    #name::#variant_name(inner) => { #format_output Ok(()) }
582                }
583            }
584            Fields::Unnamed(_) => {
585                quote! {
586                    #name::#variant_name(..) => todo!("Multi-field tuple variants not yet supported")
587                }
588            }
589            Fields::Unit => {
590                if attrs.format.normal.is_some() || attrs.format.pretty.is_some() {
591                    let format_output = generate_format_output(
592                        &quote! { "" },
593                        &attrs.format,
594                        attrs.content.as_ref(),
595                        None,
596                    );
597                    quote! { #name::#variant_name => { #format_output Ok(()) } }
598                } else {
599                    let lower_name = variant_name.to_string().to_lowercase();
600                    quote! { #name::#variant_name => write!(ctx, #lower_name) }
601                }
602            }
603        }
604    }).collect();
605
606    quote! {
607        match self {
608            #(#match_arms,)*
609        }
610    }
611}
612
613fn is_type_ident(ty: &syn::Type, ident_name: &str) -> bool {
614    if let syn::Type::Path(type_path) = ty {
615        if let Some(segment) = type_path.path.segments.last() {
616            return segment.ident == ident_name;
617        }
618    }
619    false
620}