Skip to main content

subtr_actor/collector/ndarray/
traits.rs

1use crate::*;
2use boxcars;
3/// Re-export of `derive_new` used by the public ndarray feature macros.
4pub use derive_new;
5/// Re-export of `paste` used by the public ndarray feature macros.
6pub use paste;
7use std::sync::Arc;
8
/// Object-safe interface for frame-level feature extraction.
///
/// Implementors append features for one frame to a caller-owned buffer. Because the
/// trait is object safe, differently-typed adders can be stored side by side as
/// trait objects (see [`FeatureAdders`]).
pub trait FeatureAdder<F> {
    /// Number of feature columns this adder contributes; by default, the number of
    /// column headers it reports.
    fn features_added(&self) -> usize {
        self.get_column_headers().len()
    }

    /// Column headers for the values appended by [`FeatureAdder::add_features`].
    fn get_column_headers(&self) -> &[&str];

    /// Appends this adder's features for `frame` to the end of `vector`.
    ///
    /// `frame_count` and `current_time` are passed through from the caller
    /// (presumably the frame index and the replay time of `frame` — TODO confirm
    /// against the collector that drives this trait).
    fn add_features(
        &self,
        processor: &ReplayProcessor,
        frame: &boxcars::Frame,
        frame_count: usize,
        current_time: f32,
        vector: &mut Vec<F>,
    ) -> SubtrActorResult<()>;
}
26
/// Heterogeneous collection of frame-level feature adders.
///
/// Each adder is a shared (`Arc`) trait object that is also `Send + Sync`, so the
/// collection can be cloned cheaply and handed across threads.
pub type FeatureAdders<F> = Vec<Arc<dyn FeatureAdder<F> + Send + Sync>>;
29
/// Fixed-width feature extractor with compile-time column count validation.
///
/// The width `N` is a const generic shared by the header array and the feature
/// array, so an implementation whose headers and values disagree in length does not
/// compile. [`impl_feature_adder!`] derives the object-safe [`FeatureAdder`]
/// interface from an implementation of this trait.
pub trait LengthCheckedFeatureAdder<F, const N: usize> {
    /// The `N` column headers as a fixed-size array.
    fn get_column_headers_array(&self) -> &[&str; N];

    /// Computes exactly `N` feature values for `frame`, one per column header.
    ///
    /// The conditions under which an error is returned depend on the implementor.
    fn get_features(
        &self,
        processor: &ReplayProcessor,
        frame: &boxcars::Frame,
        frame_count: usize,
        current_time: f32,
    ) -> SubtrActorResult<[F; N]>;
}
42
/// Implements [`FeatureAdder`] for a type that already satisfies [`LengthCheckedFeatureAdder`].
///
/// `$struct_name` must be a type with a single generic parameter `F` (it is used as
/// `$struct_name<F>`) that already implements `LengthCheckedFeatureAdder<F, N>` for some
/// `N` under the same bounds on `F`. The generated `add_features` appends the fixed-size
/// array returned by `get_features` to `vector`. The generated `get_column_headers`
/// exposes the fixed-size header array as a slice.
///
/// The expansion names `FeatureAdder`, `ReplayProcessor`, `SubtrActorResult`, `TryFrom`
/// and `boxcars::Frame` by bare path, so they must be in scope wherever the macro is
/// invoked.
#[macro_export]
macro_rules! impl_feature_adder {
    ($struct_name:ident) => {
        impl<F: TryFrom<f32>> FeatureAdder<F> for $struct_name<F>
        where
            <F as TryFrom<f32>>::Error: std::fmt::Debug,
        {
            fn add_features(
                &self,
                processor: &ReplayProcessor,
                frame: &boxcars::Frame,
                frame_count: usize,
                current_time: f32,
                vector: &mut Vec<F>,
            ) -> SubtrActorResult<()> {
                // Extract first so a failure leaves `vector` untouched, then append.
                let features = self.get_features(processor, frame, frame_count, current_time)?;
                vector.extend(features);
                Ok(())
            }

            fn get_column_headers(&self) -> &[&str] {
                self.get_column_headers_array()
            }
        }
    };
}
75
/// Object-safe interface for per-player feature extraction.
///
/// Mirrors [`FeatureAdder`], except that each call is scoped to the player identified
/// by `player_id`. Trait objects of this type are collected in [`PlayerFeatureAdders`].
pub trait PlayerFeatureAdder<F> {
    /// Number of feature columns this adder contributes per player; by default, the
    /// number of column headers it reports.
    fn features_added(&self) -> usize {
        self.get_column_headers().len()
    }

    /// Column headers for the values appended by [`PlayerFeatureAdder::add_features`].
    fn get_column_headers(&self) -> &[&str];

    /// Appends this adder's features for `player_id` at `frame` to the end of `vector`.
    ///
    /// `frame_count` and `current_time` are passed through from the caller
    /// (presumably the frame index and replay time — TODO confirm against the collector).
    fn add_features(
        &self,
        player_id: &PlayerId,
        processor: &ReplayProcessor,
        frame: &boxcars::Frame,
        frame_count: usize,
        current_time: f32,
        vector: &mut Vec<F>,
    ) -> SubtrActorResult<()>;
}
94
/// Heterogeneous collection of per-player feature adders.
///
/// Each adder is a shared (`Arc`) trait object that is also `Send + Sync`, so the
/// collection can be cloned cheaply and handed across threads.
pub type PlayerFeatureAdders<F> = Vec<Arc<dyn PlayerFeatureAdder<F> + Send + Sync>>;
97
/// Fixed-width per-player feature extractor with compile-time column count validation.
///
/// The width `N` is a const generic shared by the header array and the feature array,
/// so mismatched lengths do not compile. [`impl_player_feature_adder!`] derives the
/// object-safe [`PlayerFeatureAdder`] interface from an implementation of this trait.
pub trait LengthCheckedPlayerFeatureAdder<F, const N: usize> {
    /// The `N` column headers as a fixed-size array.
    fn get_column_headers_array(&self) -> &[&str; N];

    /// Computes exactly `N` feature values for `player_id` at `frame`, one per column
    /// header.
    ///
    /// The conditions under which an error is returned depend on the implementor.
    fn get_features(
        &self,
        player_id: &PlayerId,
        processor: &ReplayProcessor,
        frame: &boxcars::Frame,
        frame_count: usize,
        current_time: f32,
    ) -> SubtrActorResult<[F; N]>;
}
111
/// Implements [`PlayerFeatureAdder`] for a type that satisfies [`LengthCheckedPlayerFeatureAdder`].
///
/// `$struct_name` must be a type with a single generic parameter `F` that already
/// implements `LengthCheckedPlayerFeatureAdder<F, N>` for some `N` under the same bounds
/// on `F`. The generated `add_features` appends the `N` per-player values returned by
/// `get_features` to `vector`. The generated `get_column_headers` exposes the fixed-size
/// header array as a slice.
///
/// The expansion names `PlayerFeatureAdder`, `PlayerId`, `ReplayProcessor`,
/// `SubtrActorResult`, `TryFrom` and `boxcars::Frame` by bare path, so they must be in
/// scope wherever the macro is invoked.
#[macro_export]
macro_rules! impl_player_feature_adder {
    ($struct_name:ident) => {
        impl<F: TryFrom<f32>> PlayerFeatureAdder<F> for $struct_name<F>
        where
            <F as TryFrom<f32>>::Error: std::fmt::Debug,
        {
            fn add_features(
                &self,
                player_id: &PlayerId,
                processor: &ReplayProcessor,
                frame: &boxcars::Frame,
                frame_count: usize,
                current_time: f32,
                vector: &mut Vec<F>,
            ) -> SubtrActorResult<()> {
                // Extract first so a failure leaves `vector` untouched, then append.
                let features =
                    self.get_features(player_id, processor, frame, frame_count, current_time)?;
                vector.extend(features);
                Ok(())
            }

            fn get_column_headers(&self) -> &[&str] {
                self.get_column_headers_array()
            }
        }
    };
}
144
145impl<G, F, const N: usize> FeatureAdder<F> for (G, &[&str; N])
146where
147    G: Fn(&ReplayProcessor, &boxcars::Frame, usize, f32) -> SubtrActorResult<[F; N]>,
148{
149    fn add_features(
150        &self,
151        processor: &ReplayProcessor,
152        frame: &boxcars::Frame,
153        frame_count: usize,
154        current_time: f32,
155        vector: &mut Vec<F>,
156    ) -> SubtrActorResult<()> {
157        vector.extend(self.0(processor, frame, frame_count, current_time)?);
158        Ok(())
159    }
160
161    fn get_column_headers(&self) -> &[&str] {
162        self.1.as_slice()
163    }
164}
165
166impl<G, F, const N: usize> PlayerFeatureAdder<F> for (G, &[&str; N])
167where
168    G: Fn(&PlayerId, &ReplayProcessor, &boxcars::Frame, usize, f32) -> SubtrActorResult<[F; N]>,
169{
170    fn add_features(
171        &self,
172        player_id: &PlayerId,
173        processor: &ReplayProcessor,
174        frame: &boxcars::Frame,
175        frame_count: usize,
176        current_time: f32,
177        vector: &mut Vec<F>,
178    ) -> SubtrActorResult<()> {
179        vector.extend(self.0(
180            player_id,
181            processor,
182            frame,
183            frame_count,
184            current_time,
185        )?);
186        Ok(())
187    }
188
189    fn get_column_headers(&self) -> &[&str] {
190        self.1.as_slice()
191    }
192}
193
/// Declares a new global feature-adder type and wires it into the ndarray traits.
///
/// Expands to:
/// * `pub struct $struct_name<F>`, holding only a `PhantomData<F>`, with a `new()`
///   constructor generated by `derive_new`;
/// * an `arc_new()` helper returning the adder as an
///   `Arc<dyn FeatureAdder<F> + Send + Sync>`, the element type of [`FeatureAdders`];
/// * the trait impls produced by [`global_feature_adder!`], where `$prop_getter` is
///   called as `$prop_getter(self, processor, frame, frame_count, current_time)` and must
///   return `SubtrActorResult<[F; N]>` with `N` equal to the number of column names.
///
/// `derive_new`, `FeatureAdder` and the names used by the generated impls are referenced
/// by bare path, so they must resolve at the invocation site.
#[macro_export]
macro_rules! build_global_feature_adder {
    ($struct_name:ident, $prop_getter:expr, $( $column_names:expr ),* $(,)?) => {
        #[derive(derive_new::new)]
        pub struct $struct_name<F> {
            _zero: std::marker::PhantomData<F>,
        }

        impl<F: Sync + Send + TryFrom<f32> + 'static> $struct_name<F>
        where
            <F as TryFrom<f32>>::Error: std::fmt::Debug,
        {
            pub fn arc_new() -> std::sync::Arc<dyn FeatureAdder<F> + Send + Sync + 'static> {
                std::sync::Arc::new(Self::new())
            }
        }

        // `$crate::` keeps the helper reachable when this macro is invoked from a crate
        // that did not separately import `global_feature_adder!`.
        $crate::global_feature_adder!(
            $struct_name,
            $prop_getter,
            $( $column_names ),*
        );
    };
}
219
/// Implements the ndarray feature-adder traits for an existing global feature type.
///
/// For a type `$struct_name<F>` this generates:
/// * a module-private constant `<STRUCT_NAME>_LENGTH` (named via `paste`) holding the
///   number of column names, used as the const-generic width `N`;
/// * an impl of [`LengthCheckedFeatureAdder<F, N>`] whose headers are the column names
///   and whose `get_features` forwards to
///   `$prop_getter(self, processor, frame, frame_count, current_time)`;
/// * the matching [`FeatureAdder`] impl via [`impl_feature_adder!`].
///
/// If `$prop_getter` returns an array whose length differs from the number of column
/// names, the generated `get_features` fails to type-check.
#[macro_export]
macro_rules! global_feature_adder {
    ($struct_name:ident, $prop_getter:expr, $( $column_names:expr ),* $(,)?) => {
        // Local helper so the pasted constant name can be passed in as the width.
        macro_rules! _global_feature_adder {
            ($count:ident) => {
                impl<F: TryFrom<f32>> LengthCheckedFeatureAdder<F, $count> for $struct_name<F>
                where
                    <F as TryFrom<f32>>::Error: std::fmt::Debug,
                {
                    fn get_column_headers_array(&self) -> &[&str; $count] {
                        &[$( $column_names ),*]
                    }

                    fn get_features(
                        &self,
                        processor: &ReplayProcessor,
                        frame: &boxcars::Frame,
                        frame_count: usize,
                        current_time: f32,
                    ) -> SubtrActorResult<[F; $count]> {
                        $prop_getter(self, processor, frame, frame_count, current_time)
                    }
                }

                // `$crate::` keeps the helper reachable from crates that did not
                // separately import `impl_feature_adder!`.
                $crate::impl_feature_adder!($struct_name);
            };
        }
        paste::paste! {
            const [<$struct_name:snake:upper _LENGTH>]: usize = [$($column_names),*].len();
            _global_feature_adder!([<$struct_name:snake:upper _LENGTH>]);
        }
    };
}
254
/// Declares a new per-player feature-adder type and wires it into the ndarray traits.
///
/// Expands to:
/// * `pub struct $struct_name<F>`, holding only a `PhantomData<F>`, with a `new()`
///   constructor generated by `derive_new`;
/// * an `arc_new()` helper returning the adder as an
///   `Arc<dyn PlayerFeatureAdder<F> + Send + Sync>`, the element type of
///   [`PlayerFeatureAdders`];
/// * the trait impls produced by [`player_feature_adder!`], where `$prop_getter` is
///   called as `$prop_getter(self, player_id, processor, frame, frame_count, current_time)`
///   and must return `SubtrActorResult<[F; N]>` with `N` equal to the number of column
///   names.
///
/// `derive_new`, `PlayerFeatureAdder` and the names used by the generated impls are
/// referenced by bare path, so they must resolve at the invocation site.
#[macro_export]
macro_rules! build_player_feature_adder {
    ($struct_name:ident, $prop_getter:expr, $( $column_names:expr ),* $(,)?) => {
        #[derive(derive_new::new)]
        pub struct $struct_name<F> {
            _zero: std::marker::PhantomData<F>,
        }

        impl<F: Sync + Send + TryFrom<f32> + 'static> $struct_name<F>
        where
            <F as TryFrom<f32>>::Error: std::fmt::Debug,
        {
            pub fn arc_new() -> std::sync::Arc<dyn PlayerFeatureAdder<F> + Send + Sync + 'static> {
                std::sync::Arc::new(Self::new())
            }
        }

        // `$crate::` keeps the helper reachable when this macro is invoked from a crate
        // that did not separately import `player_feature_adder!`.
        $crate::player_feature_adder!(
            $struct_name,
            $prop_getter,
            $( $column_names ),*
        );
    };
}
279
/// Implements the ndarray feature-adder traits for an existing per-player feature type.
///
/// For a type `$struct_name<F>` this generates:
/// * a module-private constant `<STRUCT_NAME>_LENGTH` (named via `paste`) holding the
///   number of column names, used as the const-generic width `N`;
/// * an impl of [`LengthCheckedPlayerFeatureAdder<F, N>`] whose headers are the column
///   names and whose `get_features` forwards to
///   `$prop_getter(self, player_id, processor, frame, frame_count, current_time)`;
/// * the matching [`PlayerFeatureAdder`] impl via [`impl_player_feature_adder!`].
///
/// If `$prop_getter` returns an array whose length differs from the number of column
/// names, the generated `get_features` fails to type-check.
#[macro_export]
macro_rules! player_feature_adder {
    ($struct_name:ident, $prop_getter:expr, $( $column_names:expr ),* $(,)?) => {
        // Local helper so the pasted constant name can be passed in as the width.
        macro_rules! _player_feature_adder {
            ($count:ident) => {
                impl<F: TryFrom<f32>> LengthCheckedPlayerFeatureAdder<F, $count> for $struct_name<F>
                where
                    <F as TryFrom<f32>>::Error: std::fmt::Debug,
                {
                    fn get_column_headers_array(&self) -> &[&str; $count] {
                        &[$( $column_names ),*]
                    }

                    fn get_features(
                        &self,
                        player_id: &PlayerId,
                        processor: &ReplayProcessor,
                        frame: &boxcars::Frame,
                        frame_count: usize,
                        current_time: f32,
                    ) -> SubtrActorResult<[F; $count]> {
                        $prop_getter(self, player_id, processor, frame, frame_count, current_time)
                    }
                }

                // `$crate::` keeps the helper reachable from crates that did not
                // separately import `impl_player_feature_adder!`.
                $crate::impl_player_feature_adder!($struct_name);
            };
        }
        paste::paste! {
            const [<$struct_name:snake:upper _LENGTH>]: usize = [$($column_names),*].len();
            _player_feature_adder!([<$struct_name:snake:upper _LENGTH>]);
        }
    };
}
315
316/// Maps arbitrary conversion failures into a generic float-conversion error.
317pub fn convert_float_conversion_error<T>(_: T) -> SubtrActorError {
318    SubtrActorError::new(SubtrActorErrorVariant::FloatConversionError)
319}
320
/// Converts a fixed list of values with a caller-supplied error mapper.
///
/// Each `$item` is converted with `try_into()`, left to right. The first failure is
/// mapped through `$err` and propagated with `?`. The macro must therefore be invoked
/// inside a function returning a `Result` whose error type can be built (via `From`)
/// from `$err`'s output. On success it evaluates to `Ok([..])`, an array holding one
/// converted value per item, in order. `$err` is expanded once per item.
#[macro_export]
macro_rules! convert_all {
    ($err:expr, $( $item:expr ),* $(,)?) => {{
        Ok([
            $( $item.try_into().map_err($err)? ),*
        ])
    }};
}
330
/// Converts a fixed list of float-like values using [`convert_float_conversion_error`].
///
/// Shorthand for `convert_all!(convert_float_conversion_error, ...)`, so any failed
/// conversion becomes a `FloatConversionError`. Like `convert_all!`, it propagates the
/// first failure with `?` and must be used inside a function returning a compatible
/// `Result`. `convert_float_conversion_error` is referenced by bare name and must be in
/// scope at the invocation site.
#[macro_export]
macro_rules! convert_all_floats {
    ($( $item:expr ),* $(,)?) => {{
        // `$crate::` resolves the helper macro even in crates that never imported
        // `convert_all!` themselves.
        $crate::convert_all!(convert_float_conversion_error, $( $item ),*)
    }};
}