Skip to main content

serde_shape_derive/
lib.rs

1// Copyright 2026 FastLabs Developers
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Derive macro for `serde-shape`.
16
17use std::collections::BTreeSet;
18
19use proc_macro::TokenStream;
20use proc_macro2::TokenStream as TokenStream2;
21use quote::ToTokens;
22use quote::quote;
23use serde_derive_internals::Ctxt;
24use serde_derive_internals::Derive;
25use serde_derive_internals::ast;
26use serde_derive_internals::attr;
27use syn::DeriveInput;
28use syn::GenericArgument;
29use syn::LitStr;
30use syn::Member;
31use syn::PathArguments;
32use syn::ReturnType;
33use syn::Type;
34use syn::TypeParamBound;
35use syn::parse_macro_input;
36use syn::parse_quote;
37
38/// Derive `serde_shape::SerdeShape` from Serde derive metadata.
39#[proc_macro_derive(SerdeShape, attributes(serde))]
40pub fn derive_serde_shape(input: TokenStream) -> TokenStream {
41    let input = parse_macro_input!(input as DeriveInput);
42
43    match expand_serde_shape(&input) {
44        Ok(tokens) => tokens.into(),
45        Err(err) => err.to_compile_error().into(),
46    }
47}
48
49fn expand_serde_shape(input: &DeriveInput) -> syn::Result<TokenStream2> {
50    let cx = Ctxt::new();
51    let Some(container) = ast::Container::from_ast(&cx, input, Derive::Deserialize) else {
52        cx.check()?;
53        return Err(syn::Error::new_spanned(
54            input,
55            "serde-shape could not parse this item",
56        ));
57    };
58    cx.check()?;
59
60    let ident = &input.ident;
61    let mut generics = input.generics.clone();
62    add_shape_bounds(&mut generics, &container);
63    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
64    let body = shape_body(&container);
65
66    Ok(quote! {
67        impl #impl_generics ::serde_shape::SerdeShape for #ident #ty_generics #where_clause {
68            fn shape_in(context: &mut ::serde_shape::ShapeContext) -> ::serde_shape::ShapeRef {
69                #body
70            }
71        }
72    })
73}
74
75fn add_shape_bounds(generics: &mut syn::Generics, container: &ast::Container<'_>) {
76    if container.attrs.type_from().is_some()
77        || container.attrs.type_try_from().is_some()
78        || container.attrs.remote().is_some()
79    {
80        return;
81    }
82
83    let type_params: BTreeSet<_> = generics
84        .type_params()
85        .map(|param| param.ident.to_string())
86        .collect();
87    let mut field_bound_types = Vec::new();
88
89    match &container.data {
90        ast::Data::Struct(_, fields) => {
91            collect_field_bound_types(fields, &type_params, &mut field_bound_types);
92        }
93        ast::Data::Enum(variants) => {
94            for variant in variants {
95                if variant.attrs.skip_deserializing() || variant.attrs.deserialize_with().is_some()
96                {
97                    continue;
98                }
99                collect_field_bound_types(&variant.fields, &type_params, &mut field_bound_types);
100            }
101        }
102    }
103
104    for ty in field_bound_types {
105        generics
106            .make_where_clause()
107            .predicates
108            .push(parse_quote!(#ty: ::serde_shape::SerdeShape));
109    }
110}
111
112fn collect_field_bound_types(
113    fields: &[ast::Field<'_>],
114    type_params: &BTreeSet<String>,
115    field_bound_types: &mut Vec<Type>,
116) {
117    for field in fields {
118        if field.attrs.skip_deserializing() || field.attrs.deserialize_with().is_some() {
119            continue;
120        }
121
122        let mut used_type_params = BTreeSet::new();
123        collect_type_params(field.ty, type_params, &mut used_type_params);
124        if !used_type_params.is_empty() {
125            field_bound_types.push((*field.ty).clone());
126        }
127    }
128}
129
130fn collect_type_params(
131    ty: &Type,
132    type_params: &BTreeSet<String>,
133    used_type_params: &mut BTreeSet<String>,
134) {
135    match ty {
136        Type::Array(ty) => collect_type_params(&ty.elem, type_params, used_type_params),
137        Type::BareFn(ty) => {
138            for input in &ty.inputs {
139                collect_type_params(&input.ty, type_params, used_type_params);
140            }
141            collect_return_type_params(&ty.output, type_params, used_type_params);
142        }
143        Type::Group(ty) => collect_type_params(&ty.elem, type_params, used_type_params),
144        Type::ImplTrait(ty) => collect_type_param_bounds(&ty.bounds, type_params, used_type_params),
145        Type::Paren(ty) => collect_type_params(&ty.elem, type_params, used_type_params),
146        Type::Path(ty) => {
147            if let Some(qself) = &ty.qself {
148                collect_type_params(&qself.ty, type_params, used_type_params);
149            }
150            for segment in &ty.path.segments {
151                let ident = segment.ident.to_string();
152                if type_params.contains(&ident) {
153                    used_type_params.insert(ident);
154                }
155                collect_path_arguments(&segment.arguments, type_params, used_type_params);
156            }
157        }
158        Type::Ptr(ty) => collect_type_params(&ty.elem, type_params, used_type_params),
159        Type::Reference(ty) => collect_type_params(&ty.elem, type_params, used_type_params),
160        Type::Slice(ty) => collect_type_params(&ty.elem, type_params, used_type_params),
161        Type::TraitObject(ty) => {
162            collect_type_param_bounds(&ty.bounds, type_params, used_type_params);
163        }
164        Type::Tuple(ty) => {
165            for elem in &ty.elems {
166                collect_type_params(elem, type_params, used_type_params);
167            }
168        }
169        Type::Infer(_) | Type::Macro(_) | Type::Never(_) | Type::Verbatim(_) => {}
170        _ => {}
171    }
172}
173
174fn collect_path_arguments(
175    arguments: &PathArguments,
176    type_params: &BTreeSet<String>,
177    used_type_params: &mut BTreeSet<String>,
178) {
179    match arguments {
180        PathArguments::None => {}
181        PathArguments::AngleBracketed(arguments) => {
182            for argument in &arguments.args {
183                match argument {
184                    GenericArgument::Type(ty) => {
185                        collect_type_params(ty, type_params, used_type_params);
186                    }
187                    GenericArgument::AssocType(assoc) => {
188                        collect_type_params(&assoc.ty, type_params, used_type_params);
189                    }
190                    GenericArgument::Constraint(constraint) => {
191                        collect_type_param_bounds(
192                            &constraint.bounds,
193                            type_params,
194                            used_type_params,
195                        );
196                    }
197                    GenericArgument::Lifetime(_)
198                    | GenericArgument::Const(_)
199                    | GenericArgument::AssocConst(_) => {}
200                    _ => {}
201                }
202            }
203        }
204        PathArguments::Parenthesized(arguments) => {
205            for input in &arguments.inputs {
206                collect_type_params(input, type_params, used_type_params);
207            }
208            collect_return_type_params(&arguments.output, type_params, used_type_params);
209        }
210    }
211}
212
213fn collect_type_param_bounds(
214    bounds: &syn::punctuated::Punctuated<TypeParamBound, syn::Token![+]>,
215    type_params: &BTreeSet<String>,
216    used_type_params: &mut BTreeSet<String>,
217) {
218    for bound in bounds {
219        if let TypeParamBound::Trait(bound) = bound {
220            for segment in &bound.path.segments {
221                collect_path_arguments(&segment.arguments, type_params, used_type_params);
222            }
223        }
224    }
225}
226
227fn collect_return_type_params(
228    return_type: &ReturnType,
229    type_params: &BTreeSet<String>,
230    used_type_params: &mut BTreeSet<String>,
231) {
232    if let ReturnType::Type(_, ty) = return_type {
233        collect_type_params(ty, type_params, used_type_params);
234    }
235}
236
237fn shape_body(container: &ast::Container<'_>) -> TokenStream2 {
238    let serde_name = lit(container.attrs.name().deserialize_name());
239    let kind = definition_kind(container);
240
241    quote! {
242        context.define_named_type(
243            ::serde_shape::TypeName {
244                rust_name: ::std::any::type_name::<Self>(),
245                serde_name: #serde_name,
246            },
247            |context| {
248                #kind
249            },
250        )
251    }
252}
253
254fn definition_kind(container: &ast::Container<'_>) -> TokenStream2 {
255    if let Some(ty) = container.attrs.type_from() {
256        return opaque_definition("FromType", ty);
257    }
258    if let Some(ty) = container.attrs.type_try_from() {
259        return opaque_definition("TryFromType", ty);
260    }
261    if let Some(path) = container.attrs.remote() {
262        return opaque_definition("Remote", path);
263    }
264
265    let attributes = container_attributes(&container.attrs);
266    match &container.data {
267        ast::Data::Struct(style, fields) => {
268            let style = fields_style(*style);
269            let fields = fields.iter().map(field_shape);
270            quote! {
271                ::serde_shape::DefinitionKind::Struct(::serde_shape::StructShape {
272                    style: #style,
273                    fields: ::std::vec![#(#fields),*],
274                    attributes: #attributes,
275                })
276            }
277        }
278        ast::Data::Enum(variants) => {
279            let repr = tagging(container.attrs.tag());
280            let variants = variants.iter().map(variant_shape);
281            quote! {
282                ::serde_shape::DefinitionKind::Enum(::serde_shape::EnumShape {
283                    repr: #repr,
284                    variants: ::std::vec![#(#variants),*],
285                    attributes: #attributes,
286                })
287            }
288        }
289    }
290}
291
292fn opaque_definition<T>(reason: &str, detail: T) -> TokenStream2
293where
294    T: ToTokens,
295{
296    let reason = opaque_reason(reason);
297    let detail = lit(detail.to_token_stream().to_string());
298
299    quote! {
300        ::serde_shape::DefinitionKind::Opaque(::serde_shape::OpaqueShape {
301            type_name: ::std::any::type_name::<Self>(),
302            reason: #reason,
303            detail: ::std::option::Option::Some(#detail),
304        })
305    }
306}
307
308fn container_attributes(attrs: &attr::Container) -> TokenStream2 {
309    let tagging = tagging(attrs.tag());
310    let deny_unknown_fields = attrs.deny_unknown_fields();
311    let default = default_shape(attrs.default());
312    let has_flatten = attrs.has_flatten();
313    let transparent = attrs.transparent();
314    let expecting = option_lit(attrs.expecting());
315    let non_exhaustive = attrs.non_exhaustive();
316
317    quote! {
318        ::serde_shape::ContainerAttributes {
319            tagging: #tagging,
320            deny_unknown_fields: #deny_unknown_fields,
321            default: #default,
322            has_flatten: #has_flatten,
323            transparent: #transparent,
324            expecting: #expecting,
325            non_exhaustive: #non_exhaustive,
326        }
327    }
328}
329
330fn variant_shape(variant: &ast::Variant<'_>) -> TokenStream2 {
331    let rust_name = lit(variant.ident.to_string());
332    let deserialize_name = lit(variant.attrs.name().deserialize_name());
333    let deserialize_aliases = aliases(variant.attrs.aliases());
334    let style = fields_style(variant.style);
335    let skip_deserializing = variant.attrs.skip_deserializing();
336    let custom_deserializer = variant.attrs.deserialize_with().is_some();
337    let other = variant.attrs.other();
338    let untagged = variant.attrs.untagged();
339    let fields: Vec<_> = if skip_deserializing || custom_deserializer {
340        Vec::new()
341    } else {
342        variant.fields.iter().map(field_shape).collect()
343    };
344
345    quote! {
346        ::serde_shape::VariantShape {
347            rust_name: #rust_name,
348            deserialize_name: #deserialize_name,
349            deserialize_aliases: #deserialize_aliases,
350            style: #style,
351            fields: ::std::vec![#(#fields),*],
352            skip_deserializing: #skip_deserializing,
353            custom_deserializer: #custom_deserializer,
354            other: #other,
355            untagged: #untagged,
356        }
357    }
358}
359
360fn field_shape(field: &ast::Field<'_>) -> TokenStream2 {
361    let member = field_member(&field.member);
362    let deserialize_name = lit(field.attrs.name().deserialize_name());
363    let deserialize_aliases = aliases(field.attrs.aliases());
364    let skip_deserializing = field.attrs.skip_deserializing();
365    let custom_deserializer = field.attrs.deserialize_with().is_some();
366    let default = default_shape(field.attrs.default());
367    let flatten = field.attrs.flatten();
368    let transparent = field.attrs.transparent();
369    let ty = field.ty;
370    let shape = if skip_deserializing || custom_deserializer {
371        quote!(::std::option::Option::None)
372    } else {
373        quote!(::std::option::Option::Some(<#ty as ::serde_shape::SerdeShape>::shape_in(context)))
374    };
375
376    quote! {
377        ::serde_shape::FieldShape {
378            member: #member,
379            deserialize_name: #deserialize_name,
380            deserialize_aliases: #deserialize_aliases,
381            shape: #shape,
382            default: #default,
383            flatten: #flatten,
384            skip_deserializing: #skip_deserializing,
385            custom_deserializer: #custom_deserializer,
386            transparent: #transparent,
387        }
388    }
389}
390
391fn field_member(member: &Member) -> TokenStream2 {
392    match member {
393        Member::Named(ident) => {
394            let ident = lit(ident.to_string());
395            quote!(::serde_shape::FieldMember::Named(#ident))
396        }
397        Member::Unnamed(index) => {
398            let index = index.index as usize;
399            quote!(::serde_shape::FieldMember::Unnamed(#index))
400        }
401    }
402}
403
404fn fields_style(style: ast::Style) -> TokenStream2 {
405    match style {
406        ast::Style::Struct => quote!(::serde_shape::FieldsStyle::Struct),
407        ast::Style::Tuple => quote!(::serde_shape::FieldsStyle::Tuple),
408        ast::Style::Newtype => quote!(::serde_shape::FieldsStyle::Newtype),
409        ast::Style::Unit => quote!(::serde_shape::FieldsStyle::Unit),
410    }
411}
412
413fn tagging(tag: &attr::TagType) -> TokenStream2 {
414    match tag {
415        attr::TagType::External => quote!(::serde_shape::Tagging::External),
416        attr::TagType::Internal { tag } => {
417            let tag = lit(tag);
418            quote!(::serde_shape::Tagging::Internal { tag: #tag })
419        }
420        attr::TagType::Adjacent { tag, content } => {
421            let tag = lit(tag);
422            let content = lit(content);
423            quote!(::serde_shape::Tagging::Adjacent {
424                tag: #tag,
425                content: #content,
426            })
427        }
428        attr::TagType::None => quote!(::serde_shape::Tagging::Untagged),
429    }
430}
431
432fn default_shape(default: &attr::Default) -> TokenStream2 {
433    match default {
434        attr::Default::None => quote!(::serde_shape::DefaultShape::None),
435        attr::Default::Default => quote!(::serde_shape::DefaultShape::Default),
436        attr::Default::Path(path) => {
437            let path = lit(path.to_token_stream().to_string());
438            quote!(::serde_shape::DefaultShape::Path(#path))
439        }
440    }
441}
442
443fn opaque_reason(reason: &str) -> TokenStream2 {
444    match reason {
445        "FromType" => quote!(::serde_shape::OpaqueReason::FromType),
446        "TryFromType" => quote!(::serde_shape::OpaqueReason::TryFromType),
447        "Remote" => quote!(::serde_shape::OpaqueReason::Remote),
448        _ => quote!(::serde_shape::OpaqueReason::Unsupported),
449    }
450}
451
452fn aliases(aliases: &std::collections::BTreeSet<String>) -> TokenStream2 {
453    let aliases = aliases.iter().map(lit);
454    quote!(::std::vec![#(#aliases),*])
455}
456
457fn option_lit(value: Option<&str>) -> TokenStream2 {
458    match value {
459        Some(value) => {
460            let value = lit(value);
461            quote!(::std::option::Option::Some(#value))
462        }
463        None => quote!(::std::option::Option::None),
464    }
465}
466
467fn lit(value: impl AsRef<str>) -> LitStr {
468    LitStr::new(value.as_ref(), proc_macro2::Span::call_site())
469}