safe_builder_derive/
lib.rs

1#![allow(dead_code)]
2
3extern crate proc_macro;
4extern crate syn;
5extern crate itertools;
6
7#[macro_use]
8extern crate quote;
9
10#[cfg(test)]
11mod test;
12
13use proc_macro::TokenStream;
14use itertools::Itertools;
15
16#[proc_macro_derive(SafeBuilder)]
17pub fn safe_builder(input: TokenStream) -> TokenStream
18{
19    let s = input.to_string();
20
21    let ast = syn::parse_macro_input(&s).unwrap();
22
23    let target = TargetStruct::new(&ast);
24
25    target.build().parse().unwrap()
26}
27
28use syn::{Body, VariantData, Ty};
29use std::collections::HashMap;
30
31#[derive(Debug, Clone, PartialEq)]
32struct TargetStruct
33{
34    name: String,
35    fields: HashMap<String, Ty>,
36    partials: Partials,
37}
38
39impl TargetStruct
40{
41    pub fn new(input: &syn::MacroInput) -> TargetStruct
42    {
43        assert!(input.generics.lifetimes.len() == 0, "safe-builder-derive does not support lifetimes");
44        assert!(input.generics.ty_params.len() == 0, "safe-builder-derive does not support generic types");
45
46        let name = input.ident.to_string();
47
48        if let Body::Struct(VariantData::Struct(ref fields)) = input.body
49        {
50            match fields.first()
51            {
52                None => TargetStruct
53                {
54                    name: name,
55                    fields: HashMap::new(),
56                    partials: Partials::new(Vec::new())
57                },
58                Some(ref first) => match first.ident
59                {
60                    Some(_) =>
61                    {
62                        let mut map = HashMap::new();
63                        let mut field_names = Vec::new();
64
65                        for field in fields.iter()
66                        {
67                            let name = field.ident.clone().unwrap().to_string();
68
69                            field_names.push(name.clone());
70                            map.insert(name, field.ty.clone());
71                        }
72
73                        let field_combinations = (0..fields.len())
74                            .flat_map(|n| field_names.iter()
75                                .combinations(n)
76                                .map(|v| v.into_iter()
77                                    .map(|s| s.to_owned())
78                                    .collect::<Vec<_>>()))
79                            .map(|mut v| {v.sort(); v});
80                        
81                        let mut names = Vec::new();
82                        let mut partials = Vec::new();
83
84                        for combo in field_combinations
85                        {
86                            let mut name = format!("{}BuilderWith{}", name,
87                                combo.iter().fold(String::new(), |mut a, b|
88                                {
89                                    a.push_str(b);
90                                    a
91                                }));
92                            
93                            while names.contains(&name)
94                            {
95                                name.push('_'); // TODO: better way to make names unique
96                            }
97
98                            names.push(name.clone());
99
100                            partials.push(PartialStruct::new(name, combo));
101                        }
102
103                        TargetStruct
104                        {
105                            name: name,
106                            fields: map,
107                            partials: Partials::new(partials)
108                        }
109                    }
110                    None => panic!("safe-builder-derive does not support tuple struct")
111                }
112            }
113        }
114        else
115        {
116            panic!("safe-builder-derive does not support enums");
117        }
118    }
119
120    pub fn build(&self) -> quote::Tokens
121    {
122        let target_id = quote::Ident::from(self.name.as_ref());
123
124        if self.fields.len() == 0
125        {
126            quote!
127            {
128                impl ::safe_builder::PartialBuilder for #target_id { }
129
130                impl ::safe_builder::SafeBuilder for #target_id
131                {
132                    fn build() -> #target_id
133                    {
134                        #target_id { }
135                    }
136                }
137            }
138        }
139        else
140        {
141            let init_struct_id = quote::Ident::from(self.partials.at_order(0).unwrap()[0].name.as_str());
142
143            let target_impl = quote!
144            {
145                impl ::safe_builder::SafeBuilder<#init_struct_id> for #target_id
146                {
147                    fn build() -> #init_struct_id
148                    {
149                        #init_struct_id{ }
150                    }
151                }
152            };
153
154            let other_impls = self.partials.all().into_iter()
155                .map(|partial| partial.build(&self))
156                .collect::<Vec<_>>();
157            
158            quote!
159            {
160                #target_impl
161
162                #(#other_impls)*
163            }
164        }
165    }
166}
167
168#[derive(Debug, Clone, PartialEq)]
169struct Partials(HashMap<usize, Vec<PartialStruct>>);
170
171impl Partials
172{
173    pub fn new(partials: Vec<PartialStruct>) -> Partials
174    {
175        let mut map: HashMap<usize, Vec<PartialStruct>> = HashMap::new();
176
177        for partial in partials.into_iter()
178        {
179            let o = partial.order();
180
181            if let Some(_) = map.get_mut(&o)
182            {
183                map.get_mut(&o).unwrap().push(partial);
184            }
185            else
186            {
187                map.insert(o, vec![partial]);
188            }
189        }
190
191        Partials(map)
192    }
193
194    pub fn at_order<'a>(&'a self, order: usize) -> Option<&'a [PartialStruct]>
195    {
196        match self.0.get(&order)
197        {
198            Some(vec) => Some(&vec),
199            None => None
200        }
201    }
202
203    pub fn all<'a>(&'a self) -> Vec<&'a PartialStruct>
204    {
205        self.0.values()
206            .flat_map(|order| order.into_iter())
207            .collect::<Vec<_>>()
208    }
209}
210
211#[derive(Debug, Clone, PartialEq)]
212struct PartialStruct
213{
214    name: String,
215    fields: Vec<String>,
216}
217
218impl PartialStruct
219{
220    pub fn new(name: String, fields: Vec<String>) -> PartialStruct
221    {
222        PartialStruct
223        {
224            name: name,
225            fields: fields
226        }
227    }
228
229    pub fn order(&self) -> usize
230    {
231        self.fields.len()
232    }
233
234    pub fn step<'a>(&self, other: &'a PartialStruct) -> Option<String>
235    {
236        if self.order() == other.order() - 1 && other.order() != 0
237        {
238            let mut s = String::new();
239            for field in other.fields.iter()
240            {
241                if !self.fields.contains(field)
242                {
243                    s = field.to_owned();
244                }
245            }
246
247            if s == String::new()
248            {
249                panic!("partial of order n - 1 can't find last field in target!");
250            }
251            else
252            {
253                Some(s)
254            }
255        }
256        else
257        {
258            None
259        }
260    }
261
262    pub fn build(&self, target: &TargetStruct) -> quote::Tokens
263    {
264        let self_id = quote::Ident::from(self.name.as_str());
265        let partial_struct =
266        {
267            let fields = self.fields.iter()
268                .map(|name|
269                {
270                    let id = quote::Ident::from(name.as_str());
271                    let ty = target.fields.get(name).unwrap();
272
273                    quote!
274                    {
275                        #id: #ty
276                    }
277                })
278                .collect::<Vec<_>>();
279
280            quote!
281            {
282                pub struct #self_id
283                {
284                    #(#fields),*
285                }
286
287                impl ::safe_builder::PartialBuilder for #self_id { }
288            }
289        };
290
291        let partial_steps = if self.fields.len() < target.fields.len() - 1
292        {
293            let steps = target.partials.at_order(self.order() + 1).unwrap().iter()
294                .filter(|partial| self.fields.iter()
295                    .all(|field| partial.fields.contains(field)))
296                .map(|partial|
297                {
298                    let step = self.step(partial).unwrap().clone();
299
300                    let step_id = quote::Ident::from(step.as_str());
301                    let step_ty = target.fields.get(&step).unwrap();
302
303                    let step_struct = quote::Ident::from(partial.name.as_str());
304
305                    let step_field = quote!
306                    {
307                        #step_id: #step_id
308                    };
309
310                    let fields = self.fields.iter()
311                        .map(|name|
312                        {
313                            let id = quote::Ident::from(name.as_str());
314
315                            quote!
316                            {
317                                #id: self.#id
318                            }
319                        });
320
321                    if fields.len() == 0
322                    {
323                        quote!
324                        {
325                            fn #step_id(self, #step_id: #step_ty) -> #step_struct
326                            {
327                                #step_struct
328                                {
329                                    #step_field
330                                }
331                            }
332                        }
333                    }
334                    else
335                    {
336                        quote!
337                        {
338                            fn #step_id(self, #step_id: #step_ty) -> #step_struct
339                            {
340                                #step_struct
341                                {
342                                    #(#fields),*,
343                                    #step_field
344                                }
345                            }
346                        }
347                    }
348                });
349
350            quote!
351            {
352                impl #self_id
353                {
354                    #(#steps)*
355                }
356            }
357        }
358        else
359        {
360            let target_id = quote::Ident::from(target.name.as_ref());
361
362            let missing =
363            {
364                let mut s = String::new();
365                for field in target.fields.keys()
366                {
367                    if !self.fields.contains(field)
368                    {
369                        s = field.to_owned();
370                    }
371                }
372
373                if s == String::new()
374                {
375                    panic!("partial of order n - 1 can't find last field in target!");
376                }
377                else
378                {
379                    s
380                }
381            };
382
383            let missing_id = quote::Ident::from(missing.as_str());
384            let missing_ty = target.fields.get(&missing);
385
386            let fields = self.fields.iter()
387                .map(|name|
388                {
389                    let id = quote::Ident::from(name.as_str());
390
391                    quote!
392                    {
393                        #id: self.#id
394                    }
395                });
396
397            if fields.len() == 0
398            {
399                quote!
400                {
401                    impl #self_id
402                    {
403                        fn #missing_id(self, #missing_id: #missing_ty) -> #target_id
404                        {
405                            #missing_id: #missing_id
406                        }
407                    }
408                }
409            }
410            else
411            {
412                quote!
413                {
414                    impl #self_id
415                    {
416                        fn #missing_id(self, #missing_id: #missing_ty) -> #target_id
417                        {
418                            #target_id
419                            {
420                                #(#fields),*,
421                                #missing_id: #missing_id
422                            }
423                        }
424                    }
425                }
426            }
427        };
428
429        quote!
430        {
431            #partial_struct
432
433            #partial_steps
434        }
435    }
436}