tidy_builder/lib.rs
//! The [Builder](`crate::Builder`) derive macro creates a compile-time correct builder.
//! It only allows you to build the given struct once you have provided a value for
//! every one of its required fields.
//!
//! A field is interpreted as required if it's not wrapped in an `Option`.
//! Any field inside of an `Option` is not required in order to build the given struct.
//! For example in:
//! ```rust
//! pub struct MyStruct {
//!     foo: String,
//!     bar: Option<usize>,
//! }
//! ```
//! The `foo` field is required and `bar` is optional. **Note** that although
//! `std::option::Option` refers to the same type, for now this macro only recognizes the
//! bare `Option` spelling, so a field written as `std::option::Option<T>` is treated as
//! required.
//!
//! The builder generated by the [Builder](`crate::Builder`) macro guarantees correctness
//! by encoding which required fields have been initialized in const generic parameters.
//! An example makes this clear. Let's assume we have a struct that has two required fields
//! and an optional one:
//! ```rust
//! pub struct MyStruct {
//!     req1: String,
//!     req2: String,
//!     opt1: Option<String>
//! }
//! ```
//! The generated builder will be:
//! ```rust
//! pub struct MyStructBuilder<const P0: bool, const P1: bool> {
//!     req1: Option<String>,
//!     req2: Option<String>,
//!     opt1: Option<String>,
//! }
//! ```
//! `P0` indicates whether the first required field has been initialized, and `P1` does the
//! same for the second required field. The initial state of the builder is
//! `MyStructBuilder<false, false>`. The first time a required field is initialized, its
//! corresponding const generic parameter is set to `true`, which moves the builder into a
//! different state. Setting an optional field does not change the state and therefore keeps
//! the same const generic parameters. Only once the builder reaches
//! `MyStructBuilder<true, true>` can you call the `build` function on it.
//!
//! So the complete generated code for the given example struct is:
//! ```rust
//! pub struct MyStruct {
//!     req1: String,
//!     req2: String,
//!     opt1: Option<String>
//! }
//!
//! pub struct MyStructBuilder<const P0: bool, const P1: bool> {
//!     req1: Option<String>,
//!     req2: Option<String>,
//!     opt1: Option<String>,
//! }
//!
//! impl MyStruct {
//!     pub fn builder() -> MyStructBuilder<false, false> {
//!         MyStructBuilder {
//!             req1: None,
//!             req2: None,
//!             opt1: None,
//!         }
//!     }
//! }
//!
//! impl<const P0: bool, const P1: bool> MyStructBuilder<P0, P1> {
//!     pub fn req1(self, req1: String) -> MyStructBuilder<true, P1> {
//!         MyStructBuilder {
//!             req1: Some(req1),
//!             req2: self.req2,
//!             opt1: self.opt1,
//!         }
//!     }
//!
//!     pub fn req2(self, req2: String) -> MyStructBuilder<P0, true> {
//!         MyStructBuilder {
//!             req1: self.req1,
//!             req2: Some(req2),
//!             opt1: self.opt1,
//!         }
//!     }
//!
//!     pub fn opt1(self, opt1: String) -> MyStructBuilder<P0, P1> {
//!         MyStructBuilder {
//!             req1: self.req1,
//!             req2: self.req2,
//!             opt1: Some(opt1),
//!         }
//!     }
//! }
//!
//! impl MyStructBuilder<true, true> {
//!     pub fn build(self) -> MyStruct {
//!         unsafe {
//!             MyStruct {
//!                 req1: self.req1.unwrap_unchecked(),
//!                 req2: self.req2.unwrap_unchecked(),
//!                 opt1: self.opt1,
//!             }
//!         }
//!     }
//! }
//! ```
106
107mod error;
108
109use error::BuilderError::*;
110
111use proc_macro2::TokenStream;
112use quote::{quote, ToTokens};
113use syn::spanned::Spanned;
114use syn::*;
115
116// Only `Type::Path` are supported here. These types have the form: segment0::segment1::segment2.
117// Currently this method only detects whether the type is an `Option` if it's written as `Option<_>`.
118//
119// TODO: We could also support:
120// * ::std::option::Option
121// * std::option::Option
122//
123// # Arguments
124// * `ty`: The type to check whether it's an `Option` or not.
125//
126// # Returns
127// * `Some`: Containing the type inside `Option`. For example calling this function
128// on `Option<T>` returns `Some(T)`.
129// * `None`: If the type is not option.
130#[rustfmt::skip]
131fn is_option(ty: &Type) -> Option<Type> {
132 // If `ty` is a `Type::Path`, it will contain one or more segments.
133 // For example:
134 // std::option::Option
135 // --- ------ ------
136 // s0 s1 s2
137 // has three segments.
138 if let Type::Path(TypePath { path: Path { segments, .. }, .. }) = ty {
139 // Becuase we only look for a type like `Option<_>`, we only check the first segment.
140 if segments[0].ident == "Option" {
141 // A type can have zero or more arguments. In case of `Option<_>`, we expect
142 // to see `AngleBracketed` arguments. So anything else cannot be an `Option`.
143 return match &segments[0].arguments {
144 PathArguments::None => None,
145 PathArguments::Parenthesized(_) => None,
146 PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) => {
147 // We expect the argument to be a type. For example in `Option<String>`,
148 // The argument is a type and its `String`.
149 if let GenericArgument::Type(inner_ty) = &args[0] {
150 Some(inner_ty.clone())
151 } else {
152 None
153 }
154 }
155 };
156 }
157 }
158
159 None
160}
161
162// Sometimes we only need the name of a generic parameter.
163// For example in `T: std::fmt::Display`, the whole thing is
164// a generic parameter but we want to extract the `T` from it.
165// Since we have three types of generic parameters, we need to
166// distinguish between their names too.
167// * A `Type` is like `T: std::fmt::Display` from which we want the `T` which is the `Ident`.
168// * A `Lifetime` is like `'a: 'b` from which we want the `'a` which is the `Lifetime`.
169// * A `Const` is like `const N: usize` from which we want the `N` which is the `Ident`.
170#[derive(Clone)]
171enum GenericParamName {
172 Type(Ident),
173 Lifetime(Lifetime),
174 Const(Ident),
175}
176
177// We need this trait to be able to interpolate on a vector of `GenericParamName`.
178impl ToTokens for GenericParamName {
179 fn to_tokens(&self, tokens: &mut TokenStream) {
180 match self {
181 GenericParamName::Type(ty) => ty.to_tokens(tokens),
182 GenericParamName::Lifetime(lt) => lt.to_tokens(tokens),
183 GenericParamName::Const(ct) => ct.to_tokens(tokens),
184 }
185 }
186}
187
188// Extracts the name of each generic parameter in `generics`.
189fn param_to_name(generics: &Generics) -> Vec<GenericParamName> {
190 generics
191 .params
192 .iter()
193 .map(|param| match param {
194 GenericParam::Type(ty) => GenericParamName::Type(ty.ident.clone()),
195 GenericParam::Lifetime(lt) => GenericParamName::Lifetime(lt.lifetime.clone()),
196 GenericParam::Const(c) => GenericParamName::Const(c.ident.clone()),
197 })
198 .collect()
199}
200
201// Splits the generic parameter names into three categories.
202fn split_param_names(
203 param_names: Vec<GenericParamName>,
204) -> (
205 Vec<GenericParamName>, // Lifetime generic parameters
206 Vec<GenericParamName>, // Const generic parameters
207 Vec<GenericParamName>, // Type generic parameters
208) {
209 let mut lifetimes = vec![];
210 let mut consts = vec![];
211 let mut types = vec![];
212
213 for param_name in param_names {
214 match param_name {
215 GenericParamName::Lifetime(_) => lifetimes.push(param_name.clone()),
216 GenericParamName::Const(_) => consts.push(param_name.clone()),
217 GenericParamName::Type(_) => types.push(param_name.clone()),
218 }
219 }
220
221 (lifetimes, consts, types)
222}
223
224// Splits generic parameters into three categories.
225fn split_params(
226 params: Vec<GenericParam>,
227) -> (
228 Vec<GenericParam>, // Lifetime generic parameters
229 Vec<GenericParam>, // Const generic parameters
230 Vec<GenericParam>, // Type generic parameters
231) {
232 let mut lifetimes = vec![];
233 let mut consts = vec![];
234 let mut types = vec![];
235
236 for param in params {
237 match param {
238 GenericParam::Lifetime(_) => lifetimes.push(param.clone()),
239 GenericParam::Const(_) => consts.push(param.clone()),
240 GenericParam::Type(_) => types.push(param.clone()),
241 }
242 }
243
244 (lifetimes, consts, types)
245}
246
247#[proc_macro_derive(Builder)]
248pub fn builder(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
249 let ast = parse_macro_input!(input as DeriveInput);
250
251 match ast.data {
252 Data::Struct(struct_t) => match struct_t.fields {
253 Fields::Named(FieldsNamed { named, .. }) => {
254 let fields = named;
255 let struct_ident = ast.ident.clone();
256
257 // In the definition below, the boundary of each value is depicted.
258 //
259 // impl<T: std::fmt::Debug> Foo<T> where T: std::fmt::Display
260 // -------------------- --- --------------------------
261 // 0 1 2
262 //
263 // 0: impl_generics
264 // 1: ty_generics
265 // 2: where_clause
266 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
267
268 let builder_ident =
269 Ident::new(&format!("{struct_ident}Builder"), struct_ident.span());
270
271 //--- Struct generic Parameters ---//
272 let st_param_names = param_to_name(&ast.generics);
273 // st_lt_pn: struct lifetime param names
274 // st_ct_pn: struct const param names
275 // st_ty_pn: struct type param names
276 let (st_lt_pn, st_ct_pn, st_ty_pn) = split_param_names(st_param_names);
277
278 let st_params: Vec<_> = ast.generics.params.iter().cloned().collect();
279 // st_lt_p: struct lifetime params
280 // st_ct_p: struct const params
281 // st_ty_p: struct type params
282 let (st_lt_p, st_ct_p, st_ty_p) = split_params(st_params);
283
284 //--- Builder generic parameters ---//
285 let (optional_fields, required_fields): (Vec<_>, Vec<_>) = fields
286 .iter()
287 .partition(|field| is_option(&field.ty).is_some());
288
289 // Contains all the builder parameters as `false`.
290 // So it helps to create:
291 // `Builder<false, false, false>`.
292 let mut all_false = vec![];
293
294 // Contains all the builder parameters as `true`.
295 // So it helps to create:
296 // `Builder<true, true, true>`.
297 let mut all_true = vec![];
298
299 // Contains the names of all builder parameters
300 // So it helps to create:
301 // `Builder<P0, P1, P2>`.
302 let mut b_ct_pn = vec![];
303
304 // Contains all builder parameters
305 // So it helps to create:
306 // `Builder<const P0: bool, const P1: bool, const P2: bool>`.
307 let mut b_ct_p = vec![];
308
309 // Contains all the fields of the builder.
310 // For example if the struct is:
311 // struct MyStruct {
312 // foo: Option<String>,
313 // bar: usize
314 // }
315 // The fields of the builder gonna be:
316 // struct MyStructBuilder {
317 // foo: Option<String>,
318 // bar: Option<usize>
319 // }
320 let mut b_fields = vec![];
321
322 // Contains all the initializers of the builder struct.
323 // For example for the builder on the comment above it's going to be:
324 // MyStructBuilder {
325 // foo: None,
326 // bar: None
327 // }
328 let mut b_inits = vec![];
329
330 // When we set the value of a required field, we must create the next state in the
331 // state machine. For that matter, we need to move the fields from the previous state to the new one.
332 // This field contains the moves of required fields.
333 let mut req_moves = vec![];
334
335 // When we reach the final state of the state machine and want to build the struct,
336 // we will call `unwrap` on the required fields because we know they are not `None`.
337 // For example:
338 // fn builder(self) -> MyStruct {
339 // MyStruct {
340 // foo: self.foo,
341 // bar: self.bar.unwrap()
342 // }
343 // }
344 // This variable contains the unwraps of required fields.
345 let mut req_unwraps = vec![];
346
347 for (index, field) in required_fields.iter().enumerate() {
348 let field_ident = &field.ident;
349 let field_ty = &field.ty;
350 let ct_param_ident = Ident::new(&format!("P{}", index), field.span());
351
352 b_fields.push(quote! { #field_ident: ::std::option::Option<#field_ty> });
353 b_inits.push(quote! { #field_ident: None });
354
355 req_moves.push(quote! { #field_ident: self.#field_ident });
356 req_unwraps.push(quote! { #field_ident: self.#field_ident.unwrap_unchecked() });
357
358 all_false.push(quote! { false });
359 all_true.push(quote! { true });
360 b_ct_pn.push(quote! { #ct_param_ident });
361 b_ct_p.push(quote! { const #ct_param_ident: bool });
362 }
363
364 // When we set the value of an optional field, we must create the current state in the
365 // state machine but set the optional field. For that matter,
366 // we need to move the fields from the previous state to the new one.
367 // This field contains the moves of optional fields.
368 let mut opt_moves = vec![];
369
370 for opt_field in &optional_fields {
371 let field_ident = &opt_field.ident;
372 let field_ty = &opt_field.ty;
373
374 opt_moves.push(quote! { #field_ident: self.#field_ident });
375
376 b_fields.push(quote! { #field_ident: #field_ty });
377 b_inits.push(quote! { #field_ident: None });
378 }
379
380 //--- State machine actions: Setters ---//
381
382 // Setting the value of an optional field:
383 let mut opt_setters = vec![];
384 for opt_field in &optional_fields {
385 let field_ident = &opt_field.ident;
386 let field_ty = &opt_field.ty;
387 let inner_ty = is_option(field_ty).unwrap();
388
389 // When we set an optional field, we stay in the same state.
390 // Therefore, we just need to set the value of the optional field.
391 opt_setters.push(
392 quote! {
393 pub fn #field_ident(mut self, #field_ident: #inner_ty) ->
394 #builder_ident<#(#st_lt_pn,)* #(#st_ct_pn,)* #(#b_ct_pn,)* #(#st_ty_pn,)*>
395 {
396 self.#field_ident = Some(#field_ident);
397 self
398 }
399 }
400 );
401 }
402
403 // Setting the value of a required field.
404 let mut req_setters = vec![];
405 for (index, req_field) in required_fields.iter().enumerate() {
406 let field_ident = &req_field.ident;
407 let field_ty = &req_field.ty;
408
409 // When setting a required field, we need to move the other required fields
410 // into the new state. So we pick the moves before and after this field.
411 let before_req_moves = &req_moves[..index];
412 let after_req_moves = &req_moves[index + 1..];
413
414 // When setting a parameter to `true`, we need to copy the other parameter
415 // names. So we pick the parameter names before and after the parameter that
416 // corresponds to this required field.
417 let before_pn = &b_ct_pn[..index];
418 let after_pn = &b_ct_pn[index + 1..];
419
420 // When we set the value of a required field, we must change to a state in
421 // which the parameter corresponding to that field is set to `true`.
422 req_setters.push(
423 quote! {
424 pub fn #field_ident(self, #field_ident: #field_ty) ->
425 #builder_ident<#(#st_lt_pn,)* #(#st_ct_pn,)* #(#before_pn,)* true, #(#after_pn,)* #(#st_ty_pn,)*>
426 {
427 #builder_ident {
428 #(#before_req_moves,)*
429 #field_ident: Some(#field_ident),
430 #(#after_req_moves,)*
431 #(#opt_moves,)*
432 }
433 }
434 }
435 );
436 }
437
438 //--- Generating the builder ---//
439 quote! {
440 // Definition of the builder struct.
441 pub struct #builder_ident<#(#st_lt_p,)* #(#st_ct_p,)* #(#b_ct_p,)* #(#st_ty_p,)*> #where_clause {
442 #(#b_fields),*
443 }
444
445 // An impl on the given struct to add the `builder` method to initialize the
446 // builder.
447 impl #impl_generics #struct_ident #ty_generics #where_clause {
448 pub fn builder() -> #builder_ident<#(#st_lt_pn,)* #(#st_ct_pn,)* #(#all_false,)* #(#st_ty_pn,)*> {
449 #builder_ident {
450 #(#b_inits),*
451 }
452 }
453 }
454
455 // impl on the builder containing the setter methods.
456 impl<#(#st_lt_p,)* #(#st_ct_p,)* #(#b_ct_p,)* #(#st_ty_p,)*>
457 #builder_ident<#(#st_lt_pn,)* #(#st_ct_pn,)* #(#b_ct_pn,)* #(#st_ty_pn,)* >
458 #where_clause
459 {
460 #(#opt_setters)*
461 #(#req_setters)*
462 }
463
464 // impl block on a builder with all of its parameters set to true.
465 // Meaning it's in the final state and can actually build the given struct.
466 impl<#(#st_lt_p,)* #(#st_ct_p,)* #(#st_ty_p,)*>
467 #builder_ident<#(#st_lt_pn,)* #(#st_ct_pn,)* #(#all_true,)* #(#st_ty_pn,)* >
468 #where_clause
469 {
470 fn build(self) -> #struct_ident #ty_generics {
471 unsafe {
472 #struct_ident {
473 #(#opt_moves,)*
474 #(#req_unwraps,)*
475 }
476 }
477 }
478 }
479
480 }
481 .into()
482 }
483 Fields::Unnamed(_) => UnnamedFields(struct_t.fields).into(),
484 Fields::Unit => UnitStruct(struct_t.fields).into(),
485 },
486 Data::Enum(enum_t) => Enum(enum_t).into(),
487 Data::Union(union_t) => Union(union_t).into(),
488 }
489}