// serde_ordered_derive/lib.rs
use std::collections::HashSet;

use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use syn::ext::IdentExt;
use syn::{parse_macro_input, Data, DeriveInput, Field, Fields, Ident, LitInt, Result, Type};

8struct FieldOrder {
9 pub order: usize,
10 pub field_name: Ident,
11 pub dtype: Type
12}
13
14#[proc_macro_derive(DeserializeOrdered, attributes(order))]
16pub fn derive_order(input: TokenStream) -> TokenStream {
17 let input = parse_macro_input!(input as DeriveInput);
18 let name = &input.ident;
19
20 let fields = match &input.data {
21 Data::Struct(data_struct) => match &data_struct.fields {
22 Fields::Named(named_field) => &named_field.named,
23 _ => return syn::Error::new_spanned(
24 &input,
25 "DeserializeOrdered only supports structs with named fields",
26 ).to_compile_error().into(),
27 },
28 _ => return syn::Error::new_spanned(
29 &input,
30 "DeserializeOrdered can only be derived for structs",
31 ).to_compile_error().into(),
32 };
33
34 let mut field_orders = vec![];
35
36 for field in fields {
37 let field_name = field.ident.as_ref().unwrap();
38
39 let order = match get_order_from_field(field) {
41 Ok(order) => order,
42 Err(err) => return err.to_compile_error().into(),
43 };
44
45 field_orders.push(FieldOrder {
46 order,
47 field_name: field_name.clone(),
48 dtype: field.ty.clone(),
49 });
50 }
51
52 let total_fields = fields.len();
54 if field_orders.len() != total_fields {
55 return syn::Error::new_spanned(
56 &input,
57 "DeserializeOrdered requires all fields do have #[serde(order = x)]",
58 ).to_compile_error().into();
59 }
60
61 let orders_set = field_orders.iter().map(|fo| fo.order).collect::<HashSet<_>>();
63 if orders_set.len() != total_fields {
64 return syn::Error::new_spanned(
65 &input,
66 "DeserializeOrdered requires all fields to have unique orders",
67 ).to_compile_error().into();
68 }
69
70 field_orders.sort_by_key(|order| order.order);
72 let field_names: Vec<_> = field_orders.iter().map(|fo| fo.field_name.to_owned()).collect();
73 let field_types: Vec<_> = field_orders.iter().map(|fo| fo.dtype.to_owned()).collect();
74 let field_orders_only: Vec<_> = field_orders.iter().map(|fo| fo.order).collect();
75
76 let field_enum = Ident::new("__SerdeOrderedField", Span::call_site());
77 let field_enum_variants: Vec<Ident> = field_orders
78 .iter()
79 .enumerate()
80 .map(|(index, _)| Ident::new(&format!("__Field{}", index), Span::call_site()))
81 .collect();
82
83 let deserialization = quote! {
85 impl<'de> serde::Deserialize<'de> for #name {
86 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
87 where
88 D: serde::Deserializer<'de>,
89 {
90 use serde::de::{IgnoredAny, MapAccess, SeqAccess, Unexpected, Visitor};
91 use std::fmt;
92
93 const FIELDS: &'static [&'static str] = &[#(stringify!(#field_names)),*];
94
95 #[allow(non_camel_case_types)]
96 enum #field_enum {
97 #(#field_enum_variants),*,
98 __Ignore,
99 }
100
101 impl<'de> serde::Deserialize<'de> for #field_enum {
102 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
103 where
104 D: serde::Deserializer<'de>,
105 {
106 struct FieldVisitor;
107
108 impl<'de> Visitor<'de> for FieldVisitor {
109 type Value = #field_enum;
110
111 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
112 formatter.write_str("field identifier")
113 }
114
115 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
116 where
117 E: serde::de::Error,
118 {
119 match value {
120 #(stringify!(#field_names) => Ok(#field_enum::#field_enum_variants),)*
121 _ => Ok(#field_enum::__Ignore),
122 }
123 }
124
125 fn visit_bytes<E>(self, value: &[u8]) -> Result<Self::Value, E>
126 where
127 E: serde::de::Error,
128 {
129 match std::str::from_utf8(value) {
130 Ok(s) => self.visit_str(s),
131 Err(_) => Err(E::invalid_value(Unexpected::Bytes(value), &"field identifier")),
132 }
133 }
134 }
135
136 deserializer.deserialize_identifier(FieldVisitor)
137 }
138 }
139
140 struct OrderedVisitor;
141
142 impl<'de> Visitor<'de> for OrderedVisitor {
143 type Value = #name;
144
145 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
146 formatter.write_str("a struct represented as a sequence or map with ordered fields")
147 }
148
149 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
150 where
151 A: SeqAccess<'de>,
152 {
153 let mut index: usize = 0;
154
155 #(
156 let mut #field_names: Option<#field_types> = None;
157 )*
158
159 loop {
160 let handled = match index {
161 #(
162 #field_orders_only => match seq.next_element::<#field_types>()? {
163 Some(value) => {
164 #field_names = Some(value);
165 true
166 }
167 None => false,
168 },
169 )*
170 _ => match seq.next_element::<IgnoredAny>()? {
171 Some(_) => true,
172 None => false,
173 },
174 };
175
176 if !handled {
177 break;
178 }
179
180 index += 1;
181 }
182
183 #(
184 let #field_names: #field_types = match #field_names {
185 Some(result) => result,
186 None => return Err(serde::de::Error::custom(concat!("Order for ", stringify!(#field_names), " was missing from the sequence"))),
187 };
188 )*
189
190 Ok(#name { #(#field_names),* })
191 }
192
193 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
194 where
195 A: MapAccess<'de>,
196 {
197 #(
198 let mut #field_names: Option<#field_types> = None;
199 )*
200
201 while let Some(key) = map.next_key::<#field_enum>()? {
202 match key {
203 #(
204 #field_enum::#field_enum_variants => {
205 if #field_names.is_some() {
206 return Err(serde::de::Error::duplicate_field(stringify!(#field_names)));
207 }
208 #field_names = Some(map.next_value()?);
209 },
210 )*
211 #field_enum::__Ignore => {
212 let _: IgnoredAny = map.next_value()?;
213 }
214 }
215 }
216
217 #(
218 let #field_names: #field_types = match #field_names {
219 Some(result) => result,
220 None => return Err(serde::de::Error::missing_field(stringify!(#field_names))),
221 };
222 )*
223
224 Ok(#name { #(#field_names),* })
225 }
226 }
227
228 deserializer.deserialize_struct(stringify!(#name), FIELDS, OrderedVisitor)
229 }
230 }
231 };
232
233 TokenStream::from(deserialization)
234}
235
236fn get_order_from_field(field: &Field) -> Result<usize> {
238 for attribute in &field.attrs {
239 if attribute.path().is_ident("order") {
240 let order: LitInt = attribute.parse_args()?;
241 return Ok(order.base10_parse::<usize>()?);
242 }
243 }
244
245 Err(syn::Error::new_spanned(
246 field,
247 "No `order` attribute found, which is required for DeserializeOrdered",
248 ))
249}