1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, Type, Data, DeriveInput, Field, Fields, LitInt, Result, Ident};
6
7struct FieldOrder {
8 pub order: i32,
9 pub field_name: Ident,
10 pub dtype: Type
11}
12
13#[proc_macro_derive(DeserializeOrdered, attributes(order))]
15pub fn derive_order(input: TokenStream) -> TokenStream {
16 let input = parse_macro_input!(input as DeriveInput);
17 let name = &input.ident;
18
19 let fields = match &input.data {
20 Data::Struct(data_struct) => match &data_struct.fields {
21 Fields::Named(named_field) => &named_field.named,
22 _ => return syn::Error::new_spanned(
23 &input,
24 "DeserializeOrdered only supports structs with named fields",
25 ).to_compile_error().into(),
26 },
27 _ => return syn::Error::new_spanned(
28 &input,
29 "DeserializeOrdered can only be derived for structs",
30 ).to_compile_error().into(),
31 };
32
33 let mut field_orders = vec![];
34
35 for field in fields {
36 let field_name = field.ident.as_ref().unwrap();
37
38 let order = match get_order_from_field(field) {
40 Ok(order) => order,
41 Err(err) => return err.to_compile_error().into(),
42 };
43
44 field_orders.push(FieldOrder {
45 order,
46 field_name: field_name.clone(),
47 dtype: field.ty.clone(),
48 });
49 }
50
51 let total_fields = fields.len();
53 if field_orders.len() != total_fields {
54 return syn::Error::new_spanned(
55 &input,
56 "DeserializeOrdered requires all fields do have #[serde(order = x)]",
57 ).to_compile_error().into();
58 }
59
60 let orders_set = field_orders.iter().map(|fo| fo.order).collect::<HashSet<_>>();
62 if orders_set.len() != total_fields {
63 return syn::Error::new_spanned(
64 &input,
65 "DeserializeOrdered requires all fields to have unique orders",
66 ).to_compile_error().into();
67 }
68
69 field_orders.sort_by_key(|order| order.order);
71 let field_names: Vec<_> = field_orders.iter().map(|fo| fo.field_name.to_owned()).collect();
72 let field_types: Vec<_> = field_orders.iter().map(|fo| fo.dtype.to_owned()).collect();
73 let orders: Vec<_> = field_orders.iter().map(|fo| fo.order.to_owned()).collect();
74
75 let deserialization = quote! {
77 impl<'de> serde::Deserialize<'de> for #name {
78 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
79 where
80 D: serde::Deserializer<'de>,
81 {
82 use serde::de::{SeqAccess, Visitor};
83 use std::fmt;
84
85 struct OrderedVisitor;
86
87 impl<'de> Visitor<'de> for OrderedVisitor {
88 type Value = #name;
89
90 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
91 formatter.write_str("a sequence with ordered fields")
92 }
93
94 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
95 where
96 A: SeqAccess<'de>,
97 {
98 let mut index = 0;
99
100 #(
101 let mut #field_names: Option<#field_types> = None;
102 )*
103
104 while let Ok(element) = seq.next_element::<serde_value::Value>() {
105 if element.is_none() {break;}
106
107 let element = element.unwrap();
108 match index {
109 #(
110 #orders => {
112 let result = match element.deserialize_into::<#field_types>() {
113 Ok(result) => result,
114 Err(err) =>
115 return Err(serde::de::Error::custom(format!("Failed to deserialize key because {:?}", err))),
116 };
117
118 #field_names = Some(result);
119 },
120 )*
121 _ => {}
122 }
123
124 index+=1;
125 }
126
127 #(
128 let #field_names: #field_types = match #field_names {
129 Some(result) => result,
130 None => return Err(serde::de::Error::custom("Order was outside the bounds of the message")),
131 };
132 )*
133
134 Ok(#name {
135 #(#field_names),*
136 })
137 }
138 }
139
140 deserializer.deserialize_seq(OrderedVisitor)
141 }
142 }
143 };
144
145 TokenStream::from(deserialization)
146}
147
148fn get_order_from_field(field: &Field) -> Result<i32> {
150 for attribute in &field.attrs {
151 if attribute.path().is_ident("order") {
152 let order: LitInt = attribute.parse_args()?;
153 return Ok(order.base10_parse::<i32>()?);
154 }
155 }
156
157 Err(syn::Error::new_spanned(
158 field,
159 "No `order` attribute found, which is required for DeserializeOrdered",
160 ))
161}