tracker_macros/
lib.rs

//! Macros for the `tracker` crate.

// Crate-wide lints. `clippy::cargo` was previously listed twice; once is enough.
#![warn(
    missing_debug_implementations,
    missing_docs,
    rust_2018_idioms,
    unreachable_pub,
    clippy::cargo,
    clippy::must_use_candidate
)]
12
use proc_macro::{self, Span, TokenStream};
use proc_macro2::{Span as Span2, TokenStream as TokenStream2};
use quote::{format_ident, quote, quote_spanned, ToTokens};
use syn::{
    parse_macro_input, Attribute, Error, Field, Fields, GenericParam, Ident, ItemStruct, Type,
};
19
/// Field attribute name that excludes a field from tracking (no bit, no methods).
const DO_NOT_TRACK: &str = "do_not_track";
/// Field attribute name that makes the generated setter skip the `!=` check.
const NO_EQ: &str = "no_eq";
22
23/// Implements tracker methods for structs.
24#[proc_macro_attribute]
25pub fn track(attr: TokenStream, item: TokenStream) -> TokenStream {
26    if !attr.is_empty() {
27        return Error::new(
28            attr.into_iter().next().unwrap().span().into(),
29            "This macro doesn't handle attributes",
30        )
31        .into_compile_error()
32        .into();
33    }
34
35    let mut data: ItemStruct = parse_macro_input!(item);
36    let ident = data.ident.clone();
37    let tracker_ty;
38    let struct_vis = &data.vis;
39    let where_clause = &data.generics.where_clause;
40
41    // Remove default type parameters (like <Type=DefaultType>).
42    let mut generics = data.generics.clone();
43    for param in generics.params.iter_mut() {
44        if let GenericParam::Type(ty) = param {
45            ty.eq_token = None;
46            ty.default = None;
47        }
48    }
49
50    let mut generics_iter = data.generics.params.iter();
51    let mut generic_idents = TokenStream2::new();
52
53    if let Some(first) = generics_iter.next() {
54        impl_struct_generics(first, &mut generic_idents);
55        for generic_param in generics_iter {
56            generic_idents.extend(quote! {,});
57            impl_struct_generics(generic_param, &mut generic_idents);
58        }
59    }
60
61    let mut field_list = Vec::new();
62    if let Fields::Named(named_fields) = &mut data.fields {
63        for field in &mut named_fields.named {
64            let (do_not_track, no_eq) = parse_field_attrs(&mut field.attrs);
65            if !do_not_track {
66                let ident = field.ident.clone().expect("Field has no identifier");
67                let ty: Type = field.ty.clone();
68                field_list.push((ident, ty, no_eq, field.vis.clone()));
69            }
70        }
71
72        tracker_ty = tracker_type(field_list.len());
73        let change_field = Field {
74            attrs: Vec::new(),
75            vis: syn::Visibility::Inherited,
76            mutability: syn::FieldMutability::None,
77            ident: Some(Ident::new("tracker", Span::call_site().into())),
78            colon_token: None,
79            ty: Type::Verbatim(tracker_ty.clone()),
80        };
81
82        named_fields.named.push(change_field);
83    } else {
84        panic!("No named fields");
85    }
86
87    let mut output = data.to_token_stream();
88
89    let mut methods = proc_macro2::TokenStream::new();
90    for (num, (id, ty, no_eq, vis)) in field_list.iter().enumerate() {
91        let id_span: Span2 = id.span().unwrap().into();
92
93        let get_id = Ident::new(&format!("get_{}", id), id_span);
94        let get_mut_id = Ident::new(&format!("get_mut_{}", id), id_span);
95        let update_id = Ident::new(&format!("update_{}", id), id_span);
96        let changed_id = Ident::new(&format!("changed_{}", id), id_span);
97        let set_id = Ident::new(&format!("set_{}", id), id_span);
98
99        let get_doc = format!("Get an immutable reference to the {id} field.");
100        let get_mut_doc =
101            format!("Get a mutable reference to the {id} field and mark the field as changed.");
102        let update_doc =
103            format!("Use a closure to update the {id} field and mark the field as changed.");
104        let changed_doc =
105            format!("Check if value of {id} field has changed.");
106        let bit_mask_doc = format!("Get a bit mask to look for changes on the {id} field.");
107
108        methods.extend(quote_spanned! { id_span =>
109            #[allow(dead_code, non_snake_case)]
110            #[must_use]
111            #[doc = #get_doc]
112            #vis fn #get_id(&self) -> &#ty {
113                &self.#id
114            }
115
116            #[allow(dead_code, non_snake_case)]
117            #[must_use]
118            #[doc = #get_mut_doc]
119            #vis fn #get_mut_id(&mut self) -> &mut #ty {
120                self.tracker |= Self::#id();
121                &mut self.#id
122            }
123
124            #[allow(dead_code, non_snake_case)]
125            #[doc = #update_doc]
126            #vis fn #update_id<F: FnOnce(&mut #ty)>(&mut self, f: F) {
127                self.tracker |= Self::#id();
128                f(&mut self.#id);
129            }
130
131            #[allow(dead_code, non_snake_case)]
132            #[doc = #changed_doc]
133            #vis fn #changed_id(&self) -> bool {
134                self.changed(Self::#id())
135            }
136
137            #[allow(dead_code, non_snake_case)]
138            #[must_use]
139            #[doc = #bit_mask_doc]
140            #vis fn #id() -> #tracker_ty {
141                1 << #num
142            }
143        });
144
145        if *no_eq {
146            let set_doc = format!("Set the value of field {id} and mark the field as changed.");
147            methods.extend(quote_spanned! { id_span =>
148                #[allow(dead_code, non_snake_case)]
149                #[doc = #set_doc]
150                #vis fn #set_id(&mut self, value: #ty) {
151                    self.tracker |= Self::#id();
152                    self.#id = value;
153                }
154            });
155        } else {
156            let set_doc = format!("Set the value of field {id} and mark the field as changed if it's not equal to the previous value.");
157            methods.extend(quote_spanned! { id_span =>
158                #[allow(dead_code, non_snake_case)]
159                #[doc = #set_doc]
160                #vis fn #set_id(&mut self, value: #ty) {
161                    if self.#id != value {
162                        self.tracker |= Self::#id();
163                    }
164                    self.#id = value;
165                }
166            });
167        }
168    }
169
170    output.extend(quote_spanned! { ident.span() =>
171        impl #generics #ident < #generic_idents > #where_clause {
172            #methods
173            #[allow(dead_code)]
174            #[must_use]
175            /// Get a bit mask to look for changes on all fields.
176            #struct_vis fn track_all() -> #tracker_ty {
177                #tracker_ty::MAX
178            }
179
180            #[allow(dead_code)]
181            /// Mark all fields of the struct as changed.
182            #struct_vis fn mark_all_changed(&mut self) {
183                self.tracker = #tracker_ty::MAX;
184            }
185
186            /// Check for changes made to this struct with a given bitmask.
187            ///
188            /// To receive the bitmask, simply call `Type::#field_name()`
189            /// or `Type::#track_all()`.
190            #[warn(dead_code)]
191            #[must_use]
192            #struct_vis fn changed(&self, mask: #tracker_ty) -> bool {
193                self.tracker & mask != 0
194            }
195
196            /// Check for any changes made to this struct.
197            #[allow(dead_code)]
198            #[must_use]
199            #struct_vis fn changed_any(&self) -> bool {
200                self.tracker != 0
201            }
202
203            /// Resets the tracker value of this struct to mark all fields
204            /// as unchanged again.
205            #[warn(dead_code)]
206            #struct_vis fn reset(&mut self) {
207                self.tracker = 0;
208            }
209        }
210    });
211
212    output.into()
213}
214
215fn impl_struct_generics(param: &GenericParam, stream: &mut TokenStream2) {
216    match param {
217        GenericParam::Type(ty) => ty.ident.to_tokens(stream),
218        GenericParam::Const(cnst) => cnst.to_tokens(stream),
219        GenericParam::Lifetime(lifetime) => lifetime.to_tokens(stream),
220    }
221}
222
223/// Look for no_eq and do_not_track attributes and remove
224/// them from the tokens.
225fn parse_field_attrs(attrs: &mut Vec<Attribute>) -> (bool, bool) {
226    let mut do_not_track = false;
227    let mut no_eq = false;
228    let attrs_clone = attrs.clone();
229
230    for (index, attr) in attrs_clone.iter().enumerate() {
231        let segs = &attr.path().segments;
232        match segs.len() {
233            1 => {
234                let first = &segs.first().unwrap().ident;
235                if first == NO_EQ {
236                    attrs.remove(index);
237                    no_eq = true;
238                } else if first == DO_NOT_TRACK {
239                    attrs.remove(index);
240                    do_not_track = true;
241                }
242            }
243            2 => {
244                let mut iter = segs.iter();
245                let first = &iter.next().unwrap().ident;
246                if first == "tracker" {
247                    let second = &iter.next().unwrap().ident;
248                    if second == NO_EQ {
249                        attrs.remove(index);
250                        no_eq = true;
251                    } else if second == DO_NOT_TRACK {
252                        attrs.remove(index);
253                        do_not_track = true;
254                    }
255                }
256            }
257            _ => {}
258        }
259    }
260
261    (do_not_track, no_eq)
262}
263
264fn tracker_type(len: usize) -> proc_macro2::TokenStream {
265    match len {
266        0..=8 => {
267            quote! {u8}
268        }
269        9..=16 => {
270            quote! {u16}
271        }
272        17..=32 => {
273            quote! {u32}
274        }
275        33..=64 => {
276            quote! {u64}
277        }
278        65..=128 => {
279            quote! {u128}
280        }
281        _ => {
282            panic!("You can only track up to 128 values")
283        }
284    }
285}