Skip to main content

simple_dst_derive/
lib.rs

1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::{
4    Attribute, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, GenericParam, Generics,
5    Ident, Member, Path, Token, TraitBound, Type, TypeParamBound, Visibility, parse_macro_input,
6    parse_quote,
7};
8
9fn require_repr_c(attrs: &[Attribute]) -> syn::Result<()> {
10    let mut found = false;
11    for attr in attrs {
12        if !attr.path().is_ident("repr") {
13            continue;
14        }
15
16        attr.parse_nested_meta(|meta| {
17            if meta.path.is_ident("C") {
18                Ok(())
19            } else {
20                Err(meta.error("only #[repr(C)] is supported"))
21            }
22        })?;
23        if found {
24            return Err(syn::Error::new_spanned(attr, "only one #[repr(C)] allowed"));
25        }
26        found = true;
27    }
28    if !found {
29        return Err(syn::Error::new(
30            Span::call_site(),
31            "type must be #[repr(C)]",
32        ));
33    }
34    Ok(())
35}
36
37fn get_fields(
38    data: &Data,
39) -> syn::Result<(
40    impl Iterator<Item = Member> + Clone,
41    impl Iterator<Item = &Type> + Clone,
42    usize,
43)> {
44    Ok(match data {
45        Data::Struct(DataStruct { fields, .. }) => {
46            (fields.members(), fields.iter().map(|f| &f.ty), fields.len())
47        }
48        Data::Enum(DataEnum { enum_token, .. }) => {
49            return Err(Error::new_spanned(enum_token, "only structs are supported"));
50        }
51        Data::Union(DataUnion { union_token, .. }) => {
52            return Err(Error::new_spanned(
53                union_token,
54                "only structs are supported",
55            ));
56        }
57    })
58}
59
60struct DstAttrs {
61    simple_dst_path: Path,
62    new_unchecked_vis: Visibility,
63}
64
65fn get_dst_attrs(attrs: &[Attribute]) -> syn::Result<DstAttrs> {
66    let mut simple_dst_path: Option<Path> = None;
67    let mut new_unchecked_vis: Option<Visibility> = None;
68    for attr in attrs {
69        if !attr.path().is_ident("dst") {
70            continue;
71        }
72
73        attr.parse_nested_meta(|meta| {
74            if meta.path.is_ident("simple_dst_path") {
75                if simple_dst_path.is_some() {
76                    return Err(meta.error("only one #[dst(simple_dst_path = ...)] is allowed"));
77                }
78                simple_dst_path = Some({
79                    meta.input.parse::<Token![=]>()?;
80                    meta.input.parse()?
81                });
82            } else if meta.path.is_ident("new_unchecked_vis") {
83                if new_unchecked_vis.is_some() {
84                    return Err(meta.error("only one #[dst(new_unchecked_vis = ...)] is allowed"));
85                }
86                new_unchecked_vis = Some({
87                    meta.input.parse::<Token![=]>()?;
88                    meta.input.parse()?
89                });
90            } else {
91                return Err(meta.error("unrecognised #[dst(...)] argument"));
92            }
93            Ok(())
94        })?;
95    }
96
97    let dst_attrs = DstAttrs {
98        simple_dst_path: simple_dst_path.unwrap_or_else(|| parse_quote! { ::simple_dst }),
99        new_unchecked_vis: new_unchecked_vis.unwrap_or(Visibility::Inherited),
100    };
101    Ok(dst_attrs)
102}
103
104fn has_unsized_bound<'a>(bounds: impl Iterator<Item = &'a TypeParamBound>) -> bool {
105    for bound in bounds {
106        if let TypeParamBound::Trait(TraitBound {
107            modifier: syn::TraitBoundModifier::Maybe(_),
108            lifetimes: None,
109            path,
110            ..
111        }) = bound
112            && path.is_ident("Sized")
113        {
114            return true;
115        }
116    }
117    false
118}
119
120fn add_dst_trait_bounds(mut generics: Generics, simple_dst_path: &Path) -> Generics {
121    for param in &mut generics.params {
122        if let GenericParam::Type(type_param) = param
123            && has_unsized_bound(type_param.bounds.iter())
124        {
125            type_param
126                .bounds
127                .push(parse_quote! { #simple_dst_path::Dst });
128            type_param
129                .bounds
130                .push(parse_quote! { #simple_dst_path::CloneToUninit });
131        }
132    }
133    generics
134}
135
136/// Derive macro for the `Dst` trait.
137///
138/// This derive also creates a `new_unchecked` function, which takes each of the fields
139/// in the struct as arguments, with the last field (the DST) being taken as a reference.
140/// This `new_unchecked` function is marked as `unsafe` as it doesn't check any of the
141/// type's interior invariants. The visibility of this generated function can be modified
142/// with the `#[dst(new_unchecked_vis = ...)]` attribute.
143///
144/// The path to the `simple_dst` crate can be modified with the
145/// `#[dst(simple_dst_path = ...)]` attribute.
146///
147/// If there are any generic types with a `?Sized` trait bound, those are assumed to be
148/// the type of the DST, so the `Dst` and `CloneToUninit` trait bounds will be added.
149#[proc_macro_derive(Dst, attributes(dst))]
150pub fn derive_dst(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
151    let input = parse_macro_input!(input as DeriveInput);
152    derive_dst_impl(input)
153        .unwrap_or_else(syn::Error::into_compile_error)
154        .into()
155}
156
157fn derive_dst_impl(input: DeriveInput) -> syn::Result<TokenStream> {
158    require_repr_c(&input.attrs)?;
159
160    let name = input.ident;
161
162    let DstAttrs {
163        simple_dst_path,
164        new_unchecked_vis,
165    } = get_dst_attrs(&input.attrs)?;
166
167    let generics = add_dst_trait_bounds(input.generics, &simple_dst_path);
168    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
169
170    let (members, tys, n_fields) = get_fields(&input.data)?;
171    if n_fields == 0 {
172        return Err(Error::new_spanned(
173            name,
174            "type must have at least one field",
175        ));
176    }
177
178    let idxs: Vec<_> = (0..n_fields).collect();
179    let last_idx = n_fields - 1;
180    let first_idxs: Vec<_> = (0..n_fields - 1).collect();
181
182    let first_members: Vec<_> = members.clone().take(n_fields - 1).collect();
183    let last_member = members.clone().nth(last_idx);
184
185    let first_tys: Vec<_> = tys.clone().take(n_fields - 1).collect();
186    let last_ty = tys.clone().nth(last_idx);
187
188    Ok(quote! {
189        #[automatically_derived]
190        unsafe impl #impl_generics #simple_dst_path::Dst for #name #ty_generics #where_clause {
191            fn len(&self) -> usize {
192                #simple_dst_path::Dst::len(&self.#last_member)
193            }
194
195            fn layout(len: usize) -> ::core::result::Result<::core::alloc::Layout, ::core::alloc::LayoutError> {
196                let (layout, _) = Self::__dst_impl_layout_offsets(len)?;
197                ::core::result::Result::Ok(layout)
198            }
199
200            fn retype(ptr: ::core::ptr::NonNull<u8>, len: usize) -> ::core::ptr::NonNull<Self> {
201                // FUTURE: switch to ptr::from_raw_parts_mut() when it has stabilised.
202                // SAFETY: the pointer value doesn't change when using `slice_from_raw_parts_mut`,
203                // so the invariants of `NonNull` are upheld
204                unsafe {
205                    #[allow(
206                        clippy::cast_ptr_alignment,
207                        reason = "the responsibility to provide a pointer with the correct alignment is on the caller"
208                    )]
209                    ::core::ptr::NonNull::new_unchecked(::core::ptr::slice_from_raw_parts_mut(ptr.as_ptr(), len) as *mut Self)
210                }
211            }
212        }
213
214        #[automatically_derived]
215        impl #impl_generics #name #ty_generics #where_clause {
216            #[doc(hidden)]
217            #[inline]
218            fn __dst_impl_layout_offsets(len: usize) -> ::core::result::Result<(::core::alloc::Layout, [usize; #n_fields]), ::core::alloc::LayoutError> {
219                let layouts = [#(::core::alloc::Layout::new::<#first_tys>()),*, <#last_ty as #simple_dst_path::Dst>::layout(len)?];
220                let mut offsets = [0; #n_fields];
221                let layout = ::core::alloc::Layout::from_size_align(0, 1)?;
222                #(
223                    let (layout, offset) = layout.extend(layouts[#idxs])?;
224                    offsets[#idxs] = offset;
225                )*
226                ::core::result::Result::Ok((layout.pad_to_align(), offsets))
227            }
228
229            #new_unchecked_vis unsafe fn new_unchecked<A: #simple_dst_path::AllocDst<Self>>(
230                #( #first_members: #first_tys, )*
231                #last_member: &#last_ty
232            ) -> ::core::result::Result<A, ::core::alloc::LayoutError> {
233                let (layout, offsets) = Self::__dst_impl_layout_offsets(#last_member.len())?;
234                Ok(unsafe {
235                    A::new_dst(<#last_ty as #simple_dst_path::Dst>::len(#last_member), layout, |ptr| {
236                        let dest = ptr.cast::<u8>();
237
238                        <#last_ty as #simple_dst_path::CloneToUninit>::clone_to_uninit(#last_member, dest.add(offsets[#last_idx]).as_ptr());
239
240                        #(
241                            dest.add(offsets[#first_idxs]).cast::<#first_tys>().write(#first_members);
242                        )*
243                    })
244                })
245            }
246        }
247    })
248}
249
250fn add_clone_to_uninit_trait_bounds(mut generics: Generics, simple_dst_path: &Path) -> Generics {
251    for param in &mut generics.params {
252        if let GenericParam::Type(type_param) = param {
253            let bound = if has_unsized_bound(type_param.bounds.iter()) {
254                parse_quote! { #simple_dst_path::CloneToUninit }
255            } else {
256                parse_quote! { ::core::clone::Clone }
257            };
258            type_param.bounds.push(bound);
259        }
260    }
261    generics
262}
263
264/// Derive macro for the `CloneToUninit` trait for DSTs.
265///
266/// If there are any generic types with a `?Sized` trait bound, those are assumed to be
267/// the type of the DST, so the `CloneToUninit` trait bound will be added. All other
268/// generic types will have the [`Clone`] trait bound added.
269#[proc_macro_derive(CloneToUninit, attributes(dst))]
270pub fn derive_clone_to_uninit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
271    let input = parse_macro_input!(input as DeriveInput);
272    derive_clone_to_uninit_impl(input)
273        .unwrap_or_else(syn::Error::into_compile_error)
274        .into()
275}
276
277fn derive_clone_to_uninit_impl(input: DeriveInput) -> syn::Result<TokenStream> {
278    let name = input.ident;
279
280    let DstAttrs {
281        simple_dst_path, ..
282    } = get_dst_attrs(&input.attrs)?;
283
284    let generics = add_clone_to_uninit_trait_bounds(input.generics, &simple_dst_path);
285    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
286
287    let (members, tys, n_fields) = get_fields(&input.data)?;
288    if n_fields == 0 {
289        return Err(Error::new_spanned(
290            name,
291            "type must have at least one field",
292        ));
293    }
294
295    let last_idx = n_fields - 1;
296
297    // let first_idxs: Vec<_> = (0..n_fields - 1).collect();
298
299    let first_members: Vec<_> = members.clone().take(n_fields - 1).collect();
300    let last_member = members.clone().nth(last_idx);
301
302    let first_tys: Vec<_> = tys.clone().take(n_fields - 1).collect();
303    let last_ty = tys.clone().nth(last_idx);
304
305    Ok(quote! {
306        #[automatically_derived]
307        unsafe impl #impl_generics #simple_dst_path::CloneToUninit for #name #ty_generics #where_clause {
308            unsafe fn clone_to_uninit(&self, dest: *mut u8) {
309                // SAFETY:
310                // * `&self.slice` >= `self` because `slice` is a field in `self`, and Self is
311                //   `#[repr(C)]`.
312                // * both pointers must be from the same allocation since they are within the
313                //   same object, and thus the memory range between them is also in bounds of
314                //   the object.
315                // * TODO: how can the distance be an exact multiple of [i128]?
316                let last_offset = unsafe { (&raw const self.#last_member).byte_offset_from_unsigned(self) };
317
318                #(
319                    let #first_members = <#first_tys as ::core::clone::Clone>::clone(&self.#first_members);
320                )*
321
322                unsafe {
323                    <#last_ty as #simple_dst_path::CloneToUninit>::clone_to_uninit(&self.#last_member, dest.add(last_offset));
324
325                    #(
326                        dest.add(::core::mem::offset_of!(Self, #first_members)).cast::<#first_tys>().write(#first_members);
327                    )*
328                }
329            }
330        }
331    })
332}
333
334struct ToOwnedAttrs {
335    alloc_path: Path,
336    owned: Type,
337}
338
339fn get_to_owned_attrs(attrs: &[Attribute], name: &Ident) -> syn::Result<ToOwnedAttrs> {
340    let mut alloc_path: Option<Path> = None;
341    let mut owned: Option<Type> = None;
342    for attr in attrs {
343        if !attr.path().is_ident("to_owned") {
344            continue;
345        }
346
347        attr.parse_nested_meta(|meta| {
348            if meta.path.is_ident("alloc_path") {
349                if alloc_path.is_some() {
350                    return Err(meta.error("only one #[to_owned(alloc_path = ...)] is allowed"));
351                }
352                alloc_path = Some({
353                    meta.input.parse::<Token![=]>()?;
354                    meta.input.parse()?
355                });
356            } else if meta.path.is_ident("owned") {
357                if owned.is_some() {
358                    return Err(meta.error("only one #[to_owned(owned = ...)] is allowed"));
359                }
360                owned = Some({
361                    meta.input.parse::<Token![=]>()?;
362                    meta.input.parse()?
363                });
364            } else {
365                return Err(meta.error("unrecognised #[to_owned(...)] argument"));
366            }
367            Ok(())
368        })?;
369    }
370
371    let alloc_path = alloc_path.unwrap_or_else(|| parse_quote! { ::std });
372    let to_owned_attrs = ToOwnedAttrs {
373        alloc_path: alloc_path.clone(),
374        owned: owned.unwrap_or_else(|| parse_quote! { #alloc_path::boxed::Box<#name> }),
375    };
376    Ok(to_owned_attrs)
377}
378
379#[proc_macro_derive(ToOwned, attributes(dst, to_owned))]
380pub fn derive_to_owned(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
381    let input = parse_macro_input!(input as DeriveInput);
382    derive_to_owned_impl(input)
383        .unwrap_or_else(syn::Error::into_compile_error)
384        .into()
385}
386
387fn derive_to_owned_impl(input: DeriveInput) -> syn::Result<TokenStream> {
388    let name = input.ident;
389
390    let DstAttrs {
391        simple_dst_path, ..
392    } = get_dst_attrs(&input.attrs)?;
393    let ToOwnedAttrs { alloc_path, owned } = get_to_owned_attrs(&input.attrs, &name)?;
394
395    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
396
397    Ok(quote! {
398        #[automatically_derived]
399        impl #impl_generics #alloc_path::borrow::ToOwned for #name #ty_generics #where_clause {
400            type Owned = #owned;
401
402            fn to_owned(&self) -> Self::Owned {
403                let layout = ::core::alloc::Layout::for_value(self);
404
405                unsafe {
406                    <#owned as #simple_dst_path::AllocDst<#name>>::new_dst(
407                        <#name as #simple_dst_path::Dst>::len(self),
408                        layout,
409                        |ptr| {
410                            let dest = ptr.cast::<u8>();
411
412                            <#name as #simple_dst_path::CloneToUninit>::clone_to_uninit(self, dest.as_ptr())
413                        },
414                    )
415                }
416            }
417        }
418    })
419}