1extern crate proc_macro;
2use proc_macro2::{Ident, Span, TokenStream, TokenTree};
3use quote::quote;
4
5#[proc_macro_derive(BitFlags)]
6pub fn derive_sawp_flags(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
7 let ast: syn::DeriveInput = syn::parse(input).unwrap();
8 impl_sawp_flags(&ast).into()
9}
10
11fn impl_sawp_flags(ast: &syn::DeriveInput) -> TokenStream {
12 let name = &ast.ident;
13 let repr = if let Some(repr) = get_repr(ast) {
14 repr
15 } else {
16 panic!("BitFlags enum must have a `repr` attribute with numeric argument");
17 };
18 match &ast.data {
19 syn::Data::Enum(data) => impl_enum_traits(name, &repr, data),
20 _ => panic!("Bitflags is only supported on enums"),
21 }
22}
23
24fn get_repr(ast: &syn::DeriveInput) -> Option<Ident> {
25 ast.attrs.iter().find_map(|attr| {
26 if let Some(path) = attr.path.get_ident() {
27 if path == "repr" {
28 if let Some(tree) = attr.tokens.clone().into_iter().next() {
29 match tree {
30 TokenTree::Group(group) => {
31 if let Some(ident) = group.stream().into_iter().next() {
32 match ident {
33 TokenTree::Ident(ident) => Some(ident),
34 _ => None,
35 }
36 } else {
37 None
38 }
39 }
40 _ => None,
41 }
42 } else {
43 None
44 }
45 } else {
46 None
47 }
48 } else {
49 None
50 }
51 })
52}
53
54fn impl_enum_traits(name: &syn::Ident, repr: &Ident, data: &syn::DataEnum) -> TokenStream {
55 let list_items = data.variants.iter().map(|variant| &variant.ident);
57 let list_all = list_items.clone();
58 let display_items = list_items.clone();
59 let from_str_items = list_items.clone();
60 let from_str_items_str = list_items.clone().map(|variant| {
61 Ident::new(
62 variant.to_string().to_lowercase().as_str(),
63 Span::call_site(),
64 )
65 });
66
67 quote! {
68 impl Flag for #name {
69 type Primitive = #repr;
70
71 const ITEMS: &'static [Self] = &[#(#name::#list_items),*];
72
73 fn bits(self) -> Self::Primitive {
74 self as #repr
75 }
76
77 fn none() -> Flags<Self> {
78 Flags::from_bits(0)
79 }
80
81 fn all() -> Flags<Self> {
82 Flags::from_bits(#(#name::#list_all as Self::Primitive)|*)
83 }
84 }
85
86 impl std::ops::BitOr for #name {
87 type Output = Flags<#name>;
88
89 fn bitor(self, other: Self) -> Self::Output {
90 Flags::from_bits(self.bits() | other.bits())
91 }
92 }
93
94 impl std::ops::BitAnd for #name {
95 type Output = Flags<#name>;
96
97 fn bitand(self, other: Self) -> Self::Output {
98 Flags::from_bits(self.bits() & other.bits())
99 }
100 }
101
102 impl std::ops::BitXor for #name {
103 type Output = Flags<#name>;
104
105 fn bitxor(self, other: Self) -> Self::Output {
106 Flags::from_bits(self.bits() ^ other.bits())
107 }
108 }
109
110 impl std::ops::Not for #name {
111 type Output = Flags<#name>;
112
113 fn not(self) -> Self::Output {
114 Flags::from_bits(!self.bits())
115 }
116 }
117
118 impl std::fmt::Display for #name {
119 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 let empty = self.bits() == Self::none().bits();
121 let mut first = true;
122 #(
123 if self.bits() & #name::#display_items.bits() == #name::#display_items.bits() {
124 write!(f, "{}{}", if first { "" } else { " | " }, stringify!(#display_items))?;
125 first = false;
126
127 if empty {
128 return Ok(());
129 }
130 }
131 )*
132
133 if empty {
134 write!(f, "NONE")?;
135 }
136
137 Ok(())
138 }
139 }
140
141 impl std::str::FromStr for #name {
142 type Err = ();
143 fn from_str(val: &str) -> std::result::Result<#name, Self::Err> {
144 match val.to_lowercase().as_str() {
145 #(stringify!(#from_str_items_str) => Ok(#name::#from_str_items),)*
146 _ => Err(()),
147 }
148 }
149 }
150
151 impl PartialEq<Flags<Self>> for #name {
152 fn eq(&self, other: &Flags<Self>) -> bool {
153 self.bits() == other.bits()
154 }
155 }
156
157 impl std::fmt::Binary for #name {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 std::fmt::Binary::fmt(&self.bits(), f)
160 }
161 }
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as a derive input and runs the `BitFlags` expansion on it.
    fn expand(source: &str) -> TokenStream {
        let parsed: syn::DeriveInput = syn::parse_str(source).unwrap();
        impl_sawp_flags(&parsed)
    }

    /// A `#[repr(u8)]` enum expands without panicking.
    #[test]
    fn test_macro_enum() {
        expand(
            r#"
            #[repr(u8)]
            enum Test {
                A = 0b0000,
                B = 0b0001,
                C = 0b0010,
                D = 0b0100,
            }
            "#,
        );
    }

    /// An enum with no `repr` attribute is rejected.
    #[test]
    #[should_panic(expected = "BitFlags enum must have a `repr` attribute")]
    fn test_macro_repr_panic() {
        expand(
            r#"
            enum Test {
                A = 0b0000,
                B = 0b0001,
                C = 0b0010,
                D = 0b0100,
            }
            "#,
        );
    }

    /// A struct is rejected even when it carries a `repr` attribute.
    #[test]
    #[should_panic(expected = "Bitflags is only supported on enums")]
    fn test_macro_not_enum_panic() {
        expand(
            r#"
            #[repr(u8)]
            struct Test {
            }
            "#,
        );
    }
}
214}