tidy_builder/
lib.rs

1//! The [Builder](`crate::Builder`) derive macro creates a compile-time correct builder.
2//! It means that it only allows you to build the given struct as long as you provide a
3//! value for all of its required fields.
4//!
5//! A field is interpreted as required if it's not wrapped in an `Option`.
6//! Any field inside of an `Option` is not considered required in order to
7//! build the given struct. For example in:
8//! ```rust
9//! pub struct MyStruct {
10//!     foo: String,
11//!     bar: Option<usize>,
12//! }
13//! ```
14//! The `foo` field is required and `bar` is optional. **Note** that although
15//! `std::option::Option` also refers to the same type, for now this macro doesn't
16//! recognize anything other than `Option`.
17//!
18//! The builder generated using the [Builder](`crate::Builder`) macro guarantees correctness
19//! by encoding the initialized set using const generics. An example makes it clear. Let's assume
20//! we have a struct that has two required fields and an optional one:
21//! ```rust
22//! pub struct MyStruct {
23//!     req1: String,
24//!     req2: String,
25//!     opt1: Option<String>
26//! }
27//! ```
28//! The generated builder will be:
29//! ```rust
30//! pub struct MyStructBuilder<const P0: bool, const P1: bool> {
31//!     req1: Option<String>,
32//!     req2: Option<String>,
33//!     opt1: Option<String>,
34//! }
35//! ```
36//! The `P0` indicates whether the first required parameter is initialized or not. And similarly,
37//! the `P1` does the same thing for the second required parameter. The initial state of the
38//! builder will be `MyStructBuilder<false, false>` and the first time a required field is
39//! initialized, its corresponding const generic parameter will be set to true which indicates a
40//! different state. Setting an optional value does not change the state and consequently keeps the
41//! same const generic parameters. Only once the builder reaches the `MyStructBuilder<true, true>`
42//! state can you call the `build` function on the builder.
43//!
44//! So the complete generated code for the given example struct is:
45//! ```rust
46//! pub struct MyStruct {
47//!     req1: String,
48//!     req2: String,
49//!     opt1: Option<String>
50//! }
51//!
52//! pub struct MyStructBuilder<const P0: bool, const P1: bool> {
53//!     req1: Option<String>,
54//!     req2: Option<String>,
55//!     opt1: Option<String>,
56//! }
57//!
58//! impl MyStruct {
59//!     pub fn builder() -> MyStructBuilder<false, false> {
60//!         MyStructBuilder {
61//!             req1: None,
62//!             req2: None,
63//!             opt1: None,
64//!         }
65//!     }
66//! }
67//!
68//! impl<const P0: bool, const P1: bool> MyStructBuilder<P0, P1> {
69//!     pub fn req1(self, req1: String) -> MyStructBuilder<true, P1> {
70//!         MyStructBuilder {
71//!             req1: Some(req1),
72//!             req2: self.req2,
73//!             opt1: self.opt1,
74//!         }
75//!     }
76//!
77//!     pub fn req2(self, req2: String) -> MyStructBuilder<P0, true> {
78//!         MyStructBuilder {
79//!             req1: self.req1,
80//!             req2: Some(req2),
81//!             opt1: self.opt1,
82//!         }
83//!     }
84//!
85//!     pub fn opt1(self, opt1: String) -> MyStructBuilder<P0, P1> {
86//!         MyStructBuilder {
87//!             req1: self.req1,
88//!             req2: self.req2,
89//!             opt1: Some(opt1),
90//!         }
91//!     }
92//! }
93//!
94//! impl MyStructBuilder<true, true> {
95//!     pub fn build(self) -> MyStruct {
96//!         unsafe {
97//!             MyStruct {
98//!                 req1: self.req1.unwrap_unchecked(),
99//!                 req2: self.req2.unwrap_unchecked(),
100//!                 opt1: self.opt1,
101//!             }
102//!         }
103//!     }
104//! }
105//! ```
106
107mod error;
108
109use error::BuilderError::*;
110
111use proc_macro2::TokenStream;
112use quote::{quote, ToTokens};
113use syn::spanned::Spanned;
114use syn::*;
115
116// Only `Type::Path` are supported here. These types have the form: segment0::segment1::segment2.
117// Currently this method only detects whether the type is an `Option` if it's written as `Option<_>`.
118//
119// TODO: We could also support:
120//      * ::std::option::Option
121//      * std::option::Option
122//
123// # Arguments
124// * `ty`: The type to check whether it's an `Option` or not.
125//
126// # Returns
127// * `Some`: Containing the type inside `Option`. For example calling this function
128//           on `Option<T>` returns `Some(T)`.
129// * `None`: If the type is not option.
130#[rustfmt::skip]
131fn is_option(ty: &Type) -> Option<Type> {
132    // If `ty` is a `Type::Path`, it will contain one or more segments.
133    // For example:
134    //      std::option::Option
135    //      ---  ------  ------
136    //       s0    s1      s2
137    // has three segments.
138    if let Type::Path(TypePath { path: Path { segments, .. }, .. }) = ty {
139        // Becuase we only look for a type like `Option<_>`, we only check the first segment.
140        if segments[0].ident == "Option" {
141            // A type can have zero or more arguments. In case of `Option<_>`, we expect
142            // to see `AngleBracketed` arguments. So anything else cannot be an `Option`.
143            return match &segments[0].arguments {
144                PathArguments::None => None,
145                PathArguments::Parenthesized(_) => None,
146                PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) => {
147                    // We expect the argument to be a type. For example in `Option<String>`,
148                    // The argument is a type and its `String`.
149                    if let GenericArgument::Type(inner_ty) = &args[0] {
150                        Some(inner_ty.clone())
151                    } else {
152                        None
153                    }
154                }
155            };
156        }
157    }
158
159    None
160}
161
162// Sometimes we only need the name of a generic parameter.
163// For example in `T: std::fmt::Display`, the whole thing is
164// a generic parameter but we want to extract the `T` from it.
165// Since we have three types of generic parameters, we need to
166// distinguish between their names too.
167//  * A `Type` is like `T: std::fmt::Display` from which we want the `T` which is the `Ident`.
168//  * A `Lifetime` is like `'a: 'b` from which we want the `'a` which is the `Lifetime`.
169//  * A `Const` is like `const N: usize` from which we want the `N` which is the `Ident`.
170#[derive(Clone)]
171enum GenericParamName {
172    Type(Ident),
173    Lifetime(Lifetime),
174    Const(Ident),
175}
176
177// We need this trait to be able to interpolate on a vector of `GenericParamName`.
178impl ToTokens for GenericParamName {
179    fn to_tokens(&self, tokens: &mut TokenStream) {
180        match self {
181            GenericParamName::Type(ty) => ty.to_tokens(tokens),
182            GenericParamName::Lifetime(lt) => lt.to_tokens(tokens),
183            GenericParamName::Const(ct) => ct.to_tokens(tokens),
184        }
185    }
186}
187
188// Extracts the name of each generic parameter in `generics`.
189fn param_to_name(generics: &Generics) -> Vec<GenericParamName> {
190    generics
191        .params
192        .iter()
193        .map(|param| match param {
194            GenericParam::Type(ty) => GenericParamName::Type(ty.ident.clone()),
195            GenericParam::Lifetime(lt) => GenericParamName::Lifetime(lt.lifetime.clone()),
196            GenericParam::Const(c) => GenericParamName::Const(c.ident.clone()),
197        })
198        .collect()
199}
200
201// Splits the generic parameter names into three categories.
202fn split_param_names(
203    param_names: Vec<GenericParamName>,
204) -> (
205    Vec<GenericParamName>, // Lifetime generic parameters
206    Vec<GenericParamName>, // Const generic parameters
207    Vec<GenericParamName>, // Type generic parameters
208) {
209    let mut lifetimes = vec![];
210    let mut consts = vec![];
211    let mut types = vec![];
212
213    for param_name in param_names {
214        match param_name {
215            GenericParamName::Lifetime(_) => lifetimes.push(param_name.clone()),
216            GenericParamName::Const(_) => consts.push(param_name.clone()),
217            GenericParamName::Type(_) => types.push(param_name.clone()),
218        }
219    }
220
221    (lifetimes, consts, types)
222}
223
224// Splits generic parameters into three categories.
225fn split_params(
226    params: Vec<GenericParam>,
227) -> (
228    Vec<GenericParam>, // Lifetime generic parameters
229    Vec<GenericParam>, // Const generic parameters
230    Vec<GenericParam>, // Type generic parameters
231) {
232    let mut lifetimes = vec![];
233    let mut consts = vec![];
234    let mut types = vec![];
235
236    for param in params {
237        match param {
238            GenericParam::Lifetime(_) => lifetimes.push(param.clone()),
239            GenericParam::Const(_) => consts.push(param.clone()),
240            GenericParam::Type(_) => types.push(param.clone()),
241        }
242    }
243
244    (lifetimes, consts, types)
245}
246
247#[proc_macro_derive(Builder)]
248pub fn builder(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
249    let ast = parse_macro_input!(input as DeriveInput);
250
251    match ast.data {
252        Data::Struct(struct_t) => match struct_t.fields {
253            Fields::Named(FieldsNamed { named, .. }) => {
254                let fields = named;
255                let struct_ident = ast.ident.clone();
256
257                // In the definition below, the boundary of each value is depicted.
258                //
259                // impl<T: std::fmt::Debug> Foo<T> where T: std::fmt::Display
260                //     --------------------    --- --------------------------
261                //              0               1               2
262                //
263                //  0: impl_generics
264                //  1: ty_generics
265                //  2: where_clause
266                let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
267
268                let builder_ident =
269                    Ident::new(&format!("{struct_ident}Builder"), struct_ident.span());
270
271                //--- Struct generic Parameters ---//
272                let st_param_names = param_to_name(&ast.generics);
273                // st_lt_pn: struct lifetime param names
274                // st_ct_pn: struct const param names
275                // st_ty_pn: struct type param names
276                let (st_lt_pn, st_ct_pn, st_ty_pn) = split_param_names(st_param_names);
277
278                let st_params: Vec<_> = ast.generics.params.iter().cloned().collect();
279                // st_lt_p: struct lifetime params
280                // st_ct_p: struct const params
281                // st_ty_p: struct type params
282                let (st_lt_p, st_ct_p, st_ty_p) = split_params(st_params);
283
284                //--- Builder generic parameters ---//
285                let (optional_fields, required_fields): (Vec<_>, Vec<_>) = fields
286                    .iter()
287                    .partition(|field| is_option(&field.ty).is_some());
288
289                // Contains all the builder parameters as `false`.
290                // So it helps to create:
291                //      `Builder<false, false, false>`.
292                let mut all_false = vec![];
293
294                // Contains all the builder parameters as `true`.
295                // So it helps to create:
296                //      `Builder<true, true, true>`.
297                let mut all_true = vec![];
298
299                // Contains the names of all builder parameters
300                // So it helps to create:
301                //      `Builder<P0, P1, P2>`.
302                let mut b_ct_pn = vec![];
303
304                // Contains all builder parameters
305                // So it helps to create:
306                //      `Builder<const P0: bool, const P1: bool, const P2: bool>`.
307                let mut b_ct_p = vec![];
308
309                // Contains all the fields of the builder.
310                // For example if the struct is:
311                //      struct MyStruct {
312                //          foo: Option<String>,
313                //          bar: usize
314                //      }
315                // The fields of the builder gonna be:
316                //      struct MyStructBuilder {
317                //          foo: Option<String>,
318                //          bar: Option<usize>
319                //      }
320                let mut b_fields = vec![];
321
322                // Contains all the initializers of the builder struct.
323                // For example for the builder on the comment above it's going to be:
324                //      MyStructBuilder {
325                //          foo: None,
326                //          bar: None
327                //      }
328                let mut b_inits = vec![];
329
330                // When we set the value of a required field, we must create the next state in the
331                // state machine. For that matter, we need to move the fields from the previous state to the new one.
332                // This field contains the moves of required fields.
333                let mut req_moves = vec![];
334
335                // When we reach the final state of the state machine and want to build the struct,
336                // we will call `unwrap` on the required fields because we know they are not `None`.
337                // For example:
338                //      fn builder(self) -> MyStruct {
339                //          MyStruct {
340                //              foo: self.foo,
341                //              bar: self.bar.unwrap()
342                //          }
343                //      }
344                // This variable contains the unwraps of required fields.
345                let mut req_unwraps = vec![];
346
347                for (index, field) in required_fields.iter().enumerate() {
348                    let field_ident = &field.ident;
349                    let field_ty = &field.ty;
350                    let ct_param_ident = Ident::new(&format!("P{}", index), field.span());
351
352                    b_fields.push(quote! { #field_ident: ::std::option::Option<#field_ty> });
353                    b_inits.push(quote! { #field_ident: None });
354
355                    req_moves.push(quote! { #field_ident: self.#field_ident });
356                    req_unwraps.push(quote! { #field_ident: self.#field_ident.unwrap_unchecked() });
357
358                    all_false.push(quote! { false });
359                    all_true.push(quote! { true });
360                    b_ct_pn.push(quote! { #ct_param_ident });
361                    b_ct_p.push(quote! { const #ct_param_ident: bool });
362                }
363
364                // When we set the value of an optional field, we must create the current state in the
365                // state machine but set the optional field. For that matter,
366                // we need to move the fields from the previous state to the new one.
367                // This field contains the moves of optional fields.
368                let mut opt_moves = vec![];
369                
370                for opt_field in &optional_fields {
371                    let field_ident = &opt_field.ident;
372                    let field_ty = &opt_field.ty;
373
374                    opt_moves.push(quote! { #field_ident: self.#field_ident });
375
376                    b_fields.push(quote! { #field_ident: #field_ty });
377                    b_inits.push(quote! { #field_ident: None });
378                }
379
380                //--- State machine actions: Setters ---//
381
382                // Setting the value of an optional field:
383                let mut opt_setters = vec![];
384                for opt_field in &optional_fields {
385                    let field_ident = &opt_field.ident;
386                    let field_ty = &opt_field.ty;
387                    let inner_ty = is_option(field_ty).unwrap();
388
389                    // When we set an optional field, we stay in the same state.
390                    // Therefore, we just need to set the value of the optional field.
391                    opt_setters.push(
392                        quote! {
393                            pub fn #field_ident(mut self, #field_ident: #inner_ty) ->
394                                #builder_ident<#(#st_lt_pn,)* #(#st_ct_pn,)* #(#b_ct_pn,)* #(#st_ty_pn,)*>
395                            {
396                                self.#field_ident = Some(#field_ident);
397                                self
398                            }
399                        }
400                    );
401                }
402
403                // Setting the value of a required field.
404                let mut req_setters = vec![];
405                for (index, req_field) in required_fields.iter().enumerate() {
406                    let field_ident = &req_field.ident;
407                    let field_ty = &req_field.ty;
408
409                    // When setting a required field, we need to move the other required fields
410                    // into the new state. So we pick the moves before and after this field.
411                    let before_req_moves = &req_moves[..index];
412                    let after_req_moves = &req_moves[index + 1..];
413
414                    // When setting a parameter to `true`, we need to copy the other parameter
415                    // names. So we pick the parameter names before and after the parameter that
416                    // corresponds to this required field.
417                    let before_pn = &b_ct_pn[..index];
418                    let after_pn = &b_ct_pn[index + 1..];
419
420                    // When we set the value of a required field, we must change to a state in
421                    // which the parameter corresponding to that field is set to `true`.
422                    req_setters.push(
423                        quote! {
424                            pub fn #field_ident(self, #field_ident: #field_ty) ->
425                                #builder_ident<#(#st_lt_pn,)* #(#st_ct_pn,)* #(#before_pn,)* true, #(#after_pn,)* #(#st_ty_pn,)*>
426                            {
427                                #builder_ident {
428                                    #(#before_req_moves,)*
429                                    #field_ident: Some(#field_ident),
430                                    #(#after_req_moves,)*
431                                    #(#opt_moves,)*
432                                }
433                            }
434                        }
435                    );
436                }
437
438                //--- Generating the builder ---//
439                quote! {
440                    // Definition of the builder struct.
441                    pub struct #builder_ident<#(#st_lt_p,)* #(#st_ct_p,)* #(#b_ct_p,)* #(#st_ty_p,)*> #where_clause {
442                        #(#b_fields),*
443                    }
444
445                    // An impl on the given struct to add the `builder` method to initialize the
446                    // builder.
447                    impl #impl_generics #struct_ident #ty_generics #where_clause {
448                        pub fn builder() -> #builder_ident<#(#st_lt_pn,)* #(#st_ct_pn,)* #(#all_false,)* #(#st_ty_pn,)*> {
449                            #builder_ident {
450                                #(#b_inits),*
451                            }
452                        }
453                    }
454
455                    // impl on the builder containing the setter methods.
456                    impl<#(#st_lt_p,)* #(#st_ct_p,)* #(#b_ct_p,)* #(#st_ty_p,)*>
457                        #builder_ident<#(#st_lt_pn,)* #(#st_ct_pn,)* #(#b_ct_pn,)* #(#st_ty_pn,)* >
458                        #where_clause
459                    {
460                        #(#opt_setters)*
461                        #(#req_setters)*
462                    }
463
464                    // impl block on a builder with all of its parameters set to true.
465                    // Meaning it's in the final state and can actually build the given struct.
466                    impl<#(#st_lt_p,)* #(#st_ct_p,)* #(#st_ty_p,)*>
467                        #builder_ident<#(#st_lt_pn,)* #(#st_ct_pn,)* #(#all_true,)* #(#st_ty_pn,)* >
468                        #where_clause
469                    {
470                        fn build(self) -> #struct_ident #ty_generics {
471                            unsafe {
472                                #struct_ident {
473                                    #(#opt_moves,)*
474                                    #(#req_unwraps,)*
475                                }
476                            }
477                        }
478                    }
479
480                }
481                .into()
482            }
483            Fields::Unnamed(_) => UnnamedFields(struct_t.fields).into(),
484            Fields::Unit => UnitStruct(struct_t.fields).into(),
485        },
486        Data::Enum(enum_t) => Enum(enum_t).into(),
487        Data::Union(union_t) => Union(union_t).into(),
488    }
489}