Skip to main content

strict_typing/
lib.rs

1//! A macro to enforce strict typing on the fields in Rust.
2//!
3//! Please refer to the documentation of the macro for more details:
4//! [`macro@strict_types`].
5
6use proc_macro::TokenStream;
7use quote::{ToTokens, quote};
8use syn::{
9    Fields, Ident, Item, Path, ReturnType, Token, Type,
10    parse::{Parse, ParseStream},
11    parse_macro_input, parse_quote,
12    punctuated::Punctuated,
13};
14
15#[derive(Default, Clone)]
16enum Mode {
17    #[default]
18    Default,
19    Allow(Vec<Path>),
20    Disallow(Vec<Path>),
21}
22
23#[derive(Default)]
24struct StrictTypesArgs {
25    disallow: Vec<Path>,
26    mode: Mode,
27}
28
29impl Parse for StrictTypesArgs {
30    fn parse(input: ParseStream) -> syn::Result<Self> {
31        if input.is_empty() {
32            return Ok(Self::default());
33        }
34
35        let key: Ident = input.parse()?;
36        let content;
37        let _ = syn::parenthesized!(content in input);
38
39        let paths: Punctuated<Path, Token![,]> =
40            content.parse_terminated(Path::parse, Token![,])?;
41        let paths_vec: Vec<Path> = paths.into_iter().collect();
42
43        let mode;
44        let mut disallow = default_disallowed_types();
45        let disallow = match key.to_string().as_str() {
46            "disallow" => {
47                // let new_paths: Vec<Path> = paths_vec
48                //     .iter()
49                //     .filter(|path| !disallow.contains(path))
50                //     .cloned()
51                //     .collect();
52                // mode = Mode::Disallow(new_paths);
53                mode = Mode::Disallow(paths_vec.clone());
54                disallow.extend(paths_vec);
55                disallow
56            }
57            "allow" => {
58                mode = Mode::Allow(paths_vec.clone());
59                disallow.retain(|path| !paths_vec.contains(path));
60                disallow
61            }
62            _ => {
63                return Err(syn::Error::new_spanned(
64                    key,
65                    "expected `disallow(...)` or `allow(...)`",
66                ));
67            }
68        };
69
70        Ok(Self { disallow, mode })
71    }
72}
73
74fn default_disallowed_types() -> Vec<Path> {
75    vec![
76        parse_quote!(u8),
77        parse_quote!(u16),
78        parse_quote!(u32),
79        parse_quote!(u64),
80        parse_quote!(u128),
81        parse_quote!(usize),
82        parse_quote!(i8),
83        parse_quote!(i16),
84        parse_quote!(i32),
85        parse_quote!(i64),
86        parse_quote!(i128),
87        parse_quote!(isize),
88        parse_quote!(f32),
89        parse_quote!(f64),
90        parse_quote!(bool),
91        parse_quote!(char),
92    ]
93}
94
95fn contains_forbidden_type(ty: &Type, disallowed: &[Path]) -> bool {
96    match ty {
97        Type::Path(type_path) => {
98            if disallowed.contains(&type_path.path) {
99                return true;
100            }
101
102            for segment in &type_path.path.segments {
103                if let syn::PathArguments::AngleBracketed(generic_args) = &segment.arguments {
104                    for arg in &generic_args.args {
105                        if let syn::GenericArgument::Type(inner_ty) = arg {
106                            if contains_forbidden_type(inner_ty, disallowed) {
107                                return true;
108                            }
109                        }
110                    }
111                }
112            }
113
114            false
115        }
116
117        Type::Tuple(tuple) => tuple
118            .elems
119            .iter()
120            .any(|elem| contains_forbidden_type(elem, disallowed)),
121
122        Type::Group(group) => contains_forbidden_type(&group.elem, disallowed),
123        Type::Paren(paren) => contains_forbidden_type(&paren.elem, disallowed),
124
125        _ => false, // you can expand this for more complex cases like references, impl traits, etc.
126    }
127}
128
129fn doc_lines(attrs: &[syn::Attribute]) -> Vec<String> {
130    attrs
131        .iter()
132        .filter_map(|attr| {
133            if attr.path().is_ident("doc") {
134                if let Ok(nv) = attr.meta.clone().require_name_value() {
135                    if let syn::Expr::Lit(syn::ExprLit {
136                        lit: syn::Lit::Str(s),
137                        ..
138                    }) = &nv.value
139                    {
140                        return Some(s.value().trim().to_string());
141                    }
142                }
143            }
144            None
145        })
146        .collect()
147}
148
149fn verify_docs(mode: Mode, docs: &[String], input: &Item) -> Vec<syn::Error> {
150    let mut errors = Vec::new();
151
152    if let Mode::Allow(paths) | Mode::Disallow(paths) = &mode {
153        let mut strict_section_found = false;
154        let mut documented_types = Vec::new();
155
156        for line in docs {
157            if line.trim() == "# Strictness" {
158                strict_section_found = true;
159                continue;
160            }
161
162            if strict_section_found {
163                if let Some(rest) = line.trim().strip_prefix("- [") {
164                    if let Some(end_idx) = rest.find(']') {
165                        let type_str = &rest[..end_idx];
166                        documented_types.push(type_str.to_string());
167                    }
168                }
169            }
170        }
171
172        for path in paths {
173            let ty_str = quote!(#path).to_string();
174            if !documented_types.iter().any(|doc| doc == &ty_str) {
175                errors.push(syn::Error::new_spanned(
176                    path,
177                    format!(
178                        "Missing `/// - [{ty_str}] justification` in `/// # Strictness` section"
179                    ),
180                ));
181            }
182        }
183
184        if errors.is_empty() && !strict_section_found {
185            errors.push(syn::Error::new_spanned(
186                input,
187                "Missing `/// # Strictness` section for `allow(...)` or `disallow(...)` override",
188            ));
189        }
190    }
191
192    errors
193}
194
195/// A macro to enforce strict typing on struct and enum fields.
196/// It checks if any field uses a primitive type and generates a
197/// compile-time error if it does. The idea is to encourage the use of
198/// newtype wrappers for primitive types to ensure type safety and
199/// clarity in the codebase.
200///
201/// The motivation behind this macro is to prevent the use of primitive
202/// types directly in structs, which can lead to confusion and bugs.
/// Primitive types are often too generic, have too wide a range of
/// values, can be misused in different contexts, and do not convey the
/// intent of the data being represented: a dedicated type gives the value
/// a meaningful name and documents the intention behind it.
207///
/// Also, the values that make sense for a concept are often only a
/// subset of the range a primitive type can hold, so they need
/// validation beyond the width of the type. For example,
211/// a `u8` type can be used to represent a percentage, but it can also
212/// be used to represent a count of items, which is a different
213/// concept. In this case, the `u8` type does not convey the intent of
214/// the data being represented, and it is better to use a newtype wrapper
215/// to make the intent clear. There might be at least two "Percentage"
216/// types in the codebase, one is limited to the range of `0-100`, and
217/// another type which can go beyond 100 (but still not less than zero),
218/// to express the surpassing of the 100% mark. Not to mention that
219/// sometimes, in certain contexts, the percentage can be negative
220/// (e.g. when calculating the difference between two values).
221/// This macro is a way to enforce the use of newtype wrappers for
222/// primitive types in structs, which can help to avoid confusion and
223/// bugs in the codebase. It is a compile-time check that will generate
224/// an error if any field in a struct uses a primitive type directly.
225///
226/// # Example usage:
227///
228/// ```rust
229/// use strict_typing::strict_types;
230///
/// #[repr(transparent)]
/// struct Percentage(u8);
///
/// #[strict_types]
/// struct MyStruct {
///     // This would generate a compile-time error
///     // because `u8` is a primitive type:
///     // my_field: u8,
///     // A generic wrapper such as `Wrapper<u8>` would be rejected as
///     // well, because generic arguments are checked too.
///     // But this is accepted:
///     my_field: Percentage,
/// }
242/// ```
243///
244/// Yes, this is a very simple macro, but it is intended to be used
245/// as a way to enforce strict typing in the codebase, and to encourage
246/// the use of newtype wrappers for primitive types in structs.
247///
/// # Example with `disallow`, which **adds** types to the disallowed list
250///
251/// ```rust,ignore,no_run
252/// use strict_typing::strict_types;
253///
254/// #[strict_types(disallow(String))]
255/// struct MyStruct {
256///    // This will generate a compile-time error
257///    // because `String` is now also a forbidden type.
258///    my_field: String,
259/// }
260/// ```
261///
262/// When a type is added to the disallowed list or removed from it,
263/// the macro requires the user to document the reason for
264/// the change in the `/// # Strictness` section of the documentation.
265/// The documentation should be in the form of a list of items,
266/// where each item is a type that is allowed or disallowed, example:
267///
268/// ```rust,ignore,no_run
269/// use strict_typing::strict_types;
270///
271/// /// # Strictness
272/// ///
273/// /// - [String] - this is a disallowed type, because it is too bad.
274/// #[strict_types(disallow(String))]
275/// struct MyStruct {
276///     my_field: String,
277/// }
278/// ```
279///
280/// To remove from the default disallow list, you can use the
281/// `allow` directive:
282/// ```rust,ignore,no_run
283/// use strict_typing::strict_types;
284/// /// # Strictness
285/// ///
286/// /// - [u8] - this is an allowed type, because it is used for
287/// ///   representing a small number of items.
288/// #[strict_types(allow(u8))]
289/// struct MyStruct {
290///     my_field: u8,
291/// }
292/// ```
293///
/// The macro also works on free functions and on whole `impl` and
/// `trait` items, analysing the parameter and return types of every
/// function signature; however, annotating an individual trait method
/// or impl method is not yet possible due to Rust limitations.
298#[proc_macro_attribute]
299pub fn strict_types(attr: TokenStream, item: TokenStream) -> TokenStream {
300    let args = parse_macro_input!(attr as StrictTypesArgs);
301    let item_clone = item.clone();
302    let input = parse_macro_input!(item as Item);
303
304    let disallowed: Vec<Path> = if args.disallow.is_empty() {
305        default_disallowed_types()
306    } else {
307        args.disallow
308    };
309
310    let mut errors = Vec::new();
311
312    let attrs = match &input {
313        Item::Struct(struct_item) => {
314            for field in &struct_item.fields {
315                if let Type::Path(tp) = &field.ty {
316                    if contains_forbidden_type(&field.ty, &disallowed) {
317                        let fname = field
318                            .ident
319                            .as_ref()
320                            .map(|i| i.to_string())
321                            .unwrap_or("<unnamed>".into());
322                        errors.push(syn::Error::new_spanned(
323                            &field.ty,
324                            format!("field `{}` uses disallowed type `{}`", fname, quote!(#tp)),
325                        ));
326                    }
327                }
328            }
329            &struct_item.attrs
330        }
331
332        Item::Enum(enum_item) => {
333            for variant in &enum_item.variants {
334                match &variant.fields {
335                    Fields::Unit => {}
336                    Fields::Named(fields) => {
337                        for field in &fields.named {
338                            if let Type::Path(tp) = &field.ty {
339                                if contains_forbidden_type(&field.ty, &disallowed) {
340                                    errors.push(syn::Error::new_spanned(
341                                        &field.ty,
342                                        format!(
343                                            "variant `{}` has field with disallowed type `{}`",
344                                            variant.ident,
345                                            quote!(#tp)
346                                        ),
347                                    ));
348                                }
349                            }
350                        }
351                    }
352                    Fields::Unnamed(fields) => {
353                        for field in &fields.unnamed {
354                            if let Type::Path(tp) = &field.ty {
355                                if contains_forbidden_type(&field.ty, &disallowed) {
356                                    errors.push(syn::Error::new_spanned(
357                                        &field.ty,
358                                        format!(
359                                            "variant `{}` has field with disallowed type `{}`",
360                                            variant.ident,
361                                            quote!(#tp)
362                                        ),
363                                    ));
364                                }
365                            }
366                        }
367                    }
368                }
369            }
370
371            &enum_item.attrs
372        }
373
374        Item::Fn(fn_item) => {
375            let sig = &fn_item.sig;
376
377            for arg in &sig.inputs {
378                if let syn::FnArg::Typed(pat_type) = arg {
379                    if let Type::Path(tp) = &*pat_type.ty {
380                        if contains_forbidden_type(&pat_type.ty, &disallowed) {
381                            let path = &tp.path;
382                            let arg_str = quote!(#path).to_string();
383                            errors.push(syn::Error::new_spanned(
384                                &pat_type.ty,
385                                format!("function parameter uses disallowed type `{arg_str}`"),
386                            ));
387                        }
388                    }
389                }
390            }
391
392            if let ReturnType::Type(_, ty) = &fn_item.sig.output {
393                if let Type::Path(tp) = ty.as_ref() {
394                    if contains_forbidden_type(ty, &disallowed) {
395                        errors.push(syn::Error::new_spanned(
396                            tp,
397                            format!(
398                                "function return type is disallowed: `{}`",
399                                tp.path.to_token_stream()
400                            ),
401                        ));
402                    }
403                }
404            }
405
406            errors.extend(verify_docs(args.mode, &doc_lines(&fn_item.attrs), &input));
407
408            let diagnostics = errors.into_iter().map(|e| e.to_compile_error());
409            let output = quote! {
410                #fn_item
411                #(#diagnostics)*
412            };
413
414            return output.into();
415        }
416
417        Item::Trait(item_trait) => {
418            for item in &item_trait.items {
419                if let syn::TraitItem::Fn(method) = item {
420                    if let ReturnType::Type(_, ty) = &method.sig.output {
421                        if let Type::Path(tp) = ty.as_ref() {
422                            if contains_forbidden_type(ty, &disallowed) {
423                                errors.push(syn::Error::new_spanned(
424                                    tp,
425                                    format!(
426                                        "trait method return type is disallowed: `{}`",
427                                        tp.path.to_token_stream()
428                                    ),
429                                ));
430                            }
431                        }
432                    }
433
434                    for arg in &method.sig.inputs {
435                        if let syn::FnArg::Typed(pat_type) = arg {
436                            if let Type::Path(tp) = &*pat_type.ty {
437                                if contains_forbidden_type(&pat_type.ty, &disallowed) {
438                                    let path = &tp.path;
439                                    let arg_str = quote!(#path).to_string();
440                                    errors.push(syn::Error::new_spanned(
441                                        &pat_type.ty,
442                                        format!("trait method parameter uses disallowed type `{arg_str}`"),
443                                    ));
444                                }
445                            }
446                        }
447                    }
448                }
449            }
450
451            &item_trait.attrs
452        }
453
454        Item::Impl(item_impl) => {
455            for item in &item_impl.items {
456                if let syn::ImplItem::Fn(method) = item {
457                    if let ReturnType::Type(_, ty) = &method.sig.output {
458                        if let Type::Path(tp) = ty.as_ref() {
459                            if contains_forbidden_type(ty, &disallowed) {
460                                errors.push(syn::Error::new_spanned(
461                                    tp,
462                                    format!(
463                                        "impl method return type is disallowed: `{}`",
464                                        tp.path.to_token_stream()
465                                    ),
466                                ));
467                            }
468                        }
469                    }
470
471                    for arg in &method.sig.inputs {
472                        if let syn::FnArg::Typed(pat_type) = arg {
473                            if let Type::Path(tp) = &*pat_type.ty {
474                                if contains_forbidden_type(&pat_type.ty, &disallowed) {
475                                    let path = &tp.path;
476                                    let arg_str = quote!(#path).to_string();
477                                    errors.push(syn::Error::new_spanned(
478                                        &pat_type.ty,
479                                        format!(
480                                            "impl method parameter uses disallowed type `{arg_str}`"
481                                        ),
482                                    ));
483                                }
484                            }
485                        }
486                    }
487                }
488            }
489
490            &item_impl.attrs
491        }
492
493        _ => {
494            errors.push(syn::Error::new_spanned(
495                &input,
496                "#[strict_types] only works on structs, enums, functions, impls and traits",
497            ));
498
499            let original = proc_macro2::TokenStream::from(item_clone);
500            let diagnostics = errors.into_iter().map(|e| e.to_compile_error());
501
502            return quote! {
503                #original
504                #(#diagnostics)*
505            }
506            .into();
507        }
508    };
509
510    errors.extend(verify_docs(args.mode, &doc_lines(attrs), &input));
511
512    let original = proc_macro2::TokenStream::from(item_clone);
513    let diagnostics = errors.into_iter().map(|e| e.to_compile_error());
514
515    quote! {
516        #original
517        #(#diagnostics)*
518    }
519    .into()
520}
521
522// #[proc_macro_attribute]
523// pub fn strict_types(attr: TokenStream, item: TokenStream) -> TokenStream {
524//     let forbidden = {
525//         let parsed = parse_macro_input!(attr as StrictTypesArgs);
526
527//         if parsed.disallow.is_empty() {
528//             default_forbidden_types()
529//         } else {
530//             parsed.disallow
531//         }
532//     };
533
534//     let input = parse_macro_input!(item as DeriveInput);
535//     let ident = &input.ident;
536
537//     let error_tokens = if let Data::Struct(data_struct) = &input.data {
538//         let mut errors = Vec::new();
539
540//         for field in data_struct.fields.iter() {
541//             if let Type::Path(type_path) = &field.ty {
542//                 if let Some(ident) = type_path.path.get_ident() {
543//                     if forbidden.contains(&type_path.path) {
544//                         let field_name = field
545//                             .ident
546//                             .as_ref()
547//                             .map(|i| i.to_string())
548//                             .unwrap_or("<unnamed>".into());
549
550//                         errors.push(syn::Error::new_spanned(
551//                             &field.ty,
552//                             format!(
553//                                 "field `{field_name}` uses forbidden primitive type `{ty_str}` — use a newtype wrapper"
554//                             ),
555//                         ));
556//                     }
557//                 }
558//             }
559//         }
560
561//         if errors.is_empty() {
562//             quote! {}
563//         } else {
564//             let combined = errors.iter().map(syn::Error::to_compile_error);
565//             quote! { #(#combined)* }
566//         }
567//     } else {
568//         syn::Error::new_spanned(ident, "#[enforce_strict_types] only works on structs")
569//             .to_compile_error()
570//     };
571
572//     let output = quote! {
573//         #input
574//         #error_tokens
575//     };
576
577//     output.into()
578// }