// Crate-wide lint configuration for this proc-macro crate (each lint listed once).
#![warn(
    missing_debug_implementations,
    missing_docs,
    rust_2018_idioms,
    unreachable_pub,
    clippy::cargo,
    clippy::must_use_candidate
)]
12
13use proc_macro::{self, Span, TokenStream};
14use proc_macro2::{Span as Span2, TokenStream as TokenStream2};
15use quote::{quote, quote_spanned, ToTokens};
16use syn::{
17 parse_macro_input, Attribute, Error, Field, Fields, GenericParam, Ident, ItemStruct, Type,
18};
19
/// Field attribute name (also accepted as `tracker::do_not_track`) that excludes a
/// field from tracking entirely.
const DO_NOT_TRACK: &str = "do_not_track";
/// Field attribute name (also accepted as `tracker::no_eq`) that makes the generated
/// setter mark the field as changed unconditionally instead of comparing values.
const NO_EQ: &str = "no_eq";
22
23#[proc_macro_attribute]
25pub fn track(attr: TokenStream, item: TokenStream) -> TokenStream {
26 if !attr.is_empty() {
27 return Error::new(
28 attr.into_iter().next().unwrap().span().into(),
29 "This macro doesn't handle attributes",
30 )
31 .into_compile_error()
32 .into();
33 }
34
35 let mut data: ItemStruct = parse_macro_input!(item);
36 let ident = data.ident.clone();
37 let tracker_ty;
38 let struct_vis = &data.vis;
39 let where_clause = &data.generics.where_clause;
40
41 let mut generics = data.generics.clone();
43 for param in generics.params.iter_mut() {
44 if let GenericParam::Type(ty) = param {
45 ty.eq_token = None;
46 ty.default = None;
47 }
48 }
49
50 let mut generics_iter = data.generics.params.iter();
51 let mut generic_idents = TokenStream2::new();
52
53 if let Some(first) = generics_iter.next() {
54 impl_struct_generics(first, &mut generic_idents);
55 for generic_param in generics_iter {
56 generic_idents.extend(quote! {,});
57 impl_struct_generics(generic_param, &mut generic_idents);
58 }
59 }
60
61 let mut field_list = Vec::new();
62 if let Fields::Named(named_fields) = &mut data.fields {
63 for field in &mut named_fields.named {
64 let (do_not_track, no_eq) = parse_field_attrs(&mut field.attrs);
65 if !do_not_track {
66 let ident = field.ident.clone().expect("Field has no identifier");
67 let ty: Type = field.ty.clone();
68 field_list.push((ident, ty, no_eq, field.vis.clone()));
69 }
70 }
71
72 tracker_ty = tracker_type(field_list.len());
73 let change_field = Field {
74 attrs: Vec::new(),
75 vis: syn::Visibility::Inherited,
76 mutability: syn::FieldMutability::None,
77 ident: Some(Ident::new("tracker", Span::call_site().into())),
78 colon_token: None,
79 ty: Type::Verbatim(tracker_ty.clone()),
80 };
81
82 named_fields.named.push(change_field);
83 } else {
84 panic!("No named fields");
85 }
86
87 let mut output = data.to_token_stream();
88
89 let mut methods = proc_macro2::TokenStream::new();
90 for (num, (id, ty, no_eq, vis)) in field_list.iter().enumerate() {
91 let id_span: Span2 = id.span().unwrap().into();
92
93 let get_id = Ident::new(&format!("get_{}", id), id_span);
94 let get_mut_id = Ident::new(&format!("get_mut_{}", id), id_span);
95 let update_id = Ident::new(&format!("update_{}", id), id_span);
96 let changed_id = Ident::new(&format!("changed_{}", id), id_span);
97 let set_id = Ident::new(&format!("set_{}", id), id_span);
98
99 let get_doc = format!("Get an immutable reference to the {id} field.");
100 let get_mut_doc =
101 format!("Get a mutable reference to the {id} field and mark the field as changed.");
102 let update_doc =
103 format!("Use a closure to update the {id} field and mark the field as changed.");
104 let changed_doc =
105 format!("Check if value of {id} field has changed.");
106 let bit_mask_doc = format!("Get a bit mask to look for changes on the {id} field.");
107
108 methods.extend(quote_spanned! { id_span =>
109 #[allow(dead_code, non_snake_case)]
110 #[must_use]
111 #[doc = #get_doc]
112 #vis fn #get_id(&self) -> &#ty {
113 &self.#id
114 }
115
116 #[allow(dead_code, non_snake_case)]
117 #[must_use]
118 #[doc = #get_mut_doc]
119 #vis fn #get_mut_id(&mut self) -> &mut #ty {
120 self.tracker |= Self::#id();
121 &mut self.#id
122 }
123
124 #[allow(dead_code, non_snake_case)]
125 #[doc = #update_doc]
126 #vis fn #update_id<F: FnOnce(&mut #ty)>(&mut self, f: F) {
127 self.tracker |= Self::#id();
128 f(&mut self.#id);
129 }
130
131 #[allow(dead_code, non_snake_case)]
132 #[doc = #changed_doc]
133 #vis fn #changed_id(&self) -> bool {
134 self.changed(Self::#id())
135 }
136
137 #[allow(dead_code, non_snake_case)]
138 #[must_use]
139 #[doc = #bit_mask_doc]
140 #vis fn #id() -> #tracker_ty {
141 1 << #num
142 }
143 });
144
145 if *no_eq {
146 let set_doc = format!("Set the value of field {id} and mark the field as changed.");
147 methods.extend(quote_spanned! { id_span =>
148 #[allow(dead_code, non_snake_case)]
149 #[doc = #set_doc]
150 #vis fn #set_id(&mut self, value: #ty) {
151 self.tracker |= Self::#id();
152 self.#id = value;
153 }
154 });
155 } else {
156 let set_doc = format!("Set the value of field {id} and mark the field as changed if it's not equal to the previous value.");
157 methods.extend(quote_spanned! { id_span =>
158 #[allow(dead_code, non_snake_case)]
159 #[doc = #set_doc]
160 #vis fn #set_id(&mut self, value: #ty) {
161 if self.#id != value {
162 self.tracker |= Self::#id();
163 }
164 self.#id = value;
165 }
166 });
167 }
168 }
169
170 output.extend(quote_spanned! { ident.span() =>
171 impl #generics #ident < #generic_idents > #where_clause {
172 #methods
173 #[allow(dead_code)]
174 #[must_use]
175 #struct_vis fn track_all() -> #tracker_ty {
177 #tracker_ty::MAX
178 }
179
180 #[allow(dead_code)]
181 #struct_vis fn mark_all_changed(&mut self) {
183 self.tracker = #tracker_ty::MAX;
184 }
185
186 #[warn(dead_code)]
191 #[must_use]
192 #struct_vis fn changed(&self, mask: #tracker_ty) -> bool {
193 self.tracker & mask != 0
194 }
195
196 #[allow(dead_code)]
198 #[must_use]
199 #struct_vis fn changed_any(&self) -> bool {
200 self.tracker != 0
201 }
202
203 #[warn(dead_code)]
206 #struct_vis fn reset(&mut self) {
207 self.tracker = 0;
208 }
209 }
210 });
211
212 output.into()
213}
214
215fn impl_struct_generics(param: &GenericParam, stream: &mut TokenStream2) {
216 match param {
217 GenericParam::Type(ty) => ty.ident.to_tokens(stream),
218 GenericParam::Const(cnst) => cnst.to_tokens(stream),
219 GenericParam::Lifetime(lifetime) => lifetime.to_tokens(stream),
220 }
221}
222
223fn parse_field_attrs(attrs: &mut Vec<Attribute>) -> (bool, bool) {
226 let mut do_not_track = false;
227 let mut no_eq = false;
228 let attrs_clone = attrs.clone();
229
230 for (index, attr) in attrs_clone.iter().enumerate() {
231 let segs = &attr.path().segments;
232 match segs.len() {
233 1 => {
234 let first = &segs.first().unwrap().ident;
235 if first == NO_EQ {
236 attrs.remove(index);
237 no_eq = true;
238 } else if first == DO_NOT_TRACK {
239 attrs.remove(index);
240 do_not_track = true;
241 }
242 }
243 2 => {
244 let mut iter = segs.iter();
245 let first = &iter.next().unwrap().ident;
246 if first == "tracker" {
247 let second = &iter.next().unwrap().ident;
248 if second == NO_EQ {
249 attrs.remove(index);
250 no_eq = true;
251 } else if second == DO_NOT_TRACK {
252 attrs.remove(index);
253 do_not_track = true;
254 }
255 }
256 }
257 _ => {}
258 }
259 }
260
261 (do_not_track, no_eq)
262}
263
264fn tracker_type(len: usize) -> proc_macro2::TokenStream {
265 match len {
266 0..=8 => {
267 quote! {u8}
268 }
269 9..=16 => {
270 quote! {u16}
271 }
272 17..=32 => {
273 quote! {u32}
274 }
275 33..=64 => {
276 quote! {u64}
277 }
278 65..=128 => {
279 quote! {u128}
280 }
281 _ => {
282 panic!("You can only track up to 128 values")
283 }
284 }
285}