1use core::fmt;
2
3use attrs::*;
4use proc_macro2::{Span, TokenStream};
5use quote::{ToTokens, quote};
6use syn::{
7    Token,
8    parse::{Parse, ParseStream},
9    parse_quote,
10    punctuated::Punctuated,
11    spanned::Spanned,
12    token,
13};
14
15#[proc_macro]
16pub fn assert_abi(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
17    syn::parse_macro_input!(item with do_assert_abi).into()
18}
19
20fn do_assert_abi(input: ParseStream) -> syn::Result<TokenStream> {
21    let mut krate: syn::Path = parse_quote!(seasick);
22    if input.peek(Token![crate]) && (input.peek2(Token![=]) || input.peek2(token::Paren)) {
23        input.parse::<Token![crate]>()?;
24        with::peq(on::maybe_str(&mut krate))(input)?;
25        input.parse::<Token![;]>()?;
26    }
27    let asserts = input.parse_terminated(AssertAbi::parse, Token![;])?;
28
29    let mut checks = TokenStream::new();
30    for AssertAbi {
31        this:
32            ref this @ Cast {
33                path: _,
34                as_token: _,
35                bare_fn:
36                    syn::TypeBareFn {
37                        lifetimes: _,
38                        ref unsafety,
39                        ref abi,
40                        fn_token: _,
41                        paren_token: _,
42                        ref inputs,
43                        ref variadic,
44                        ref output,
45                    },
46            },
47        mut other,
48    } in asserts
49    {
50        check_abi(abi.as_ref(), other.bare_fn.abi.as_ref())?;
51        if let (None, Some(it)) = (unsafety, &other.bare_fn.unsafety) {
52            return Err(syn::Error::new(
53                it.span,
54                "Safe fn may not be replaced by an unsafe fn",
55            ));
56        }
57        if let (Some(var), _) | (_, Some(var)) = (variadic, &other.bare_fn.variadic) {
58            return Err(syn::Error::new(
59                var.span(),
60                "Variadic functions are not supported",
61            ));
62        }
63        if inputs.len() != other.bare_fn.inputs.len() {
64            return Err(syn::Error::new(
65                inputs.span(),
66                format_args!(
67                    "Mismatched function arity ({} vs {})",
68                    inputs.len(),
69                    other.bare_fn.inputs.len()
70                ),
71            ));
72        }
73        for (ix, param) in other.bare_fn.inputs.iter_mut().enumerate() {
74            if matches!(param.ty, syn::Type::Infer(_)) {
75                let ty = &inputs[ix].ty;
76                param.ty = parse_quote!(#ty)
77            }
78        }
79        for (ix, (this, other)) in inputs.iter().zip(&other.bare_fn.inputs).enumerate() {
80            let this = &this.ty;
81            let other = &other.ty;
82            let bad_size = format!("Mismatched size in parameter {ix}");
83            let bad_align = format!("Mismatched align in parameter {ix}");
84            checks.extend(quote! {
85                let this = Layout::new::<#this>();
86                let other = Layout::new::<#other>();
87                if this.size() != other.size() {
88                    panic!(#bad_size)
89                }
90                if this.align() != other.align() {
91                    panic!(#bad_align)
92                }
93            });
94        }
95        if let syn::ReturnType::Type(_, it) = &mut other.bare_fn.output {
96            if let syn::Type::Infer(_) = &mut **it {
97                other.bare_fn.output = parse_quote!(#output)
98            }
99        }
100        let this_ret = return_type(output);
101        let other_ret = return_type(&other.bare_fn.output);
102
103        checks.extend(quote! {
104            let this = Layout::new::<#this_ret>();
105            let other = Layout::new::<#other_ret>();
106            if this.size() != other.size() {
107                panic!("Mismatched size in return type")
108            }
109            if this.align() != other.align() {
110                panic!("Mismatched align in return type")
111            }
112            let _ = #this;
113            let _ = #other;
114        });
115    }
116    Ok(quote! {
117        const _: () = {
118            use #krate::{
119                __private::{
120                    core::{
121                        alloc::Layout,
122                        panic,
123                    }
124                }
125            };
126
127            #checks
128        };
129    })
130}
131
132fn return_type(it: &syn::ReturnType) -> syn::Type {
133    match it {
134        syn::ReturnType::Default => syn::Type::Tuple(syn::TypeTuple {
135            paren_token: token::Paren::default(),
136            elems: Punctuated::new(),
137        }),
138        syn::ReturnType::Type(_, it) => parse_quote!(#it),
139    }
140}
141
142fn check_abi(left: Option<&syn::Abi>, right: Option<&syn::Abi>) -> syn::Result<()> {
143    let left = abi_string(left);
144    let right = abi_string(right);
145    if left == right {
146        return Ok(());
147    }
148
149    match (&*left, &*right) {
150        ("C-unwind", "C") => Ok(()),
151        _ => Err({
152            let mut e = syn::Error::new(left.span(), "Mismatched ABI");
153            e.combine(syn::Error::new(right.span(), "Mismatched ABI"));
154            e
155        }),
156    }
157}
158
159fn abi_string(abi: Option<&syn::Abi>) -> String {
160    match abi {
161        Some(it) => match &it.name {
162            Some(s) => s.value(),
163            None => String::from("C"),
164        },
165        None => String::from("Rust"),
166    }
167}
168
169struct AssertAbi {
170    this: Cast,
171    other: Cast,
172}
173impl Parse for AssertAbi {
174    fn parse(input: ParseStream) -> syn::Result<Self> {
175        Ok(Self {
176            this: input.parse()?,
177            other: {
178                input.parse::<Token![==]>()?;
179                input.parse()?
180            },
181        })
182    }
183}
184
185struct Cast {
186    path: syn::Path,
187    as_token: Token![as],
188    bare_fn: syn::TypeBareFn,
189}
190
191impl Parse for Cast {
192    fn parse(input: ParseStream) -> syn::Result<Self> {
193        Ok(Self {
194            path: input.parse()?,
195            as_token: input.parse()?,
196            bare_fn: input.parse()?,
197        })
198    }
199}
200
201impl ToTokens for Cast {
202    fn to_tokens(&self, tokens: &mut TokenStream) {
203        let Self {
204            path,
205            as_token,
206            bare_fn,
207        } = self;
208        path.to_tokens(tokens);
209        as_token.to_tokens(tokens);
210        bare_fn.to_tokens(tokens);
211    }
212}
213
214#[proc_macro_derive(TransmuteFrom, attributes(transmute))]
215pub fn transmute_from(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
216    expand_transmute_from(syn::parse_macro_input!(item as _))
217        .unwrap_or_else(syn::Error::into_compile_error)
218        .into()
219}
220
221fn expand_transmute_from(input: syn::DeriveInput) -> syn::Result<TokenStream> {
222    let syn::DeriveInput {
223        ref attrs,
224        vis: _,
225        ident: local,
226        generics,
227        data,
228    } = input;
229    let syn::DataStruct {
230        struct_token: _,
231        fields,
232        semi_token: _,
233    } = as_struct(data)?;
234
235    let ContainerArgs {
236        from: remote,
237        krate,
238        strict,
239    } = ContainerArgs::parse(local.span(), attrs)?;
240
241    let mut checks = TokenStream::new();
242    let mut remote_members = vec![];
243
244    for StructMember {
245        ref attrs,
246        member: local_member,
247        ty: local_ty,
248    } in struct_members(fields)
249    {
250        let (remote_member, remote_ty) = match FieldArgs::parse_attrs(attrs)? {
251            FieldArgs::Skip => continue,
252            FieldArgs::Field { member, ty } => (member.unwrap_or(parse_quote!(#local_member)), ty),
253        };
254
255        let preamble = format!(
256            "`{local}.{}` and (from) `{}.{}`: mismatched",
257            Fmt(&local_member),
258            Fmt(&remote),
259            Fmt(&remote_member)
260        );
261        let bad_size = format!("{preamble} size");
262        let bad_align = format!("{preamble} alignment");
263        let bad_offset = format!("{preamble} offset");
264
265        checks.extend(quote! {
266            let local = layout_of_field(|it: &#local|&it.#local_member);
267            let remote = layout_of_field(|it: &#remote| &it.#remote_member);
268
269            if local.size() != remote.size() {
270                panic!(#bad_size)
271            }
272            if local.align() != remote.align() {
273                panic!(#bad_align)
274            }
275            if offset_of!(#local, #local_member) != offset_of!(#remote, #remote_member) {
276                panic!(#bad_offset)
277            }
278        });
279
280        let remote_ty = match strict {
281            true => Some(remote_ty.unwrap_or(local_ty)),
282            false => remote_ty,
283        };
284
285        if let Some(remote_ty) = remote_ty {
286            if strict && matches!(remote_ty, syn::Type::Infer(_)) {
287                return Err(syn::Error::new(
288                    remote_ty.span(),
289                    "Inferred types may not be used with `strict`",
290                ));
291            }
292            checks.extend(quote! {
293                layout_of_field::<#remote, #remote_ty>(|it| &it.#remote_member);
294            });
295        }
296
297        remote_members.push(remote_member);
298    }
299
300    checks.extend(quote! {
301        fn exhaustive(#remote {
302            #(#remote_members: _,)*
303        }: #remote) {}
304    });
305
306    let preamble = format!("`{local}` and (from) `{}`: mismatched", Fmt(&remote));
307    let bad_size = format!("{preamble} size");
308    let bad_align = format!("{preamble} alignment");
309
310    checks.extend(quote! {
311        let local = Layout::new::<#local>();
312        let remote = Layout::new::<#remote>();
313
314        if local.size() != remote.size() {
315            panic!(#bad_size)
316        }
317        if local.align() != remote.align() {
318            panic!(#bad_align)
319        }
320    });
321
322    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
323
324    Ok(quote! {
325        const _: () = {
326            use #krate::{
327                TransmuteFrom,
328                __private::{
329                    layout_of_field,
330                    core::{
331                        alloc::Layout,
332                        panic,
333                        mem::{offset_of, transmute},
334                    }
335                }
336            };
337
338            #checks
339
340            unsafe impl #impl_generics TransmuteFrom<#remote> for #local #ty_generics #where_clause {
341                unsafe fn transmute_from(remote: #remote) -> Self {
342                    unsafe { transmute::<#remote, #local>(remote) }
343                }
344            }
345            unsafe impl #impl_generics TransmuteFrom<#local #ty_generics> for #remote #where_clause {
346                unsafe fn transmute_from(local: #local) -> Self {
347                    unsafe { transmute::<#local, #remote>(local) }
348                }
349            }
350        };
351    })
352}
353
354struct ContainerArgs {
355    from: syn::Path,
356    krate: syn::Path,
357    strict: bool,
358}
359
360impl ContainerArgs {
361    fn parse(span: Span, attrs: &[syn::Attribute]) -> syn::Result<Self> {
362        let mut from = None;
363        let mut krate = parse_quote!(seasick);
364        let mut strict = false;
365
366        Attrs::new()
367            .once("from", with::peq(set::maybe_str(&mut from)))
368            .once("crate", with::peq(on::maybe_str(&mut krate)))
369            .once("strict", flag::or_peq(&mut strict))
370            .parse_attrs("transmute", attrs)?;
371
372        match from {
373            Some(from) => Ok(Self {
374                from,
375                krate,
376                strict,
377            }),
378            None => Err(syn::Error::new(span, "Requires `#[transmute(from = ...)]`")),
379        }
380    }
381}
382
383enum FieldArgs {
384    Skip,
385    Field {
386        member: Option<syn::Member>,
387        ty: Option<syn::Type>,
388    },
389}
390
391impl FieldArgs {
392    fn parse_attrs(attrs: &[syn::Attribute]) -> syn::Result<Self> {
393        let attributes = &*attrs
394            .iter()
395            .filter(|it| it.meta.path().is_ident("transmute"))
396            .collect::<Vec<_>>();
397        match attributes {
398            [] => Ok(Self::Field {
399                member: None,
400                ty: None,
401            }),
402            [one] => {
403                let this = one.parse_args();
404                match this {
405                    Ok(t) => Ok(t),
406                    Err(mut e) => Err({
407                        let message = "Expected `#[transmute(skip)]`, `#[transmute($ident: $ty)` or `#[transmute($ty)`";
408                        e.combine(syn::Error::new(one.span(), message));
409                        e
410                    }),
411                }
412            }
413            [_, two, ..] => Err(syn::Error::new(
414                two.span(),
415                "Only one `#[transmute(..)]` attribute is permitted",
416            )),
417        }
418    }
419}
420
421impl Parse for FieldArgs {
422    fn parse(input: ParseStream) -> syn::Result<Self> {
423        syn::custom_keyword!(skip);
424        if (input.peek(syn::Ident) || input.peek(syn::LitInt))
425            && input.peek2(Token![:])
426            && !input.peek2(Token![::])
427        {
428            Ok(Self::Field {
429                member: Some(input.parse()?),
430                ty: Some({
431                    input.parse::<Token![:]>()?;
432                    input.parse()?
433                }),
434            })
435        } else if input.peek(skip) && input.peek2(syn::parse::End) {
436            input.parse::<skip>()?;
437            Ok(Self::Skip)
438        } else {
439            Ok(Self::Field {
440                member: None,
441                ty: Some(input.parse()?),
442            })
443        }
444    }
445}
446
447fn as_struct(data: syn::Data) -> syn::Result<syn::DataStruct> {
448    match data {
449        syn::Data::Struct(it) => Ok(it),
450        syn::Data::Enum(syn::DataEnum { enum_token, .. }) => Err(syn::Error::new(
451            enum_token.span,
452            "Only `struct` is supported",
453        )),
454        syn::Data::Union(syn::DataUnion { union_token, .. }) => Err(syn::Error::new(
455            union_token.span,
456            "Only `struct` is supported",
457        )),
458    }
459}
460
461fn struct_members(fields: syn::Fields) -> impl Iterator<Item = StructMember> {
462    fields
463        .into_iter()
464        .enumerate()
465        .map(|(ix, field)| StructMember {
466            attrs: field.attrs,
467            member: match field.ident {
468                Some(it) => it.into(),
469                None => ix.into(),
470            },
471            ty: field.ty,
472        })
473}
474
475struct StructMember {
476    attrs: Vec<syn::Attribute>,
477    member: syn::Member,
478    ty: syn::Type,
479}
480
481struct Fmt<T>(T);
482
483impl fmt::Display for Fmt<&syn::Member> {
484    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
485        match self.0 {
486            syn::Member::Named(ident) => ident.fmt(f),
487            syn::Member::Unnamed(syn::Index { index, .. }) => index.fmt(f),
488        }
489    }
490}
491
492impl fmt::Display for Fmt<&syn::Path> {
493    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
494        let syn::Path {
495            leading_colon,
496            segments,
497        } = self.0;
498        if leading_colon.is_some() {
499            f.write_str("::")?
500        }
501        let mut first = true;
502        for syn::PathSegment {
503            ident,
504            arguments: _,
505        } in segments
506        {
507            match first {
508                true => first = false,
509                false => f.write_str("::")?,
510            }
511            f.write_fmt(format_args!("{}", ident))?
512        }
513        Ok(())
514    }
515}