1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::{
4 Attribute, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, GenericParam, Generics,
5 Ident, Member, Path, Token, TraitBound, Type, TypeParamBound, Visibility, parse_macro_input,
6 parse_quote,
7};
8
9fn require_repr_c(attrs: &[Attribute]) -> syn::Result<()> {
10 let mut found = false;
11 for attr in attrs {
12 if !attr.path().is_ident("repr") {
13 continue;
14 }
15
16 attr.parse_nested_meta(|meta| {
17 if meta.path.is_ident("C") {
18 Ok(())
19 } else {
20 Err(meta.error("only #[repr(C)] is supported"))
21 }
22 })?;
23 if found {
24 return Err(syn::Error::new_spanned(attr, "only one #[repr(C)] allowed"));
25 }
26 found = true;
27 }
28 if !found {
29 return Err(syn::Error::new(
30 Span::call_site(),
31 "type must be #[repr(C)]",
32 ));
33 }
34 Ok(())
35}
36
37fn get_fields(
38 data: &Data,
39) -> syn::Result<(
40 impl Iterator<Item = Member> + Clone,
41 impl Iterator<Item = &Type> + Clone,
42 usize,
43)> {
44 Ok(match data {
45 Data::Struct(DataStruct { fields, .. }) => {
46 (fields.members(), fields.iter().map(|f| &f.ty), fields.len())
47 }
48 Data::Enum(DataEnum { enum_token, .. }) => {
49 return Err(Error::new_spanned(enum_token, "only structs are supported"));
50 }
51 Data::Union(DataUnion { union_token, .. }) => {
52 return Err(Error::new_spanned(
53 union_token,
54 "only structs are supported",
55 ));
56 }
57 })
58}
59
60struct DstAttrs {
61 simple_dst_path: Path,
62 new_unchecked_vis: Visibility,
63}
64
65fn get_dst_attrs(attrs: &[Attribute]) -> syn::Result<DstAttrs> {
66 let mut simple_dst_path: Option<Path> = None;
67 let mut new_unchecked_vis: Option<Visibility> = None;
68 for attr in attrs {
69 if !attr.path().is_ident("dst") {
70 continue;
71 }
72
73 attr.parse_nested_meta(|meta| {
74 if meta.path.is_ident("simple_dst_path") {
75 if simple_dst_path.is_some() {
76 return Err(meta.error("only one #[dst(simple_dst_path = ...)] is allowed"));
77 }
78 simple_dst_path = Some({
79 meta.input.parse::<Token![=]>()?;
80 meta.input.parse()?
81 });
82 } else if meta.path.is_ident("new_unchecked_vis") {
83 if new_unchecked_vis.is_some() {
84 return Err(meta.error("only one #[dst(new_unchecked_vis = ...)] is allowed"));
85 }
86 new_unchecked_vis = Some({
87 meta.input.parse::<Token![=]>()?;
88 meta.input.parse()?
89 });
90 } else {
91 return Err(meta.error("unrecognised #[dst(...)] argument"));
92 }
93 Ok(())
94 })?;
95 }
96
97 let dst_attrs = DstAttrs {
98 simple_dst_path: simple_dst_path.unwrap_or_else(|| parse_quote! { ::simple_dst }),
99 new_unchecked_vis: new_unchecked_vis.unwrap_or(Visibility::Inherited),
100 };
101 Ok(dst_attrs)
102}
103
104fn has_unsized_bound<'a>(bounds: impl Iterator<Item = &'a TypeParamBound>) -> bool {
105 for bound in bounds {
106 if let TypeParamBound::Trait(TraitBound {
107 modifier: syn::TraitBoundModifier::Maybe(_),
108 lifetimes: None,
109 path,
110 ..
111 }) = bound
112 && path.is_ident("Sized")
113 {
114 return true;
115 }
116 }
117 false
118}
119
120fn add_dst_trait_bounds(mut generics: Generics, simple_dst_path: &Path) -> Generics {
121 for param in &mut generics.params {
122 if let GenericParam::Type(type_param) = param
123 && has_unsized_bound(type_param.bounds.iter())
124 {
125 type_param
126 .bounds
127 .push(parse_quote! { #simple_dst_path::Dst });
128 type_param
129 .bounds
130 .push(parse_quote! { #simple_dst_path::CloneToUninit });
131 }
132 }
133 generics
134}
135
136#[proc_macro_derive(Dst, attributes(dst))]
150pub fn derive_dst(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
151 let input = parse_macro_input!(input as DeriveInput);
152 derive_dst_impl(input)
153 .unwrap_or_else(syn::Error::into_compile_error)
154 .into()
155}
156
157fn derive_dst_impl(input: DeriveInput) -> syn::Result<TokenStream> {
158 require_repr_c(&input.attrs)?;
159
160 let name = input.ident;
161
162 let DstAttrs {
163 simple_dst_path,
164 new_unchecked_vis,
165 } = get_dst_attrs(&input.attrs)?;
166
167 let generics = add_dst_trait_bounds(input.generics, &simple_dst_path);
168 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
169
170 let (members, tys, n_fields) = get_fields(&input.data)?;
171 if n_fields == 0 {
172 return Err(Error::new_spanned(
173 name,
174 "type must have at least one field",
175 ));
176 }
177
178 let idxs: Vec<_> = (0..n_fields).collect();
179 let last_idx = n_fields - 1;
180 let first_idxs: Vec<_> = (0..n_fields - 1).collect();
181
182 let first_members: Vec<_> = members.clone().take(n_fields - 1).collect();
183 let last_member = members.clone().nth(last_idx);
184
185 let first_tys: Vec<_> = tys.clone().take(n_fields - 1).collect();
186 let last_ty = tys.clone().nth(last_idx);
187
188 Ok(quote! {
189 #[automatically_derived]
190 unsafe impl #impl_generics #simple_dst_path::Dst for #name #ty_generics #where_clause {
191 fn len(&self) -> usize {
192 #simple_dst_path::Dst::len(&self.#last_member)
193 }
194
195 fn layout(len: usize) -> ::core::result::Result<::core::alloc::Layout, ::core::alloc::LayoutError> {
196 let (layout, _) = Self::__dst_impl_layout_offsets(len)?;
197 ::core::result::Result::Ok(layout)
198 }
199
200 fn retype(ptr: ::core::ptr::NonNull<u8>, len: usize) -> ::core::ptr::NonNull<Self> {
201 unsafe {
205 #[allow(
206 clippy::cast_ptr_alignment,
207 reason = "the responsibility to provide a pointer with the correct alignment is on the caller"
208 )]
209 ::core::ptr::NonNull::new_unchecked(::core::ptr::slice_from_raw_parts_mut(ptr.as_ptr(), len) as *mut Self)
210 }
211 }
212 }
213
214 #[automatically_derived]
215 impl #impl_generics #name #ty_generics #where_clause {
216 #[doc(hidden)]
217 #[inline]
218 fn __dst_impl_layout_offsets(len: usize) -> ::core::result::Result<(::core::alloc::Layout, [usize; #n_fields]), ::core::alloc::LayoutError> {
219 let layouts = [#(::core::alloc::Layout::new::<#first_tys>()),*, <#last_ty as #simple_dst_path::Dst>::layout(len)?];
220 let mut offsets = [0; #n_fields];
221 let layout = ::core::alloc::Layout::from_size_align(0, 1)?;
222 #(
223 let (layout, offset) = layout.extend(layouts[#idxs])?;
224 offsets[#idxs] = offset;
225 )*
226 ::core::result::Result::Ok((layout.pad_to_align(), offsets))
227 }
228
229 #new_unchecked_vis unsafe fn new_unchecked<A: #simple_dst_path::AllocDst<Self>>(
230 #( #first_members: #first_tys, )*
231 #last_member: &#last_ty
232 ) -> ::core::result::Result<A, ::core::alloc::LayoutError> {
233 let (layout, offsets) = Self::__dst_impl_layout_offsets(#last_member.len())?;
234 Ok(unsafe {
235 A::new_dst(<#last_ty as #simple_dst_path::Dst>::len(#last_member), layout, |ptr| {
236 let dest = ptr.cast::<u8>();
237
238 <#last_ty as #simple_dst_path::CloneToUninit>::clone_to_uninit(#last_member, dest.add(offsets[#last_idx]).as_ptr());
239
240 #(
241 dest.add(offsets[#first_idxs]).cast::<#first_tys>().write(#first_members);
242 )*
243 })
244 })
245 }
246 }
247 })
248}
249
250fn add_clone_to_uninit_trait_bounds(mut generics: Generics, simple_dst_path: &Path) -> Generics {
251 for param in &mut generics.params {
252 if let GenericParam::Type(type_param) = param {
253 let bound = if has_unsized_bound(type_param.bounds.iter()) {
254 parse_quote! { #simple_dst_path::CloneToUninit }
255 } else {
256 parse_quote! { ::core::clone::Clone }
257 };
258 type_param.bounds.push(bound);
259 }
260 }
261 generics
262}
263
264#[proc_macro_derive(CloneToUninit, attributes(dst))]
270pub fn derive_clone_to_uninit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
271 let input = parse_macro_input!(input as DeriveInput);
272 derive_clone_to_uninit_impl(input)
273 .unwrap_or_else(syn::Error::into_compile_error)
274 .into()
275}
276
277fn derive_clone_to_uninit_impl(input: DeriveInput) -> syn::Result<TokenStream> {
278 let name = input.ident;
279
280 let DstAttrs {
281 simple_dst_path, ..
282 } = get_dst_attrs(&input.attrs)?;
283
284 let generics = add_clone_to_uninit_trait_bounds(input.generics, &simple_dst_path);
285 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
286
287 let (members, tys, n_fields) = get_fields(&input.data)?;
288 if n_fields == 0 {
289 return Err(Error::new_spanned(
290 name,
291 "type must have at least one field",
292 ));
293 }
294
295 let last_idx = n_fields - 1;
296
297 let first_members: Vec<_> = members.clone().take(n_fields - 1).collect();
300 let last_member = members.clone().nth(last_idx);
301
302 let first_tys: Vec<_> = tys.clone().take(n_fields - 1).collect();
303 let last_ty = tys.clone().nth(last_idx);
304
305 Ok(quote! {
306 #[automatically_derived]
307 unsafe impl #impl_generics #simple_dst_path::CloneToUninit for #name #ty_generics #where_clause {
308 unsafe fn clone_to_uninit(&self, dest: *mut u8) {
309 let last_offset = unsafe { (&raw const self.#last_member).byte_offset_from_unsigned(self) };
317
318 #(
319 let #first_members = <#first_tys as ::core::clone::Clone>::clone(&self.#first_members);
320 )*
321
322 unsafe {
323 <#last_ty as #simple_dst_path::CloneToUninit>::clone_to_uninit(&self.#last_member, dest.add(last_offset));
324
325 #(
326 dest.add(::core::mem::offset_of!(Self, #first_members)).cast::<#first_tys>().write(#first_members);
327 )*
328 }
329 }
330 }
331 })
332}
333
334struct ToOwnedAttrs {
335 alloc_path: Path,
336 owned: Type,
337}
338
339fn get_to_owned_attrs(attrs: &[Attribute], name: &Ident) -> syn::Result<ToOwnedAttrs> {
340 let mut alloc_path: Option<Path> = None;
341 let mut owned: Option<Type> = None;
342 for attr in attrs {
343 if !attr.path().is_ident("to_owned") {
344 continue;
345 }
346
347 attr.parse_nested_meta(|meta| {
348 if meta.path.is_ident("alloc_path") {
349 if alloc_path.is_some() {
350 return Err(meta.error("only one #[to_owned(alloc_path = ...)] is allowed"));
351 }
352 alloc_path = Some({
353 meta.input.parse::<Token![=]>()?;
354 meta.input.parse()?
355 });
356 } else if meta.path.is_ident("owned") {
357 if owned.is_some() {
358 return Err(meta.error("only one #[to_owned(owned = ...)] is allowed"));
359 }
360 owned = Some({
361 meta.input.parse::<Token![=]>()?;
362 meta.input.parse()?
363 });
364 } else {
365 return Err(meta.error("unrecognised #[to_owned(...)] argument"));
366 }
367 Ok(())
368 })?;
369 }
370
371 let alloc_path = alloc_path.unwrap_or_else(|| parse_quote! { ::std });
372 let to_owned_attrs = ToOwnedAttrs {
373 alloc_path: alloc_path.clone(),
374 owned: owned.unwrap_or_else(|| parse_quote! { #alloc_path::boxed::Box<#name> }),
375 };
376 Ok(to_owned_attrs)
377}
378
379#[proc_macro_derive(ToOwned, attributes(dst, to_owned))]
380pub fn derive_to_owned(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
381 let input = parse_macro_input!(input as DeriveInput);
382 derive_to_owned_impl(input)
383 .unwrap_or_else(syn::Error::into_compile_error)
384 .into()
385}
386
387fn derive_to_owned_impl(input: DeriveInput) -> syn::Result<TokenStream> {
388 let name = input.ident;
389
390 let DstAttrs {
391 simple_dst_path, ..
392 } = get_dst_attrs(&input.attrs)?;
393 let ToOwnedAttrs { alloc_path, owned } = get_to_owned_attrs(&input.attrs, &name)?;
394
395 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
396
397 Ok(quote! {
398 #[automatically_derived]
399 impl #impl_generics #alloc_path::borrow::ToOwned for #name #ty_generics #where_clause {
400 type Owned = #owned;
401
402 fn to_owned(&self) -> Self::Owned {
403 let layout = ::core::alloc::Layout::for_value(self);
404
405 unsafe {
406 <#owned as #simple_dst_path::AllocDst<#name>>::new_dst(
407 <#name as #simple_dst_path::Dst>::len(self),
408 layout,
409 |ptr| {
410 let dest = ptr.cast::<u8>();
411
412 <#name as #simple_dst_path::CloneToUninit>::clone_to_uninit(self, dest.as_ptr())
413 },
414 )
415 }
416 }
417 }
418 })
419}