syn_utils/
lib.rs

1use syn::{
2  spanned::Spanned, GenericArgument, ParenthesizedGenericArguments, PathArguments, PathSegment,
3  Type,
4};
5#[macro_use]
6mod macros;
7
8mod expr_trait;
9use std::{fmt::Display, str::FromStr};
10
11pub use expr_trait::*;
12use proc_macro2::{Span, TokenStream as TokenStream2};
13use quote::{quote, ToTokens};
14use syn::{
15  parse::Parse,
16  punctuated::{Iter, IterMut, Punctuated},
17  Attribute, Expr, ExprCall, ExprClosure, Field, Fields, Ident, Lit, LitInt, LitStr, Meta, Path,
18  Token, Variant,
19};
20
21pub trait GenericArgumentExt {
22  fn as_type(&self) -> syn::Result<&Type>;
23  fn as_type_mut(&mut self) -> syn::Result<&mut Type>;
24}
25
26impl GenericArgumentExt for GenericArgument {
27  fn as_type(&self) -> syn::Result<&Type> {
28    if let GenericArgument::Type(ty) = self {
29      Ok(ty)
30    } else {
31      bail!(self, "Expected this generic argument to be a type");
32    }
33  }
34
35  fn as_type_mut(&mut self) -> syn::Result<&mut Type> {
36    if let GenericArgument::Type(ty) = self {
37      Ok(ty)
38    } else {
39      let reborrow = &self;
40      bail!(reborrow, "Expected this generic argument to be a type");
41    }
42  }
43}
44
45pub trait PathSegmentExt {
46  fn generic_args(&self) -> Option<Iter<'_, GenericArgument>>;
47  fn generic_args_mut(&mut self) -> Option<IterMut<'_, GenericArgument>>;
48  fn parenthesized_args(&self) -> Option<&ParenthesizedGenericArguments>;
49  fn parenthesized_args_mut(&mut self) -> Option<&mut ParenthesizedGenericArguments>;
50}
51
52impl PathSegmentExt for PathSegment {
53  fn parenthesized_args_mut(&mut self) -> Option<&mut ParenthesizedGenericArguments> {
54    match &mut self.arguments {
55      PathArguments::None => None,
56      PathArguments::AngleBracketed(_) => None,
57      PathArguments::Parenthesized(par) => Some(par),
58    }
59  }
60
61  fn parenthesized_args(&self) -> Option<&ParenthesizedGenericArguments> {
62    match &self.arguments {
63      PathArguments::None => None,
64      PathArguments::AngleBracketed(_) => None,
65      PathArguments::Parenthesized(par) => Some(par),
66    }
67  }
68
69  fn generic_args_mut(&mut self) -> Option<IterMut<'_, GenericArgument>> {
70    match &mut self.arguments {
71      PathArguments::None => None,
72      PathArguments::AngleBracketed(ab) => Some(ab.args.iter_mut()),
73      PathArguments::Parenthesized(_) => None,
74    }
75  }
76
77  fn generic_args(&self) -> Option<Iter<'_, GenericArgument>> {
78    match &self.arguments {
79      PathArguments::None => None,
80      PathArguments::AngleBracketed(ab) => Some(ab.args.iter()),
81      PathArguments::Parenthesized(_) => None,
82    }
83  }
84}
85
86pub trait PathExt {
87  fn last_segment(&self) -> &PathSegment;
88  fn last_segment_mut(&mut self) -> &mut PathSegment;
89}
90
91impl PathExt for Path {
92  fn last_segment(&self) -> &PathSegment {
93    self.segments.last().unwrap()
94  }
95
96  fn last_segment_mut(&mut self) -> &mut PathSegment {
97    self.segments.last_mut().unwrap()
98  }
99}
100
101pub trait TypeExt {
102  fn as_path(&self) -> syn::Result<&Path>;
103  fn as_path_mut(&mut self) -> syn::Result<&mut Path>;
104}
105
106impl TypeExt for Type {
107  fn as_path(&self) -> syn::Result<&Path> {
108    if let Type::Path(path) = self {
109      Ok(&path.path)
110    } else {
111      bail!(self, "Expected a type path");
112    }
113  }
114
115  fn as_path_mut(&mut self) -> syn::Result<&mut Path> {
116    if let Type::Path(path) = self {
117      Ok(&mut path.path)
118    } else {
119      bail!(self, "Expected a type path");
120    }
121  }
122}
123
124pub trait EnumVariant {
125  fn is_single_tuple(&self) -> bool;
126  fn typ(&self) -> syn::Result<&Type>;
127  fn path(&self) -> syn::Result<&Path>;
128  fn type_mut(&mut self) -> syn::Result<&mut Type>;
129  fn path_mut(&mut self) -> syn::Result<&mut Path>;
130  fn is_unit(&self) -> bool;
131  fn named_fields(&self) -> syn::Result<Iter<'_, Field>>;
132  fn named_fields_mut(&mut self) -> syn::Result<IterMut<'_, Field>>;
133  fn unnamed_fields(&self) -> syn::Result<Iter<'_, Field>>;
134  fn unnamed_fields_mut(&mut self) -> syn::Result<IterMut<'_, Field>>;
135}
136
137impl EnumVariant for Variant {
138  fn is_single_tuple(&self) -> bool {
139    if let Fields::Unnamed(fields) = &self.fields && fields.unnamed.len() == 1 {
140      true
141    } else {
142      false
143    }
144  }
145
146  fn path_mut(&mut self) -> syn::Result<&mut Path> {
147    let span = self.span();
148
149    if let Fields::Unnamed(fields) = &mut self.fields && fields.unnamed.len() == 1 {
150      Ok(fields.unnamed.last_mut().unwrap().ty.as_path_mut()?)
151    } else {
152      bail_with_span!(span, "Expected this variant to have a single unnamed field");
153    }
154  }
155
156  fn type_mut(&mut self) -> syn::Result<&mut Type> {
157    let span = self.span();
158
159    if let Fields::Unnamed(fields) = &mut self.fields && fields.unnamed.len() == 1 {
160      Ok(&mut fields.unnamed.last_mut().unwrap().ty)
161    } else {
162      bail_with_span!(span, "Expected this variant to have a single unnamed field");
163    }
164  }
165
166  fn path(&self) -> syn::Result<&Path> {
167    if let Fields::Unnamed(fields) = &self.fields && fields.unnamed.len() == 1 {
168      Ok(fields.unnamed.last().unwrap().ty.as_path()?)
169    } else {
170      bail!(self, "Expected this variant to have a single unnamed field");
171    }
172  }
173
174  fn typ(&self) -> syn::Result<&Type> {
175    if let Fields::Unnamed(fields) = &self.fields && fields.unnamed.len() == 1 {
176      Ok(&fields.unnamed.last().unwrap().ty)
177    } else {
178      bail!(self, "Expected this variant to have a single unnamed field");
179    }
180  }
181
182  fn is_unit(&self) -> bool {
183    matches!(self.fields, Fields::Unit)
184  }
185
186  fn named_fields(&self) -> syn::Result<Iter<'_, Field>> {
187    if let Fields::Named(fields) = &self.fields {
188      Ok(fields.named.iter())
189    } else {
190      bail!(self, "Expected this variant to have named fields");
191    }
192  }
193
194  fn named_fields_mut(&mut self) -> syn::Result<IterMut<'_, Field>> {
195    let span = self.span();
196
197    if let Fields::Named(fields) = &mut self.fields {
198      Ok(fields.named.iter_mut())
199    } else {
200      bail_with_span!(span, "Expected this variant to have named fields");
201    }
202  }
203
204  fn unnamed_fields(&self) -> syn::Result<Iter<'_, Field>> {
205    if let Fields::Unnamed(fields) = &self.fields {
206      Ok(fields.unnamed.iter())
207    } else {
208      bail!(self, "Expected this variant to have unnamed fields");
209    }
210  }
211
212  fn unnamed_fields_mut(&mut self) -> syn::Result<IterMut<'_, Field>> {
213    let span = self.span();
214
215    if let Fields::Unnamed(fields) = &mut self.fields {
216      Ok(fields.unnamed.iter_mut())
217    } else {
218      bail_with_span!(span, "Expected this variant to have unnamed fields");
219    }
220  }
221}
222
223pub trait AsNamedField {
224  fn ident(&self) -> syn::Result<&Ident>;
225}
226
227impl AsNamedField for Field {
228  fn ident(&self) -> syn::Result<&Ident> {
229    self
230      .ident
231      .as_ref()
232      .ok_or(error!(self, "Expected a named field"))
233  }
234}
235
236pub fn filter_attributes(attrs: &[Attribute], allowed_idents: &[&str]) -> syn::Result<Vec<Meta>> {
237  let mut metas = Vec::new();
238
239  for attr in attrs {
240    let attr_ident = if let Some(ident) = attr.path().get_ident() {
241      ident.to_string()
242    } else {
243      continue;
244    };
245
246    if !allowed_idents.contains(&attr_ident.as_str()) {
247      continue;
248    }
249
250    let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
251    let args = attr.parse_args_with(parser)?;
252
253    metas.extend(args);
254  }
255
256  Ok(metas)
257}
258
259pub fn new_ident(name: &str) -> Ident {
260  Ident::new(name, Span::call_site())
261}
262
263#[derive(Default, Clone, Debug)]
264pub struct ControlFlow {
265  pub dummy: Option<TokenStream2>,
266}
267
268impl ControlFlow {
269  pub fn new() -> Self {
270    Self { dummy: None }
271  }
272
273  pub fn with_custom_dummy(dummy: &TokenStream2) -> Self {
274    Self {
275      dummy: Some(dummy.to_token_stream()),
276    }
277  }
278}
279
280pub trait MacroResult: Sized {
281  type Output;
282
283  fn unwrap_or_dummy(self, dummy: TokenStream2) -> Result<Self::Output, TokenStream2>;
284
285  fn unwrap_or_unimplemented(self) -> Result<Self::Output, TokenStream2> {
286    self.unwrap_or_dummy(quote! { unimplemented!() })
287  }
288}
289
290impl<T> MacroResult for syn::Result<T> {
291  type Output = T;
292
293  fn unwrap_or_dummy(self, dummy: TokenStream2) -> Result<Self::Output, TokenStream2> {
294    match self {
295      Ok(o) => Ok(o),
296      Err(e) => {
297        let error = e.into_compile_error();
298
299        Err(quote! {
300          #error #dummy
301        })
302      }
303    }
304  }
305}
306
307#[derive(Debug, Clone)]
308pub enum CallOrClosure {
309  Call(ExprCall),
310  Closure(ExprClosure),
311}
312
313impl ToTokens for CallOrClosure {
314  fn to_tokens(&self, tokens: &mut TokenStream2) {
315    let output = match self {
316      CallOrClosure::Call(call) => call.to_token_stream(),
317      CallOrClosure::Closure(expr_closure) => expr_closure.to_token_stream(),
318    };
319
320    tokens.extend(output);
321  }
322}
323
324pub struct PunctuatedItems<T: Parse> {
325  pub inner: Vec<T>,
326}
327
328impl<T: Parse> Parse for PunctuatedItems<T> {
329  fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
330    let inner = Punctuated::<T, Token![,]>::parse_terminated(input)?
331      .into_iter()
332      .collect();
333
334    Ok(Self { inner })
335  }
336}
337
338pub struct StringList {
339  pub list: Vec<String>,
340}
341
342impl Parse for StringList {
343  fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
344    let items = Punctuated::<LitStr, Token![,]>::parse_terminated(input)?;
345
346    let list: Vec<String> = items
347      .into_iter()
348      .map(|lit_str| lit_str.value())
349      .collect();
350
351    Ok(Self { list })
352  }
353}
354
355pub struct NumList {
356  pub list: Vec<i32>,
357}
358
359impl Parse for NumList {
360  fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
361    let items = Punctuated::<LitInt, Token![,]>::parse_terminated(input)?;
362
363    let mut list: Vec<i32> = Vec::new();
364
365    for item in items {
366      list.push(item.base10_parse()?);
367    }
368
369    Ok(Self { list })
370  }
371}