tuple_iter/
lib.rs

//! Procedural macros that iterate over the elements of a tuple as trait objects,
//! for tuples whose elements all satisfy the same trait bounds.
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote, ToTokens};
5use syn::parse::{Parse, ParseStream};
6use syn::punctuated::Punctuated;
7use syn::Result;
8
9/// Creates an iterator for the shared reference to a tuple as trait objects.
10///
11/// ```
12/// trait Foo {
13///     fn bar(&self) -> i32;
14/// }
15///
16/// struct A(i32);
17/// impl Foo for A {
18///     fn bar(&self) -> i32 { self.0 }
19/// }
20///
21/// struct B(i32);
22/// impl Foo for B {
23///     fn bar(&self) -> i32 { self.0 }
24/// }
25///
26/// let my_tuple = (A(1), B(2), A(3));
27/// let iter = tuple_iter::iter!(my_tuple, (Foo + Send + Sync + 'static; 3));
28/// let vec: Vec<i32> = iter.map(|foo| foo.bar()).collect();
29/// assert_eq!(vec, vec![1, 2, 3]);
30/// ```
31#[proc_macro]
32pub fn iter(ts: proc_macro::TokenStream) -> proc_macro::TokenStream {
33    iter_impl(false, ts.into())
34        .unwrap_or_else(|err| err.to_compile_error())
35        .into()
36}
37
38/// Creates an iterator for the mutable reference to a tuple as trait objects.
39///
40/// ```
41/// trait Foo {
42///     fn bar(&mut self) -> i32;
43/// }
44///
45/// struct A(i32);
46/// impl Foo for A {
47///     fn bar(&mut self) -> i32 { self.0 }
48/// }
49///
50/// struct B(i32);
51/// impl Foo for B {
52///     fn bar(&mut self) -> i32 { self.0 }
53/// }
54///
55/// let mut my_tuple = (A(1), B(2), A(3));
56/// let iter = tuple_iter::iter_mut!(my_tuple, (Foo + Send + Sync + 'static; 3));
57/// let vec: Vec<i32> = iter.map(|foo| foo.bar()).collect();
58/// assert_eq!(vec, vec![1, 2, 3]);
59/// ```
60#[proc_macro]
61pub fn iter_mut(ts: proc_macro::TokenStream) -> proc_macro::TokenStream {
62    iter_impl(true, ts.into())
63        .unwrap_or_else(|err| err.to_compile_error())
64        .into()
65}
66
67fn iter_impl(is_mut: bool, ts: TokenStream) -> Result<TokenStream> {
68    let (is_mut, ptr) = if is_mut {
69        (quote!(mut), quote!(*mut))
70    } else {
71        (quote!(), quote!(*const))
72    };
73
74    let input = syn::parse2::<Input>(ts)?;
75    let Input {
76        expr,
77        _comma,
78        _parentheses,
79        bounds,
80        _semicolon,
81        count,
82    } = &input;
83    let count = *count;
84
85    let ordinal = 0..count;
86    let ty_params: Vec<_> = ordinal.clone().map(|i| format_ident!("Ty{}", i)).collect();
87    let ordinal = ordinal.map(|i| proc_macro2::Literal::usize_unsuffixed(i).to_token_stream());
88
89    let code = quote! {
90        {
91            struct __TupleIter<T>(T, usize);
92
93            impl<'t, #(#ty_params),*> Iterator for __TupleIter<&'t #is_mut (#(#ty_params),*)>
94                where #(#ty_params: #bounds),* {
95                    type Item = &'t #is_mut (dyn #bounds);
96
97                    fn next(&mut self) -> Option<Self::Item> {
98                        match self.1 {
99                            #(
100                                #ordinal => {
101                                    self.1 += 1;
102                                    let ptr = &#is_mut (self.0).#ordinal as #ptr #ty_params;
103                                    let ptr: &#is_mut #ty_params = unsafe { &#is_mut *ptr };
104                                    let dyn_ptr: &#is_mut (dyn #bounds) = ptr;
105                                    Some(dyn_ptr)
106                                },
107                            )*
108                            _ => None,
109                        }
110                    }
111                }
112
113            __TupleIter(&#is_mut #expr, 0)
114        }
115    };
116    Ok(code)
117}
118
119struct Input {
120    expr: syn::Expr,
121    _comma: syn::Token![,],
122    _parentheses: syn::token::Paren,
123    bounds: Punctuated<syn::TypeParamBound, syn::Token![+]>,
124    _semicolon: syn::Token![;],
125    count: usize,
126}
127
128impl Parse for Input {
129    fn parse(input: ParseStream) -> Result<Self> {
130        let expr = input.parse::<syn::Expr>()?;
131        let comma = input.parse::<syn::Token![,]>()?;
132        let inner;
133        let parentheses = syn::parenthesized!(inner in input);
134        let bounds = Punctuated::parse_separated_nonempty(&inner)?;
135        let semicolon = inner.parse::<syn::Token![;]>()?;
136        let count = inner.parse::<syn::LitInt>()?;
137        let count = count.base10_parse()?;
138
139        Ok(Self {
140            expr,
141            _comma: comma,
142            _parentheses: parentheses,
143            bounds,
144            _semicolon: semicolon,
145            count,
146        })
147    }
148}