// shiv_macro_impl/system_param.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::{
4    parse_quote, punctuated::Punctuated, token::Comma, Data, DeriveInput, GenericParam, Generics,
5    Ident, Index, Path, Type,
6};
7
8pub fn derive_system_param(input: DeriveInput, shiv: Path) -> proc_macro2::TokenStream {
9    validate_lifetimes(&input.generics);
10
11    let fields = fields(&input.data);
12    let field_idents = field_idents(&input.data);
13
14    let state_generics = state_generics(&input.generics, &shiv);
15    let fetch_generics = fetch_generics(&input.generics);
16    let read_only_generics = read_only_generics(&input.generics, &shiv);
17
18    let (state_impl_generics, state_ty_generics, state_where_clause) =
19        state_generics.split_for_impl();
20    let (fetch_impl_generics, _, _) = fetch_generics.split_for_impl();
21    let (_, _, read_only_where_clause) = read_only_generics.split_for_impl();
22
23    let marker_generics = marker_generics(&input.generics);
24    let fetch_ty_generics = fetch_ty_generics(&input.generics, &fields, &shiv);
25
26    let indices = (0..fields.len()).map(|i| Index::from(i));
27
28    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
29
30    let vis = input.vis;
31    let name = input.ident;
32
33    quote! {
34        const _: () = {
35            #[automatically_derived]
36            impl #impl_generics #shiv::system::SystemParam for #name #ty_generics #where_clause {
37                type Fetch = FetchState<#fetch_ty_generics>;
38            }
39
40            #vis struct FetchState #state_ty_generics #state_where_clause {
41                state: __TSystemParamState,
42                marker: ::std::marker::PhantomData<fn() -> (#marker_generics)>,
43            }
44
45            #[automatically_derived]
46            unsafe impl #state_impl_generics #shiv::system::ReadOnlySystemParamFetch for
47                FetchState #state_ty_generics #read_only_where_clause
48            {
49            }
50
51            #[automatically_derived]
52            unsafe impl #state_impl_generics #shiv::system::SystemParamState for FetchState
53                #state_ty_generics #state_where_clause
54            {
55                #[inline]
56                fn init(
57                    world: &mut #shiv::world::World,
58                    meta: &mut #shiv::system::SystemMeta,
59                ) -> Self {
60                    Self {
61                        state: __TSystemParamState::init(world, meta),
62                        marker: ::std::marker::PhantomData,
63                    }
64                }
65
66                #[inline]
67                fn apply(&mut self, world: &mut #shiv::world::World) {
68                    self.state.apply(world);
69                }
70            }
71
72            #[automatically_derived]
73            impl #fetch_impl_generics #shiv::system::SystemParamFetch<'w, 's> for
74                FetchState<#fetch_ty_generics> #state_where_clause
75            {
76                type Item = #name #ty_generics;
77
78                #[inline]
79                #[allow(dead_code)]
80                unsafe fn get_param(
81                    &'s mut self,
82                    meta: &#shiv::system::SystemMeta,
83                    world: &'w #shiv::world::World,
84                    change_ticks: ::std::primitive::u32,
85                ) -> Self::Item {
86                    let param = #shiv::system::SystemParamFetch::get_param(
87                        &mut self.state,
88                        meta,
89                        world,
90                        change_ticks
91                    );
92
93                    #name {#(#field_idents: param.#indices,)*}
94                }
95            }
96        };
97    }
98}
99
100fn validate_lifetimes(generics: &Generics) {
101    for lifetime in generics.lifetimes() {
102        let ident = &lifetime.lifetime.ident;
103
104        if !(ident == "w" || ident == "s") {
105            panic!(
106                "Invalid lifetime: {}, only valid lifetimes are 'w and 's",
107                ident
108            );
109        }
110    }
111}
112
113fn has_lifetime(generics: &Generics, lifetime: &str) -> bool {
114    for lt in generics.lifetimes() {
115        if lt.lifetime.ident == lifetime {
116            return true;
117        }
118    }
119
120    false
121}
122
123fn fetch_generics(generics: &Generics) -> Generics {
124    let mut generics = generics.clone();
125
126    if !has_lifetime(&generics, "w") {
127        generics.params.push(parse_quote!('w));
128    }
129
130    if !has_lifetime(&generics, "s") {
131        generics.params.push(parse_quote!('s));
132    }
133
134    generics
135}
136
137fn state_generics(generics: &Generics, shiv: &Path) -> Generics {
138    let mut generics = generics.clone();
139
140    generics.params = generics
141        .params
142        .clone()
143        .into_pairs()
144        .filter(|param| match param.value() {
145            syn::GenericParam::Lifetime(_) => false,
146            _ => true,
147        })
148        .collect();
149
150    generics.params.push(parse_quote!(
151        __TSystemParamState: #shiv::system::SystemParamState
152    ));
153
154    generics.make_where_clause().predicates.push(parse_quote!(
155        Self: ::std::marker::Send + ::std::marker::Sync + 'static
156    ));
157
158    generics
159}
160
161fn read_only_generics(generics: &Generics, shiv: &Path) -> Generics {
162    let mut generics = generics.clone();
163
164    let where_clause = generics.make_where_clause();
165    where_clause.predicates.push(parse_quote!(
166        __TSystemParamState: #shiv::system::ReadOnlySystemParamFetch
167    ));
168    where_clause.predicates.push(parse_quote!(
169        Self: for<'w, 's> #shiv::system::SystemParamFetch<'w, 's>
170    ));
171
172    generics
173}
174
175fn marker_generics(generics: &Generics) -> Punctuated<TokenStream, Comma> {
176    let mut marker_generics = Punctuated::<TokenStream, Comma>::new();
177    for generic in generics.params.iter() {
178        if let GenericParam::Type(ty) = generic {
179            let ident = &ty.ident;
180            marker_generics.push(parse_quote!(#ident));
181        }
182    }
183
184    marker_generics
185}
186
187fn fetch_ty_generics(
188    generics: &Generics,
189    fields: &[Type],
190    shiv: &Path,
191) -> Punctuated<TokenStream, Comma> {
192    let mut fetch_ty_generics = Punctuated::<TokenStream, Comma>::new();
193    for generic in generics.params.iter() {
194        if let GenericParam::Type(ty) = generic {
195            let ident = &ty.ident;
196            fetch_ty_generics.push(parse_quote!(#ident));
197        }
198    }
199
200    fetch_ty_generics.push(quote!((#(<#fields as #shiv::system::SystemParam>::Fetch,)*)));
201
202    fetch_ty_generics
203}
204
205fn fields(data: &Data) -> Vec<Type> {
206    match data {
207        Data::Struct(s) => match &s.fields {
208            syn::Fields::Named(fields) => {
209                fields.named.iter().map(|field| field.ty.clone()).collect()
210            }
211            syn::Fields::Unnamed(_) => unimplemented!("Unnamed fields are not supported"),
212            syn::Fields::Unit => Vec::new(),
213        },
214        _ => unimplemented!("Only structs are supported"),
215    }
216}
217
218fn field_idents(data: &Data) -> Vec<Ident> {
219    match data {
220        Data::Struct(s) => match &s.fields {
221            syn::Fields::Named(fields) => fields
222                .named
223                .iter()
224                .map(|field| field.ident.clone().unwrap())
225                .collect(),
226            syn::Fields::Unnamed(_) => unimplemented!("Unnamed fields are not supported"),
227            syn::Fields::Unit => Vec::new(),
228        },
229        _ => unimplemented!("Only structs are supported"),
230    }
231}