Skip to main content

ryft_core/
parameters.rs

1use std::{fmt::Debug, marker::PhantomData};
2
3use half::{bf16, f16};
4use paste::paste;
5
6use crate::errors::Error;
7
// TODO(eaplatanios): Add thorough documentation for [Parameterized].
//  - [Parameter]s are the leaves and all `P: Parameter` are [Parameterized].
//  - `PhantomData<P: Parameterized>` is [Parameterized].
//  - `(P: Parameterized)`, `(P: Parameterized, P: Parameterized)`, ..., up to size 12, are all [Parameterized].
//  - `[P: Parameterized; N]` is [Parameterized].
//  - `Vec<P: Parameterized>` is [Parameterized].
//  - TODO(eaplatanios): [HashMap]s.
//  - TODO(eaplatanios): [Box]s.
//  - `#[derive(Parameterized)]` provides support for custom structs and enums, which also support nested tuples
//    that mix [Parameterized] and non-[Parameterized] fields. However, they can only be nested within other tuples.
//    If they appear in, e.g., `Vec<(P, usize)>`, then those tuples are not supported.
//  - Only the `: Parameter` bound is supported by the derive macro. No additional bounds are supported for `P`.
//
// TODO(eaplatanios): Add tests for each of the [Parameterized] implementations included in this file.
22
23// TODO(eaplatanios): Add support for `named_parameters` which pairs each parameter with a path.
24// TODO(eaplatanios): Support something like a `broadcast` operation (e.g., I want to use the same learning rate
25//  for every sub-node from a specific point in the data structure). This is along the lines of what are called
26//  PyTree prefixes in JAX. Related: https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees.
27// TODO(eaplatanios): Borrow some of Equinox's tree manipulation capabilities.
28//  Reference: https://docs.kidger.site/equinox/api/manipulation.
29
30// For reference, in JAX, to register custom types as trees, we only need to implement these two functions:
31// - flatten(tree) -> (children, aux_data)
32// - unflatten(aux_data, children) -> tree
33
// TODO(eaplatanios): Document that this is an empty parameter acting as a placeholder for when we want to manipulate
//  parameter structures without having to worry about specific parameter types.
// TODO(eaplatanios): Document that this is a marker trait for parameter types (i.e., leaf nodes).
//  Furthermore, explain why we need this. Provide `Vec<P>` as a motivating example along with an explanation
//  for why something like specialization would need to be stable for us to support this.
// TODO(eaplatanios): Should `Parameter`s always have a static `Rank` and a static `DataType`?
/// Marker trait for parameter types, i.e., the leaf nodes of [`Parameterized`] data structures.
///
/// The trait has no methods; implementing it for a type only declares that values of that type
/// may appear as leaves.
pub trait Parameter {}
41
42impl Parameter for bool {}
43impl Parameter for i8 {}
44impl Parameter for i16 {}
45impl Parameter for i32 {}
46impl Parameter for i64 {}
47impl Parameter for i128 {}
48impl Parameter for u8 {}
49impl Parameter for u16 {}
50impl Parameter for u32 {}
51impl Parameter for u64 {}
52impl Parameter for u128 {}
53impl Parameter for bf16 {}
54impl Parameter for f16 {}
55impl Parameter for f32 {}
56impl Parameter for f64 {}
57impl Parameter for usize {}
58
/// Empty parameter value that stands in for a real parameter.
///
/// It carries no data and its [`Debug`] representation is always the literal text `<Parameter>`.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Placeholder;

impl Debug for Placeholder {
    /// Writes the fixed text `<Parameter>`, regardless of any formatting flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // There is nothing to interpolate, so write the string directly.
        f.write_str("<Parameter>")
    }
}
67
68impl Parameter for Placeholder {}
69
// TODO(eaplatanios): `Vec<(P, non-P)>` is not supported.
// TODO(eaplatanios): Unit structs should be impossible.
// TODO(eaplatanios): Talk about the derive macro we have for this trait:
//  - We also provide a `#[derive(Parameter)]` macro for convenience.
//  - Supports both structs and enums already.
//  - The parameter type must be a generic type parameter bounded by [Parameter].
//  - There must be only one such generic type parameter. Not zero and not more than one.
//  - All fields that reference / depend on the parameter type are considered parameter fields.
//  - Attributes of generic parameters are not visited/transformed and they are always carried around as they are.
//  - We need a recursive helper in order to properly handle tuple types. Tuples are not [Parameterized]
//    themselves (that is done in order to avoid issues with blanket implementations since we only instantiate
//    [Parameterized] implementations using prespecified parameter types), but they are supported when nested
//    within other types, for which we are deriving [Parameterized] implementations.
//  - Configurable `macro_param_lifetime` and `macro_param_type`.
// TODO(eaplatanios): Document the following:
//  - Vec<P> is not Parameterized<P>. Vec<T: Parameterized<P>> is Parameterized<P>.
//  - HashMap<K, P> is not Parameterized<P>. HashMap<K, V: Parameterized<P>> is Parameterized<P>.
//  - Same goes for arrays and other collection types.
88pub trait Parameterized<P: Parameter>: Sized {
89    // TODO(eaplatanios): We need to prove that `Self::To<P> = Self`.
90    // TODO(eaplatanios): What if `P` has additional trait bounds?
91    type To<T: Parameter>: Parameterized<T, To<P> = Self> + Parameterized<T, To<Placeholder> = Self::To<Placeholder>>;
92    // + Parameterized<T, To<JvpTracer<P>> = Self::To<JvpTracer<P>>>;
93
94    // #![feature(associated_type_defaults)]
95    // type ParamStructure = Self::To<ParamPlaceholder>;
96
97    // TODO(eaplatanios): Explain that we use associated types instead of `RPITIT` in order to support
98    //  deriving [Parameterized] for enums without the need to do any boxing. Though, is that really true?
99    //  I mean the wrapping enum would have to box anyway...hmm...maybe enums should always use `Box<dyn Iterator>`.
100    type ParamIterator<'t, T: 't + Parameter>: 't + Iterator<Item = &'t T>
101    where
102        // TODO(eaplatanios): Configure rustfmt to put these in the same line when possible.
103        Self: 't;
104
105    type ParamIteratorMut<'t, T: 't + Parameter>: 't + Iterator<Item = &'t mut T>
106    where
107        Self: 't;
108
109    type ParamIntoIterator<T: Parameter>: Iterator<Item = T>;
110
111    /// Returns the number of parameters in this [Parameterized] instance.
112    fn param_count(&self) -> usize;
113
114    fn param_structure(&self) -> Self::To<Placeholder>;
115
116    fn params(&self) -> Self::ParamIterator<'_, P>;
117    fn params_mut(&mut self) -> Self::ParamIteratorMut<'_, P>;
118    fn into_params(self) -> Self::ParamIntoIterator<P>;
119
120    fn from_params_with_remainder<I: Iterator<Item = P>>(
121        structure: Self::To<Placeholder>,
122        params: &mut I,
123    ) -> Result<Self, Error>;
124
125    fn from_params<I: IntoIterator<Item = P>>(structure: Self::To<Placeholder>, params: I) -> Result<Self, Error> {
126        let mut params = params.into_iter();
127        let parameterized = Self::from_params_with_remainder(structure, &mut params)?;
128        params.next().map(|_| Err(Error::UnusedParams)).unwrap_or_else(|| Ok(parameterized))
129    }
130
131    // TODO(eaplatanios): Document that this maps the parameters in this type.
132    fn map_params<T: Parameter, F: FnMut(P) -> T>(self, map_fn: F) -> Result<Self::To<T>, Error> {
133        Self::To::<T>::from_params(self.param_structure(), self.into_params().map(map_fn))
134    }
135}
136
137impl<P: Parameter> Parameterized<P> for P {
138    type To<T: Parameter> = T;
139
140    type ParamIterator<'t, T: 't + Parameter>
141        = std::iter::Once<&'t T>
142    where
143        Self: 't;
144    type ParamIteratorMut<'t, T: 't + Parameter>
145        = std::iter::Once<&'t mut T>
146    where
147        Self: 't;
148    type ParamIntoIterator<T: Parameter> = std::iter::Once<T>;
149
150    fn param_count(&self) -> usize {
151        1
152    }
153
154    fn param_structure(&self) -> Self::To<Placeholder> {
155        Placeholder
156    }
157
158    fn params(&self) -> Self::ParamIterator<'_, P> {
159        std::iter::once(self)
160    }
161
162    fn params_mut(&mut self) -> Self::ParamIteratorMut<'_, P> {
163        std::iter::once(self)
164    }
165
166    fn into_params(self) -> Self::ParamIntoIterator<P> {
167        std::iter::once(self)
168    }
169
170    fn from_params_with_remainder<I: Iterator<Item = P>>(
171        _structure: Self::To<Placeholder>,
172        params: &mut I,
173    ) -> Result<Self, Error> {
174        params.next().ok_or(Error::InsufficientParams { expected_count: 1 })
175    }
176}
177
178impl<P: Parameter> Parameterized<P> for PhantomData<P> {
179    type To<T: Parameter> = PhantomData<T>;
180
181    type ParamIterator<'t, T: 't + Parameter>
182        = std::iter::Empty<&'t T>
183    where
184        Self: 't;
185    type ParamIteratorMut<'t, T: 't + Parameter>
186        = std::iter::Empty<&'t mut T>
187    where
188        Self: 't;
189    type ParamIntoIterator<T: Parameter> = std::iter::Empty<T>;
190
191    fn param_count(&self) -> usize {
192        0
193    }
194
195    fn param_structure(&self) -> Self::To<Placeholder> {
196        PhantomData
197    }
198
199    fn params(&self) -> Self::ParamIterator<'_, P> {
200        std::iter::empty()
201    }
202
203    fn params_mut(&mut self) -> Self::ParamIteratorMut<'_, P> {
204        std::iter::empty()
205    }
206
207    fn into_params(self) -> Self::ParamIntoIterator<P> {
208        std::iter::empty()
209    }
210
211    fn from_params_with_remainder<I: Iterator<Item = P>>(
212        _structure: Self::To<Placeholder>,
213        _params: &mut I,
214    ) -> Result<Self, Error> {
215        Ok(PhantomData)
216    }
217}
218
219// TODO(eaplatanios): Add implementation for [Box].
220
221// Use declarative macros to provide implementations for tuples of [Parameterized] items. Note that if a tuple contains
222// a mix of [Parameterized] and non-[Parameterized] items, then the generated implementations here will not cover it.
223// Instead, such tuples are supported when nested within `struct`s or `enum`s tagged with `#[derive(Parameterized)]`
224// as the `derive` macro for [Parameterized] provides special treatment for them.
225
// Generates a [Parameterized] implementation for the tuple whose element types are the provided identifiers.
// Each element type must itself be [Parameterized] over the same parameter type `P`. Parameters are visited
// in element order. The lowercased element type names (produced via `paste!`) are used as local variable names
// for the tuple elements.
macro_rules! tuple_parameterized_impl {
    ($($T:ident),*) => {
        paste! {
            impl<P: Parameter$(, $T: Parameterized<P>)*> Parameterized<P> for ($($T,)*) {
                type To<T: Parameter> = ($($T::To<T>,)*);

                type ParamIterator<'t, T: 't + Parameter> =
                    tuple_param_iterator_ty!('t, T, ($($T,)*))
                where Self: 't;

                type ParamIteratorMut<'t, T: 't + Parameter> =
                    tuple_param_iterator_mut_ty!('t, T, ($($T,)*))
                where Self: 't;

                type ParamIntoIterator<T: Parameter> = tuple_param_into_iterator_ty!(T, ($($T,)*));

                // Sum of the parameter counts of all the elements.
                fn param_count(&self) -> usize {
                    let ($([<$T:lower>],)*) = self;
                    $([<$T:lower>].param_count()+)* 0usize
                }

                // Tuple of the structures of all the elements.
                fn param_structure(&self) -> Self::To<Placeholder> {
                    let ($([<$T:lower>],)*) = self;
                    ($([<$T:lower>].param_structure(),)*)
                }

                // Chains the parameter iterators of all the elements, in element order.
                fn params(&self) -> Self::ParamIterator<'_, P> {
                    let ($([<$T:lower>],)*) = self;
                    tuple_param_iterator!(P, ($([<$T:lower>],)*))
                }

                fn params_mut(&mut self) -> Self::ParamIteratorMut<'_, P> {
                    let ($([<$T:lower>],)*) = self;
                    tuple_param_iterator_mut!(P, ($([<$T:lower>],)*))
                }

                fn into_params(self) -> Self::ParamIntoIterator<P> {
                    let ($([<$T:lower>],)*) = self;
                    tuple_param_into_iterator!(P, ($([<$T:lower>],)*))
                }

                // Rebuilds the elements one after the other, in element order, each one consuming its own
                // parameters from the shared `params` iterator. The first error aborts the construction.
                fn from_params_with_remainder<I: Iterator<Item = P>>(
                    structure: Self::To<Placeholder>,
                    params: &mut I,
                ) -> Result<Self, Error> {
                    let ($([<$T:lower _field>],)*) = structure;
                    $(let [<$T:lower>] = $T::from_params_with_remainder([<$T:lower _field>], params)?;)*
                    Ok(($([<$T:lower>],)*))
                }
            }
        }
    };
}
279
// Expands to the type of the iterator over shared parameter references of a tuple: a nested chain of the
// `ParamIterator` types of the listed element types, terminated by an empty iterator.
macro_rules! tuple_param_iterator_ty {
    // No element types left: the chain ends with an empty iterator.
    ($t:lifetime, $T:ty, ()) => {
        std::iter::Empty<&$t $T>
    };

    // Chain the iterator type of the first element type with the type generated for the remaining ones.
    ($t:lifetime, $T:ty, ($head:ident, $($tail:ident,)*)) => {
        std::iter::Chain<$head::ParamIterator<$t, $T>, tuple_param_iterator_ty!($t, $T, ($($tail,)*))>
    };
}
289
// Expands to the type of the iterator over mutable parameter references of a tuple: a nested chain of the
// `ParamIteratorMut` types of the listed element types, terminated by an empty iterator.
macro_rules! tuple_param_iterator_mut_ty {
    // No element types left: the chain ends with an empty iterator.
    ($t:lifetime, $T:ty, ()) => {
        std::iter::Empty<&$t mut $T>
    };

    // Chain the iterator type of the first element type with the type generated for the remaining ones.
    ($t:lifetime, $T:ty, ($head:ident, $($tail:ident,)*)) => {
        std::iter::Chain<$head::ParamIteratorMut<$t, $T>, tuple_param_iterator_mut_ty!($t, $T, ($($tail,)*))>
    };
}
299
// Expands to the type of the iterator over owned parameters of a tuple: a nested chain of the
// `ParamIntoIterator` types of the listed element types, terminated by an empty iterator.
macro_rules! tuple_param_into_iterator_ty {
    // No element types left: the chain ends with an empty iterator.
    ($T:ty, ()) => {
        std::iter::Empty<$T>
    };

    // Chain the iterator type of the first element type with the type generated for the remaining ones.
    ($T:ty, ($head:ident, $($tail:ident,)*)) => {
        std::iter::Chain<$head::ParamIntoIterator<$T>, tuple_param_into_iterator_ty!($T, ($($tail,)*))>
    };
}
309
// Expands to an expression that chains the `params()` iterators of the listed variables, in order,
// and terminates the chain with an empty iterator over `&$T`.
macro_rules! tuple_param_iterator {
    // No variables left: the chain ends with an empty iterator.
    ($T:tt, ()) => {
        std::iter::empty::<&'_ $T>()
    };

    // Chain the parameters of the first variable with the expression generated for the remaining ones.
    ($T:tt, ($head:ident, $($tail:ident,)*)) => {
        $head.params().chain(tuple_param_iterator!($T, ($($tail,)*)))
    };
}
319
// Expands to an expression that chains the `params_mut()` iterators of the listed variables, in order,
// and terminates the chain with an empty iterator over `&mut $T`.
macro_rules! tuple_param_iterator_mut {
    // No variables left: the chain ends with an empty iterator.
    ($T:tt, ()) => {
        std::iter::empty::<&'_ mut $T>()
    };

    // Chain the parameters of the first variable with the expression generated for the remaining ones.
    ($T:tt, ($head:ident, $($tail:ident,)*)) => {
        $head.params_mut().chain(tuple_param_iterator_mut!($T, ($($tail,)*)))
    };
}
329
// Expands to an expression that chains the `into_params()` iterators of the listed variables, in order,
// and terminates the chain with an empty iterator over `$T`. The listed variables are consumed.
macro_rules! tuple_param_into_iterator {
    // No variables left: the chain ends with an empty iterator.
    ($T:tt, ()) => {
        std::iter::empty::<$T>()
    };

    // Chain the parameters of the first variable with the expression generated for the remaining ones.
    ($T:tt, ($head:ident, $($tail:ident,)*)) => {
        $head.into_params().chain(tuple_param_into_iterator!($T, ($($tail,)*)))
    };
}
339
340tuple_parameterized_impl!(T0);
341tuple_parameterized_impl!(T0, T1);
342tuple_parameterized_impl!(T0, T1, T2);
343tuple_parameterized_impl!(T0, T1, T2, T3);
344tuple_parameterized_impl!(T0, T1, T2, T3, T4);
345tuple_parameterized_impl!(T0, T1, T2, T3, T4, T5);
346tuple_parameterized_impl!(T0, T1, T2, T3, T4, T5, T6);
347tuple_parameterized_impl!(T0, T1, T2, T3, T4, T5, T6, T7);
348tuple_parameterized_impl!(T0, T1, T2, T3, T4, T5, T6, T7, T8);
349tuple_parameterized_impl!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9);
350tuple_parameterized_impl!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
351tuple_parameterized_impl!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
352
353impl<P: Parameter, V: Parameterized<P>, const N: usize> Parameterized<P> for [V; N] {
354    type To<T: Parameter> = [V::To<T>; N];
355
356    type ParamIterator<'t, T: 't + Parameter>
357        = std::iter::FlatMap<
358        std::slice::Iter<'t, V>,
359        <V as Parameterized<P>>::ParamIterator<'t, T>,
360        fn(&'t V) -> <V as Parameterized<P>>::ParamIterator<'t, T>,
361    >
362    where
363        Self: 't;
364
365    type ParamIteratorMut<'t, T: 't + Parameter>
366        = std::iter::FlatMap<
367        std::slice::IterMut<'t, V>,
368        <V as Parameterized<P>>::ParamIteratorMut<'t, T>,
369        fn(&'t mut V) -> <V as Parameterized<P>>::ParamIteratorMut<'t, T>,
370    >
371    where
372        Self: 't;
373
374    type ParamIntoIterator<T: Parameter> = std::iter::FlatMap<
375        std::array::IntoIter<V, N>,
376        <V as Parameterized<P>>::ParamIntoIterator<T>,
377        fn(V) -> <V as Parameterized<P>>::ParamIntoIterator<T>,
378    >;
379
380    fn param_count(&self) -> usize {
381        self.iter().map(|value| value.param_count()).sum()
382    }
383
384    fn param_structure(&self) -> Self::To<Placeholder> {
385        std::array::from_fn(|i| self[i].param_structure())
386    }
387
388    fn params(&self) -> Self::ParamIterator<'_, P> {
389        self.iter().flat_map(V::params)
390    }
391
392    fn params_mut(&mut self) -> Self::ParamIteratorMut<'_, P> {
393        self.iter_mut().flat_map(V::params_mut)
394    }
395
396    fn into_params(self) -> Self::ParamIntoIterator<P> {
397        self.into_iter().flat_map(V::into_params)
398    }
399
400    fn from_params_with_remainder<I: Iterator<Item = P>>(
401        structure: Self::To<Placeholder>,
402        params: &mut I,
403    ) -> Result<Self, Error> {
404        // Make this more efficient by using [std::array::try_from_fn] once it becomes stable.
405        // Tracking issue: https://github.com/rust-lang/rust/issues/89379.
406        let values = structure
407            .into_iter()
408            .map(|value_structure| V::from_params_with_remainder(value_structure, params))
409            .collect::<Result<Vec<V>, _>>()?;
410        Ok(unsafe { values.try_into().unwrap_unchecked() })
411    }
412}
413
414impl<P: Parameter, V: Parameterized<P>> Parameterized<P> for Vec<V> {
415    type To<T: Parameter> = Vec<V::To<T>>;
416
417    type ParamIterator<'t, T: 't + Parameter>
418        = std::iter::FlatMap<
419        std::slice::Iter<'t, V>,
420        <V as Parameterized<P>>::ParamIterator<'t, T>,
421        fn(&'t V) -> <V as Parameterized<P>>::ParamIterator<'t, T>,
422    >
423    where
424        Self: 't;
425
426    type ParamIteratorMut<'t, T: 't + Parameter>
427        = std::iter::FlatMap<
428        std::slice::IterMut<'t, V>,
429        <V as Parameterized<P>>::ParamIteratorMut<'t, T>,
430        fn(&'t mut V) -> <V as Parameterized<P>>::ParamIteratorMut<'t, T>,
431    >
432    where
433        Self: 't;
434
435    type ParamIntoIterator<T: Parameter> = std::iter::FlatMap<
436        std::vec::IntoIter<V>,
437        <V as Parameterized<P>>::ParamIntoIterator<T>,
438        fn(V) -> <V as Parameterized<P>>::ParamIntoIterator<T>,
439    >;
440
441    fn param_count(&self) -> usize {
442        self.iter().map(|value| value.param_count()).sum()
443    }
444
445    fn param_structure(&self) -> Self::To<Placeholder> {
446        self.iter().map(|value| value.param_structure()).collect()
447    }
448
449    fn params(&self) -> Self::ParamIterator<'_, P> {
450        self.iter().flat_map(|value| value.params())
451    }
452
453    fn params_mut(&mut self) -> Self::ParamIteratorMut<'_, P> {
454        self.iter_mut().flat_map(|value| value.params_mut())
455    }
456
457    fn into_params(self) -> Self::ParamIntoIterator<P> {
458        self.into_iter().flat_map(|value| value.into_params())
459    }
460
461    fn from_params_with_remainder<I: Iterator<Item = P>>(
462        structure: Self::To<Placeholder>,
463        params: &mut I,
464    ) -> Result<Self, Error> {
465        let expected_count = structure.len();
466        let mut values = Vec::new();
467        values.reserve_exact(expected_count);
468        for value_structure in structure {
469            values.push(V::from_params_with_remainder(value_structure, params).map_err(|error| match error {
470                Error::InsufficientParams { .. } => Error::InsufficientParams { expected_count },
471                error => error,
472            })?);
473        }
474        Ok(values)
475    }
476}
477
478// TODO(eaplatanios): Implement this for arrays, HashMap<K, _>, etc.