serde_ordered_derive/
lib.rs

1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, Type, Data, DeriveInput, Field, Fields, LitInt, Result, Ident};
6
7struct FieldOrder {
8    pub order: i32,
9    pub field_name: Ident,
10    pub dtype: Type
11}
12
13///A procedural macro for deserializing ordered arrays into keyed structs using Serde.
14#[proc_macro_derive(DeserializeOrdered, attributes(order))]
15pub fn derive_order(input: TokenStream) -> TokenStream {
16    let input = parse_macro_input!(input as DeriveInput);
17    let name = &input.ident;
18
19    let fields = match &input.data {
20        Data::Struct(data_struct) => match &data_struct.fields {
21            Fields::Named(named_field) => &named_field.named,
22            _ => return syn::Error::new_spanned(
23                &input,
24                "DeserializeOrdered only supports structs with named fields",
25            ).to_compile_error().into(),
26        },
27        _ => return syn::Error::new_spanned(
28            &input,
29            "DeserializeOrdered can only be derived for structs",
30        ).to_compile_error().into(),
31    };
32
33    let mut field_orders = vec![];
34
35    for field in fields {
36        let field_name = field.ident.as_ref().unwrap();
37
38        // Extract #[serde(order = x)] attribute
39        let order = match get_order_from_field(field) {
40            Ok(order) => order,
41            Err(err) => return err.to_compile_error().into(),
42        };
43
44        field_orders.push(FieldOrder {
45            order,
46            field_name: field_name.clone(),
47            dtype: field.ty.clone(),
48        });
49    }
50
51    // Check if every field has an order and that all orders are unique
52    let total_fields = fields.len();
53    if field_orders.len() != total_fields {
54        return syn::Error::new_spanned(
55            &input,
56            "DeserializeOrdered requires all fields do have #[serde(order = x)]",
57        ).to_compile_error().into();
58    }
59
60    //Check for duplicate orders
61    let orders_set = field_orders.iter().map(|fo| fo.order).collect::<HashSet<_>>();
62    if orders_set.len() != total_fields {
63        return syn::Error::new_spanned(
64            &input,
65            "DeserializeOrdered requires all fields to have unique orders",
66        ).to_compile_error().into();
67    }
68
69    // Sort fields by order index
70    field_orders.sort_by_key(|order| order.order);
71    let field_names: Vec<_> = field_orders.iter().map(|fo| fo.field_name.to_owned()).collect();
72    let field_types: Vec<_> = field_orders.iter().map(|fo| fo.dtype.to_owned()).collect();
73    let orders: Vec<_> = field_orders.iter().map(|fo| fo.order.to_owned()).collect();
74    
75    // Generate deserialization logic (same as before)
76    let deserialization = quote! {
77        impl<'de> serde::Deserialize<'de> for #name {
78            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
79            where
80                D: serde::Deserializer<'de>,
81            {
82                use serde::de::{IgnoredAny, MapAccess, SeqAccess, Visitor};
83                use std::fmt;
84
85                const FIELDS: &'static [&'static str] = &[#(stringify!(#field_names)),*];
86
87                struct OrderedVisitor;
88
89                impl<'de> Visitor<'de> for OrderedVisitor {
90                    type Value = #name;
91
92                    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
93                        formatter.write_str("a struct represented as a sequence or map with ordered fields")
94                    }
95
96                    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
97                    where
98                        A: SeqAccess<'de>,
99                    {
100                        let mut index = 0;
101
102                        #(
103                            let mut #field_names: Option<#field_types> = None;
104                        )*
105
106                        while let Ok(element) = seq.next_element::<::serde_ordered::value::Value>() {
107                            if element.is_none() {break;}
108                
109                            let element = element.unwrap();
110                            match index {
111                                #(
112                                    #orders => {
113                                        let result = match element.deserialize_into::<#field_types>() {
114                                            Ok(result) => result,
115                                            Err(err) => 
116                                                return Err(serde::de::Error::custom(format!("Failed to deserialize key because {:?}", err))),
117                                        };
118                    
119                                        #field_names = Some(result);
120                                    },
121                                )*
122                                _ => {}
123                            }
124                
125                            index+=1;
126                        }
127
128                        #(
129                            let #field_names: #field_types = match #field_names {
130                                Some(result) => result,
131                                None => return Err(serde::de::Error::custom("Order was outside the bounds of the message")),
132                            };
133                        )*
134
135                        Ok(#name {
136                            #(#field_names),*
137                        })
138                    }
139
140                    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
141                    where
142                        A: MapAccess<'de>,
143                    {
144                        #(
145                            let mut #field_names: Option<#field_types> = None;
146                        )*
147
148                        while let Some(key) = map.next_key::<String>()? {
149                            match key.as_str() {
150                                #(
151                                    stringify!(#field_names) => {
152                                        if #field_names.is_some() {
153                                            return Err(serde::de::Error::duplicate_field(stringify!(#field_names)));
154                                        }
155                                        let value: ::serde_ordered::value::Value = map.next_value()?;
156                                        let result = match value.deserialize_into::<#field_types>() {
157                                            Ok(result) => result,
158                                            Err(err) => 
159                                                return Err(serde::de::Error::custom(format!("Failed to deserialize key because {:?}", err))),
160                                        };
161                                        #field_names = Some(result);
162                                    },
163                                )*
164                                _ => {
165                                    let _: IgnoredAny = map.next_value()?;
166                                }
167                            }
168                        }
169
170                        #(
171                            let #field_names: #field_types = match #field_names {
172                                Some(result) => result,
173                                None => return Err(serde::de::Error::missing_field(stringify!(#field_names))),
174                            };
175                        )*
176
177                        Ok(#name {
178                            #(#field_names),*
179                        })
180                    }
181                }
182
183                deserializer.deserialize_struct(stringify!(#name), FIELDS, OrderedVisitor)
184            }
185        }
186    };
187
188    TokenStream::from(deserialization)
189}
190
191//Grabs the order from #[order = ...]
192fn get_order_from_field(field: &Field) -> Result<i32> {
193    for attribute in &field.attrs {
194        if attribute.path().is_ident("order") {
195            let order: LitInt = attribute.parse_args()?;
196            return Ok(order.base10_parse::<i32>()?);
197        }
198    }
199
200    Err(syn::Error::new_spanned(
201        field,
202        "No `order` attribute found, which is required for DeserializeOrdered",
203    ))
204}