serde_ordered_derive/
lib.rs1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, Type, Data, DeriveInput, Field, Fields, LitInt, Result, Ident};
6
7struct FieldOrder {
8 pub order: i32,
9 pub field_name: Ident,
10 pub dtype: Type
11}
12
13#[proc_macro_derive(DeserializeOrdered, attributes(order))]
15pub fn derive_order(input: TokenStream) -> TokenStream {
16 let input = parse_macro_input!(input as DeriveInput);
17 let name = &input.ident;
18
19 let fields = match &input.data {
20 Data::Struct(data_struct) => match &data_struct.fields {
21 Fields::Named(named_field) => &named_field.named,
22 _ => return syn::Error::new_spanned(
23 &input,
24 "DeserializeOrdered only supports structs with named fields",
25 ).to_compile_error().into(),
26 },
27 _ => return syn::Error::new_spanned(
28 &input,
29 "DeserializeOrdered can only be derived for structs",
30 ).to_compile_error().into(),
31 };
32
33 let mut field_orders = vec![];
34
35 for field in fields {
36 let field_name = field.ident.as_ref().unwrap();
37
38 let order = match get_order_from_field(field) {
40 Ok(order) => order,
41 Err(err) => return err.to_compile_error().into(),
42 };
43
44 field_orders.push(FieldOrder {
45 order,
46 field_name: field_name.clone(),
47 dtype: field.ty.clone(),
48 });
49 }
50
51 let total_fields = fields.len();
53 if field_orders.len() != total_fields {
54 return syn::Error::new_spanned(
55 &input,
56 "DeserializeOrdered requires all fields do have #[serde(order = x)]",
57 ).to_compile_error().into();
58 }
59
60 let orders_set = field_orders.iter().map(|fo| fo.order).collect::<HashSet<_>>();
62 if orders_set.len() != total_fields {
63 return syn::Error::new_spanned(
64 &input,
65 "DeserializeOrdered requires all fields to have unique orders",
66 ).to_compile_error().into();
67 }
68
69 field_orders.sort_by_key(|order| order.order);
71 let field_names: Vec<_> = field_orders.iter().map(|fo| fo.field_name.to_owned()).collect();
72 let field_types: Vec<_> = field_orders.iter().map(|fo| fo.dtype.to_owned()).collect();
73 let orders: Vec<_> = field_orders.iter().map(|fo| fo.order.to_owned()).collect();
74
75 let deserialization = quote! {
77 impl<'de> serde::Deserialize<'de> for #name {
78 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
79 where
80 D: serde::Deserializer<'de>,
81 {
82 use serde::de::{IgnoredAny, MapAccess, SeqAccess, Visitor};
83 use std::fmt;
84
85 const FIELDS: &'static [&'static str] = &[#(stringify!(#field_names)),*];
86
87 struct OrderedVisitor;
88
89 impl<'de> Visitor<'de> for OrderedVisitor {
90 type Value = #name;
91
92 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
93 formatter.write_str("a struct represented as a sequence or map with ordered fields")
94 }
95
96 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
97 where
98 A: SeqAccess<'de>,
99 {
100 let mut index = 0;
101
102 #(
103 let mut #field_names: Option<#field_types> = None;
104 )*
105
106 while let Ok(element) = seq.next_element::<::serde_ordered::value::Value>() {
107 if element.is_none() {break;}
108
109 let element = element.unwrap();
110 match index {
111 #(
112 #orders => {
113 let result = match element.deserialize_into::<#field_types>() {
114 Ok(result) => result,
115 Err(err) =>
116 return Err(serde::de::Error::custom(format!("Failed to deserialize key because {:?}", err))),
117 };
118
119 #field_names = Some(result);
120 },
121 )*
122 _ => {}
123 }
124
125 index+=1;
126 }
127
128 #(
129 let #field_names: #field_types = match #field_names {
130 Some(result) => result,
131 None => return Err(serde::de::Error::custom("Order was outside the bounds of the message")),
132 };
133 )*
134
135 Ok(#name {
136 #(#field_names),*
137 })
138 }
139
140 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
141 where
142 A: MapAccess<'de>,
143 {
144 #(
145 let mut #field_names: Option<#field_types> = None;
146 )*
147
148 while let Some(key) = map.next_key::<String>()? {
149 match key.as_str() {
150 #(
151 stringify!(#field_names) => {
152 if #field_names.is_some() {
153 return Err(serde::de::Error::duplicate_field(stringify!(#field_names)));
154 }
155 let value: ::serde_ordered::value::Value = map.next_value()?;
156 let result = match value.deserialize_into::<#field_types>() {
157 Ok(result) => result,
158 Err(err) =>
159 return Err(serde::de::Error::custom(format!("Failed to deserialize key because {:?}", err))),
160 };
161 #field_names = Some(result);
162 },
163 )*
164 _ => {
165 let _: IgnoredAny = map.next_value()?;
166 }
167 }
168 }
169
170 #(
171 let #field_names: #field_types = match #field_names {
172 Some(result) => result,
173 None => return Err(serde::de::Error::missing_field(stringify!(#field_names))),
174 };
175 )*
176
177 Ok(#name {
178 #(#field_names),*
179 })
180 }
181 }
182
183 deserializer.deserialize_struct(stringify!(#name), FIELDS, OrderedVisitor)
184 }
185 }
186 };
187
188 TokenStream::from(deserialization)
189}
190
191fn get_order_from_field(field: &Field) -> Result<i32> {
193 for attribute in &field.attrs {
194 if attribute.path().is_ident("order") {
195 let order: LitInt = attribute.parse_args()?;
196 return Ok(order.base10_parse::<i32>()?);
197 }
198 }
199
200 Err(syn::Error::new_spanned(
201 field,
202 "No `order` attribute found, which is required for DeserializeOrdered",
203 ))
204}