serde_ordered_derive/
lib.rs

1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use proc_macro2::Span;
5use quote::quote;
6use syn::{parse_macro_input, Type, Data, DeriveInput, Field, Fields, LitInt, Result, Ident};
7
8struct FieldOrder {
9    pub order: usize,
10    pub field_name: Ident,
11    pub dtype: Type
12}
13
14///A procedural macro for deserializing ordered arrays into keyed structs using Serde.
15#[proc_macro_derive(DeserializeOrdered, attributes(order))]
16pub fn derive_order(input: TokenStream) -> TokenStream {
17    let input = parse_macro_input!(input as DeriveInput);
18    let name = &input.ident;
19
20    let fields = match &input.data {
21        Data::Struct(data_struct) => match &data_struct.fields {
22            Fields::Named(named_field) => &named_field.named,
23            _ => return syn::Error::new_spanned(
24                &input,
25                "DeserializeOrdered only supports structs with named fields",
26            ).to_compile_error().into(),
27        },
28        _ => return syn::Error::new_spanned(
29            &input,
30            "DeserializeOrdered can only be derived for structs",
31        ).to_compile_error().into(),
32    };
33
34    let mut field_orders = vec![];
35
36    for field in fields {
37        let field_name = field.ident.as_ref().unwrap();
38
39        // Extract #[order(x)] attribute
40        let order = match get_order_from_field(field) {
41            Ok(order) => order,
42            Err(err) => return err.to_compile_error().into(),
43        };
44
45        field_orders.push(FieldOrder {
46            order,
47            field_name: field_name.clone(),
48            dtype: field.ty.clone(),
49        });
50    }
51
52    // Check if every field has an order and that all orders are unique
53    let total_fields = fields.len();
54    if field_orders.len() != total_fields {
55        return syn::Error::new_spanned(
56            &input,
57            "DeserializeOrdered requires all fields do have #[serde(order = x)]",
58        ).to_compile_error().into();
59    }
60
61    //Check for duplicate orders
62    let orders_set = field_orders.iter().map(|fo| fo.order).collect::<HashSet<_>>();
63    if orders_set.len() != total_fields {
64        return syn::Error::new_spanned(
65            &input,
66            "DeserializeOrdered requires all fields to have unique orders",
67        ).to_compile_error().into();
68    }
69
70    // Sort fields by order index
71    field_orders.sort_by_key(|order| order.order);
72    let field_names: Vec<_> = field_orders.iter().map(|fo| fo.field_name.to_owned()).collect();
73    let field_types: Vec<_> = field_orders.iter().map(|fo| fo.dtype.to_owned()).collect();
74    let field_orders_only: Vec<_> = field_orders.iter().map(|fo| fo.order).collect();
75
76    let field_enum = Ident::new("__SerdeOrderedField", Span::call_site());
77    let field_enum_variants: Vec<Ident> = field_orders
78        .iter()
79        .enumerate()
80        .map(|(index, _)| Ident::new(&format!("__Field{}", index), Span::call_site()))
81        .collect();
82    
83    // Generate deserialization logic
84    let deserialization = quote! {
85        impl<'de> serde::Deserialize<'de> for #name {
86            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
87            where
88                D: serde::Deserializer<'de>,
89            {
90                use serde::de::{IgnoredAny, MapAccess, SeqAccess, Unexpected, Visitor};
91                use std::fmt;
92
93                const FIELDS: &'static [&'static str] = &[#(stringify!(#field_names)),*];
94
95                #[allow(non_camel_case_types)]
96                enum #field_enum {
97                    #(#field_enum_variants),*,
98                    __Ignore,
99                }
100
101                impl<'de> serde::Deserialize<'de> for #field_enum {
102                    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
103                    where
104                        D: serde::Deserializer<'de>,
105                    {
106                        struct FieldVisitor;
107
108                        impl<'de> Visitor<'de> for FieldVisitor {
109                            type Value = #field_enum;
110
111                            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
112                                formatter.write_str("field identifier")
113                            }
114
115                            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
116                            where
117                                E: serde::de::Error,
118                            {
119                                match value {
120                                    #(stringify!(#field_names) => Ok(#field_enum::#field_enum_variants),)*
121                                    _ => Ok(#field_enum::__Ignore),
122                                }
123                            }
124
125                            fn visit_bytes<E>(self, value: &[u8]) -> Result<Self::Value, E>
126                            where
127                                E: serde::de::Error,
128                            {
129                                match std::str::from_utf8(value) {
130                                    Ok(s) => self.visit_str(s),
131                                    Err(_) => Err(E::invalid_value(Unexpected::Bytes(value), &"field identifier")),
132                                }
133                            }
134                        }
135
136                        deserializer.deserialize_identifier(FieldVisitor)
137                    }
138                }
139
140                struct OrderedVisitor;
141
142                impl<'de> Visitor<'de> for OrderedVisitor {
143                    type Value = #name;
144
145                    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
146                        formatter.write_str("a struct represented as a sequence or map with ordered fields")
147                    }
148
149                    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
150                    where
151                        A: SeqAccess<'de>,
152                    {
153                        let mut index: usize = 0;
154
155                        #(
156                            let mut #field_names: Option<#field_types> = None;
157                        )*
158
159                        loop {
160                            let handled = match index {
161                                #(
162                                    #field_orders_only => match seq.next_element::<#field_types>()? {
163                                        Some(value) => {
164                                            #field_names = Some(value);
165                                            true
166                                        }
167                                        None => false,
168                                    },
169                                )*
170                                _ => match seq.next_element::<IgnoredAny>()? {
171                                    Some(_) => true,
172                                    None => false,
173                                },
174                            };
175
176                            if !handled {
177                                break;
178                            }
179
180                            index += 1;
181                        }
182
183                        #(
184                            let #field_names: #field_types = match #field_names {
185                                Some(result) => result,
186                                None => return Err(serde::de::Error::custom(concat!("Order for ", stringify!(#field_names), " was missing from the sequence"))),
187                            };
188                        )*
189
190                        Ok(#name { #(#field_names),* })
191                    }
192
193                    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
194                    where
195                        A: MapAccess<'de>,
196                    {
197                        #(
198                            let mut #field_names: Option<#field_types> = None;
199                        )*
200
201                        while let Some(key) = map.next_key::<#field_enum>()? {
202                            match key {
203                                #(
204                                    #field_enum::#field_enum_variants => {
205                                        if #field_names.is_some() {
206                                            return Err(serde::de::Error::duplicate_field(stringify!(#field_names)));
207                                        }
208                                        #field_names = Some(map.next_value()?);
209                                    },
210                                )*
211                                #field_enum::__Ignore => {
212                                    let _: IgnoredAny = map.next_value()?;
213                                }
214                            }
215                        }
216
217                        #(
218                            let #field_names: #field_types = match #field_names {
219                                Some(result) => result,
220                                None => return Err(serde::de::Error::missing_field(stringify!(#field_names))),
221                            };
222                        )*
223
224                        Ok(#name { #(#field_names),* })
225                    }
226                }
227
228                deserializer.deserialize_struct(stringify!(#name), FIELDS, OrderedVisitor)
229            }
230        }
231    };
232
233    TokenStream::from(deserialization)
234}
235
236//Grabs the order from #[order = ...]
237fn get_order_from_field(field: &Field) -> Result<usize> {
238    for attribute in &field.attrs {
239        if attribute.path().is_ident("order") {
240            let order: LitInt = attribute.parse_args()?;
241            return Ok(order.base10_parse::<usize>()?);
242        }
243    }
244
245    Err(syn::Error::new_spanned(
246        field,
247        "No `order` attribute found, which is required for DeserializeOrdered",
248    ))
249}