struct_arithmetic/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::quote;
4use syn::{
5    self, parse_macro_input, punctuated::Punctuated, token::Comma, Data, DataStruct, DeriveInput,
6    Field, Fields, Ident, Path, Type, TypePath,
7};
8
9#[proc_macro_derive(StructArithmetic, attributes(helper))]
10pub fn struct_arithmetic(tokens: TokenStream) -> TokenStream {
11    let input = parse_macro_input!(tokens as DeriveInput);
12    let name = input.ident;
13
14    let fields: Punctuated<Field, Comma> = match input.data {
15        Data::Struct(DataStruct {
16            fields: Fields::Named(fields),
17            ..
18        }) => fields.named,
19        _ => panic!("Only structs with named fields can be annotated with ToUrl"),
20    };
21
22    let field_type = match &fields.first().unwrap().ty {
23        Type::Path(TypePath {
24            path: Path { segments, .. },
25            ..
26        }) => {
27            if let Some(path_seg) = segments.first() {
28                let ident: &proc_macro2::Ident = &path_seg.ident;
29                Some(ident.to_string())
30            } else {
31                None
32            }
33        }
34        Type::Array(arr) => match &(*arr.elem) {
35            Type::Path(TypePath {
36                path: Path { segments, .. },
37                ..
38            }) => {
39                if let Some(path_seg) = segments.first() {
40                    let ident: &proc_macro2::Ident = &path_seg.ident;
41                    Some(ident.to_string())
42                } else {
43                    None
44                }
45            }
46            _ => None,
47        },
48        _ => None,
49    }
50    .unwrap();
51
52    let factor = Ident::new("factor", Span::call_site());
53    let numerator = Ident::new("numerator", Span::call_site());
54    let denominator = Ident::new("denominator", Span::call_site());
55    let fields_type = Ident::new(&field_type, Span::call_site());
56
57    let (addition, addition_array) = generate_add(&fields);
58    let addition_assign = generate_add_assign(&fields);
59    let (subtraction, subtraction_array) = generate_sub(&fields);
60    let subtraction_assign = generate_sub_assign(&fields);
61    let (multiplication, multiplication_array) = generate_mul(&fields);
62    let (division, division_array) = generate_div(&fields);
63    let (division_scalar, division_scalar_array) = generate_div_scalar(&fields, factor.clone());
64    let (multiplication_scalar, multiplication_scalar_array) =
65        generate_mul_scalar(&fields, factor.clone());
66    let (multiplication_fraction, multiplication_fraction_array) =
67        generate_mul_fraction(&fields, numerator, denominator);
68
69    let (new_constructor_args, new_constructor_struct) = generate_new(&fields);
70    let is_zero = generate_is_zero(&fields);
71    // let token_amount = generate_token_amount(&fields);
72
73    let modified = quote! {
74        impl #name {
75            pub fn new(#(#new_constructor_args)*) -> #name {
76                #name {
77                #(#new_constructor_struct)*
78                }
79            }
80
81            pub fn is_zero(&self) -> bool {
82                #(#is_zero)*
83                return true;
84            }
85
86            pub fn add(&self, other: &#name) -> Option<#name> {
87                #(#addition_array)*
88                Some(#name::new(
89                    #(#addition)*
90                ))
91            }
92
93            pub fn add_assign(&mut self, other: &#name) -> Option<()> {
94                #(#addition_assign)*
95
96                Some(())
97            }
98
99            pub fn sub(&self, other: &#name) -> Option<#name> {
100                #(#subtraction_array)*
101                Some(#name::new(
102                    #(#subtraction)*
103                ))
104            }
105
106            pub fn sub_assign(&mut self, other: &#name) -> Option<()> {
107                #(#subtraction_assign)*
108
109                Some(())
110            }
111
112            pub fn div(&self, other: &#name) -> Option<#name> {
113                #(#division_array)*
114                Some(#name::new(
115                    #(#division)*
116                ))
117            }
118
119            pub fn div_scalar(&self, factor: #fields_type) -> Option<#name> {
120                #(#division_scalar_array)*
121                Some(#name::new(
122                    #(#division_scalar)*
123                ))
124            }
125
126            pub fn mul(&self, other: &#name) -> Option<#name> {
127                #(#multiplication_array)*
128                Some(#name::new(
129                    #(#multiplication)*
130                ))
131            }
132
133            pub fn mul_scalar(&self, factor: #fields_type) -> Option<#name> {
134                #(#multiplication_scalar_array)*
135                Some(#name::new(
136                    #(#multiplication_scalar)*
137                ))
138            }
139
140            pub fn mul_fraction(&self, numerator: #fields_type, denominator: #fields_type) -> Option<#name> {
141                #(#multiplication_fraction_array)*
142                Some(#name::new(
143                    #(#multiplication_fraction)*
144                ))
145            }
146
147            pub fn mul_bps(&self, factor: u16) -> Option<#name> {
148                self.mul_fraction(factor as #fields_type, 10_000)
149            }
150
151            pub fn mul_bps_u64(&self, factor: u64) -> Option<#name> {
152                self.mul_fraction(factor as #fields_type, 10_000)
153            }
154
155            pub fn mul_percent(&self, factor: u16) -> Option<#name> {
156                self.mul_fraction(factor as #fields_type, 100)
157            }
158
159        }
160    };
161    TokenStream::from(modified)
162}
163
164//     args_code
165// }
166
167fn generate_is_zero<'a>(
168    fields: &'a Punctuated<Field, Comma>,
169) -> impl Iterator<Item = proc_macro2::TokenStream> + 'a {
170    let args_code_array = fields
171        .into_iter()
172        .enumerate()
173        .filter(|(_i, field)| match &field.ty {
174            Type::Array(_arr) => !field
175                .ident
176                .as_ref()
177                .unwrap()
178                .to_string()
179                .starts_with("_reserved"),
180            _ => false,
181        })
182        .map(move |(_i, field)| {
183            let field_ident = field.ident.as_ref().unwrap();
184            quote! {
185                for i in 0..self.#field_ident.len() {
186                    if self.#field_ident[i] != 0 {
187                        return false;
188                    }
189                }
190            }
191        });
192    let args_code = fields
193        .into_iter()
194        .filter(|field| match &field.ty {
195            Type::Array(_arr) => false,
196            _ => true,
197        })
198        .map(move |field| {
199            let field_ident = field.ident.as_ref().unwrap();
200            quote! {
201                if self.#field_ident != 0 {
202                    return false;
203                }
204            }
205        });
206
207    args_code.chain(args_code_array)
208}
209
210fn generate_new(
211    fields: &Punctuated<Field, Comma>,
212) -> (
213    impl Iterator<Item = proc_macro2::TokenStream> + '_,
214    impl Iterator<Item = proc_macro2::TokenStream> + '_,
215) {
216    let args_code = fields.iter().enumerate().map(move |(i, field)| {
217        let field_ident = field.ident.as_ref().unwrap();
218        let field_type = &field.ty;
219        if field_ident.to_string().starts_with("_reserved") {
220            return quote! {};
221        }
222        if i < fields.len() - 1 {
223            quote! { #field_ident: #field_type, }
224        } else {
225            quote! { #field_ident: #field_type }
226        }
227    });
228    let struct_code = fields.iter().map(move |field| {
229        let field_ident = field.ident.as_ref().unwrap();
230        if field_ident.to_string().starts_with("_reserved") {
231            let reserved_len = match &field.ty {
232                Type::Array(arr) => &arr.len,
233                _ => panic!("_reserved can only be an array"),
234            };
235            return quote! { #field_ident: [0; #reserved_len], };
236        }
237        return quote! { #field_ident, };
238    });
239    (args_code, struct_code)
240}
241
242fn generate_add(
243    fields: &Punctuated<Field, Comma>,
244) -> (
245    impl Iterator<Item = proc_macro2::TokenStream> + '_,
246    impl Iterator<Item = proc_macro2::TokenStream> + '_,
247) {
248    let code_array = fields
249        .into_iter()
250        .enumerate()
251        .filter(|(_i, field)| match &field.ty {
252            Type::Array(_arr) => !field
253                .ident
254                .as_ref()
255                .unwrap()
256                .to_string()
257                .starts_with("_reserved"),
258            _ => false,
259        })
260        .map(move |(_i, field)| {
261            let field_ident = field.ident.as_ref().unwrap();
262            let (field_size, field_type) = match &field.ty {
263                Type::Array(arr) => (&arr.len, &arr.elem),
264                _ => panic!("Only arrays are accepted"),
265            };
266            quote! {
267                let mut #field_ident = [#field_type::default(); #field_size];
268                for i in 0..self.#field_ident.len() {
269                    #field_ident[i] = self.#field_ident[i].checked_add(other.#field_ident[i])?;
270                }
271            }
272        });
273    let code = fields.into_iter().map(move |field| {
274        let field_ident = field.ident.as_ref().unwrap();
275        if field_ident.to_string().starts_with("_reserved") {
276            return quote! {};
277        }
278        match &field.ty {
279            Type::Array(_arr) => quote! { #field_ident, },
280            _ => quote! { self.#field_ident.checked_add(other.#field_ident)?, },
281        }
282    });
283
284    (code, code_array)
285}
286
287fn generate_add_assign(
288    fields: &Punctuated<Field, Comma>,
289) -> impl Iterator<Item = proc_macro2::TokenStream> + '_ {
290    let code = fields.iter().map(|field| {
291        let field_ident = field.ident.as_ref().unwrap();
292        if field_ident.to_string().starts_with("_reserved") {
293            return quote! {};
294        }
295        match &field.ty {
296            Type::Array(_arr) => quote! {
297                for i in 0..self.#field_ident.len() {
298                    self.#field_ident[i] = self.#field_ident[i].checked_add(other.#field_ident[i])?;
299                }
300            },
301            _ => quote! { self.#field_ident = self.#field_ident.checked_add(other.#field_ident)?; },
302        }
303    });
304    code
305}
306
307fn generate_sub(
308    fields: &Punctuated<Field, Comma>,
309) -> (
310    impl Iterator<Item = proc_macro2::TokenStream> + '_,
311    impl Iterator<Item = proc_macro2::TokenStream> + '_,
312) {
313    let code_array = fields
314        .into_iter()
315        .enumerate()
316        .filter(|(_i, field)| match &field.ty {
317            Type::Array(_arr) => !field
318                .ident
319                .as_ref()
320                .unwrap()
321                .to_string()
322                .starts_with("_reserved"),
323            _ => false,
324        })
325        .map(move |(_i, field)| {
326            let field_ident = field.ident.as_ref().unwrap();
327            let (field_size, field_type) = match &field.ty {
328                Type::Array(arr) => (&arr.len, &arr.elem),
329                _ => panic!("Only arrays are accepted"),
330            };
331            quote! {
332                let mut #field_ident = [#field_type::default(); #field_size];
333                for i in 0..self.#field_ident.len() {
334                    #field_ident[i] = self.#field_ident[i].checked_sub(other.#field_ident[i])?;
335                }
336            }
337        });
338    let code = fields.into_iter().map(move |field| {
339        let field_ident = field.ident.as_ref().unwrap();
340        if field_ident.to_string().starts_with("_reserved") {
341            return quote! {};
342        }
343        match &field.ty {
344            Type::Array(_arr) => quote! { #field_ident, },
345            _ => quote! { self.#field_ident.checked_sub(other.#field_ident)?, },
346        }
347    });
348
349    (code, code_array)
350}
351
352fn generate_sub_assign(
353    fields: &Punctuated<Field, Comma>,
354) -> impl Iterator<Item = proc_macro2::TokenStream> + '_ {
355    let code = fields.iter().map(|field| {
356        let field_ident = field.ident.as_ref().unwrap();
357        if field_ident.to_string().starts_with("_reserved") {
358            return quote! {};
359        }
360        match &field.ty {
361            Type::Array(_arr) => quote! {
362                for i in 0..self.#field_ident.len() {
363                    self.#field_ident[i] = self.#field_ident[i].checked_sub(other.#field_ident[i])?;
364                }
365            },
366            _ => quote! { self.#field_ident = self.#field_ident.checked_sub(other.#field_ident)?; },
367        }
368    });
369    code
370}
371
372fn generate_mul(
373    fields: &Punctuated<Field, Comma>,
374) -> (
375    impl Iterator<Item = proc_macro2::TokenStream> + '_,
376    impl Iterator<Item = proc_macro2::TokenStream> + '_,
377) {
378    let code_array = fields
379        .into_iter()
380        .enumerate()
381        .filter(|(_i, field)| match &field.ty {
382            Type::Array(_arr) => !field
383                .ident
384                .as_ref()
385                .unwrap()
386                .to_string()
387                .starts_with("_reserved"),
388            _ => false,
389        })
390        .map(move |(_i, field)| {
391            let field_ident = field.ident.as_ref().unwrap();
392            let (field_size, field_type) = match &field.ty {
393                Type::Array(arr) => (&arr.len, &arr.elem),
394                _ => panic!("Only arrays are accepted"),
395            };
396            quote! {
397                let mut #field_ident = [#field_type::default(); #field_size];
398                for i in 0..self.#field_ident.len() {
399                    #field_ident[i] = self.#field_ident[i].checked_mul(other.#field_ident[i])?;
400                }
401            }
402        });
403    let code = fields.into_iter().map(move |field| {
404        let field_ident = field.ident.as_ref().unwrap();
405        if field_ident.to_string().starts_with("_reserved") {
406            return quote! {};
407        }
408        match &field.ty {
409            Type::Array(_arr) => quote! { #field_ident, },
410            _ => quote! { self.#field_ident.checked_mul(other.#field_ident)?, },
411        }
412    });
413
414    (code, code_array)
415}
416
417fn generate_div(
418    fields: &Punctuated<Field, Comma>,
419) -> (
420    impl Iterator<Item = proc_macro2::TokenStream> + '_,
421    impl Iterator<Item = proc_macro2::TokenStream> + '_,
422) {
423    let code_array = fields
424        .into_iter()
425        .enumerate()
426        .filter(|(_i, field)| match &field.ty {
427            Type::Array(_arr) => !field
428                .ident
429                .as_ref()
430                .unwrap()
431                .to_string()
432                .starts_with("_reserved"),
433            _ => false,
434        })
435        .map(move |(_i, field)| {
436            let field_ident = field.ident.as_ref().unwrap();
437            let (field_size, field_type) = match &field.ty {
438                Type::Array(arr) => (&arr.len, &arr.elem),
439                _ => panic!("Only arrays are accepted"),
440            };
441            quote! {
442                let mut #field_ident = [#field_type::default(); #field_size];
443                for i in 0..self.#field_ident.len() {
444                    #field_ident[i] = self.#field_ident[i].checked_div(other.#field_ident[i])?;
445                }
446            }
447        });
448    let code = fields.into_iter().map(move |field| {
449        let field_ident = field.ident.as_ref().unwrap();
450        if field_ident.to_string().starts_with("_reserved") {
451            return quote! {};
452        }
453        match &field.ty {
454            Type::Array(_arr) => quote! { #field_ident, },
455            _ => quote! { self.#field_ident.checked_div(other.#field_ident)?, },
456        }
457    });
458
459    (code, code_array)
460}
461
462fn generate_div_scalar(
463    fields: &Punctuated<Field, Comma>,
464    factor: Ident,
465) -> (
466    impl Iterator<Item = proc_macro2::TokenStream> + '_,
467    impl Iterator<Item = proc_macro2::TokenStream> + '_,
468) {
469    let factor2 = factor.clone();
470    let code_array = fields
471        .into_iter()
472        .enumerate()
473        .filter(|(_i, field)| match &field.ty {
474            Type::Array(_arr) => !field
475                .ident
476                .as_ref()
477                .unwrap()
478                .to_string()
479                .starts_with("_reserved"),
480            _ => false,
481        })
482        .map(move |(_i, field)| {
483            let field_ident = field.ident.as_ref().unwrap();
484            let (field_size, field_type) = match &field.ty {
485                Type::Array(arr) => (&arr.len, &arr.elem),
486                _ => panic!("Only arrays are accepted"),
487            };
488            quote! {
489                let mut #field_ident = [#field_type::default(); #field_size];
490                for i in 0..self.#field_ident.len() {
491                    #field_ident[i] = self.#field_ident[i].checked_div(#factor.into())?;
492                }
493            }
494        });
495    let code = fields.into_iter().map(move |field| {
496        let field_ident = field.ident.as_ref().unwrap();
497        if field_ident.to_string().starts_with("_reserved") {
498            return quote! {};
499        }
500        match &field.ty {
501            Type::Array(_arr) => quote! { #field_ident, },
502            _ => quote! { self.#field_ident.checked_div(#factor2)?, },
503        }
504    });
505
506    (code, code_array)
507}
508
509fn generate_mul_scalar(
510    fields: &Punctuated<Field, Comma>,
511    factor: Ident,
512) -> (
513    impl Iterator<Item = proc_macro2::TokenStream> + '_,
514    impl Iterator<Item = proc_macro2::TokenStream> + '_,
515) {
516    let factor2 = factor.clone();
517    let code_array = fields
518        .into_iter()
519        .enumerate()
520        .filter(|(_i, field)| match &field.ty {
521            Type::Array(_arr) => !field
522                .ident
523                .as_ref()
524                .unwrap()
525                .to_string()
526                .starts_with("_reserved"),
527            _ => false,
528        })
529        .map(move |(_i, field)| {
530            let field_ident = field.ident.as_ref().unwrap();
531            let (field_size, field_type) = match &field.ty {
532                Type::Array(arr) => (&arr.len, &arr.elem),
533                _ => panic!("Only arrays are accepted"),
534            };
535            quote! {
536                let mut #field_ident = [#field_type::default(); #field_size];
537                for i in 0..self.#field_ident.len() {
538                    #field_ident[i] = self.#field_ident[i].checked_mul(#factor.into())?;
539                }
540            }
541        });
542    let code = fields.into_iter().map(move |field| {
543        let field_ident = field.ident.as_ref().unwrap();
544        if field_ident.to_string().starts_with("_reserved") {
545            return quote! {};
546        }
547        match &field.ty {
548            Type::Array(_arr) => quote! { #field_ident, },
549            _ => quote! { self.#field_ident.checked_mul(#factor2)?, },
550        }
551    });
552
553    (code, code_array)
554}
555
556fn generate_mul_fraction(
557    fields: &Punctuated<Field, Comma>,
558    numerator: Ident,
559    denominator: Ident,
560) -> (
561    impl Iterator<Item = proc_macro2::TokenStream> + '_,
562    impl Iterator<Item = proc_macro2::TokenStream> + '_,
563) {
564    let numerator2 = numerator.clone();
565    let denominator2 = denominator.clone();
566    let code_array = fields
567        .into_iter()
568        .enumerate()
569        .filter(|(_i, field)| match &field.ty {
570            Type::Array(_arr) =>!field
571                .ident
572                .as_ref()
573                .unwrap()
574                .to_string()
575                .starts_with("_reserved"),
576            _ => false,
577        })
578        .map(move |(_i, field)| {
579            let field_ident = field.ident.as_ref().unwrap();
580            let (field_size, field_type) = match &field.ty {
581                Type::Array(arr) => (&arr.len, &arr.elem),
582                _ => panic!("Only arrays are accepted"),
583            };
584            quote! {
585                let mut #field_ident = [#field_type::default(); #field_size];
586                for i in 0..self.#field_ident.len() {
587                    #field_ident[i] = ((self.#field_ident[i] as u128).checked_mul(#numerator as u128)?.checked_div(#denominator as u128)?) as #field_type;
588                }
589            }
590        });
591    let code = fields.into_iter().map(move |field| {
592        let field_ident = field.ident.as_ref().unwrap();
593        if field_ident.to_string().starts_with("_reserved") {
594            return quote! {};
595        }
596        match &field.ty {
597            Type::Array(_arr) => quote! { #field_ident, },
598            _ => {
599                let s = match &field.ty {
600                    Type::Path(TypePath {
601                        path: Path { segments, .. },
602                        ..
603                    }) => {
604                        if let Some(path_seg) = segments.first() {
605                            let ident: &proc_macro2::Ident = &path_seg.ident;
606                            Some(ident.to_string())
607                        } else {
608                            None
609                        }
610                    }
611                    _ => None,
612                }.unwrap();
613                let field_type = Ident::new(&s, Span::call_site());
614                quote! {
615                ((self.#field_ident as u128).checked_mul(#numerator2 as u128)?.checked_div(#denominator2 as u128)?) as #field_type, }
616            },
617        }
618    });
619
620    (code, code_array)
621}