1use std::collections::HashSet;
16
17use proc_macro2::{Span, TokenStream};
18use quote::quote;
19use syn::parse_quote;
20use syn::spanned::Spanned;
21
22#[cfg(test)]
23mod tests;
24
25#[proc_macro_derive(Wire, attributes(wire))]
28pub fn derive_wire(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
29 let item = syn::parse_macro_input!(item as syn::DeriveInput);
30 print_item(derive_item(&item)).into()
31}
32
33#[proc_macro_attribute]
34pub fn internal_wire(
35 attr: proc_macro::TokenStream, item: proc_macro::TokenStream,
36) -> proc_macro::TokenStream {
37 let item = syn::parse_macro_input!(item as syn::DeriveInput);
38 if !attr.is_empty() {
39 return syn::Error::new(item.span(), "unexpected attribute").into_compile_error().into();
40 }
41 print_item(derive_item(&item)).into()
42}
43
44fn print_item(item: syn::Result<syn::ItemImpl>) -> TokenStream {
45 match item {
46 #[cfg(feature = "_dev")]
47 Ok(_) => TokenStream::new(),
48 #[cfg(not(feature = "_dev"))]
49 Ok(x) => quote!(#x),
50 Err(e) => e.into_compile_error(),
51 }
52}
53
54fn derive_item(item: &syn::DeriveInput) -> syn::Result<syn::ItemImpl> {
55 let mut attrs = Attrs::parse(&item.attrs, AttrsKind::Item)?;
56 let wire = match attrs.crate_.take() {
57 Attr::Absent => syn::Ident::new("wasefire_wire", Span::call_site()).into(),
58 Attr::Present(_, x) => x,
59 };
60 let param = match attrs.param.take() {
61 Attr::Absent => {
62 let first = item.generics.lifetimes().next();
63 first.map_or_else(
64 || syn::Lifetime::new("'wire", Span::call_site()),
65 |x| x.lifetime.clone(),
66 )
67 }
68 Attr::Present(_, x) => x,
69 };
70 let type_param = syn::Lifetime::new(&format!("'_{}", param.ident), param.span());
71 let (_, parameters, _) = item.generics.split_for_impl();
72 let mut generics = item.generics.clone();
73 if generics.lifetimes().all(|x| x.lifetime != param) {
74 generics.params.insert(0, parse_quote!(#param));
75 }
76 if let Attr::Present(_, where_) = attrs.where_.take() {
77 generics.make_where_clause().predicates.extend(where_);
78 }
79 let (generics, _, where_clause) = generics.split_for_impl();
80 let (schema, encode, decode) = match &item.data {
81 syn::Data::Struct(x) => derive_struct(&mut attrs, &wire, ¶m, &item.ident, x)?,
82 syn::Data::Enum(x) => derive_enum(&mut attrs, &wire, &item.ident, x)?,
83 syn::Data::Union(_) => {
84 return Err(syn::Error::new(item.span(), "unions are not supported"));
85 }
86 };
87 if let Attr::Present(span, _) = attrs.range.take() {
88 return Err(syn::Error::new(span, "range is only supported for enums"));
89 }
90 if let Attr::Present(span, _) = attrs.refine.take() {
91 return Err(syn::Error::new(span, "refine is only supported for structs with one field"));
92 }
93 let ident = &item.ident;
94 let statics: Vec<syn::Type> = match attrs.static_.take() {
95 Attr::Absent => Vec::new(),
96 Attr::Present(_, xs) => xs.into_iter().map(|x| parse_quote!(#x)).collect(),
97 };
98 let mut visitor = MakeType { src: ¶m, dst: &type_param, statics: &statics };
99 let mut type_: syn::Type = parse_quote!(#ident #parameters);
100 syn::visit_mut::visit_type_mut(&mut visitor, &mut type_);
101 let schema = match cfg!(feature = "schema") {
102 false => None,
103 true => Some(quote! {
104 fn schema(rules: &mut #wire::internal::Rules) {
105 #(#schema)*
106 }
107 }),
108 };
109 let impl_wire: syn::ItemImpl = parse_quote! {
110 #[automatically_derived]
111 impl #generics #wire::internal::Wire<#param> for #ident #parameters #where_clause {
112 type Type<#type_param> = #type_;
113 #schema
114 fn encode(
115 &self, writer: &mut #wire::internal::Writer<#param>
116 ) -> #wire::internal::Result<()> {
117 #(#encode)*
118 }
119 fn decode(
120 reader: &mut #wire::internal::Reader<#param>
121 ) -> #wire::internal::Result<Self> {
122 #(#decode)*
123 }
124 }
125 };
126 Ok(impl_wire)
127}
128
129fn derive_struct(
130 attrs: &mut Attrs, wire: &syn::Path, param: &syn::Lifetime, name: &syn::Ident,
131 item: &syn::DataStruct,
132) -> syn::Result<(Vec<syn::Stmt>, Vec<syn::Stmt>, Vec<syn::Stmt>)> {
133 let mut types = Types::default();
134 if item.fields.len() == 1
135 && let Attr::Present(_, refine) = attrs.refine.take()
136 {
137 let inner = item.fields.iter().next().unwrap();
138 let ty = &inner.ty;
139 let schema = parse_quote! {
140 rules.alias::<
141 Self::Type<'static>, <#ty as #wire::internal::Wire<#param>>::Type<'static>>();
142 };
143 let encode: syn::Expr = match &inner.ident {
144 None => parse_quote!(self.0.encode(writer)),
145 Some(name) => parse_quote!(self.#name.encode(writer)),
146 };
147 let decode: syn::Expr = parse_quote!(#refine(<#ty>::decode(reader)?));
148 return Ok((
149 vec![schema],
150 vec![syn::Stmt::Expr(encode, None)],
151 vec![syn::Stmt::Expr(decode, None)],
152 ));
153 }
154 let (mut schema, mut encode, decode) =
155 derive_fields(wire, &parse_quote!(#name), &item.fields, &mut types)?;
156 if cfg!(feature = "schema") {
157 types.stmts(&mut schema, wire, "struct_", "fields");
158 }
159 if !encode.is_empty() {
160 let path = parse_quote!(#name);
161 let encode_pat = fields_pat(&path, &item.fields, false);
162 encode.insert(0, parse_quote!(let #encode_pat = self;));
163 }
164 Ok((schema, encode, decode))
165}
166
167fn derive_enum(
168 attrs: &mut Attrs, wire: &syn::Path, name: &syn::Ident, item: &syn::DataEnum,
169) -> syn::Result<(Vec<syn::Stmt>, Vec<syn::Stmt>, Vec<syn::Stmt>)> {
170 let mut schema = Vec::new();
171 let mut types = Types::default();
172 let mut encode = Vec::<syn::Arm>::new();
173 let mut decode = Vec::<syn::Arm>::new();
174 let mut tags = Tags::default();
175 match item.variants.len() {
176 _ if !cfg!(feature = "schema") => (),
177 0 => schema.push(parse_quote!(let variants = #wire::internal::Vec::new();)),
178 n => schema.push(parse_quote!(let mut variants = #wire::internal::Vec::with_capacity(#n);)),
179 }
180 for variant in &item.variants {
181 let mut attrs = Attrs::parse(&variant.attrs, AttrsKind::Variant)?;
182 let tag = match attrs.tag.take() {
183 Attr::Present(span, tag) => tags.use_(span, tag)?,
184 Attr::Absent => tags.next(),
185 };
186 let ident = &variant.ident;
187 let path = parse_quote!(#name::#ident);
188 let (variant_schema, variant_encode, variant_decode) =
189 derive_fields(wire, &path, &variant.fields, &mut types)?;
190 let pat_encode = fields_pat(&path, &variant.fields, true);
191 let ident_schema = format!("{ident}");
192 if cfg!(feature = "schema") {
193 schema.push(parse_quote!({
194 #(#variant_schema)*
195 variants.push((#ident_schema, #tag, fields));
196 }));
197 }
198 encode.push(parse_quote!(#pat_encode => {
199 #wire::internal::encode_tag(#tag, writer)?;
200 #(#variant_encode)*
201 }));
202 decode.push(parse_quote!(#tag => { #(#variant_decode)* }));
203 }
204 if let Attr::Present(span, range) = attrs.range.take()
205 && (tags.used.len() as u32 != range || tags.used.iter().any(|x| range <= *x))
206 {
207 return Err(syn::Error::new(span, "tags don't form a range"));
208 }
209 if cfg!(feature = "schema") {
210 types.stmts(&mut schema, wire, "enum_", "variants");
211 }
212 let encode = parse_quote!(match *self { #(#encode)* });
213 let decode = parse_quote! {
214 let tag = #wire::internal::decode_tag(reader)?;
215 match tag {
216 #(#decode)*
217 _ => Err(#wire::internal::INVALID_TAG),
218 }
219 };
220 Ok((schema, encode, decode))
221}
222
223fn derive_fields<'a>(
224 wire: &syn::Path, name: &syn::Path, fields: &'a syn::Fields, types: &mut Types<'a>,
225) -> syn::Result<(Vec<syn::Stmt>, Vec<syn::Stmt>, Vec<syn::Stmt>)> {
226 let mut schema = Vec::new();
227 let mut encode = Vec::new();
228 let mut decode = Vec::new();
229 match fields.len() {
230 _ if !cfg!(feature = "schema") => (),
231 0 => schema.push(parse_quote!(let fields = #wire::internal::Vec::new();)),
232 n => schema.push(parse_quote!(let mut fields = #wire::internal::Vec::with_capacity(#n);)),
233 }
234 for (i, field) in fields.iter().enumerate() {
235 let _ = Attrs::parse(&field.attrs, AttrsKind::Invalid)?;
236 let name = field_name(i, field);
237 let ty = &field.ty;
238 if cfg!(feature = "schema") {
239 let name_str: syn::Expr = match &field.ident {
240 Some(x) => {
241 let x = format!("{x}");
242 parse_quote!(Some(#x))
243 }
244 None => parse_quote!(None),
245 };
246 schema.push(parse_quote!(fields.push((#name_str, #wire::internal::type_id::<#ty>()));));
247 types.insert(ty);
248 }
249 encode.push(parse_quote!(<#ty as #wire::internal::Wire>::encode(#name, writer)?;));
250 decode.push(parse_quote!(let #name = <#ty as #wire::internal::Wire>::decode(reader)?;));
251 }
252 encode.push(syn::Stmt::Expr(parse_quote!(Ok(())), None));
253 let fields_pat = fields_pat(name, fields, false);
254 decode.push(syn::Stmt::Expr(parse_quote!(Ok(#fields_pat)), None));
255 Ok((schema, encode, decode))
256}
257
258fn fields_pat(name: &syn::Path, fields: &syn::Fields, ref_: bool) -> syn::Pat {
259 let ref_: Option<syn::Token![ref]> = ref_.then_some(parse_quote!(ref));
260 let names = fields.iter().enumerate().map(|(i, field)| field_name(i, field));
261 match fields {
262 syn::Fields::Named(_) => parse_quote!(#name { #(#ref_ #names),* }),
263 syn::Fields::Unnamed(_) => parse_quote!(#name(#(#ref_ #names),*)),
264 syn::Fields::Unit => parse_quote!(#name),
265 }
266}
267
268fn field_name(i: usize, field: &syn::Field) -> syn::Ident {
269 match &field.ident {
270 Some(x) => x.clone(),
271 None => syn::Ident::new(&format!("x{i}"), field.span()),
272 }
273}
274
275struct MakeType<'a> {
276 src: &'a syn::Lifetime,
277 dst: &'a syn::Lifetime,
278 statics: &'a [syn::Type],
279}
280
281impl syn::visit_mut::VisitMut for MakeType<'_> {
282 fn visit_generic_argument_mut(&mut self, x: &mut syn::GenericArgument) {
283 match x {
284 syn::GenericArgument::Lifetime(a) if a == self.src => *a = self.dst.clone(),
285 syn::GenericArgument::Type(t) if !self.statics.contains(t) => {
286 let a = &self.dst;
287 *x = parse_quote!(#t::Type<#a>)
288 }
289 _ => (),
290 }
291 }
292}
293
294#[derive(Default)]
295struct Tags {
296 used: HashSet<u32>,
297 next: u32,
298}
299
300impl Tags {
301 fn use_(&mut self, span: Span, tag: u32) -> syn::Result<u32> {
302 if !self.used.insert(tag) {
303 return Err(syn::Error::new(span, "duplicate tag"));
304 }
305 Ok(self.update_next(tag))
306 }
307
308 fn next(&mut self) -> u32 {
309 while !self.used.insert(self.next) {
310 self.next = self.next.wrapping_add(1);
311 }
312 self.update_next(self.next)
313 }
314
315 fn update_next(&mut self, tag: u32) -> u32 {
316 self.next = tag.wrapping_add(1);
317 tag
318 }
319}
320
321#[derive(Default)]
322struct Types<'a>(Vec<&'a syn::Type>);
323
324impl<'a> Types<'a> {
325 fn insert(&mut self, x: &'a syn::Type) {
326 if !self.0.contains(&x) {
327 self.0.push(x);
328 }
329 }
330
331 fn stmts(&self, schema: &mut Vec<syn::Stmt>, wire: &syn::Path, fun: &str, var: &str) {
332 let types: Vec<syn::Stmt> =
333 self.0.iter().map(|ty| parse_quote!(#wire::internal::schema::<#ty>(rules);)).collect();
334 let fun = syn::Ident::new(fun, Span::call_site());
335 let var = syn::Ident::new(var, Span::call_site());
336 schema.push(parse_quote! {
337 if rules.#fun::<Self::Type<'static>>(#var) {
338 #(#types)*
339 }
340 });
341 }
342}
343
/// Where a set of `wire` attributes was found, which determines the allowed keys.
enum AttrsKind {
    /// On the struct or enum itself.
    Item,
    /// On an enum variant.
    Variant,
    /// On a field, where no key is allowed.
    Invalid,
}

/// The `wire` attribute keys validated by `Attrs::check_kind`.
#[derive(PartialEq, Eq)]
enum AttrKind {
    /// `crate = <path>`
    Crate,
    /// `param = <lifetime>`
    Param,
    /// `where = <predicate>`
    Where,
    /// `tag = <u32>`
    Tag,
    /// `static = <ident>`
    Static,
    /// `range = <u32>`
    Range,
}
359
360#[derive(Default)]
361enum Attr<T> {
362 #[default]
363 Absent,
364 Present(Span, T),
365}
366
367impl<T> Attr<T> {
368 fn span(&self) -> Option<Span> {
369 match self {
370 Attr::Absent => None,
371 Attr::Present(x, _) => Some(*x),
372 }
373 }
374
375 fn set(&mut self, span: Span, value: T) -> syn::Result<()> {
376 match self {
377 Attr::Absent => Ok(*self = Attr::Present(span, value)),
378 Attr::Present(other, _) => {
379 let mut error = syn::Error::new(span, "attribute already defined");
380 error.combine(syn::Error::new(*other, "first attribute definition"));
381 Err(error)
382 }
383 }
384 }
385
386 fn take(&mut self) -> Self {
387 std::mem::take(self)
388 }
389}
390
391impl<T> Attr<Vec<T>> {
392 fn push(&mut self, span: Span, value: T) {
393 match self {
394 Attr::Absent => *self = Attr::Present(span, vec![value]),
395 Attr::Present(_, values) => values.push(value),
396 }
397 }
398}
399
400#[derive(Default)]
401struct Attrs {
402 crate_: Attr<syn::Path>,
403 param: Attr<syn::Lifetime>,
404 where_: Attr<Vec<syn::WherePredicate>>,
405 tag: Attr<u32>,
406 static_: Attr<Vec<syn::Ident>>,
407 range: Attr<u32>,
408 refine: Attr<syn::Path>,
409}
410
411impl Attrs {
412 fn parse(attrs: &[syn::Attribute], kind: AttrsKind) -> syn::Result<Self> {
413 let mut result = Attrs::default();
414 for attr in attrs {
415 result.parse_attr(attr)?;
416 }
417 result.check_kind(kind)?;
418 Ok(result)
419 }
420
421 fn parse_attr(&mut self, attr: &syn::Attribute) -> syn::Result<()> {
422 if !attr.path().is_ident("wire") {
423 return Ok(());
424 }
425 attr.parse_nested_meta(|meta| {
426 if meta.path.is_ident("crate") {
427 self.crate_.set(attr.span(), meta.value()?.parse()?)?;
429 }
430 if meta.path.is_ident("param") {
431 self.param.set(attr.span(), meta.value()?.parse()?)?;
433 }
434 if meta.path.is_ident("where") {
435 self.where_.push(attr.span(), meta.value()?.parse()?);
437 }
438 if meta.path.is_ident("tag") {
439 let tag: syn::LitInt = meta.value()?.parse()?;
441 self.tag.set(attr.span(), tag.base10_parse()?)?;
442 }
443 if meta.path.is_ident("static") {
444 self.static_.push(attr.span(), meta.value()?.parse()?);
446 }
447 if meta.path.is_ident("range") {
448 let range: syn::LitInt = meta.value()?.parse()?;
450 self.range.set(attr.span(), range.base10_parse()?)?;
451 }
452 if meta.path.is_ident("refine") {
453 self.refine.set(attr.span(), meta.value()?.parse()?)?;
455 }
456 Ok(())
457 })
458 }
459
460 fn check_kind(&self, kind: AttrsKind) -> syn::Result<()> {
461 let expected: &[AttrKind] = match kind {
462 AttrsKind::Item => &[
463 AttrKind::Crate,
464 AttrKind::Param,
465 AttrKind::Where,
466 AttrKind::Static,
467 AttrKind::Range,
468 ],
469 AttrsKind::Variant => &[AttrKind::Tag],
470 AttrsKind::Invalid => &[],
471 };
472 let check = |name, actual, expected: &[AttrKind]| {
473 if let Some(actual) = actual
474 && !expected.contains(&name)
475 {
476 return Err(syn::Error::new(actual, "unexpected attribute"));
477 }
478 Ok(())
479 };
480 check(AttrKind::Crate, self.crate_.span(), expected)?;
481 check(AttrKind::Param, self.param.span(), expected)?;
482 check(AttrKind::Where, self.where_.span(), expected)?;
483 check(AttrKind::Tag, self.tag.span(), expected)?;
484 check(AttrKind::Static, self.static_.span(), expected)?;
485 check(AttrKind::Range, self.range.span(), expected)?;
486 Ok(())
487 }
488}