1use proc_macro2::TokenStream;
2use quote::ToTokens;
3use syn::{Attribute, Meta, Path, Result, parse::Parse};
4
5mod splitter;
6use splitter::CommaSplitter;
7
8#[derive(Clone)]
10pub enum ExpandedAttr {
11 Direct(Attribute),
13 Nested {
17 attr: Meta,
18 condition: TokenStream,
19 original: Box<Attribute>,
21 },
22}
23
24impl ExpandedAttr {
25 pub fn parse_args<T: Parse>(&self) -> Result<T> {
26 match self {
27 ExpandedAttr::Direct(attr) => attr.parse_args(),
28 ExpandedAttr::Nested { attr, .. } => {
29 match attr {
30 Meta::List(list) => list.parse_args(),
31 Meta::NameValue(nv) => {
32 syn::parse2(nv.value.to_token_stream())
35 },
36 Meta::Path(_) => Err(syn::Error::new_spanned(
37 attr,
38 "Attribute path has no arguments",
39 )),
40 }
41 },
42 }
43 }
44
45 pub fn path(&self) -> &Path {
46 match self {
47 ExpandedAttr::Direct(attr) => attr.path(),
48 ExpandedAttr::Nested { attr, .. } => attr.path(),
49 }
50 }
51
52 pub fn is_ident(&self, ident: &str) -> bool {
53 self.path().is_ident(ident)
54 }
55}
56
57pub trait AttributeHelpers {
59 fn flattened_attributes(&self) -> Vec<ExpandedAttr>;
64
65 fn find_attribute(&self, ident: &str) -> Vec<ExpandedAttr>;
70}
71
72impl AttributeHelpers for Vec<Attribute> {
73 fn flattened_attributes(&self) -> Vec<ExpandedAttr> {
74 let mut results = Vec::new();
75 for attr in self {
76 flatten_attr_recursive(attr, &mut results, None);
77 }
78 results
79 }
80
81 fn find_attribute(&self, ident: &str) -> Vec<ExpandedAttr> {
82 self.flattened_attributes()
83 .into_iter()
84 .filter(|attr| attr.is_ident(ident))
85 .collect()
86 }
87}
88
89fn flatten_attr_recursive(
90 attr: &Attribute,
91 results: &mut Vec<ExpandedAttr>,
92 _inherited_condition: Option<&TokenStream>,
93) {
94 if attr.path().is_ident("cfg_attr") {
95 let tokens = match &attr.meta {
96 Meta::List(list) => &list.tokens,
97 _ => return,
98 };
99
100 let mut splitter = CommaSplitter::new(tokens.clone());
101
102 if let Some(condition_stream) = splitter.next() {
103 for inner_tokens in splitter {
108 if let Ok(nested_meta) = syn::parse2::<Meta>(inner_tokens.clone()) {
109 if nested_meta.path().is_ident("cfg_attr") {
110 let synthetic_attr = Attribute {
111 pound_token: Default::default(),
112 style: syn::AttrStyle::Outer,
113 bracket_token: Default::default(),
114 meta: nested_meta,
115 };
116 flatten_attr_recursive(&synthetic_attr, results, Some(&condition_stream));
117 } else {
118 results.push(ExpandedAttr::Nested {
119 attr: nested_meta,
120 condition: condition_stream.clone(),
121 original: Box::new(attr.clone()),
122 });
123 }
124 }
125 }
126 }
127 } else {
128 results.push(ExpandedAttr::Direct(attr.clone()));
129 }
130}
131
132#[cfg(test)]
133mod tests {
134 use super::*;
135 use syn::parse_quote;
136
137 #[test]
138 fn test_flatten_basic() {
139 let attrs: Vec<Attribute> = vec![parse_quote!(#[foo]), parse_quote!(#[bar(x)])];
140 let flattened = attrs.flattened_attributes();
141 assert_eq!(flattened.len(), 2);
142 assert!(flattened[0].is_ident("foo"));
143 assert!(flattened[1].is_ident("bar"));
144 }
145
146 #[test]
147 fn test_flatten_cfg_attr() {
148 let attrs: Vec<Attribute> = vec![parse_quote!(#[cfg_attr(all(), foo, bar(y))])];
149 let flattened = attrs.flattened_attributes();
150 assert_eq!(flattened.len(), 2);
151 assert!(flattened[0].is_ident("foo"));
152 assert!(flattened[1].is_ident("bar"));
153
154 match &flattened[0] {
155 ExpandedAttr::Nested { condition, .. } => {
156 assert_eq!(condition.to_string(), "all ()");
157 },
158 _ => panic!("Expected Nested"),
159 }
160 }
161
162 #[test]
163 fn test_flatten_recursive_cfg_attr() {
164 let attrs: Vec<Attribute> = vec![parse_quote!(#[cfg_attr(a, cfg_attr(b, foo))])];
165 let flattened = attrs.flattened_attributes();
166 assert_eq!(flattened.len(), 1);
167 assert!(flattened[0].is_ident("foo"));
168
169 match &flattened[0] {
170 ExpandedAttr::Nested { condition, .. } => {
171 assert_eq!(condition.to_string(), "b");
172 },
173 _ => panic!("Expected Nested"),
174 }
175 }
176
177 #[test]
178 fn test_find_attribute() {
179 let attrs: Vec<Attribute> = vec![
180 parse_quote!(#[foo]),
181 parse_quote!(#[cfg_attr(all(), foo)]),
182 parse_quote!(#[bar]),
183 parse_quote!(#[cfg_attr(any(), bar)]),
184 ];
185 let foos = attrs.find_attribute("foo");
186 assert_eq!(foos.len(), 2);
187
188 let bars = attrs.find_attribute("bar");
189 assert_eq!(bars.len(), 2);
190 }
191
192 #[test]
193 fn test_cfg_attr_multiple_attrs() {
194 let attrs: Vec<Attribute> = vec![parse_quote!(#[cfg_attr(my_cond, a, b(val), c)])];
195 let flattened = attrs.flattened_attributes();
196 assert_eq!(flattened.len(), 3);
197 assert!(flattened[0].is_ident("a"));
198 assert!(flattened[1].is_ident("b"));
199 assert!(flattened[2].is_ident("c"));
200
201 for attr in flattened {
203 if let ExpandedAttr::Nested { condition, .. } = attr {
204 assert_eq!(condition.to_string(), "my_cond");
205 } else {
206 panic!("Expected Nested layout");
207 }
208 }
209 }
210
211 #[test]
212 fn test_complex_condition() {
213 let attrs: Vec<Attribute> =
214 vec![parse_quote!(#[cfg_attr(any(target_os="linux", feature="flag"), foo)])];
215 let flattened = attrs.flattened_attributes();
216 assert_eq!(flattened.len(), 1);
217 if let ExpandedAttr::Nested { condition, .. } = &flattened[0] {
218 let s = condition.to_string();
220 assert!(s.contains("any"));
221 assert!(s.contains("target_os"));
222 assert!(s.contains("linux"));
223 assert!(s.contains("feature"));
224 } else {
225 panic!("Expected Nested");
226 }
227 }
228
229 #[test]
230 fn test_deep_mixed_nesting() {
231 let attrs: Vec<Attribute> =
233 vec![parse_quote!(#[cfg_attr(cond_a, cfg_attr(cond_b, x, y), z)])];
234 let flattened = attrs.flattened_attributes();
235 assert_eq!(flattened.len(), 3);
236
237 let z = flattened.iter().find(|a| a.is_ident("z")).unwrap();
238 let x = flattened.iter().find(|a| a.is_ident("x")).unwrap();
239 let y = flattened.iter().find(|a| a.is_ident("y")).unwrap();
240
241 if let ExpandedAttr::Nested { condition, .. } = z {
242 assert_eq!(condition.to_string(), "cond_a");
243 }
244
245 if let ExpandedAttr::Nested { condition, .. } = x {
246 assert_eq!(condition.to_string(), "cond_b");
247 }
248
249 if let ExpandedAttr::Nested { condition, .. } = y {
250 assert_eq!(condition.to_string(), "cond_b");
251 }
252 }
253
254 #[test]
255 fn test_parse_args_variants() {
256 use syn::LitInt;
257
258 let attr1: Attribute = parse_quote!(#[foo(1)]);
260 let exp1 = ExpandedAttr::Direct(attr1);
261 assert!(exp1.parse_args::<LitInt>().is_ok());
262
263 let attr2: Attribute = parse_quote!(#[cfg_attr(c, foo(1))]);
265 let flattened = vec![attr2].flattened_attributes();
266 let exp2 = &flattened[0];
267 assert!(exp2.parse_args::<LitInt>().is_ok());
268
269 let attr3: Attribute = parse_quote!(#[cfg_attr(c, foo = "bar")]);
273 let flattened3 = vec![attr3].flattened_attributes();
274 let exp3 = &flattened3[0];
275 assert!(exp3.parse_args::<syn::LitStr>().is_ok());
277 }
278}