1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4 parse_macro_input, Data, DeriveInput, Fields, Lit, Meta, MetaNameValue,
5};
6
7#[proc_macro_derive(SyntaxFmt, attributes(syntax))]
8pub fn derive_syntax_fmt(input: TokenStream) -> TokenStream {
9 let input = parse_macro_input!(input as DeriveInput);
10 let name = &input.ident;
11 let mut generics = input.generics.clone();
12 let (_, ty_generics, _) = input.generics.split_for_impl();
13
14 let (delim, pretty_delim) = parse_delimiters(&input.attrs);
15 let state_bound = parse_state_bound(&input.attrs);
16 let outer_format = parse_outer_format(&input.attrs);
17 let field_types = collect_field_types(&input.data);
18
19 let fmt_body = match &input.data {
20 Data::Struct(data_struct) => generate_struct_fmt(&data_struct.fields),
21 Data::Enum(data_enum) => generate_enum_fmt(name, &data_enum.variants),
22 Data::Union(_) => {
23 return syn::Error::new_spanned(name, "SyntaxFmt cannot be derived for unions")
24 .to_compile_error()
25 .into();
26 }
27 };
28
29 let fmt_body = wrap_with_outer_format(fmt_body, &outer_format);
30
31 let delim_const = delim.map(|d| quote! { const DELIM: &'static str = #d; });
32 let pretty_delim_const = pretty_delim.map(|d| quote! { const PRETTY_DELIM: &'static str = #d; });
33
34 generics.params.push(syn::parse_quote! { __SyntaxFmtState });
35
36 let where_clause = build_where_clause(&mut generics, &field_types, state_bound.as_ref());
37 let (impl_generics_with_state, _, _) = generics.split_for_impl();
38
39 let expanded = quote! {
40 impl #impl_generics_with_state ::syntaxfmt::SyntaxFmt<__SyntaxFmtState> for #name #ty_generics #where_clause {
41 #delim_const
42 #pretty_delim_const
43
44 fn syntax_fmt(&self, ctx: &mut ::syntaxfmt::SyntaxFormatter<__SyntaxFmtState>) -> ::std::fmt::Result {
45 #fmt_body
46 }
47 }
48 };
49
50 TokenStream::from(expanded)
51}
52
53fn build_where_clause(
54 generics: &mut syn::Generics,
55 field_types: &[syn::Type],
56 state_bound: Option<&syn::TraitBound>,
57) -> syn::WhereClause {
58 let mut where_clause = generics.make_where_clause().clone();
59
60 if let Some(bound) = state_bound {
61 where_clause.predicates.push(syn::parse_quote! {
62 __SyntaxFmtState: #bound
63 });
64 }
65
66 for field_ty in field_types {
67 where_clause.predicates.push(syn::parse_quote! {
68 #field_ty: ::syntaxfmt::SyntaxFmt<__SyntaxFmtState>
69 });
70 }
71 where_clause
72}
73
74fn collect_field_types(data: &Data) -> Vec<syn::Type> {
75 let mut types = Vec::new();
76 match data {
77 Data::Struct(data_struct) => collect_struct_field_types(&data_struct.fields, &mut types),
78 Data::Enum(data_enum) => {
79 for variant in &data_enum.variants {
80 collect_struct_field_types(&variant.fields, &mut types);
81 }
82 }
83 Data::Union(_) => {}
84 }
85 types
86}
87
88fn collect_struct_field_types(fields: &Fields, types: &mut Vec<syn::Type>) {
89 for field in fields.iter() {
90 let attrs = parse_field_attrs(&field.attrs);
91 if attrs.skip || is_type_ident(&field.ty, "bool") {
92 continue;
93 }
94
95 let ty = extract_option_inner(&field.ty);
96 types.push(extract_collection_inner(&ty).unwrap_or(ty));
97 }
98}
99
100fn extract_option_inner(ty: &syn::Type) -> syn::Type {
101 if let syn::Type::Path(type_path) = ty {
102 if let Some(segment) = type_path.path.segments.last() {
103 if segment.ident == "Option" {
104 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
105 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
106 return inner_ty.clone();
107 }
108 }
109 }
110 }
111 }
112 ty.clone()
113}
114
115fn extract_collection_inner(ty: &syn::Type) -> Option<syn::Type> {
116 match ty {
117 syn::Type::Path(type_path) => {
118 let segment = type_path.path.segments.last()?;
119 if segment.ident == "Vec" {
120 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
121 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
122 return Some(inner_ty.clone());
123 }
124 }
125 }
126 None
127 }
128 syn::Type::Reference(type_ref) => {
129 if let syn::Type::Slice(slice) = &*type_ref.elem {
130 return Some((*slice.elem).clone());
131 }
132 None
133 }
134 syn::Type::Array(array) => Some((*array.elem).clone()),
135 _ => None,
136 }
137}
138
139fn generate_collection_iteration(
140 field_expr: &proc_macro2::TokenStream,
141 inner_ty: &syn::Type,
142) -> proc_macro2::TokenStream {
143 quote! {
144 {
145 let delim = if ctx.is_pretty() {
146 <#inner_ty as ::syntaxfmt::SyntaxFmt<__SyntaxFmtState>>::PRETTY_DELIM
147 } else {
148 <#inner_ty as ::syntaxfmt::SyntaxFmt<__SyntaxFmtState>>::DELIM
149 };
150 let fold = |r: ::std::fmt::Result, (i, e): (usize, &#inner_ty)| {
151 r?;
152 if i > 0 {
153 write!(ctx, "{}", delim)?;
154 }
155 if ctx.is_pretty() {
156 ctx.indent()?;
157 }
158 e.syntax_fmt(ctx)?;
159 Ok(())
160 };
161 (#field_expr).iter()
162 .enumerate()
163 .fold(Ok(()), fold)?;
164 }
165 }
166}
167
168fn extract_str_literal(value: &syn::Expr) -> Option<String> {
169 if let syn::Expr::Lit(syn::ExprLit { lit: Lit::Str(s), .. }) = value {
170 Some(s.value())
171 } else {
172 None
173 }
174}
175
176fn parse_pretty_string_attrs(
177 attrs: &[syn::Attribute],
178 normal_name: &str,
179 pretty_name: &str,
180) -> PrettyString {
181 let mut result = PrettyString::default();
182
183 parse_syntax_attrs(attrs, |meta| {
184 if let Meta::NameValue(MetaNameValue { path, value, .. }) = meta {
185 if let Some(s) = extract_str_literal(value) {
186 if path.is_ident(normal_name) {
187 result.normal = Some(s);
188 } else if path.is_ident(pretty_name) {
189 result.pretty = Some(s);
190 }
191 }
192 }
193 });
194
195 result
196}
197
198fn parse_delimiters(attrs: &[syn::Attribute]) -> (Option<String>, Option<String>) {
199 let result = parse_pretty_string_attrs(attrs, "delim", "pretty_delim");
200 (result.normal, result.pretty)
201}
202
203fn parse_outer_format(attrs: &[syn::Attribute]) -> PrettyString {
204 parse_pretty_string_attrs(attrs, "format", "pretty_format")
205}
206
207fn parse_state_bound(attrs: &[syn::Attribute]) -> Option<syn::TraitBound> {
208 let mut state_bound = None;
209
210 parse_syntax_attrs(attrs, |meta| {
211 if let Meta::NameValue(MetaNameValue { path, value, .. }) = meta {
212 if path.is_ident("state_bound") {
213 if let Some(s) = extract_str_literal(value) {
214 if let Ok(bound) = syn::parse_str::<syn::TraitBound>(&s) {
215 state_bound = Some(bound);
216 }
217 }
218 }
219 }
220 });
221
222 state_bound
223}
224
225fn parse_field_attrs(attrs: &[syn::Attribute]) -> FieldAttrs {
226 let mut field_attrs = FieldAttrs::default();
227
228 parse_syntax_attrs(attrs, |meta| match meta {
229 Meta::NameValue(MetaNameValue { path, value, .. }) => {
230 if path.is_ident("content") {
231 field_attrs.content = Some(value.clone());
232 } else if let Some(s) = extract_str_literal(value) {
233 if path.is_ident("format") {
234 field_attrs.format.normal = Some(s);
235 } else if path.is_ident("pretty_format") {
236 field_attrs.format.pretty = Some(s);
237 } else if path.is_ident("empty_suffix") {
238 field_attrs.empty_suffix = Some(s);
239 }
240 }
241 }
242 Meta::Path(path) => {
243 if path.is_ident("skip") {
244 field_attrs.skip = true;
245 } else if path.is_ident("indent_region") {
246 field_attrs.indent_region = true;
247 } else if path.is_ident("indent") {
248 field_attrs.indent = true;
249 }
250 }
251 _ => {}
252 });
253
254 field_attrs
255}
256
257fn parse_syntax_attrs(attrs: &[syn::Attribute], mut f: impl FnMut(&Meta)) {
258 for attr in attrs {
259 if attr.path().is_ident("syntax") {
260 if let Ok(meta_list) = attr.meta.require_list() {
261 if let Ok(nested_list) = meta_list.parse_args_with(
262 syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated
263 ) {
264 for nested in &nested_list {
265 f(nested);
266 }
267 }
268 }
269 }
270 }
271}
272
/// A string option with a normal-mode value and an optional pretty-mode override.
#[derive(Default)]
struct PrettyString {
    /// Value for normal (non-pretty) mode.
    normal: Option<String>,
    /// Value for pretty mode; `get_pair` falls back to `normal` when this is absent.
    pretty: Option<String>,
}

impl PrettyString {
    /// Resolves both modes to concrete strings as `(normal, pretty)`.
    ///
    /// A missing `normal` becomes `""`. A missing `pretty` reuses the resolved `normal`.
    fn get_pair(&self) -> (String, String) {
        let normal = self.normal.clone().unwrap_or_default();
        let pretty = match &self.pretty {
            Some(pretty) => pretty.clone(),
            None => normal.clone(),
        };
        (normal, pretty)
    }
}
286
287#[derive(Default)]
288struct FieldAttrs {
289 format: PrettyString,
290 content: Option<syn::Expr>,
291 empty_suffix: Option<String>,
292 indent_region: bool,
293 indent: bool,
294 skip: bool,
295}
296
/// Splits a format string around its first `{content}` placeholder.
///
/// Returns `(before, after, true)` when the placeholder is present. Otherwise returns
/// `(format_str, "", false)`. Later occurrences of `{content}` remain inside `after`.
fn split_format_string(format_str: &str) -> (&str, &str, bool) {
    match format_str.split_once("{content}") {
        Some((before, after)) => (before, after, true),
        None => (format_str, "", false),
    }
}
304
305fn generate_default_content(
306 field_expr: &proc_macro2::TokenStream,
307 content_expr: Option<&syn::Expr>,
308 field_ty: Option<&syn::Type>,
309) -> proc_macro2::TokenStream {
310 if let Some(content_fn) = content_expr {
311 return quote! { (#content_fn)(&#field_expr, ctx)?; };
312 }
313
314 if let Some(ty) = field_ty {
315 if let Some(inner_ty) = extract_collection_inner(ty) {
316 return generate_collection_iteration(field_expr, &inner_ty);
317 }
318 }
319
320 quote! { #field_expr.syntax_fmt(ctx)?; }
321}
322
323fn expand_format_string(
324 format_str: &str,
325 field_expr: &proc_macro2::TokenStream,
326 content_expr: Option<&syn::Expr>,
327 field_ty: Option<&syn::Type>,
328) -> proc_macro2::TokenStream {
329 let (before, after, has_placeholder) = split_format_string(format_str);
330 let mut statements = Vec::new();
331
332 if !before.is_empty() {
333 statements.push(quote! { write!(ctx, #before)?; });
334 }
335
336 if has_placeholder {
337 statements.push(generate_default_content(field_expr, content_expr, field_ty));
338 }
339
340 if !after.is_empty() {
341 statements.push(quote! { write!(ctx, #after)?; });
342 }
343
344 quote! { #(#statements)* }
345}
346
347fn pretty_conditional(
348 normal: proc_macro2::TokenStream,
349 pretty: proc_macro2::TokenStream,
350) -> proc_macro2::TokenStream {
351 quote! {
352 if ctx.is_pretty() {
353 #pretty
354 } else {
355 #normal
356 }
357 }
358}
359
360fn wrap_with_outer_format(
361 fmt_body: proc_macro2::TokenStream,
362 outer_format: &PrettyString,
363) -> proc_macro2::TokenStream {
364 if outer_format.normal.is_none() && outer_format.pretty.is_none() {
365 return fmt_body;
366 }
367
368 let (normal_fmt, pretty_fmt) = outer_format.get_pair();
369
370 let wrap_body = |format_str: &str| -> proc_macro2::TokenStream {
371 let (before, after, has_placeholder) = split_format_string(format_str);
372
373 if !has_placeholder {
374 return quote! {
375 write!(ctx, #format_str)?;
376 #fmt_body
377 };
378 }
379
380 if before.is_empty() && after.is_empty() {
381 return fmt_body.clone();
382 }
383
384 if after.is_empty() {
385 return quote! {
386 write!(ctx, #before)?;
387 #fmt_body
388 };
389 }
390
391 quote! {
392 write!(ctx, #before)?;
393 (|| -> ::std::fmt::Result { #fmt_body })()?;
394 write!(ctx, #after)?;
395 Ok(())
396 }
397 };
398
399 if normal_fmt == pretty_fmt {
400 wrap_body(&normal_fmt)
401 } else {
402 pretty_conditional(wrap_body(&normal_fmt), wrap_body(&pretty_fmt))
403 }
404}
405
406fn generate_format_output(
407 field_expr: &proc_macro2::TokenStream,
408 format: &PrettyString,
409 content_expr: Option<&syn::Expr>,
410 field_ty: Option<&syn::Type>,
411) -> proc_macro2::TokenStream {
412 if format.normal.is_none() && format.pretty.is_none() {
414 return generate_default_content(field_expr, content_expr, field_ty);
415 }
416
417 let (normal_fmt, pretty_fmt) = format.get_pair();
418
419 if format.normal.is_none() {
421 let default_content = generate_default_content(field_expr, content_expr, field_ty);
422 let pretty_write = expand_format_string(&pretty_fmt, field_expr, content_expr, field_ty);
423 return quote! {
424 if ctx.is_pretty() {
425 #pretty_write
426 } else {
427 #default_content
428 }
429 };
430 }
431
432 let normal_write = expand_format_string(&normal_fmt, field_expr, content_expr, field_ty);
434
435 if normal_fmt == pretty_fmt {
436 normal_write
437 } else {
438 let pretty_write = expand_format_string(&pretty_fmt, field_expr, content_expr, field_ty);
439 pretty_conditional(normal_write, pretty_write)
440 }
441}
442
443fn generate_struct_fmt(fields: &Fields) -> proc_macro2::TokenStream {
444 match fields {
445 Fields::Named(fields_named) => generate_named_fields_fmt(&fields_named.named),
446 Fields::Unnamed(fields_unnamed) if fields_unnamed.unnamed.len() == 1 => {
447 let field = fields_unnamed.unnamed.first().unwrap();
448 let attrs = parse_field_attrs(&field.attrs);
449 let format_output = generate_format_output(
450 "e! { self.0 },
451 &attrs.format,
452 attrs.content.as_ref(),
453 Some(&field.ty),
454 );
455 quote! {
456 #format_output
457 Ok(())
458 }
459 }
460 Fields::Unnamed(_) | Fields::Unit => quote! { Ok(()) },
461 }
462}
463
464fn generate_named_fields_fmt(
465 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
466) -> proc_macro2::TokenStream {
467 let mut statements = Vec::new();
468
469 for field in fields {
470 let field_name = field.ident.as_ref().unwrap();
471 let attrs = parse_field_attrs(&field.attrs);
472
473 if attrs.skip {
474 continue;
475 }
476
477 if is_type_ident(&field.ty, "bool") {
478 let format_output = generate_format_output(
479 "e! { &true },
480 &attrs.format,
481 attrs.content.as_ref(),
482 None,
483 );
484 statements.push(quote! {
485 if self.#field_name {
486 #format_output
487 }
488 });
489 } else if is_type_ident(&field.ty, "Option") {
490 let field_expr = quote! { #field_name };
491 let inner_ty = extract_option_inner(&field.ty);
492 let format_output = generate_format_output(
493 &field_expr,
494 &attrs.format,
495 attrs.content.as_ref(),
496 Some(&inner_ty),
497 );
498 statements.push(quote! {
499 if let Some(#field_name) = &self.#field_name {
500 #format_output
501 }
502 });
503 } else {
504 let field_expr = quote! { self.#field_name };
505 let mut field_statements = Vec::new();
506
507 if attrs.indent {
508 field_statements.push(quote! {
509 if ctx.is_pretty() {
510 ctx.indent()?;
511 }
512 });
513 }
514
515 if attrs.indent_region {
516 field_statements.push(quote! {
517 if ctx.is_pretty() {
518 ctx.inc_indent();
519 }
520 });
521 }
522
523 let format_output = generate_format_output(
524 &field_expr,
525 &attrs.format,
526 attrs.content.as_ref(),
527 Some(&field.ty),
528 );
529
530 field_statements.push(format_output);
531
532 if attrs.indent_region {
533 field_statements.push(quote! {
534 if ctx.is_pretty() {
535 ctx.dec_indent();
536 }
537 });
538 }
539
540 if let Some(empty_suffix) = &attrs.empty_suffix {
541 statements.push(quote! {
542 if self.#field_name.is_empty() {
543 write!(ctx, #empty_suffix)?;
544 } else {
545 #(#field_statements)*
546 }
547 });
548 } else {
549 statements.extend(field_statements);
550 }
551 }
552 }
553
554 statements.push(quote! { Ok(()) });
555 quote! { #(#statements)* }
556}
557
558fn generate_enum_fmt(
559 name: &syn::Ident,
560 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
561) -> proc_macro2::TokenStream {
562 let match_arms: Vec<_> = variants.iter().map(|variant| {
563 let variant_name = &variant.ident;
564 let attrs = parse_field_attrs(&variant.attrs);
565
566 match &variant.fields {
567 Fields::Named(_) => {
568 quote! {
569 #name::#variant_name { .. } => todo!("Named enum variants not yet supported")
570 }
571 }
572 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
573 let field = fields.unnamed.first().unwrap();
574 let format_output = generate_format_output(
575 "e! { inner },
576 &attrs.format,
577 attrs.content.as_ref(),
578 Some(&field.ty),
579 );
580 quote! {
581 #name::#variant_name(inner) => { #format_output Ok(()) }
582 }
583 }
584 Fields::Unnamed(_) => {
585 quote! {
586 #name::#variant_name(..) => todo!("Multi-field tuple variants not yet supported")
587 }
588 }
589 Fields::Unit => {
590 if attrs.format.normal.is_some() || attrs.format.pretty.is_some() {
591 let format_output = generate_format_output(
592 "e! { "" },
593 &attrs.format,
594 attrs.content.as_ref(),
595 None,
596 );
597 quote! { #name::#variant_name => { #format_output Ok(()) } }
598 } else {
599 let lower_name = variant_name.to_string().to_lowercase();
600 quote! { #name::#variant_name => write!(ctx, #lower_name) }
601 }
602 }
603 }
604 }).collect();
605
606 quote! {
607 match self {
608 #(#match_arms,)*
609 }
610 }
611}
612
613fn is_type_ident(ty: &syn::Type, ident_name: &str) -> bool {
614 if let syn::Type::Path(type_path) = ty {
615 if let Some(segment) = type_path.path.segments.last() {
616 return segment.ident == ident_name;
617 }
618 }
619 false
620}