1use std::collections::BTreeSet;
18
19use proc_macro::TokenStream;
20use proc_macro2::TokenStream as TokenStream2;
21use quote::ToTokens;
22use quote::quote;
23use serde_derive_internals::Ctxt;
24use serde_derive_internals::Derive;
25use serde_derive_internals::ast;
26use serde_derive_internals::attr;
27use syn::DeriveInput;
28use syn::GenericArgument;
29use syn::LitStr;
30use syn::Member;
31use syn::PathArguments;
32use syn::ReturnType;
33use syn::Type;
34use syn::TypeParamBound;
35use syn::parse_macro_input;
36use syn::parse_quote;
37
38#[proc_macro_derive(SerdeShape, attributes(serde))]
40pub fn derive_serde_shape(input: TokenStream) -> TokenStream {
41 let input = parse_macro_input!(input as DeriveInput);
42
43 match expand_serde_shape(&input) {
44 Ok(tokens) => tokens.into(),
45 Err(err) => err.to_compile_error().into(),
46 }
47}
48
49fn expand_serde_shape(input: &DeriveInput) -> syn::Result<TokenStream2> {
50 let cx = Ctxt::new();
51 let Some(container) = ast::Container::from_ast(&cx, input, Derive::Deserialize) else {
52 cx.check()?;
53 return Err(syn::Error::new_spanned(
54 input,
55 "serde-shape could not parse this item",
56 ));
57 };
58 cx.check()?;
59
60 let ident = &input.ident;
61 let mut generics = input.generics.clone();
62 add_shape_bounds(&mut generics, &container);
63 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
64 let body = shape_body(&container);
65
66 Ok(quote! {
67 impl #impl_generics ::serde_shape::SerdeShape for #ident #ty_generics #where_clause {
68 fn shape_in(context: &mut ::serde_shape::ShapeContext) -> ::serde_shape::ShapeRef {
69 #body
70 }
71 }
72 })
73}
74
75fn add_shape_bounds(generics: &mut syn::Generics, container: &ast::Container<'_>) {
76 if container.attrs.type_from().is_some()
77 || container.attrs.type_try_from().is_some()
78 || container.attrs.remote().is_some()
79 {
80 return;
81 }
82
83 let type_params: BTreeSet<_> = generics
84 .type_params()
85 .map(|param| param.ident.to_string())
86 .collect();
87 let mut field_bound_types = Vec::new();
88
89 match &container.data {
90 ast::Data::Struct(_, fields) => {
91 collect_field_bound_types(fields, &type_params, &mut field_bound_types);
92 }
93 ast::Data::Enum(variants) => {
94 for variant in variants {
95 if variant.attrs.skip_deserializing() || variant.attrs.deserialize_with().is_some()
96 {
97 continue;
98 }
99 collect_field_bound_types(&variant.fields, &type_params, &mut field_bound_types);
100 }
101 }
102 }
103
104 for ty in field_bound_types {
105 generics
106 .make_where_clause()
107 .predicates
108 .push(parse_quote!(#ty: ::serde_shape::SerdeShape));
109 }
110}
111
112fn collect_field_bound_types(
113 fields: &[ast::Field<'_>],
114 type_params: &BTreeSet<String>,
115 field_bound_types: &mut Vec<Type>,
116) {
117 for field in fields {
118 if field.attrs.skip_deserializing() || field.attrs.deserialize_with().is_some() {
119 continue;
120 }
121
122 let mut used_type_params = BTreeSet::new();
123 collect_type_params(field.ty, type_params, &mut used_type_params);
124 if !used_type_params.is_empty() {
125 field_bound_types.push((*field.ty).clone());
126 }
127 }
128}
129
130fn collect_type_params(
131 ty: &Type,
132 type_params: &BTreeSet<String>,
133 used_type_params: &mut BTreeSet<String>,
134) {
135 match ty {
136 Type::Array(ty) => collect_type_params(&ty.elem, type_params, used_type_params),
137 Type::BareFn(ty) => {
138 for input in &ty.inputs {
139 collect_type_params(&input.ty, type_params, used_type_params);
140 }
141 collect_return_type_params(&ty.output, type_params, used_type_params);
142 }
143 Type::Group(ty) => collect_type_params(&ty.elem, type_params, used_type_params),
144 Type::ImplTrait(ty) => collect_type_param_bounds(&ty.bounds, type_params, used_type_params),
145 Type::Paren(ty) => collect_type_params(&ty.elem, type_params, used_type_params),
146 Type::Path(ty) => {
147 if let Some(qself) = &ty.qself {
148 collect_type_params(&qself.ty, type_params, used_type_params);
149 }
150 for segment in &ty.path.segments {
151 let ident = segment.ident.to_string();
152 if type_params.contains(&ident) {
153 used_type_params.insert(ident);
154 }
155 collect_path_arguments(&segment.arguments, type_params, used_type_params);
156 }
157 }
158 Type::Ptr(ty) => collect_type_params(&ty.elem, type_params, used_type_params),
159 Type::Reference(ty) => collect_type_params(&ty.elem, type_params, used_type_params),
160 Type::Slice(ty) => collect_type_params(&ty.elem, type_params, used_type_params),
161 Type::TraitObject(ty) => {
162 collect_type_param_bounds(&ty.bounds, type_params, used_type_params);
163 }
164 Type::Tuple(ty) => {
165 for elem in &ty.elems {
166 collect_type_params(elem, type_params, used_type_params);
167 }
168 }
169 Type::Infer(_) | Type::Macro(_) | Type::Never(_) | Type::Verbatim(_) => {}
170 _ => {}
171 }
172}
173
174fn collect_path_arguments(
175 arguments: &PathArguments,
176 type_params: &BTreeSet<String>,
177 used_type_params: &mut BTreeSet<String>,
178) {
179 match arguments {
180 PathArguments::None => {}
181 PathArguments::AngleBracketed(arguments) => {
182 for argument in &arguments.args {
183 match argument {
184 GenericArgument::Type(ty) => {
185 collect_type_params(ty, type_params, used_type_params);
186 }
187 GenericArgument::AssocType(assoc) => {
188 collect_type_params(&assoc.ty, type_params, used_type_params);
189 }
190 GenericArgument::Constraint(constraint) => {
191 collect_type_param_bounds(
192 &constraint.bounds,
193 type_params,
194 used_type_params,
195 );
196 }
197 GenericArgument::Lifetime(_)
198 | GenericArgument::Const(_)
199 | GenericArgument::AssocConst(_) => {}
200 _ => {}
201 }
202 }
203 }
204 PathArguments::Parenthesized(arguments) => {
205 for input in &arguments.inputs {
206 collect_type_params(input, type_params, used_type_params);
207 }
208 collect_return_type_params(&arguments.output, type_params, used_type_params);
209 }
210 }
211}
212
213fn collect_type_param_bounds(
214 bounds: &syn::punctuated::Punctuated<TypeParamBound, syn::Token![+]>,
215 type_params: &BTreeSet<String>,
216 used_type_params: &mut BTreeSet<String>,
217) {
218 for bound in bounds {
219 if let TypeParamBound::Trait(bound) = bound {
220 for segment in &bound.path.segments {
221 collect_path_arguments(&segment.arguments, type_params, used_type_params);
222 }
223 }
224 }
225}
226
227fn collect_return_type_params(
228 return_type: &ReturnType,
229 type_params: &BTreeSet<String>,
230 used_type_params: &mut BTreeSet<String>,
231) {
232 if let ReturnType::Type(_, ty) = return_type {
233 collect_type_params(ty, type_params, used_type_params);
234 }
235}
236
237fn shape_body(container: &ast::Container<'_>) -> TokenStream2 {
238 let serde_name = lit(container.attrs.name().deserialize_name());
239 let kind = definition_kind(container);
240
241 quote! {
242 context.define_named_type(
243 ::serde_shape::TypeName {
244 rust_name: ::std::any::type_name::<Self>(),
245 serde_name: #serde_name,
246 },
247 |context| {
248 #kind
249 },
250 )
251 }
252}
253
254fn definition_kind(container: &ast::Container<'_>) -> TokenStream2 {
255 if let Some(ty) = container.attrs.type_from() {
256 return opaque_definition("FromType", ty);
257 }
258 if let Some(ty) = container.attrs.type_try_from() {
259 return opaque_definition("TryFromType", ty);
260 }
261 if let Some(path) = container.attrs.remote() {
262 return opaque_definition("Remote", path);
263 }
264
265 let attributes = container_attributes(&container.attrs);
266 match &container.data {
267 ast::Data::Struct(style, fields) => {
268 let style = fields_style(*style);
269 let fields = fields.iter().map(field_shape);
270 quote! {
271 ::serde_shape::DefinitionKind::Struct(::serde_shape::StructShape {
272 style: #style,
273 fields: ::std::vec![#(#fields),*],
274 attributes: #attributes,
275 })
276 }
277 }
278 ast::Data::Enum(variants) => {
279 let repr = tagging(container.attrs.tag());
280 let variants = variants.iter().map(variant_shape);
281 quote! {
282 ::serde_shape::DefinitionKind::Enum(::serde_shape::EnumShape {
283 repr: #repr,
284 variants: ::std::vec![#(#variants),*],
285 attributes: #attributes,
286 })
287 }
288 }
289 }
290}
291
292fn opaque_definition<T>(reason: &str, detail: T) -> TokenStream2
293where
294 T: ToTokens,
295{
296 let reason = opaque_reason(reason);
297 let detail = lit(detail.to_token_stream().to_string());
298
299 quote! {
300 ::serde_shape::DefinitionKind::Opaque(::serde_shape::OpaqueShape {
301 type_name: ::std::any::type_name::<Self>(),
302 reason: #reason,
303 detail: ::std::option::Option::Some(#detail),
304 })
305 }
306}
307
308fn container_attributes(attrs: &attr::Container) -> TokenStream2 {
309 let tagging = tagging(attrs.tag());
310 let deny_unknown_fields = attrs.deny_unknown_fields();
311 let default = default_shape(attrs.default());
312 let has_flatten = attrs.has_flatten();
313 let transparent = attrs.transparent();
314 let expecting = option_lit(attrs.expecting());
315 let non_exhaustive = attrs.non_exhaustive();
316
317 quote! {
318 ::serde_shape::ContainerAttributes {
319 tagging: #tagging,
320 deny_unknown_fields: #deny_unknown_fields,
321 default: #default,
322 has_flatten: #has_flatten,
323 transparent: #transparent,
324 expecting: #expecting,
325 non_exhaustive: #non_exhaustive,
326 }
327 }
328}
329
330fn variant_shape(variant: &ast::Variant<'_>) -> TokenStream2 {
331 let rust_name = lit(variant.ident.to_string());
332 let deserialize_name = lit(variant.attrs.name().deserialize_name());
333 let deserialize_aliases = aliases(variant.attrs.aliases());
334 let style = fields_style(variant.style);
335 let skip_deserializing = variant.attrs.skip_deserializing();
336 let custom_deserializer = variant.attrs.deserialize_with().is_some();
337 let other = variant.attrs.other();
338 let untagged = variant.attrs.untagged();
339 let fields: Vec<_> = if skip_deserializing || custom_deserializer {
340 Vec::new()
341 } else {
342 variant.fields.iter().map(field_shape).collect()
343 };
344
345 quote! {
346 ::serde_shape::VariantShape {
347 rust_name: #rust_name,
348 deserialize_name: #deserialize_name,
349 deserialize_aliases: #deserialize_aliases,
350 style: #style,
351 fields: ::std::vec![#(#fields),*],
352 skip_deserializing: #skip_deserializing,
353 custom_deserializer: #custom_deserializer,
354 other: #other,
355 untagged: #untagged,
356 }
357 }
358}
359
360fn field_shape(field: &ast::Field<'_>) -> TokenStream2 {
361 let member = field_member(&field.member);
362 let deserialize_name = lit(field.attrs.name().deserialize_name());
363 let deserialize_aliases = aliases(field.attrs.aliases());
364 let skip_deserializing = field.attrs.skip_deserializing();
365 let custom_deserializer = field.attrs.deserialize_with().is_some();
366 let default = default_shape(field.attrs.default());
367 let flatten = field.attrs.flatten();
368 let transparent = field.attrs.transparent();
369 let ty = field.ty;
370 let shape = if skip_deserializing || custom_deserializer {
371 quote!(::std::option::Option::None)
372 } else {
373 quote!(::std::option::Option::Some(<#ty as ::serde_shape::SerdeShape>::shape_in(context)))
374 };
375
376 quote! {
377 ::serde_shape::FieldShape {
378 member: #member,
379 deserialize_name: #deserialize_name,
380 deserialize_aliases: #deserialize_aliases,
381 shape: #shape,
382 default: #default,
383 flatten: #flatten,
384 skip_deserializing: #skip_deserializing,
385 custom_deserializer: #custom_deserializer,
386 transparent: #transparent,
387 }
388 }
389}
390
391fn field_member(member: &Member) -> TokenStream2 {
392 match member {
393 Member::Named(ident) => {
394 let ident = lit(ident.to_string());
395 quote!(::serde_shape::FieldMember::Named(#ident))
396 }
397 Member::Unnamed(index) => {
398 let index = index.index as usize;
399 quote!(::serde_shape::FieldMember::Unnamed(#index))
400 }
401 }
402}
403
404fn fields_style(style: ast::Style) -> TokenStream2 {
405 match style {
406 ast::Style::Struct => quote!(::serde_shape::FieldsStyle::Struct),
407 ast::Style::Tuple => quote!(::serde_shape::FieldsStyle::Tuple),
408 ast::Style::Newtype => quote!(::serde_shape::FieldsStyle::Newtype),
409 ast::Style::Unit => quote!(::serde_shape::FieldsStyle::Unit),
410 }
411}
412
413fn tagging(tag: &attr::TagType) -> TokenStream2 {
414 match tag {
415 attr::TagType::External => quote!(::serde_shape::Tagging::External),
416 attr::TagType::Internal { tag } => {
417 let tag = lit(tag);
418 quote!(::serde_shape::Tagging::Internal { tag: #tag })
419 }
420 attr::TagType::Adjacent { tag, content } => {
421 let tag = lit(tag);
422 let content = lit(content);
423 quote!(::serde_shape::Tagging::Adjacent {
424 tag: #tag,
425 content: #content,
426 })
427 }
428 attr::TagType::None => quote!(::serde_shape::Tagging::Untagged),
429 }
430}
431
432fn default_shape(default: &attr::Default) -> TokenStream2 {
433 match default {
434 attr::Default::None => quote!(::serde_shape::DefaultShape::None),
435 attr::Default::Default => quote!(::serde_shape::DefaultShape::Default),
436 attr::Default::Path(path) => {
437 let path = lit(path.to_token_stream().to_string());
438 quote!(::serde_shape::DefaultShape::Path(#path))
439 }
440 }
441}
442
443fn opaque_reason(reason: &str) -> TokenStream2 {
444 match reason {
445 "FromType" => quote!(::serde_shape::OpaqueReason::FromType),
446 "TryFromType" => quote!(::serde_shape::OpaqueReason::TryFromType),
447 "Remote" => quote!(::serde_shape::OpaqueReason::Remote),
448 _ => quote!(::serde_shape::OpaqueReason::Unsupported),
449 }
450}
451
452fn aliases(aliases: &std::collections::BTreeSet<String>) -> TokenStream2 {
453 let aliases = aliases.iter().map(lit);
454 quote!(::std::vec![#(#aliases),*])
455}
456
457fn option_lit(value: Option<&str>) -> TokenStream2 {
458 match value {
459 Some(value) => {
460 let value = lit(value);
461 quote!(::std::option::Option::Some(#value))
462 }
463 None => quote!(::std::option::Option::None),
464 }
465}
466
467fn lit(value: impl AsRef<str>) -> LitStr {
468 LitStr::new(value.as_ref(), proc_macro2::Span::call_site())
469}