safety_guard/
lib.rs

//! Provides the `#[safety]` attribute, which adds a `# Safety` section to a function's
//! documentation and, if a constraint is given, checks that constraint in debug builds.
3
4extern crate proc_macro;
5
use proc_macro::TokenStream;
use quote::quote;
use syn::{
	fold::Fold,
	parenthesized,
	parse::{Parse, ParseBuffer},
	parse_macro_input,
	parse_quote,
	Attribute,
	Error,
	Expr,
	Ident,
	ItemFn,
	Lit,
	LitStr,
	Meta,
	MetaNameValue,
	Result,
	Stmt,
	Token,
};
24
25/// Adds a `# Safety` section to the documentation of the function and tests the given constraints
26/// in debug builds.
27///
28/// The attribute has four different forms:
29/// - `#[safety("Description")]`: Only the description is added to the documentation
30/// - `#[safety(assert(constraint), "Description")]`: `constraint` must evaluate to `true`.
31/// - `#[safety(eq(lhs, rhs), "Description")]`: `lhs` and `rhs` needs to be equal
32/// - `#[safety(ne(lhs, rhs), "Description")]`: `lhs` and `rhs` must not be equal
33///
34/// A function with a `#[safety]` attribute must be marked as `unsafe`. Otherwise a compile error
35/// is generated.
36///
37/// If `# Safety` already exists in the documentation, the heading is not added.
38///
39/// # Examples
40///
41/// ```
42/// use safety_guard::safety;
43///
44/// #[safety(assert(lhs.checked_add(rhs).is_some()), "`lhs` + `rhs` must not overflow")]
45/// unsafe fn add_unchecked(lhs: usize, rhs: usize) -> usize {
46/// 	lhs + rhs
47/// }
48/// ```
49///
50/// generates
51///
52/// ```
53/// /// # Safety
54/// /// - `lhs` + `rhs` must not overflow
55/// unsafe fn add_unchecked(lhs: usize, rhs: usize) -> usize {
56/// 	debug_assert!(lhs.checked_add(rhs).is_some(), "`lhs` + `rhs` must not overflow");
57/// 	lhs + rhs
58/// }
59/// ```
60///
61/// Without a constraint, only the documentation is added:
62///
63/// ```
64/// use safety_guard::safety;
65///
66/// #[safety("`hash` must correspond to the `string`s hash value")]
67/// unsafe fn add_string_with_hash(string: &str, hash: u64) -> u64 {
68/// 	# unimplemented!()
69/// 	// ...
70/// }
71/// ```
72///
73/// generates
74///
75/// ```
76/// /// # Safety
77/// /// - `hash` must correspond to the `string`s hash value
78/// unsafe fn add_string_with_hash(string: &str, hash: u64) -> u64 {
79/// 	# unimplemented!()
80/// 	// ...
81/// }
82/// ```
83///
84/// It is also possible to use multiple `#[safety]` attributes:
85///
86/// ```
87/// # use core::alloc::Layout;
88/// use safety_guard::safety;
89///
90///	#[safety(eq(ptr as usize % layout.align(), 0), "`layout` must *fit* the `ptr`")]
91///	#[safety(assert(new_size > 0), "`new_size` must be greater than zero")]
92///	#[safety(
93///		"`new_size`, when rounded up to the nearest multiple of `layout.align()`, must not \
94///		 overflow (i.e., the rounded value must be less than `usize::MAX`)."
95///	)]
96///	unsafe fn realloc(
97///		ptr: *mut u8,
98///		layout: Layout,
99///		new_size: usize,
100///	) -> *mut u8 {
101/// 	# unimplemented!()
102/// 	// ...
103/// }
104/// ```
105///
106/// However, the documentation is generated in reversed order:
107///
108///
109/// ```
110/// # use core::alloc::Layout;
111/// /// # Safety
112/// /// - `new_size`, when rounded up to the nearest multiple of `layout.align()`, must not
113/// ///   overflow (i.e., the rounded value must be less than `usize::MAX`).
114/// /// - `new_size` must be greater than zero
115/// /// - `layout` must *fit* the `ptr`
116///	unsafe fn realloc(
117///		ptr: *mut u8,
118///		layout: Layout,
119///		new_size: usize,
120///	) -> *mut u8 {
121/// 	debug_assert!(new_size > 0, "`new_size` must be greater than zero");
122/// 	debug_assert_eq!(ptr as usize % layout.align(), 0, "`layout` must *fit* the `ptr`");
123/// 	# unimplemented!()
124/// 	// ...
125/// }
126/// ```
127#[proc_macro_attribute]
128pub fn safety(args: TokenStream, input: TokenStream) -> TokenStream {
129	let input = parse_macro_input!(input);
130	let mut args = parse_macro_input!(args as Args);
131	let output = args.fold_item_fn(input);
132	TokenStream::from(quote!(#output))
133}
134
135struct Binary(Expr, Expr);
136
137impl Parse for Binary {
138	fn parse(input: &ParseBuffer<'_>) -> Result<Self> {
139		let lhs = input.parse()?;
140		input.parse::<Token![,]>()?;
141		let rhs = input.parse()?;
142		Ok(Binary(lhs, rhs))
143	}
144}
145
146enum Condition {
147	Eq(Binary),
148	Ne(Binary),
149	Assert(Expr),
150}
151
152#[allow(clippy::large_enum_variant)]
153enum Args {
154	Literal(LitStr),
155	Condition(Condition, LitStr),
156}
157
158impl Condition {
159	fn stmt(&self, text: &LitStr) -> Stmt {
160		match self {
161			Condition::Eq(Binary(lhs, rhs)) => parse_quote!(debug_assert_eq!(#lhs, #rhs, #text);),
162			Condition::Ne(Binary(lhs, rhs)) => parse_quote!(debug_assert_ne!(#lhs, #rhs, #text);),
163			Condition::Assert(expr) => parse_quote!(debug_assert!(#expr, #text);),
164		}
165	}
166}
167
168impl Parse for Condition {
169	fn parse(input: &ParseBuffer<'_>) -> Result<Self> {
170		let ident: Ident = input.parse()?;
171		let parens;
172		parenthesized!(parens in input);
173		match ident {
174			ref i if i == "eq" => Ok(Condition::Eq(parens.parse()?)),
175			ref i if i == "ne" => Ok(Condition::Ne(parens.parse()?)),
176			ref i if i == "assert" => Ok(Condition::Assert(parens.parse()?)),
177			_ => {
178				let message = format!(
179					"expected string literal, `assert()`, `eq()`, or `ne()`. Found `{}`",
180					ident
181				);
182				Err(Error::new_spanned(ident, message))
183			}
184		}
185	}
186}
187
188impl Args {
189	fn attribute(&self) -> Attribute {
190		let text = match self {
191			Args::Literal(text) | Args::Condition(_, text) => text,
192		};
193		let text = format!("- {}", text.value());
194		parse_quote!(#[doc = #text])
195	}
196
197	fn stmt(&self) -> Option<Stmt> {
198		if let Args::Condition(condition, text) = self {
199			Some(condition.stmt(text))
200		} else {
201			None
202		}
203	}
204}
205
206impl Parse for Args {
207	fn parse(input: &ParseBuffer<'_>) -> Result<Self> {
208		if input.peek(LitStr) {
209			Ok(Args::Literal(input.parse()?))
210		} else {
211			let condition = input.parse()?;
212			input.parse::<Token![,]>()?;
213			let literal = input.parse()?;
214			Ok(Args::Condition(condition, literal))
215		}
216	}
217}
218
219impl Fold for Args {
220	fn fold_item_fn(&mut self, mut item: ItemFn) -> ItemFn {
221		let safety_attr = parse_quote!(#[doc = " # Safety"]);
222		let mut heading_inserted = None;
223		for (i, attr) in item.attrs.iter().enumerate() {
224			if *attr == safety_attr {
225				heading_inserted = Some(i + 1);
226				break;
227			}
228		}
229
230		if item.unsafety.is_none() {
231			let error = Error::new_spanned(
232				&item.ident,
233				"A function with safety attributes has to be marked as `unsafe`",
234			)
235			.to_compile_error();
236			item.block.stmts.insert(0, parse_quote!(#error;));
237		};
238
239		if let Some(i) = heading_inserted {
240			item.attrs.insert(i, self.attribute());
241		} else {
242			item.attrs.push(safety_attr);
243			item.attrs.push(self.attribute());
244		}
245
246		if let Some(stmt) = self.stmt() {
247			item.block.stmts.insert(0, parse_quote!(#stmt));
248		};
249
250		item
251	}
252}