silly_alloc_macros/
lib.rs

1/*!
2`silly_alloc_macros` is a macro support crate for the `silly_alloc` crate. Please see the documentation there.
3*/
4use proc_macro2::{Span, TokenStream};
5use quote::{quote, spanned::Spanned};
6use syn::{
7    parse::{Parse, ParseStream},
8    *,
9};
10
11mod cast_helpers;
12use cast_helpers::*;
13
/// Name of the runtime crate whose items (e.g. `bucket::BucketImpl`) the generated code refers to
/// when it is expanded outside of that crate.
const CRATE_NAME: &str = "silly_alloc";
15
16struct BucketAllocatorDescriptor {
17    name: Ident,
18    buckets: Vec<BucketDescriptor>,
19}
20
21impl Parse for BucketAllocatorDescriptor {
22    fn parse(input: ParseStream<'_>) -> Result<Self> {
23        let st: ItemStruct = input.parse()?;
24        let name = st.ident;
25        let buckets: Vec<Result<BucketDescriptor>> =
26            st.fields.iter().map(|field| field.try_into()).collect();
27        let buckets: Vec<BucketDescriptor> = Result::from_iter(buckets)?;
28        Ok(BucketAllocatorDescriptor { name, buckets })
29    }
30}
31
32struct BucketDescriptor {
33    _name: Ident,
34    slot_size: usize,
35    align: usize,
36    num_slots: usize,
37}
38
39impl TryFrom<&Field> for BucketDescriptor {
40    type Error = syn::Error;
41    fn try_from(field: &Field) -> Result<Self> {
42        let name = field
43            .ident
44            .as_ref()
45            .ok_or(Error::new(field.__span(), "Struct field without a name."))?;
46
47        let Type::Path(path_type) = &field.ty else { return Err(Error::new(field.__span(), "Struct field’s type must have the simple type name 'Bucket'."))} ;
48        if path_type.path.segments.len() != 1 {
49            return Err(Error::new(
50                path_type.__span(),
51                "Struct field’s type must have the simple type name 'Bucket'.",
52            ));
53        }
54        let path_seg = path_type.path.segments.iter().nth(0).unwrap();
55        if path_seg.ident.to_string() != "Bucket" {
56            return Err(Error::new(
57                path_seg.__span(),
58                "Struct field’s type must have the simple type name 'Bucket'.",
59            ));
60        }
61
62        let mut slot_size: Option<usize> = None;
63        let mut num_slots: Option<usize> = None;
64        let mut align: Option<usize> = None;
65        let PathArguments::AngleBracketed(generics) = &path_seg.arguments else { return Err(Error::new(path_seg.__span(), "Bucket is missing generic arguments")) };
66        for generic_arg in &generics.args {
67            let GenericArgument::Type(Type::Path(param_type)) = generic_arg  else { return Err(Error::new( generic_arg.__span(), "Bucket can only take type arguments."))  };
68            if param_type.path.segments.len() != 1 {
69                return Err(Error::new(
70                    param_type.__span(),
71                    "Invalid value for a Bucket property",
72                ));
73            }
74            let segment = param_type.path.segments.iter().nth(0).unwrap();
75            let param_name = &segment.ident;
76            let PathArguments::AngleBracketed(param_generic_args) = &segment.arguments else { return Err(Error::new(segment.__span(), "Bucket parameters are passed as generic arguments.")) };
77            if param_generic_args.args.len() != 1 {
78                return Err(Error::new(
79                    param_generic_args.__span(),
80                    "Bucket parameters take exactly one generic argument.",
81                ));
82            }
83            let param_generic_arg = param_generic_args.args.iter().nth(0).unwrap();
84            let GenericArgument::Const(expr) = param_generic_arg else {
85                return Err(Error::new(param_generic_arg.__span(), "Bucket parameters must be a const expr."))
86            };
87
88            match param_name.to_string().as_str() {
89                "SlotSize" => slot_size = Some(expr_to_usize(expr)?),
90                "NumSlots" => num_slots = Some(expr_to_usize(expr)?),
91                "Align" => align = Some(expr_to_usize(expr)?),
92                _ => {
93                    return Err(Error::new(
94                        name.__span(),
95                        format!("Unknown bucket parameter: {}", param_name.to_string()),
96                    ))
97                }
98            };
99        }
100
101        Ok(BucketDescriptor {
102            _name: name.clone(),
103            slot_size: slot_size
104                .ok_or(Error::new(generics.__span(), "SlotSlize was not specified"))?,
105            num_slots: num_slots
106                .ok_or(Error::new(generics.__span(), "NumSlots was not specified"))?,
107            align: align.unwrap_or(1),
108        })
109    }
110}
111
112fn expr_to_usize(expr: &Expr) -> Result<usize> {
113    expr.try_to_int_literal()
114        .ok_or_else(|| Error::new(expr.__span(), "Bucket parameter must be an integer"))?
115        .parse::<usize>()
116        .map_err(|err| Error::new(expr.__span(), format!("{}", err)))
117}
118
119// This function exists because sometimes the macro needs to emit `crate::bucket::BucketImpl` and sometimes just `silly_alloc::bucket::BucketImpl`. In the in-crate tests, `crate::`... is needed, but for the doc tests and any other external package, `silly_alloc::` is needed. To distinguish which to emit, we inspect the `CARGO_CRATE_NAME` env variable. If it’s "silly_alloc", someone is doing development on the crate itself and running the tests, so `crate::` is used. The only exception are the doc tests, where annoyingly `CARGO_CRATE_NAME` is set to "silly_alloc", but the doc tests are compiled like an external piece of code that is linked against the `silly_alloc` crate. For a lack of a better solution, an additional env variable `SILLY_ALLOC_DOC_TESTS` is checked to override that behavior.
120fn crate_path() -> Ident {
121    fn crate_name_option() -> Option<Ident> {
122        if std::env::var("SILLY_ALLOC_DOC_TESTS").is_ok() {
123            return None;
124        }
125        let pkg_name = std::env::var("CARGO_CRATE_NAME").ok()?;
126        if pkg_name == CRATE_NAME {
127            return Some(Ident::new("crate", Span::call_site()));
128        }
129        None
130    }
131    crate_name_option().unwrap_or_else(|| Ident::new(CRATE_NAME, Span::call_site()))
132}
133
134impl BucketDescriptor {
135    fn num_segments(&self) -> usize {
136        ((self.num_slots as f32) / 32.0).ceil() as usize
137    }
138
139    fn as_init_values(&self) -> TokenStream {
140        let crate_path = crate_path();
141        quote! {
142            ::core::cell::UnsafeCell::new(#crate_path::bucket::BucketImpl::new())
143        }
144        .into()
145    }
146
147    fn as_struct_fields(&self) -> TokenStream {
148        let BucketDescriptor {
149            slot_size, align, ..
150        } = self;
151        let num_segments = self.num_segments();
152        let slot_type_ident = Ident::new(&format!("SlotWithAlign{}", align), align.__span());
153        let crate_path = crate_path();
154        quote! {
155            ::core::cell::UnsafeCell<#crate_path::bucket::BucketImpl<#crate_path::bucket::#slot_type_ident<#slot_size>, #num_segments>>
156        }
157        .into()
158    }
159
160    fn as_alloc_bucket_selectors(&self, idx: usize) -> TokenStream {
161        let BucketDescriptor {
162            slot_size, align, ..
163        } = self;
164        let idx_key = Index::from(idx);
165        quote! {
166            {
167                let bucket = self.#idx_key.get().as_mut().unwrap();
168                bucket.ensure_init();
169                if size <= #slot_size && align <= #align {
170                    if let Some(ptr) = bucket.claim_first_available_slot() {
171                        return ptr as *mut u8;
172                    }
173                }
174            }
175        }
176        .into()
177    }
178
179    fn as_dealloc_bucket_selectors(&self, idx: usize) -> TokenStream {
180        let idx_key = Index::from(idx);
181        quote! {
182            {
183                let bucket = self.#idx_key.get().as_mut().unwrap();
184                bucket.ensure_init();
185                if let Some(slot_idx) = bucket.slot_idx_for_ptr(ptr) {
186                    bucket.unset_slot(slot_idx);
187                }
188            }
189        }
190        .into()
191    }
192}
193
/// Options accepted inside `#[bucket_allocator(...)]`.
struct BucketAllocatorOptions {
    /// If `true`, buckets are reordered by slot size, then alignment (both ascending), before
    /// code generation, so smaller buckets are tried first. If `false`, declaration order is kept.
    sort_buckets: bool,
}
197
198impl Parse for BucketAllocatorOptions {
199    fn parse(input: ParseStream) -> Result<Self> {
200        let mut result = Self::default();
201        while !input.is_empty() {
202            let opt_name = Ident::parse(input)?.to_string();
203            match opt_name.as_str() {
204                "sort_buckets" => {
205                    <Token![=]>::parse(input)?;
206                    result.sort_buckets = LitBool::parse(input)?.value;
207                }
208                _ => return Err(Error::new(input.span(), "Unsupported options")),
209            }
210        }
211        Ok(result)
212    }
213}
214
215impl Default for BucketAllocatorOptions {
216    fn default() -> Self {
217        BucketAllocatorOptions {
218            sort_buckets: false,
219        }
220    }
221}
222
223/// Macro to turn a struct into an allocator.
224///
225/// `bucket_allocator` is an attribute macro that builds a `GlobalAlloc`-compatible data type from a given struct. Please see the module-level documentation for details and examples.
226///
227/// The macro supports the following options:
228/// - `sort_buckets = <true|false>`: Sort buckets by item size, then alignment
229#[proc_macro_attribute]
230pub fn bucket_allocator(
231    attr: proc_macro::TokenStream,
232    input: proc_macro::TokenStream,
233) -> proc_macro::TokenStream {
234    let BucketAllocatorOptions { sort_buckets, .. } = parse_macro_input!(attr);
235    let BucketAllocatorDescriptor { name, mut buckets } = parse_macro_input!(input);
236
237    if sort_buckets {
238        buckets.sort_by(|a, b| {
239            let cmp = a.slot_size.cmp(&b.slot_size);
240            if cmp == std::cmp::Ordering::Equal {
241                a.align.cmp(&b.align)
242            } else {
243                cmp
244            }
245        });
246    }
247
248    let bucket_field_decls: Vec<TokenStream> = buckets
249        .iter()
250        .map(|bucket| bucket.as_struct_fields())
251        .collect();
252
253    let bucket_field_inits: Vec<TokenStream> = buckets
254        .iter()
255        .map(|bucket| bucket.as_init_values())
256        .collect();
257
258    let alloc_bucket_selectors: Vec<TokenStream> = buckets
259        .iter()
260        .enumerate()
261        .map(|(idx, bucket)| bucket.as_alloc_bucket_selectors(idx))
262        .collect();
263
264    let dealloc_bucket_selectors: Vec<TokenStream> = buckets
265        .iter()
266        .enumerate()
267        .map(|(idx, bucket)| bucket.as_dealloc_bucket_selectors(idx))
268        .collect();
269
270    quote! {
271            #[derive(Default, Debug)]
272            struct #name(
273                #(#bucket_field_decls),*
274            );
275
276            impl #name {
277                const fn new() -> Self {
278                    #name (
279                        #(#bucket_field_inits),*
280                    )
281                }
282            }
283
284            unsafe impl ::core::marker::Sync for #name {}
285
286            unsafe impl ::bytemuck::Zeroable for #name {}
287
288            unsafe impl ::core::alloc::GlobalAlloc for #name {
289                unsafe fn alloc(&self, layout: ::core::alloc::Layout) -> *mut u8 {
290                    let size = layout.size();
291                    let align = layout.align();
292                    #(#alloc_bucket_selectors)*
293                    core::ptr::null_mut()
294                }
295
296                unsafe fn dealloc(&self, ptr: *mut u8, layout: ::core::alloc::Layout) {
297                    let size = layout.size();
298                    let align = layout.align();
299                    #(#dealloc_bucket_selectors)*
300                }
301
302            }
303    }
304    .into()
305}