silly_alloc_macros/
lib.rs1use proc_macro2::{Span, TokenStream};
5use quote::{quote, spanned::Spanned};
6use syn::{
7 parse::{Parse, ParseStream},
8 *,
9};
10
11mod cast_helpers;
12use cast_helpers::*;
13
/// Name of the runtime crate that generated code refers to (e.g. its `bucket` module).
const CRATE_NAME: &str = "silly_alloc";
15
16struct BucketAllocatorDescriptor {
17 name: Ident,
18 buckets: Vec<BucketDescriptor>,
19}
20
21impl Parse for BucketAllocatorDescriptor {
22 fn parse(input: ParseStream<'_>) -> Result<Self> {
23 let st: ItemStruct = input.parse()?;
24 let name = st.ident;
25 let buckets: Vec<Result<BucketDescriptor>> =
26 st.fields.iter().map(|field| field.try_into()).collect();
27 let buckets: Vec<BucketDescriptor> = Result::from_iter(buckets)?;
28 Ok(BucketAllocatorDescriptor { name, buckets })
29 }
30}
31
32struct BucketDescriptor {
33 _name: Ident,
34 slot_size: usize,
35 align: usize,
36 num_slots: usize,
37}
38
39impl TryFrom<&Field> for BucketDescriptor {
40 type Error = syn::Error;
41 fn try_from(field: &Field) -> Result<Self> {
42 let name = field
43 .ident
44 .as_ref()
45 .ok_or(Error::new(field.__span(), "Struct field without a name."))?;
46
47 let Type::Path(path_type) = &field.ty else { return Err(Error::new(field.__span(), "Struct field’s type must have the simple type name 'Bucket'."))} ;
48 if path_type.path.segments.len() != 1 {
49 return Err(Error::new(
50 path_type.__span(),
51 "Struct field’s type must have the simple type name 'Bucket'.",
52 ));
53 }
54 let path_seg = path_type.path.segments.iter().nth(0).unwrap();
55 if path_seg.ident.to_string() != "Bucket" {
56 return Err(Error::new(
57 path_seg.__span(),
58 "Struct field’s type must have the simple type name 'Bucket'.",
59 ));
60 }
61
62 let mut slot_size: Option<usize> = None;
63 let mut num_slots: Option<usize> = None;
64 let mut align: Option<usize> = None;
65 let PathArguments::AngleBracketed(generics) = &path_seg.arguments else { return Err(Error::new(path_seg.__span(), "Bucket is missing generic arguments")) };
66 for generic_arg in &generics.args {
67 let GenericArgument::Type(Type::Path(param_type)) = generic_arg else { return Err(Error::new( generic_arg.__span(), "Bucket can only take type arguments.")) };
68 if param_type.path.segments.len() != 1 {
69 return Err(Error::new(
70 param_type.__span(),
71 "Invalid value for a Bucket property",
72 ));
73 }
74 let segment = param_type.path.segments.iter().nth(0).unwrap();
75 let param_name = &segment.ident;
76 let PathArguments::AngleBracketed(param_generic_args) = &segment.arguments else { return Err(Error::new(segment.__span(), "Bucket parameters are passed as generic arguments.")) };
77 if param_generic_args.args.len() != 1 {
78 return Err(Error::new(
79 param_generic_args.__span(),
80 "Bucket parameters take exactly one generic argument.",
81 ));
82 }
83 let param_generic_arg = param_generic_args.args.iter().nth(0).unwrap();
84 let GenericArgument::Const(expr) = param_generic_arg else {
85 return Err(Error::new(param_generic_arg.__span(), "Bucket parameters must be a const expr."))
86 };
87
88 match param_name.to_string().as_str() {
89 "SlotSize" => slot_size = Some(expr_to_usize(expr)?),
90 "NumSlots" => num_slots = Some(expr_to_usize(expr)?),
91 "Align" => align = Some(expr_to_usize(expr)?),
92 _ => {
93 return Err(Error::new(
94 name.__span(),
95 format!("Unknown bucket parameter: {}", param_name.to_string()),
96 ))
97 }
98 };
99 }
100
101 Ok(BucketDescriptor {
102 _name: name.clone(),
103 slot_size: slot_size
104 .ok_or(Error::new(generics.__span(), "SlotSlize was not specified"))?,
105 num_slots: num_slots
106 .ok_or(Error::new(generics.__span(), "NumSlots was not specified"))?,
107 align: align.unwrap_or(1),
108 })
109 }
110}
111
112fn expr_to_usize(expr: &Expr) -> Result<usize> {
113 expr.try_to_int_literal()
114 .ok_or_else(|| Error::new(expr.__span(), "Bucket parameter must be an integer"))?
115 .parse::<usize>()
116 .map_err(|err| Error::new(expr.__span(), format!("{}", err)))
117}
118
119fn crate_path() -> Ident {
121 fn crate_name_option() -> Option<Ident> {
122 if std::env::var("SILLY_ALLOC_DOC_TESTS").is_ok() {
123 return None;
124 }
125 let pkg_name = std::env::var("CARGO_CRATE_NAME").ok()?;
126 if pkg_name == CRATE_NAME {
127 return Some(Ident::new("crate", Span::call_site()));
128 }
129 None
130 }
131 crate_name_option().unwrap_or_else(|| Ident::new(CRATE_NAME, Span::call_site()))
132}
133
134impl BucketDescriptor {
135 fn num_segments(&self) -> usize {
136 ((self.num_slots as f32) / 32.0).ceil() as usize
137 }
138
139 fn as_init_values(&self) -> TokenStream {
140 let crate_path = crate_path();
141 quote! {
142 ::core::cell::UnsafeCell::new(#crate_path::bucket::BucketImpl::new())
143 }
144 .into()
145 }
146
147 fn as_struct_fields(&self) -> TokenStream {
148 let BucketDescriptor {
149 slot_size, align, ..
150 } = self;
151 let num_segments = self.num_segments();
152 let slot_type_ident = Ident::new(&format!("SlotWithAlign{}", align), align.__span());
153 let crate_path = crate_path();
154 quote! {
155 ::core::cell::UnsafeCell<#crate_path::bucket::BucketImpl<#crate_path::bucket::#slot_type_ident<#slot_size>, #num_segments>>
156 }
157 .into()
158 }
159
160 fn as_alloc_bucket_selectors(&self, idx: usize) -> TokenStream {
161 let BucketDescriptor {
162 slot_size, align, ..
163 } = self;
164 let idx_key = Index::from(idx);
165 quote! {
166 {
167 let bucket = self.#idx_key.get().as_mut().unwrap();
168 bucket.ensure_init();
169 if size <= #slot_size && align <= #align {
170 if let Some(ptr) = bucket.claim_first_available_slot() {
171 return ptr as *mut u8;
172 }
173 }
174 }
175 }
176 .into()
177 }
178
179 fn as_dealloc_bucket_selectors(&self, idx: usize) -> TokenStream {
180 let idx_key = Index::from(idx);
181 quote! {
182 {
183 let bucket = self.#idx_key.get().as_mut().unwrap();
184 bucket.ensure_init();
185 if let Some(slot_idx) = bucket.slot_idx_for_ptr(ptr) {
186 bucket.unset_slot(slot_idx);
187 }
188 }
189 }
190 .into()
191 }
192}
193
194struct BucketAllocatorOptions {
195 sort_buckets: bool,
196}
197
198impl Parse for BucketAllocatorOptions {
199 fn parse(input: ParseStream) -> Result<Self> {
200 let mut result = Self::default();
201 while !input.is_empty() {
202 let opt_name = Ident::parse(input)?.to_string();
203 match opt_name.as_str() {
204 "sort_buckets" => {
205 <Token![=]>::parse(input)?;
206 result.sort_buckets = LitBool::parse(input)?.value;
207 }
208 _ => return Err(Error::new(input.span(), "Unsupported options")),
209 }
210 }
211 Ok(result)
212 }
213}
214
215impl Default for BucketAllocatorOptions {
216 fn default() -> Self {
217 BucketAllocatorOptions {
218 sort_buckets: false,
219 }
220 }
221}
222
223#[proc_macro_attribute]
230pub fn bucket_allocator(
231 attr: proc_macro::TokenStream,
232 input: proc_macro::TokenStream,
233) -> proc_macro::TokenStream {
234 let BucketAllocatorOptions { sort_buckets, .. } = parse_macro_input!(attr);
235 let BucketAllocatorDescriptor { name, mut buckets } = parse_macro_input!(input);
236
237 if sort_buckets {
238 buckets.sort_by(|a, b| {
239 let cmp = a.slot_size.cmp(&b.slot_size);
240 if cmp == std::cmp::Ordering::Equal {
241 a.align.cmp(&b.align)
242 } else {
243 cmp
244 }
245 });
246 }
247
248 let bucket_field_decls: Vec<TokenStream> = buckets
249 .iter()
250 .map(|bucket| bucket.as_struct_fields())
251 .collect();
252
253 let bucket_field_inits: Vec<TokenStream> = buckets
254 .iter()
255 .map(|bucket| bucket.as_init_values())
256 .collect();
257
258 let alloc_bucket_selectors: Vec<TokenStream> = buckets
259 .iter()
260 .enumerate()
261 .map(|(idx, bucket)| bucket.as_alloc_bucket_selectors(idx))
262 .collect();
263
264 let dealloc_bucket_selectors: Vec<TokenStream> = buckets
265 .iter()
266 .enumerate()
267 .map(|(idx, bucket)| bucket.as_dealloc_bucket_selectors(idx))
268 .collect();
269
270 quote! {
271 #[derive(Default, Debug)]
272 struct #name(
273 #(#bucket_field_decls),*
274 );
275
276 impl #name {
277 const fn new() -> Self {
278 #name (
279 #(#bucket_field_inits),*
280 )
281 }
282 }
283
284 unsafe impl ::core::marker::Sync for #name {}
285
286 unsafe impl ::bytemuck::Zeroable for #name {}
287
288 unsafe impl ::core::alloc::GlobalAlloc for #name {
289 unsafe fn alloc(&self, layout: ::core::alloc::Layout) -> *mut u8 {
290 let size = layout.size();
291 let align = layout.align();
292 #(#alloc_bucket_selectors)*
293 core::ptr::null_mut()
294 }
295
296 unsafe fn dealloc(&self, ptr: *mut u8, layout: ::core::alloc::Layout) {
297 let size = layout.size();
298 let align = layout.align();
299 #(#dealloc_bucket_selectors)*
300 }
301
302 }
303 }
304 .into()
305}