serde_ordered/
lib.rs

1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, Type, Data, DeriveInput, Field, Fields, LitInt, Result, Ident};
6
7struct FieldOrder {
8    pub order: i32,
9    pub field_name: Ident,
10    pub dtype: Type
11}
12
13///A procedural macro for deserializing ordered arrays into keyed structs using Serde.
14#[proc_macro_derive(DeserializeOrdered, attributes(order))]
15pub fn derive_order(input: TokenStream) -> TokenStream {
16    let input = parse_macro_input!(input as DeriveInput);
17    let name = &input.ident;
18
19    let fields = match &input.data {
20        Data::Struct(data_struct) => match &data_struct.fields {
21            Fields::Named(named_field) => &named_field.named,
22            _ => return syn::Error::new_spanned(
23                &input,
24                "DeserializeOrdered only supports structs with named fields",
25            ).to_compile_error().into(),
26        },
27        _ => return syn::Error::new_spanned(
28            &input,
29            "DeserializeOrdered can only be derived for structs",
30        ).to_compile_error().into(),
31    };
32
33    let mut field_orders = vec![];
34
35    for field in fields {
36        let field_name = field.ident.as_ref().unwrap();
37
38        // Extract #[serde(order = x)] attribute
39        let order = match get_order_from_field(field) {
40            Ok(order) => order,
41            Err(err) => return err.to_compile_error().into(),
42        };
43
44        field_orders.push(FieldOrder {
45            order,
46            field_name: field_name.clone(),
47            dtype: field.ty.clone(),
48        });
49    }
50
51    // Check if every field has an order and that all orders are unique
52    let total_fields = fields.len();
53    if field_orders.len() != total_fields {
54        return syn::Error::new_spanned(
55            &input,
56            "DeserializeOrdered requires all fields do have #[serde(order = x)]",
57        ).to_compile_error().into();
58    }
59
60    //Check for duplicate orders
61    let orders_set = field_orders.iter().map(|fo| fo.order).collect::<HashSet<_>>();
62    if orders_set.len() != total_fields {
63        return syn::Error::new_spanned(
64            &input,
65            "DeserializeOrdered requires all fields to have unique orders",
66        ).to_compile_error().into();
67    }
68
69    // Sort fields by order index
70    field_orders.sort_by_key(|order| order.order);
71    let field_names: Vec<_> = field_orders.iter().map(|fo| fo.field_name.to_owned()).collect();
72    let field_types: Vec<_> = field_orders.iter().map(|fo| fo.dtype.to_owned()).collect();
73    let orders: Vec<_> = field_orders.iter().map(|fo| fo.order.to_owned()).collect();
74    
75    // Generate deserialization logic (same as before)
76    let deserialization = quote! {
77        impl<'de> serde::Deserialize<'de> for #name {
78            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
79            where
80                D: serde::Deserializer<'de>,
81            {
82                use serde::de::{SeqAccess, Visitor};
83                use std::fmt;
84
85                struct OrderedVisitor;
86
87                impl<'de> Visitor<'de> for OrderedVisitor {
88                    type Value = #name;
89
90                    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
91                        formatter.write_str("a sequence with ordered fields")
92                    }
93
94                    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
95                    where
96                        A: SeqAccess<'de>,
97                    {
98                        let mut index = 0;
99
100                        #(
101                            let mut #field_names: Option<#field_types> = None;
102                        )*
103
104                        while let Ok(element) = seq.next_element::<serde_value::Value>() {
105                            if element.is_none() {break;}
106                
107                            let element = element.unwrap();
108                            match index {
109                                #(
110                                    // #orders => #field_names = element.deserialize_into::<#field_types>().unwrap(),
111                                    #orders => {
112                                        let result = match element.deserialize_into::<#field_types>() {
113                                            Ok(result) => result,
114                                            Err(err) => 
115                                                return Err(serde::de::Error::custom(format!("Failed to deserialize key because {:?}", err))),
116                                        };
117                    
118                                        #field_names = Some(result);
119                                    },
120                                )*
121                                _ => {}
122                            }
123                
124                            index+=1;
125                        }
126
127                        #(
128                            let #field_names: #field_types = match #field_names {
129                                Some(result) => result,
130                                None => return Err(serde::de::Error::custom("Order was outside the bounds of the message")),
131                            };
132                        )*
133
134                        Ok(#name {
135                            #(#field_names),*
136                        })
137                    }
138                }
139
140                deserializer.deserialize_seq(OrderedVisitor)
141            }
142        }
143    };
144
145    TokenStream::from(deserialization)
146}
147
148//Grabs the order from #[order = ...]
149fn get_order_from_field(field: &Field) -> Result<i32> {
150    for attribute in &field.attrs {
151        if attribute.path().is_ident("order") {
152            let order: LitInt = attribute.parse_args()?;
153            return Ok(order.base10_parse::<i32>()?);
154        }
155    }
156
157    Err(syn::Error::new_spanned(
158        field,
159        "No `order` attribute found, which is required for DeserializeOrdered",
160    ))
161}