strict_typing/lib.rs
1//! A macro to enforce strict typing on the fields in Rust.
2//!
3//! Please refer to the documentation of the macro for more details:
4//! [`macro@strict_types`].
5
6use proc_macro::TokenStream;
7use quote::{ToTokens, quote};
8use syn::{
9 Fields, Ident, Item, Path, ReturnType, Token, Type,
10 parse::{Parse, ParseStream},
11 parse_macro_input, parse_quote,
12 punctuated::Punctuated,
13};
14
15#[derive(Default, Clone)]
16enum Mode {
17 #[default]
18 Default,
19 Allow(Vec<Path>),
20 Disallow(Vec<Path>),
21}
22
23#[derive(Default)]
24struct StrictTypesArgs {
25 disallow: Vec<Path>,
26 mode: Mode,
27}
28
29impl Parse for StrictTypesArgs {
30 fn parse(input: ParseStream) -> syn::Result<Self> {
31 if input.is_empty() {
32 return Ok(Self::default());
33 }
34
35 let key: Ident = input.parse()?;
36 let content;
37 let _ = syn::parenthesized!(content in input);
38
39 let paths: Punctuated<Path, Token![,]> =
40 content.parse_terminated(Path::parse, Token![,])?;
41 let paths_vec: Vec<Path> = paths.into_iter().collect();
42
43 let mode;
44 let mut disallow = default_disallowed_types();
45 let disallow = match key.to_string().as_str() {
46 "disallow" => {
47 // let new_paths: Vec<Path> = paths_vec
48 // .iter()
49 // .filter(|path| !disallow.contains(path))
50 // .cloned()
51 // .collect();
52 // mode = Mode::Disallow(new_paths);
53 mode = Mode::Disallow(paths_vec.clone());
54 disallow.extend(paths_vec);
55 disallow
56 }
57 "allow" => {
58 mode = Mode::Allow(paths_vec.clone());
59 disallow.retain(|path| !paths_vec.contains(path));
60 disallow
61 }
62 _ => {
63 return Err(syn::Error::new_spanned(
64 key,
65 "expected `disallow(...)` or `allow(...)`",
66 ));
67 }
68 };
69
70 Ok(Self { disallow, mode })
71 }
72}
73
74fn default_disallowed_types() -> Vec<Path> {
75 vec![
76 parse_quote!(u8),
77 parse_quote!(u16),
78 parse_quote!(u32),
79 parse_quote!(u64),
80 parse_quote!(u128),
81 parse_quote!(usize),
82 parse_quote!(i8),
83 parse_quote!(i16),
84 parse_quote!(i32),
85 parse_quote!(i64),
86 parse_quote!(i128),
87 parse_quote!(isize),
88 parse_quote!(f32),
89 parse_quote!(f64),
90 parse_quote!(bool),
91 parse_quote!(char),
92 ]
93}
94
95fn contains_forbidden_type(ty: &Type, disallowed: &[Path]) -> bool {
96 match ty {
97 Type::Path(type_path) => {
98 if disallowed.contains(&type_path.path) {
99 return true;
100 }
101
102 for segment in &type_path.path.segments {
103 if let syn::PathArguments::AngleBracketed(generic_args) = &segment.arguments {
104 for arg in &generic_args.args {
105 if let syn::GenericArgument::Type(inner_ty) = arg {
106 if contains_forbidden_type(inner_ty, disallowed) {
107 return true;
108 }
109 }
110 }
111 }
112 }
113
114 false
115 }
116
117 Type::Tuple(tuple) => tuple
118 .elems
119 .iter()
120 .any(|elem| contains_forbidden_type(elem, disallowed)),
121
122 Type::Group(group) => contains_forbidden_type(&group.elem, disallowed),
123 Type::Paren(paren) => contains_forbidden_type(&paren.elem, disallowed),
124
125 _ => false, // you can expand this for more complex cases like references, impl traits, etc.
126 }
127}
128
129fn doc_lines(attrs: &[syn::Attribute]) -> Vec<String> {
130 attrs
131 .iter()
132 .filter_map(|attr| {
133 if attr.path().is_ident("doc") {
134 if let Ok(nv) = attr.meta.clone().require_name_value() {
135 if let syn::Expr::Lit(syn::ExprLit {
136 lit: syn::Lit::Str(s),
137 ..
138 }) = &nv.value
139 {
140 return Some(s.value().trim().to_string());
141 }
142 }
143 }
144 None
145 })
146 .collect()
147}
148
149fn verify_docs(mode: Mode, docs: &[String], input: &Item) -> Vec<syn::Error> {
150 let mut errors = Vec::new();
151
152 if let Mode::Allow(paths) | Mode::Disallow(paths) = &mode {
153 let mut strict_section_found = false;
154 let mut documented_types = Vec::new();
155
156 for line in docs {
157 if line.trim() == "# Strictness" {
158 strict_section_found = true;
159 continue;
160 }
161
162 if strict_section_found {
163 if let Some(rest) = line.trim().strip_prefix("- [") {
164 if let Some(end_idx) = rest.find(']') {
165 let type_str = &rest[..end_idx];
166 documented_types.push(type_str.to_string());
167 }
168 }
169 }
170 }
171
172 for path in paths {
173 let ty_str = quote!(#path).to_string();
174 if !documented_types.iter().any(|doc| doc == &ty_str) {
175 errors.push(syn::Error::new_spanned(
176 path,
177 format!(
178 "Missing `/// - [{ty_str}] justification` in `/// # Strictness` section"
179 ),
180 ));
181 }
182 }
183
184 if errors.is_empty() && !strict_section_found {
185 errors.push(syn::Error::new_spanned(
186 input,
187 "Missing `/// # Strictness` section for `allow(...)` or `disallow(...)` override",
188 ));
189 }
190 }
191
192 errors
193}
194
195/// A macro to enforce strict typing on struct and enum fields.
196/// It checks if any field uses a primitive type and generates a
197/// compile-time error if it does. The idea is to encourage the use of
198/// newtype wrappers for primitive types to ensure type safety and
199/// clarity in the codebase.
200///
201/// The motivation behind this macro is to prevent the use of primitive
202/// types directly in structs, which can lead to confusion and bugs.
203/// The primitive types are often too generic and have a too wide range
204/// of values, can be misused in different contexts, and do not
205/// convey the intent of the data being represented, especially meaning
206/// having useful names for the types and intentions behind them.
207///
208/// Also, often, the primitive types are not only checked for the width
209/// of the allowed range of values, but must also contain some values
210/// that are not allowed from within the allowed range. For example,
211/// a `u8` type can be used to represent a percentage, but it can also
212/// be used to represent a count of items, which is a different
213/// concept. In this case, the `u8` type does not convey the intent of
214/// the data being represented, and it is better to use a newtype wrapper
215/// to make the intent clear. There might be at least two "Percentage"
216/// types in the codebase, one is limited to the range of `0-100`, and
217/// another type which can go beyond 100 (but still not less than zero),
218/// to express the surpassing of the 100% mark. Not to mention that
219/// sometimes, in certain contexts, the percentage can be negative
220/// (e.g. when calculating the difference between two values).
221/// This macro is a way to enforce the use of newtype wrappers for
222/// primitive types in structs, which can help to avoid confusion and
223/// bugs in the codebase. It is a compile-time check that will generate
224/// an error if any field in a struct uses a primitive type directly.
225///
226/// # Example usage:
227///
228/// ```rust
229/// use strict_typing::strict_types;
230///
231/// #[repr(transparent)]
232/// struct MyNewTypeWrapper<T>(T);
233///
234/// #[strict_types]
235/// struct MyStruct {
236/// // This will generate a compile-time error
237/// // because `u8` is a primitive type.
238/// // my_field: u8,
239/// // But this not:
240/// my_field: MyNewTypeWrapper<u8>,
241/// }
242/// ```
243///
244/// Yes, this is a very simple macro, but it is intended to be used
245/// as a way to enforce strict typing in the codebase, and to encourage
246/// the use of newtype wrappers for primitive types in structs.
247///
248/// /// # Example with `disallow` which **adds** types to the disallowed
249/// list:
250///
251/// ```rust,ignore,no_run
252/// use strict_typing::strict_types;
253///
254/// #[strict_types(disallow(String))]
255/// struct MyStruct {
256/// // This will generate a compile-time error
257/// // because `String` is now also a forbidden type.
258/// my_field: String,
259/// }
260/// ```
261///
262/// When a type is added to the disallowed list or removed from it,
263/// the macro requires the user to document the reason for
264/// the change in the `/// # Strictness` section of the documentation.
265/// The documentation should be in the form of a list of items,
266/// where each item is a type that is allowed or disallowed, example:
267///
268/// ```rust,ignore,no_run
269/// use strict_typing::strict_types;
270///
271/// /// # Strictness
272/// ///
273/// /// - [String] - this is a disallowed type, because it is too bad.
274/// #[strict_types(disallow(String))]
275/// struct MyStruct {
276/// my_field: String,
277/// }
278/// ```
279///
280/// To remove from the default disallow list, you can use the
281/// `allow` directive:
282/// ```rust,ignore,no_run
283/// use strict_typing::strict_types;
284/// /// # Strictness
285/// ///
286/// /// - [u8] - this is an allowed type, because it is used for
287/// /// representing a small number of items.
288/// #[strict_types(allow(u8))]
289/// struct MyStruct {
290/// my_field: u8,
291/// }
292/// ```
293///
294/// The macro also supports working directly on the whole `impl` and
295/// `trait` items, analysing the function signatures and their
296/// return types; however, annotating a trait method or an impl method
297/// is yet impossible due to Rust limitations.
298#[proc_macro_attribute]
299pub fn strict_types(attr: TokenStream, item: TokenStream) -> TokenStream {
300 let args = parse_macro_input!(attr as StrictTypesArgs);
301 let item_clone = item.clone();
302 let input = parse_macro_input!(item as Item);
303
304 let disallowed: Vec<Path> = if args.disallow.is_empty() {
305 default_disallowed_types()
306 } else {
307 args.disallow
308 };
309
310 let mut errors = Vec::new();
311
312 let attrs = match &input {
313 Item::Struct(struct_item) => {
314 for field in &struct_item.fields {
315 if let Type::Path(tp) = &field.ty {
316 if contains_forbidden_type(&field.ty, &disallowed) {
317 let fname = field
318 .ident
319 .as_ref()
320 .map(|i| i.to_string())
321 .unwrap_or("<unnamed>".into());
322 errors.push(syn::Error::new_spanned(
323 &field.ty,
324 format!("field `{}` uses disallowed type `{}`", fname, quote!(#tp)),
325 ));
326 }
327 }
328 }
329 &struct_item.attrs
330 }
331
332 Item::Enum(enum_item) => {
333 for variant in &enum_item.variants {
334 match &variant.fields {
335 Fields::Unit => {}
336 Fields::Named(fields) => {
337 for field in &fields.named {
338 if let Type::Path(tp) = &field.ty {
339 if contains_forbidden_type(&field.ty, &disallowed) {
340 errors.push(syn::Error::new_spanned(
341 &field.ty,
342 format!(
343 "variant `{}` has field with disallowed type `{}`",
344 variant.ident,
345 quote!(#tp)
346 ),
347 ));
348 }
349 }
350 }
351 }
352 Fields::Unnamed(fields) => {
353 for field in &fields.unnamed {
354 if let Type::Path(tp) = &field.ty {
355 if contains_forbidden_type(&field.ty, &disallowed) {
356 errors.push(syn::Error::new_spanned(
357 &field.ty,
358 format!(
359 "variant `{}` has field with disallowed type `{}`",
360 variant.ident,
361 quote!(#tp)
362 ),
363 ));
364 }
365 }
366 }
367 }
368 }
369 }
370
371 &enum_item.attrs
372 }
373
374 Item::Fn(fn_item) => {
375 let sig = &fn_item.sig;
376
377 for arg in &sig.inputs {
378 if let syn::FnArg::Typed(pat_type) = arg {
379 if let Type::Path(tp) = &*pat_type.ty {
380 if contains_forbidden_type(&pat_type.ty, &disallowed) {
381 let path = &tp.path;
382 let arg_str = quote!(#path).to_string();
383 errors.push(syn::Error::new_spanned(
384 &pat_type.ty,
385 format!("function parameter uses disallowed type `{arg_str}`"),
386 ));
387 }
388 }
389 }
390 }
391
392 if let ReturnType::Type(_, ty) = &fn_item.sig.output {
393 if let Type::Path(tp) = ty.as_ref() {
394 if contains_forbidden_type(ty, &disallowed) {
395 errors.push(syn::Error::new_spanned(
396 tp,
397 format!(
398 "function return type is disallowed: `{}`",
399 tp.path.to_token_stream()
400 ),
401 ));
402 }
403 }
404 }
405
406 errors.extend(verify_docs(args.mode, &doc_lines(&fn_item.attrs), &input));
407
408 let diagnostics = errors.into_iter().map(|e| e.to_compile_error());
409 let output = quote! {
410 #fn_item
411 #(#diagnostics)*
412 };
413
414 return output.into();
415 }
416
417 Item::Trait(item_trait) => {
418 for item in &item_trait.items {
419 if let syn::TraitItem::Fn(method) = item {
420 if let ReturnType::Type(_, ty) = &method.sig.output {
421 if let Type::Path(tp) = ty.as_ref() {
422 if contains_forbidden_type(ty, &disallowed) {
423 errors.push(syn::Error::new_spanned(
424 tp,
425 format!(
426 "trait method return type is disallowed: `{}`",
427 tp.path.to_token_stream()
428 ),
429 ));
430 }
431 }
432 }
433
434 for arg in &method.sig.inputs {
435 if let syn::FnArg::Typed(pat_type) = arg {
436 if let Type::Path(tp) = &*pat_type.ty {
437 if contains_forbidden_type(&pat_type.ty, &disallowed) {
438 let path = &tp.path;
439 let arg_str = quote!(#path).to_string();
440 errors.push(syn::Error::new_spanned(
441 &pat_type.ty,
442 format!("trait method parameter uses disallowed type `{arg_str}`"),
443 ));
444 }
445 }
446 }
447 }
448 }
449 }
450
451 &item_trait.attrs
452 }
453
454 Item::Impl(item_impl) => {
455 for item in &item_impl.items {
456 if let syn::ImplItem::Fn(method) = item {
457 if let ReturnType::Type(_, ty) = &method.sig.output {
458 if let Type::Path(tp) = ty.as_ref() {
459 if contains_forbidden_type(ty, &disallowed) {
460 errors.push(syn::Error::new_spanned(
461 tp,
462 format!(
463 "impl method return type is disallowed: `{}`",
464 tp.path.to_token_stream()
465 ),
466 ));
467 }
468 }
469 }
470
471 for arg in &method.sig.inputs {
472 if let syn::FnArg::Typed(pat_type) = arg {
473 if let Type::Path(tp) = &*pat_type.ty {
474 if contains_forbidden_type(&pat_type.ty, &disallowed) {
475 let path = &tp.path;
476 let arg_str = quote!(#path).to_string();
477 errors.push(syn::Error::new_spanned(
478 &pat_type.ty,
479 format!(
480 "impl method parameter uses disallowed type `{arg_str}`"
481 ),
482 ));
483 }
484 }
485 }
486 }
487 }
488 }
489
490 &item_impl.attrs
491 }
492
493 _ => {
494 errors.push(syn::Error::new_spanned(
495 &input,
496 "#[strict_types] only works on structs, enums, functions, impls and traits",
497 ));
498
499 let original = proc_macro2::TokenStream::from(item_clone);
500 let diagnostics = errors.into_iter().map(|e| e.to_compile_error());
501
502 return quote! {
503 #original
504 #(#diagnostics)*
505 }
506 .into();
507 }
508 };
509
510 errors.extend(verify_docs(args.mode, &doc_lines(attrs), &input));
511
512 let original = proc_macro2::TokenStream::from(item_clone);
513 let diagnostics = errors.into_iter().map(|e| e.to_compile_error());
514
515 quote! {
516 #original
517 #(#diagnostics)*
518 }
519 .into()
520}
521
522// #[proc_macro_attribute]
523// pub fn strict_types(attr: TokenStream, item: TokenStream) -> TokenStream {
524// let forbidden = {
525// let parsed = parse_macro_input!(attr as StrictTypesArgs);
526
527// if parsed.disallow.is_empty() {
528// default_forbidden_types()
529// } else {
530// parsed.disallow
531// }
532// };
533
534// let input = parse_macro_input!(item as DeriveInput);
535// let ident = &input.ident;
536
537// let error_tokens = if let Data::Struct(data_struct) = &input.data {
538// let mut errors = Vec::new();
539
540// for field in data_struct.fields.iter() {
541// if let Type::Path(type_path) = &field.ty {
542// if let Some(ident) = type_path.path.get_ident() {
543// if forbidden.contains(&type_path.path) {
544// let field_name = field
545// .ident
546// .as_ref()
547// .map(|i| i.to_string())
548// .unwrap_or("<unnamed>".into());
549
550// errors.push(syn::Error::new_spanned(
551// &field.ty,
552// format!(
553// "field `{field_name}` uses forbidden primitive type `{ty_str}` — use a newtype wrapper"
554// ),
555// ));
556// }
557// }
558// }
559// }
560
561// if errors.is_empty() {
562// quote! {}
563// } else {
564// let combined = errors.iter().map(syn::Error::to_compile_error);
565// quote! { #(#combined)* }
566// }
567// } else {
568// syn::Error::new_spanned(ident, "#[enforce_strict_types] only works on structs")
569// .to_compile_error()
570// };
571
572// let output = quote! {
573// #input
574// #error_tokens
575// };
576
577// output.into()
578// }