shiv_macro_impl/
system_param.rs1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::{
4 parse_quote, punctuated::Punctuated, token::Comma, Data, DeriveInput, GenericParam, Generics,
5 Ident, Index, Path, Type,
6};
7
8pub fn derive_system_param(input: DeriveInput, shiv: Path) -> proc_macro2::TokenStream {
9 validate_lifetimes(&input.generics);
10
11 let fields = fields(&input.data);
12 let field_idents = field_idents(&input.data);
13
14 let state_generics = state_generics(&input.generics, &shiv);
15 let fetch_generics = fetch_generics(&input.generics);
16 let read_only_generics = read_only_generics(&input.generics, &shiv);
17
18 let (state_impl_generics, state_ty_generics, state_where_clause) =
19 state_generics.split_for_impl();
20 let (fetch_impl_generics, _, _) = fetch_generics.split_for_impl();
21 let (_, _, read_only_where_clause) = read_only_generics.split_for_impl();
22
23 let marker_generics = marker_generics(&input.generics);
24 let fetch_ty_generics = fetch_ty_generics(&input.generics, &fields, &shiv);
25
26 let indices = (0..fields.len()).map(|i| Index::from(i));
27
28 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
29
30 let vis = input.vis;
31 let name = input.ident;
32
33 quote! {
34 const _: () = {
35 #[automatically_derived]
36 impl #impl_generics #shiv::system::SystemParam for #name #ty_generics #where_clause {
37 type Fetch = FetchState<#fetch_ty_generics>;
38 }
39
40 #vis struct FetchState #state_ty_generics #state_where_clause {
41 state: __TSystemParamState,
42 marker: ::std::marker::PhantomData<fn() -> (#marker_generics)>,
43 }
44
45 #[automatically_derived]
46 unsafe impl #state_impl_generics #shiv::system::ReadOnlySystemParamFetch for
47 FetchState #state_ty_generics #read_only_where_clause
48 {
49 }
50
51 #[automatically_derived]
52 unsafe impl #state_impl_generics #shiv::system::SystemParamState for FetchState
53 #state_ty_generics #state_where_clause
54 {
55 #[inline]
56 fn init(
57 world: &mut #shiv::world::World,
58 meta: &mut #shiv::system::SystemMeta,
59 ) -> Self {
60 Self {
61 state: __TSystemParamState::init(world, meta),
62 marker: ::std::marker::PhantomData,
63 }
64 }
65
66 #[inline]
67 fn apply(&mut self, world: &mut #shiv::world::World) {
68 self.state.apply(world);
69 }
70 }
71
72 #[automatically_derived]
73 impl #fetch_impl_generics #shiv::system::SystemParamFetch<'w, 's> for
74 FetchState<#fetch_ty_generics> #state_where_clause
75 {
76 type Item = #name #ty_generics;
77
78 #[inline]
79 #[allow(dead_code)]
80 unsafe fn get_param(
81 &'s mut self,
82 meta: &#shiv::system::SystemMeta,
83 world: &'w #shiv::world::World,
84 change_ticks: ::std::primitive::u32,
85 ) -> Self::Item {
86 let param = #shiv::system::SystemParamFetch::get_param(
87 &mut self.state,
88 meta,
89 world,
90 change_ticks
91 );
92
93 #name {#(#field_idents: param.#indices,)*}
94 }
95 }
96 };
97 }
98}
99
100fn validate_lifetimes(generics: &Generics) {
101 for lifetime in generics.lifetimes() {
102 let ident = &lifetime.lifetime.ident;
103
104 if !(ident == "w" || ident == "s") {
105 panic!(
106 "Invalid lifetime: {}, only valid lifetimes are 'w and 's",
107 ident
108 );
109 }
110 }
111}
112
113fn has_lifetime(generics: &Generics, lifetime: &str) -> bool {
114 for lt in generics.lifetimes() {
115 if lt.lifetime.ident == lifetime {
116 return true;
117 }
118 }
119
120 false
121}
122
123fn fetch_generics(generics: &Generics) -> Generics {
124 let mut generics = generics.clone();
125
126 if !has_lifetime(&generics, "w") {
127 generics.params.push(parse_quote!('w));
128 }
129
130 if !has_lifetime(&generics, "s") {
131 generics.params.push(parse_quote!('s));
132 }
133
134 generics
135}
136
137fn state_generics(generics: &Generics, shiv: &Path) -> Generics {
138 let mut generics = generics.clone();
139
140 generics.params = generics
141 .params
142 .clone()
143 .into_pairs()
144 .filter(|param| match param.value() {
145 syn::GenericParam::Lifetime(_) => false,
146 _ => true,
147 })
148 .collect();
149
150 generics.params.push(parse_quote!(
151 __TSystemParamState: #shiv::system::SystemParamState
152 ));
153
154 generics.make_where_clause().predicates.push(parse_quote!(
155 Self: ::std::marker::Send + ::std::marker::Sync + 'static
156 ));
157
158 generics
159}
160
161fn read_only_generics(generics: &Generics, shiv: &Path) -> Generics {
162 let mut generics = generics.clone();
163
164 let where_clause = generics.make_where_clause();
165 where_clause.predicates.push(parse_quote!(
166 __TSystemParamState: #shiv::system::ReadOnlySystemParamFetch
167 ));
168 where_clause.predicates.push(parse_quote!(
169 Self: for<'w, 's> #shiv::system::SystemParamFetch<'w, 's>
170 ));
171
172 generics
173}
174
175fn marker_generics(generics: &Generics) -> Punctuated<TokenStream, Comma> {
176 let mut marker_generics = Punctuated::<TokenStream, Comma>::new();
177 for generic in generics.params.iter() {
178 if let GenericParam::Type(ty) = generic {
179 let ident = &ty.ident;
180 marker_generics.push(parse_quote!(#ident));
181 }
182 }
183
184 marker_generics
185}
186
187fn fetch_ty_generics(
188 generics: &Generics,
189 fields: &[Type],
190 shiv: &Path,
191) -> Punctuated<TokenStream, Comma> {
192 let mut fetch_ty_generics = Punctuated::<TokenStream, Comma>::new();
193 for generic in generics.params.iter() {
194 if let GenericParam::Type(ty) = generic {
195 let ident = &ty.ident;
196 fetch_ty_generics.push(parse_quote!(#ident));
197 }
198 }
199
200 fetch_ty_generics.push(quote!((#(<#fields as #shiv::system::SystemParam>::Fetch,)*)));
201
202 fetch_ty_generics
203}
204
205fn fields(data: &Data) -> Vec<Type> {
206 match data {
207 Data::Struct(s) => match &s.fields {
208 syn::Fields::Named(fields) => {
209 fields.named.iter().map(|field| field.ty.clone()).collect()
210 }
211 syn::Fields::Unnamed(_) => unimplemented!("Unnamed fields are not supported"),
212 syn::Fields::Unit => Vec::new(),
213 },
214 _ => unimplemented!("Only structs are supported"),
215 }
216}
217
218fn field_idents(data: &Data) -> Vec<Ident> {
219 match data {
220 Data::Struct(s) => match &s.fields {
221 syn::Fields::Named(fields) => fields
222 .named
223 .iter()
224 .map(|field| field.ident.clone().unwrap())
225 .collect(),
226 syn::Fields::Unnamed(_) => unimplemented!("Unnamed fields are not supported"),
227 syn::Fields::Unit => Vec::new(),
228 },
229 _ => unimplemented!("Only structs are supported"),
230 }
231}