1use std::{fmt::Debug, marker::PhantomData};
2
3use half::{bf16, f16};
4use paste::paste;
5
6use crate::errors::Error;
7
8pub trait Parameter {}
41
42impl Parameter for bool {}
43impl Parameter for i8 {}
44impl Parameter for i16 {}
45impl Parameter for i32 {}
46impl Parameter for i64 {}
47impl Parameter for i128 {}
48impl Parameter for u8 {}
49impl Parameter for u16 {}
50impl Parameter for u32 {}
51impl Parameter for u64 {}
52impl Parameter for u128 {}
53impl Parameter for bf16 {}
54impl Parameter for f16 {}
55impl Parameter for f32 {}
56impl Parameter for f64 {}
57impl Parameter for usize {}
58
59#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
60pub struct Placeholder;
61
62impl Debug for Placeholder {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "<Parameter>")
65 }
66}
67
68impl Parameter for Placeholder {}
69
70pub trait Parameterized<P: Parameter>: Sized {
89 type To<T: Parameter>: Parameterized<T, To<P> = Self> + Parameterized<T, To<Placeholder> = Self::To<Placeholder>>;
92 type ParamIterator<'t, T: 't + Parameter>: 't + Iterator<Item = &'t T>
101 where
102 Self: 't;
104
105 type ParamIteratorMut<'t, T: 't + Parameter>: 't + Iterator<Item = &'t mut T>
106 where
107 Self: 't;
108
109 type ParamIntoIterator<T: Parameter>: Iterator<Item = T>;
110
111 fn param_count(&self) -> usize;
113
114 fn param_structure(&self) -> Self::To<Placeholder>;
115
116 fn params(&self) -> Self::ParamIterator<'_, P>;
117 fn params_mut(&mut self) -> Self::ParamIteratorMut<'_, P>;
118 fn into_params(self) -> Self::ParamIntoIterator<P>;
119
120 fn from_params_with_remainder<I: Iterator<Item = P>>(
121 structure: Self::To<Placeholder>,
122 params: &mut I,
123 ) -> Result<Self, Error>;
124
125 fn from_params<I: IntoIterator<Item = P>>(structure: Self::To<Placeholder>, params: I) -> Result<Self, Error> {
126 let mut params = params.into_iter();
127 let parameterized = Self::from_params_with_remainder(structure, &mut params)?;
128 params.next().map(|_| Err(Error::UnusedParams)).unwrap_or_else(|| Ok(parameterized))
129 }
130
131 fn map_params<T: Parameter, F: FnMut(P) -> T>(self, map_fn: F) -> Result<Self::To<T>, Error> {
133 Self::To::<T>::from_params(self.param_structure(), self.into_params().map(map_fn))
134 }
135}
136
137impl<P: Parameter> Parameterized<P> for P {
138 type To<T: Parameter> = T;
139
140 type ParamIterator<'t, T: 't + Parameter>
141 = std::iter::Once<&'t T>
142 where
143 Self: 't;
144 type ParamIteratorMut<'t, T: 't + Parameter>
145 = std::iter::Once<&'t mut T>
146 where
147 Self: 't;
148 type ParamIntoIterator<T: Parameter> = std::iter::Once<T>;
149
150 fn param_count(&self) -> usize {
151 1
152 }
153
154 fn param_structure(&self) -> Self::To<Placeholder> {
155 Placeholder
156 }
157
158 fn params(&self) -> Self::ParamIterator<'_, P> {
159 std::iter::once(self)
160 }
161
162 fn params_mut(&mut self) -> Self::ParamIteratorMut<'_, P> {
163 std::iter::once(self)
164 }
165
166 fn into_params(self) -> Self::ParamIntoIterator<P> {
167 std::iter::once(self)
168 }
169
170 fn from_params_with_remainder<I: Iterator<Item = P>>(
171 _structure: Self::To<Placeholder>,
172 params: &mut I,
173 ) -> Result<Self, Error> {
174 params.next().ok_or(Error::InsufficientParams { expected_count: 1 })
175 }
176}
177
178impl<P: Parameter> Parameterized<P> for PhantomData<P> {
179 type To<T: Parameter> = PhantomData<T>;
180
181 type ParamIterator<'t, T: 't + Parameter>
182 = std::iter::Empty<&'t T>
183 where
184 Self: 't;
185 type ParamIteratorMut<'t, T: 't + Parameter>
186 = std::iter::Empty<&'t mut T>
187 where
188 Self: 't;
189 type ParamIntoIterator<T: Parameter> = std::iter::Empty<T>;
190
191 fn param_count(&self) -> usize {
192 0
193 }
194
195 fn param_structure(&self) -> Self::To<Placeholder> {
196 PhantomData
197 }
198
199 fn params(&self) -> Self::ParamIterator<'_, P> {
200 std::iter::empty()
201 }
202
203 fn params_mut(&mut self) -> Self::ParamIteratorMut<'_, P> {
204 std::iter::empty()
205 }
206
207 fn into_params(self) -> Self::ParamIntoIterator<P> {
208 std::iter::empty()
209 }
210
211 fn from_params_with_remainder<I: Iterator<Item = P>>(
212 _structure: Self::To<Placeholder>,
213 _params: &mut I,
214 ) -> Result<Self, Error> {
215 Ok(PhantomData)
216 }
217}
218
/// Implements `Parameterized<P>` for the tuple type formed from the listed identifiers, e.g.
/// `tuple_parameterized_impl!(T0, T1)` covers `(T0, T1)`.
///
/// - Every element type must itself be `Parameterized<P>`.
/// - The tuple's parameters are its elements' parameters, concatenated in element order.
/// - Its structure is the tuple of the elements' structures.
/// - During reconstruction, elements are rebuilt left to right and the first element error is
///   propagated unchanged.
///
/// Each identifier doubles as a generic parameter name and, lower-cased through `paste!`, as the
/// name of the binding for that element.
macro_rules! tuple_parameterized_impl {
    ($($Elem:ident),*) => {
        paste! {
            impl<P: Parameter $(, $Elem: Parameterized<P>)*> Parameterized<P> for ($($Elem,)*) {
                type To<T: Parameter> = ($($Elem::To<T>,)*);

                type ParamIterator<'t, T: 't + Parameter>
                    = tuple_param_iterator_ty!('t, T, ($($Elem,)*))
                where
                    Self: 't;

                type ParamIteratorMut<'t, T: 't + Parameter>
                    = tuple_param_iterator_mut_ty!('t, T, ($($Elem,)*))
                where
                    Self: 't;

                type ParamIntoIterator<T: Parameter> = tuple_param_into_iterator_ty!(T, ($($Elem,)*));

                fn param_count(&self) -> usize {
                    let ($([<$Elem:lower>],)*) = self;
                    0usize $(+ [<$Elem:lower>].param_count())*
                }

                fn param_structure(&self) -> Self::To<Placeholder> {
                    let ($([<$Elem:lower>],)*) = self;
                    ($([<$Elem:lower>].param_structure(),)*)
                }

                fn params(&self) -> Self::ParamIterator<'_, P> {
                    let ($([<$Elem:lower>],)*) = self;
                    tuple_param_iterator!(P, ($([<$Elem:lower>],)*))
                }

                fn params_mut(&mut self) -> Self::ParamIteratorMut<'_, P> {
                    let ($([<$Elem:lower>],)*) = self;
                    tuple_param_iterator_mut!(P, ($([<$Elem:lower>],)*))
                }

                fn into_params(self) -> Self::ParamIntoIterator<P> {
                    let ($([<$Elem:lower>],)*) = self;
                    tuple_param_into_iterator!(P, ($([<$Elem:lower>],)*))
                }

                fn from_params_with_remainder<I: Iterator<Item = P>>(
                    structure: Self::To<Placeholder>,
                    params: &mut I,
                ) -> Result<Self, Error> {
                    let ($([<$Elem:lower _structure>],)*) = structure;
                    $(
                        let [<$Elem:lower>] =
                            $Elem::from_params_with_remainder([<$Elem:lower _structure>], params)?;
                    )*
                    Ok(($([<$Elem:lower>],)*))
                }
            }
        }
    };
}
279
/// Expands to the borrowing parameter-iterator type of a tuple whose element types are the listed
/// identifiers.
///
/// The result is each element's `ParamIterator<'t, T>` chained left to right and terminated by
/// `std::iter::Empty<&'t T>`. Invoke it as `tuple_param_iterator_ty!('t, T, (T0, T1,))`; every
/// identifier, including the last, is followed by a comma.
macro_rules! tuple_param_iterator_ty {
    ($lt:lifetime, $Param:ty, ($first:ident, $($rest:ident,)*)) => {
        std::iter::Chain<
            $first::ParamIterator<$lt, $Param>,
            tuple_param_iterator_ty!($lt, $Param, ($($rest,)*)),
        >
    };
    ($lt:lifetime, $Param:ty, ()) => {
        std::iter::Empty<&$lt $Param>
    };
}
289
/// Expands to the mutable parameter-iterator type of a tuple whose element types are the listed
/// identifiers.
///
/// The result is each element's `ParamIteratorMut<'t, T>` chained left to right and terminated by
/// `std::iter::Empty<&'t mut T>`. Every identifier, including the last, is followed by a comma.
macro_rules! tuple_param_iterator_mut_ty {
    ($lt:lifetime, $Param:ty, ($first:ident, $($rest:ident,)*)) => {
        std::iter::Chain<
            $first::ParamIteratorMut<$lt, $Param>,
            tuple_param_iterator_mut_ty!($lt, $Param, ($($rest,)*)),
        >
    };
    ($lt:lifetime, $Param:ty, ()) => {
        std::iter::Empty<&$lt mut $Param>
    };
}
299
/// Expands to the owning parameter-iterator type of a tuple whose element types are the listed
/// identifiers.
///
/// The result is each element's `ParamIntoIterator<T>` chained left to right and terminated by
/// `std::iter::Empty<T>`. Every identifier, including the last, is followed by a comma.
macro_rules! tuple_param_into_iterator_ty {
    ($Param:ty, ($first:ident, $($rest:ident,)*)) => {
        std::iter::Chain<
            $first::ParamIntoIterator<$Param>,
            tuple_param_into_iterator_ty!($Param, ($($rest,)*)),
        >
    };
    ($Param:ty, ()) => {
        std::iter::Empty<$Param>
    };
}
309
/// Expands to an expression chaining `.params()` of each listed binding, left to right.
///
/// The chain ends with an empty iterator of `&$Param`, matching the type produced by
/// `tuple_param_iterator_ty!`. Every binding, including the last, is followed by a comma.
macro_rules! tuple_param_iterator {
    ($Param:tt, ($first:ident, $($rest:ident,)*)) => {
        $first.params().chain(tuple_param_iterator!($Param, ($($rest,)*)))
    };
    ($Param:tt, ()) => {
        std::iter::empty::<&'_ $Param>()
    };
}
319
/// Expands to an expression chaining `.params_mut()` of each listed binding, left to right.
///
/// The chain ends with an empty iterator of `&mut $Param`, matching the type produced by
/// `tuple_param_iterator_mut_ty!`. Every binding, including the last, is followed by a comma.
macro_rules! tuple_param_iterator_mut {
    ($Param:tt, ($first:ident, $($rest:ident,)*)) => {
        $first.params_mut().chain(tuple_param_iterator_mut!($Param, ($($rest,)*)))
    };
    ($Param:tt, ()) => {
        std::iter::empty::<&'_ mut $Param>()
    };
}
329
/// Expands to an expression chaining `.into_params()` of each listed binding, left to right.
///
/// The chain ends with an empty iterator of `$Param`, matching the type produced by
/// `tuple_param_into_iterator_ty!`. Every binding, including the last, is followed by a comma.
macro_rules! tuple_param_into_iterator {
    ($Param:tt, ($first:ident, $($rest:ident,)*)) => {
        $first.into_params().chain(tuple_param_into_iterator!($Param, ($($rest,)*)))
    };
    ($Param:tt, ()) => {
        std::iter::empty::<$Param>()
    };
}
339
340tuple_parameterized_impl!(T0);
341tuple_parameterized_impl!(T0, T1);
342tuple_parameterized_impl!(T0, T1, T2);
343tuple_parameterized_impl!(T0, T1, T2, T3);
344tuple_parameterized_impl!(T0, T1, T2, T3, T4);
345tuple_parameterized_impl!(T0, T1, T2, T3, T4, T5);
346tuple_parameterized_impl!(T0, T1, T2, T3, T4, T5, T6);
347tuple_parameterized_impl!(T0, T1, T2, T3, T4, T5, T6, T7);
348tuple_parameterized_impl!(T0, T1, T2, T3, T4, T5, T6, T7, T8);
349tuple_parameterized_impl!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9);
350tuple_parameterized_impl!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
351tuple_parameterized_impl!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
352
353impl<P: Parameter, V: Parameterized<P>, const N: usize> Parameterized<P> for [V; N] {
354 type To<T: Parameter> = [V::To<T>; N];
355
356 type ParamIterator<'t, T: 't + Parameter>
357 = std::iter::FlatMap<
358 std::slice::Iter<'t, V>,
359 <V as Parameterized<P>>::ParamIterator<'t, T>,
360 fn(&'t V) -> <V as Parameterized<P>>::ParamIterator<'t, T>,
361 >
362 where
363 Self: 't;
364
365 type ParamIteratorMut<'t, T: 't + Parameter>
366 = std::iter::FlatMap<
367 std::slice::IterMut<'t, V>,
368 <V as Parameterized<P>>::ParamIteratorMut<'t, T>,
369 fn(&'t mut V) -> <V as Parameterized<P>>::ParamIteratorMut<'t, T>,
370 >
371 where
372 Self: 't;
373
374 type ParamIntoIterator<T: Parameter> = std::iter::FlatMap<
375 std::array::IntoIter<V, N>,
376 <V as Parameterized<P>>::ParamIntoIterator<T>,
377 fn(V) -> <V as Parameterized<P>>::ParamIntoIterator<T>,
378 >;
379
380 fn param_count(&self) -> usize {
381 self.iter().map(|value| value.param_count()).sum()
382 }
383
384 fn param_structure(&self) -> Self::To<Placeholder> {
385 std::array::from_fn(|i| self[i].param_structure())
386 }
387
388 fn params(&self) -> Self::ParamIterator<'_, P> {
389 self.iter().flat_map(V::params)
390 }
391
392 fn params_mut(&mut self) -> Self::ParamIteratorMut<'_, P> {
393 self.iter_mut().flat_map(V::params_mut)
394 }
395
396 fn into_params(self) -> Self::ParamIntoIterator<P> {
397 self.into_iter().flat_map(V::into_params)
398 }
399
400 fn from_params_with_remainder<I: Iterator<Item = P>>(
401 structure: Self::To<Placeholder>,
402 params: &mut I,
403 ) -> Result<Self, Error> {
404 let values = structure
407 .into_iter()
408 .map(|value_structure| V::from_params_with_remainder(value_structure, params))
409 .collect::<Result<Vec<V>, _>>()?;
410 Ok(unsafe { values.try_into().unwrap_unchecked() })
411 }
412}
413
414impl<P: Parameter, V: Parameterized<P>> Parameterized<P> for Vec<V> {
415 type To<T: Parameter> = Vec<V::To<T>>;
416
417 type ParamIterator<'t, T: 't + Parameter>
418 = std::iter::FlatMap<
419 std::slice::Iter<'t, V>,
420 <V as Parameterized<P>>::ParamIterator<'t, T>,
421 fn(&'t V) -> <V as Parameterized<P>>::ParamIterator<'t, T>,
422 >
423 where
424 Self: 't;
425
426 type ParamIteratorMut<'t, T: 't + Parameter>
427 = std::iter::FlatMap<
428 std::slice::IterMut<'t, V>,
429 <V as Parameterized<P>>::ParamIteratorMut<'t, T>,
430 fn(&'t mut V) -> <V as Parameterized<P>>::ParamIteratorMut<'t, T>,
431 >
432 where
433 Self: 't;
434
435 type ParamIntoIterator<T: Parameter> = std::iter::FlatMap<
436 std::vec::IntoIter<V>,
437 <V as Parameterized<P>>::ParamIntoIterator<T>,
438 fn(V) -> <V as Parameterized<P>>::ParamIntoIterator<T>,
439 >;
440
441 fn param_count(&self) -> usize {
442 self.iter().map(|value| value.param_count()).sum()
443 }
444
445 fn param_structure(&self) -> Self::To<Placeholder> {
446 self.iter().map(|value| value.param_structure()).collect()
447 }
448
449 fn params(&self) -> Self::ParamIterator<'_, P> {
450 self.iter().flat_map(|value| value.params())
451 }
452
453 fn params_mut(&mut self) -> Self::ParamIteratorMut<'_, P> {
454 self.iter_mut().flat_map(|value| value.params_mut())
455 }
456
457 fn into_params(self) -> Self::ParamIntoIterator<P> {
458 self.into_iter().flat_map(|value| value.into_params())
459 }
460
461 fn from_params_with_remainder<I: Iterator<Item = P>>(
462 structure: Self::To<Placeholder>,
463 params: &mut I,
464 ) -> Result<Self, Error> {
465 let expected_count = structure.len();
466 let mut values = Vec::new();
467 values.reserve_exact(expected_count);
468 for value_structure in structure {
469 values.push(V::from_params_with_remainder(value_structure, params).map_err(|error| match error {
470 Error::InsufficientParams { .. } => Error::InsufficientParams { expected_count },
471 error => error,
472 })?);
473 }
474 Ok(values)
475 }
476}
477
478