safety_guard/lib.rs
//! Provides the `#[safety]` attribute, which adds a `# Safety` section to a function's
//! documentation and, if a constraint is given, checks that constraint in debug builds.
3
4extern crate proc_macro;
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{
9 fold::Fold,
10 parenthesized,
11 parse::{Parse, ParseBuffer},
12 parse_macro_input,
13 parse_quote,
14 Attribute,
15 Error,
16 Expr,
17 Ident,
18 ItemFn,
19 LitStr,
20 Result,
21 Stmt,
22 Token,
23};
24
25/// Adds a `# Safety` section to the documentation of the function and tests the given constraints
26/// in debug builds.
27///
28/// The attribute has four different forms:
29/// - `#[safety("Description")]`: Only the description is added to the documentation
30/// - `#[safety(assert(constraint), "Description")]`: `constraint` must evaluate to `true`.
31/// - `#[safety(eq(lhs, rhs), "Description")]`: `lhs` and `rhs` needs to be equal
32/// - `#[safety(ne(lhs, rhs), "Description")]`: `lhs` and `rhs` must not be equal
33///
34/// A function with a `#[safety]` attribute must be marked as `unsafe`. Otherwise a compile error
35/// is generated.
36///
37/// If `# Safety` already exists in the documentation, the heading is not added.
38///
39/// # Examples
40///
41/// ```
42/// use safety_guard::safety;
43///
44/// #[safety(assert(lhs.checked_add(rhs).is_some()), "`lhs` + `rhs` must not overflow")]
45/// unsafe fn add_unchecked(lhs: usize, rhs: usize) -> usize {
46/// lhs + rhs
47/// }
48/// ```
49///
50/// generates
51///
52/// ```
53/// /// # Safety
54/// /// - `lhs` + `rhs` must not overflow
55/// unsafe fn add_unchecked(lhs: usize, rhs: usize) -> usize {
56/// debug_assert!(lhs.checked_add(rhs).is_some(), "`lhs` + `rhs` must not overflow");
57/// lhs + rhs
58/// }
59/// ```
60///
61/// Without a constraint, only the documentation is added:
62///
63/// ```
64/// use safety_guard::safety;
65///
66/// #[safety("`hash` must correspond to the `string`s hash value")]
67/// unsafe fn add_string_with_hash(string: &str, hash: u64) -> u64 {
68/// # unimplemented!()
69/// // ...
70/// }
71/// ```
72///
73/// generates
74///
75/// ```
76/// /// # Safety
77/// /// - `hash` must correspond to the `string`s hash value
78/// unsafe fn add_string_with_hash(string: &str, hash: u64) -> u64 {
79/// # unimplemented!()
80/// // ...
81/// }
82/// ```
83///
84/// It is also possible to use multiple `#[safety]` attributes:
85///
86/// ```
87/// # use core::alloc::Layout;
88/// use safety_guard::safety;
89///
90/// #[safety(eq(ptr as usize % layout.align(), 0), "`layout` must *fit* the `ptr`")]
91/// #[safety(assert(new_size > 0), "`new_size` must be greater than zero")]
92/// #[safety(
93/// "`new_size`, when rounded up to the nearest multiple of `layout.align()`, must not \
94/// overflow (i.e., the rounded value must be less than `usize::MAX`)."
95/// )]
96/// unsafe fn realloc(
97/// ptr: *mut u8,
98/// layout: Layout,
99/// new_size: usize,
100/// ) -> *mut u8 {
101/// # unimplemented!()
102/// // ...
103/// }
104/// ```
105///
106/// However, the documentation is generated in reversed order:
107///
108///
109/// ```
110/// # use core::alloc::Layout;
111/// /// # Safety
112/// /// - `new_size`, when rounded up to the nearest multiple of `layout.align()`, must not
113/// /// overflow (i.e., the rounded value must be less than `usize::MAX`).
114/// /// - `new_size` must be greater than zero
115/// /// - `layout` must *fit* the `ptr`
116/// unsafe fn realloc(
117/// ptr: *mut u8,
118/// layout: Layout,
119/// new_size: usize,
120/// ) -> *mut u8 {
121/// debug_assert!(new_size > 0, "`new_size` must be greater than zero");
122/// debug_assert_eq!(ptr as usize % layout.align(), 0, "`layout` must *fit* the `ptr`");
123/// # unimplemented!()
124/// // ...
125/// }
126/// ```
127#[proc_macro_attribute]
128pub fn safety(args: TokenStream, input: TokenStream) -> TokenStream {
129 let input = parse_macro_input!(input);
130 let mut args = parse_macro_input!(args as Args);
131 let output = args.fold_item_fn(input);
132 TokenStream::from(quote!(#output))
133}
134
135struct Binary(Expr, Expr);
136
137impl Parse for Binary {
138 fn parse(input: &ParseBuffer<'_>) -> Result<Self> {
139 let lhs = input.parse()?;
140 input.parse::<Token![,]>()?;
141 let rhs = input.parse()?;
142 Ok(Binary(lhs, rhs))
143 }
144}
145
146enum Condition {
147 Eq(Binary),
148 Ne(Binary),
149 Assert(Expr),
150}
151
152#[allow(clippy::large_enum_variant)]
153enum Args {
154 Literal(LitStr),
155 Condition(Condition, LitStr),
156}
157
158impl Condition {
159 fn stmt(&self, text: &LitStr) -> Stmt {
160 match self {
161 Condition::Eq(Binary(lhs, rhs)) => parse_quote!(debug_assert_eq!(#lhs, #rhs, #text);),
162 Condition::Ne(Binary(lhs, rhs)) => parse_quote!(debug_assert_ne!(#lhs, #rhs, #text);),
163 Condition::Assert(expr) => parse_quote!(debug_assert!(#expr, #text);),
164 }
165 }
166}
167
168impl Parse for Condition {
169 fn parse(input: &ParseBuffer<'_>) -> Result<Self> {
170 let ident: Ident = input.parse()?;
171 let parens;
172 parenthesized!(parens in input);
173 match ident {
174 ref i if i == "eq" => Ok(Condition::Eq(parens.parse()?)),
175 ref i if i == "ne" => Ok(Condition::Ne(parens.parse()?)),
176 ref i if i == "assert" => Ok(Condition::Assert(parens.parse()?)),
177 _ => {
178 let message = format!(
179 "expected string literal, `assert()`, `eq()`, or `ne()`. Found `{}`",
180 ident
181 );
182 Err(Error::new_spanned(ident, message))
183 }
184 }
185 }
186}
187
188impl Args {
189 fn attribute(&self) -> Attribute {
190 let text = match self {
191 Args::Literal(text) | Args::Condition(_, text) => text,
192 };
193 let text = format!("- {}", text.value());
194 parse_quote!(#[doc = #text])
195 }
196
197 fn stmt(&self) -> Option<Stmt> {
198 if let Args::Condition(condition, text) = self {
199 Some(condition.stmt(text))
200 } else {
201 None
202 }
203 }
204}
205
206impl Parse for Args {
207 fn parse(input: &ParseBuffer<'_>) -> Result<Self> {
208 if input.peek(LitStr) {
209 Ok(Args::Literal(input.parse()?))
210 } else {
211 let condition = input.parse()?;
212 input.parse::<Token![,]>()?;
213 let literal = input.parse()?;
214 Ok(Args::Condition(condition, literal))
215 }
216 }
217}
218
219impl Fold for Args {
220 fn fold_item_fn(&mut self, mut item: ItemFn) -> ItemFn {
221 let safety_attr = parse_quote!(#[doc = " # Safety"]);
222 let mut heading_inserted = None;
223 for (i, attr) in item.attrs.iter().enumerate() {
224 if *attr == safety_attr {
225 heading_inserted = Some(i + 1);
226 break;
227 }
228 }
229
230 if item.unsafety.is_none() {
231 let error = Error::new_spanned(
232 &item.ident,
233 "A function with safety attributes has to be marked as `unsafe`",
234 )
235 .to_compile_error();
236 item.block.stmts.insert(0, parse_quote!(#error;));
237 };
238
239 if let Some(i) = heading_inserted {
240 item.attrs.insert(i, self.attribute());
241 } else {
242 item.attrs.push(safety_attr);
243 item.attrs.push(self.attribute());
244 }
245
246 if let Some(stmt) = self.stmt() {
247 item.block.stmts.insert(0, parse_quote!(#stmt));
248 };
249
250 item
251 }
252}