subtr_actor/collector/
ndarray.rs

1use crate::*;
2use ::ndarray;
3use boxcars;
4pub use derive_new;
5use lazy_static::lazy_static;
6pub use paste;
7use serde::Serialize;
8use std::sync::Arc;
9
/// Represents the column headers in the collected data of an [`NDArrayCollector`].
///
/// The two lists are kept separate because the player headers are stored only
/// once, without a per-player prefix; see
/// [`ReplayMetaWithHeaders::headers_vec_from`] for how they are expanded into
/// one group of columns per player.
///
/// # Fields
///
/// * `global_headers`: A list of strings that represent the global,
///   player-independent features' column headers.
/// * `player_headers`: A list of strings that represent the player-specific
///   features' column headers.
///
/// Use [`Self::new`] to construct an instance of this struct.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct NDArrayColumnHeaders {
    pub global_headers: Vec<String>,
    pub player_headers: Vec<String>,
}
25
26impl NDArrayColumnHeaders {
27    pub fn new(global_headers: Vec<String>, player_headers: Vec<String>) -> Self {
28        Self {
29            global_headers,
30            player_headers,
31        }
32    }
33}
34
/// A struct that contains both the metadata of a replay and the associated
/// column headers.
///
/// Instances are returned by [`NDArrayCollector::get_meta_and_ndarray`] and
/// [`NDArrayCollector::process_and_get_meta_and_headers`].
///
/// # Fields
///
/// * `replay_meta`: Contains metadata about a [`boxcars::Replay`].
/// * `column_headers`: The [`NDArrayColumnHeaders`] associated with the data
///   collected from the replay.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct ReplayMetaWithHeaders {
    pub replay_meta: ReplayMeta,
    pub column_headers: NDArrayColumnHeaders,
}
48
49impl ReplayMetaWithHeaders {
50    pub fn headers_vec(&self) -> Vec<String> {
51        self.headers_vec_from(|_, _info, index| format!("Player {index} - "))
52    }
53
54    pub fn headers_vec_from<F>(&self, player_prefix_getter: F) -> Vec<String>
55    where
56        F: Fn(&Self, &PlayerInfo, usize) -> String,
57    {
58        self.column_headers
59            .global_headers
60            .iter()
61            .cloned()
62            .chain(self.replay_meta.player_order().enumerate().flat_map(
63                move |(player_index, info)| {
64                    let player_prefix = player_prefix_getter(self, info, player_index);
65                    self.column_headers
66                        .player_headers
67                        .iter()
68                        .map(move |header| format!("{player_prefix}{header}"))
69                },
70            ))
71            .collect()
72    }
73}
74
/// [`NDArrayCollector`] is a [`Collector`] which transforms frame-based replay
/// data into a 2-dimensional array of type [`ndarray::Array2`], where each
/// element is of a specified floating point type.
///
/// It's initialized with collections of [`FeatureAdder`] instances which
/// extract global, player independent features for each frame, and
/// [`PlayerFeatureAdder`], which add player specific features for each frame.
///
/// Its main entry point is [`Self::get_meta_and_ndarray`], which provides
/// [`ndarray::Array2`] along with column headers and replay metadata.
pub struct NDArrayCollector<F> {
    // Adders invoked once per collected frame for player-independent features.
    feature_adders: FeatureAdders<F>,
    // Adders invoked once per player per collected frame.
    player_feature_adders: PlayerFeatureAdders<F>,
    // Flat buffer that the adders append to, frame after frame; it is reshaped
    // into a (frames_added x features-per-frame) array when the collector is
    // consumed.
    data: Vec<F>,
    // Metadata of the replay being processed; `None` until it has been
    // obtained from a `ReplayProcessor`.
    replay_meta: Option<ReplayMeta>,
    // Number of frames for which features have been appended to `data`.
    frames_added: usize,
}
92
93impl<F> NDArrayCollector<F> {
94    /// Creates a new instance of `NDArrayCollector`.
95    ///
96    /// # Arguments
97    ///
98    /// * `feature_adders` - A vector of [`Arc<dyn FeatureAdder<F>>`], each
99    ///   implementing the [`FeatureAdder`] trait. These are used to add global
100    ///   features to the replay data.
101    ///
102    /// * `player_feature_adders` - A vector of [`Arc<dyn PlayerFeatureAdder<F>>`],
103    ///   each implementing the [`PlayerFeatureAdder`]
104    ///   trait. These are used to add player-specific features to the replay
105    ///   data.
106    ///
107    /// # Returns
108    ///
109    /// A new [`NDArrayCollector`] instance. This instance is initialized with
110    /// empty data, no replay metadata and zero frames added.
111    pub fn new(
112        feature_adders: FeatureAdders<F>,
113        player_feature_adders: PlayerFeatureAdders<F>,
114    ) -> Self {
115        Self {
116            feature_adders,
117            player_feature_adders,
118            data: Vec::new(),
119            replay_meta: None,
120            frames_added: 0,
121        }
122    }
123
124    /// Returns the column headers of the 2-dimensional array produced by the
125    /// [`NDArrayCollector`].
126    ///
127    /// # Returns
128    ///
129    /// An instance of [`NDArrayColumnHeaders`] representing the column headers
130    /// in the collected data.
131    pub fn get_column_headers(&self) -> NDArrayColumnHeaders {
132        let global_headers = self
133            .feature_adders
134            .iter()
135            .flat_map(move |fa| {
136                fa.get_column_headers()
137                    .iter()
138                    .map(move |column_name| column_name.to_string())
139            })
140            .collect();
141        let player_headers = self
142            .player_feature_adders
143            .iter()
144            .flat_map(move |pfa| {
145                pfa.get_column_headers()
146                    .iter()
147                    .map(move |base_name| base_name.to_string())
148            })
149            .collect();
150        NDArrayColumnHeaders::new(global_headers, player_headers)
151    }
152
153    /// This function consumes the [`NDArrayCollector`] instance and returns the
154    /// data collected as an [`ndarray::Array2`].
155    ///
156    /// # Returns
157    ///
158    /// A [`SubtrActorResult`] containing the collected data as an
159    /// [`ndarray::Array2`].
160    ///
161    /// This method is a shorthand for calling [`Self::get_meta_and_ndarray`]
162    /// and discarding the replay metadata and headers.
163    pub fn get_ndarray(self) -> SubtrActorResult<ndarray::Array2<F>> {
164        self.get_meta_and_ndarray().map(|a| a.1)
165    }
166
167    /// Consumes the [`NDArrayCollector`] and returns the collected features as a
168    /// 2D ndarray, along with replay metadata and headers.
169    ///
170    /// # Returns
171    ///
172    /// A [`SubtrActorResult`] containing a tuple:
173    /// - [`ReplayMetaWithHeaders`]: The replay metadata along with the headers
174    ///   for each column in the ndarray.
175    /// - [`ndarray::Array2<F>`]: The collected features as a 2D ndarray.
176    pub fn get_meta_and_ndarray(
177        self,
178    ) -> SubtrActorResult<(ReplayMetaWithHeaders, ndarray::Array2<F>)> {
179        let features_per_row = self.try_get_frame_feature_count()?;
180        let expected_length = features_per_row * self.frames_added;
181        assert!(self.data.len() == expected_length);
182        let column_headers = self.get_column_headers();
183        Ok((
184            ReplayMetaWithHeaders {
185                replay_meta: self.replay_meta.ok_or(SubtrActorError::new(
186                    SubtrActorErrorVariant::CouldNotBuildReplayMeta,
187                ))?,
188                column_headers,
189            },
190            ndarray::Array2::from_shape_vec((self.frames_added, features_per_row), self.data)
191                .map_err(SubtrActorErrorVariant::NDArrayShapeError)
192                .map_err(SubtrActorError::new)?,
193        ))
194    }
195
196    /// Processes a [`boxcars::Replay`] and returns its metadata along with column headers.
197    ///
198    /// This method first processes the replay using a [`ReplayProcessor`]. It
199    /// then updates the `replay_meta` field if it's not already set, and
200    /// returns a clone of the `replay_meta` field along with column headers of
201    /// the data.
202    ///
203    /// # Arguments
204    ///
205    /// * `replay`: A reference to the [`boxcars::Replay`] to process.
206    ///
207    /// # Returns
208    ///
209    /// A [`SubtrActorResult`] containing a [`ReplayMetaWithHeaders`] that
210    /// includes the metadata of the replay and column headers.
211    pub fn process_and_get_meta_and_headers(
212        &mut self,
213        replay: &boxcars::Replay,
214    ) -> SubtrActorResult<ReplayMetaWithHeaders> {
215        let mut processor = ReplayProcessor::new(replay)?;
216        processor.process_long_enough_to_get_actor_ids()?;
217        self.maybe_set_replay_meta(&processor)?;
218        Ok(ReplayMetaWithHeaders {
219            replay_meta: self
220                .replay_meta
221                .as_ref()
222                .ok_or(SubtrActorError::new(
223                    SubtrActorErrorVariant::CouldNotBuildReplayMeta,
224                ))?
225                .clone(),
226            column_headers: self.get_column_headers(),
227        })
228    }
229
230    fn try_get_frame_feature_count(&self) -> SubtrActorResult<usize> {
231        let player_count = self
232            .replay_meta
233            .as_ref()
234            .ok_or(SubtrActorError::new(
235                SubtrActorErrorVariant::CouldNotBuildReplayMeta,
236            ))?
237            .player_count();
238        let global_feature_count: usize = self
239            .feature_adders
240            .iter()
241            .map(|fa| fa.features_added())
242            .sum();
243        let player_feature_count: usize = self
244            .player_feature_adders
245            .iter() // iterate
246            .map(|pfa| pfa.features_added() * player_count)
247            .sum();
248        Ok(global_feature_count + player_feature_count)
249    }
250
251    fn maybe_set_replay_meta(&mut self, processor: &ReplayProcessor) -> SubtrActorResult<()> {
252        if self.replay_meta.is_none() {
253            self.replay_meta = Some(processor.get_replay_meta()?);
254        }
255        Ok(())
256    }
257}
258
259impl<F> Collector for NDArrayCollector<F> {
260    fn process_frame(
261        &mut self,
262        processor: &ReplayProcessor,
263        frame: &boxcars::Frame,
264        frame_number: usize,
265        current_time: f32,
266    ) -> SubtrActorResult<collector::TimeAdvance> {
267        self.maybe_set_replay_meta(processor)?;
268
269        if !processor.ball_rigid_body_exists()? {
270            return Ok(collector::TimeAdvance::NextFrame);
271        }
272
273        for feature_adder in self.feature_adders.iter() {
274            feature_adder.add_features(
275                processor,
276                frame,
277                frame_number,
278                current_time,
279                &mut self.data,
280            )?;
281        }
282
283        for player_id in processor.iter_player_ids_in_order() {
284            for player_feature_adder in self.player_feature_adders.iter() {
285                player_feature_adder.add_features(
286                    player_id,
287                    processor,
288                    frame,
289                    frame_number,
290                    current_time,
291                    &mut self.data,
292                )?;
293            }
294        }
295
296        self.frames_added += 1;
297
298        Ok(collector::TimeAdvance::NextFrame)
299    }
300}
301
302impl NDArrayCollector<f32> {
303    pub fn from_strings(fa_names: &[&str], pfa_names: &[&str]) -> SubtrActorResult<Self> {
304        let feature_adders: Vec<Arc<dyn FeatureAdder<f32> + Send + Sync>> = fa_names
305            .iter()
306            .map(|name| {
307                Ok(NAME_TO_GLOBAL_FEATURE_ADDER
308                    .get(name)
309                    .ok_or_else(|| {
310                        SubtrActorError::new(SubtrActorErrorVariant::UnknownFeatureAdderName(
311                            name.to_string(),
312                        ))
313                    })?
314                    .clone())
315            })
316            .collect::<SubtrActorResult<Vec<_>>>()?;
317        let player_feature_adders: Vec<Arc<dyn PlayerFeatureAdder<f32> + Send + Sync>> = pfa_names
318            .iter()
319            .map(|name| {
320                Ok(NAME_TO_PLAYER_FEATURE_ADDER
321                    .get(name)
322                    .ok_or_else(|| {
323                        SubtrActorError::new(SubtrActorErrorVariant::UnknownFeatureAdderName(
324                            name.to_string(),
325                        ))
326                    })?
327                    .clone())
328            })
329            .collect::<SubtrActorResult<Vec<_>>>()?;
330        Ok(Self::new(feature_adders, player_feature_adders))
331    }
332}
333
334impl<F: TryFrom<f32> + Send + Sync + 'static> Default for NDArrayCollector<F>
335where
336    <F as TryFrom<f32>>::Error: std::fmt::Debug,
337{
338    fn default() -> Self {
339        NDArrayCollector::new(
340            vec![BallRigidBody::arc_new()],
341            vec![
342                PlayerRigidBody::arc_new(),
343                PlayerBoost::arc_new(),
344                PlayerAnyJump::arc_new(),
345            ],
346        )
347    }
348}
349
/// This trait acts as an abstraction over a feature adder, and is primarily
/// used to allow for heterogeneous collections of feature adders in the
/// [`NDArrayCollector`]. While it provides methods for adding features and
/// retrieving column headers, it is generally recommended to implement the
/// [`LengthCheckedFeatureAdder`] trait instead, which provides compile-time
/// guarantees about the number of features returned.
pub trait FeatureAdder<F> {
    /// Returns the number of features this adder contributes per frame.
    ///
    /// Defaults to the number of column headers. [`NDArrayCollector`] derives
    /// its row width from this value, so [`Self::add_features`] should append
    /// exactly this many values.
    fn features_added(&self) -> usize {
        self.get_column_headers().len()
    }

    /// Returns the column headers of the features this adder contributes.
    fn get_column_headers(&self) -> &[&str];

    /// Appends this adder's features for the given frame to `vector`.
    ///
    /// `frame_count` and `current_time` are forwarded from the collector's
    /// `process_frame` call (its `frame_number` and `current_time`
    /// arguments).
    fn add_features(
        &self,
        processor: &ReplayProcessor,
        frame: &boxcars::Frame,
        frame_count: usize,
        current_time: f32,
        vector: &mut Vec<F>,
    ) -> SubtrActorResult<()>;
}

/// A list of shareable, thread-safe [`FeatureAdder`] trait objects, as
/// accepted by [`NDArrayCollector::new`].
pub type FeatureAdders<F> = Vec<Arc<dyn FeatureAdder<F> + Send + Sync>>;
374
/// This trait is a stricter version of the [`FeatureAdder`] trait, enforcing at
/// compile time that the number of features added is equal to the number of
/// column headers provided: both are arrays of the same const length `N`.
/// Implementations of this trait can be automatically adapted to the
/// [`FeatureAdder`] trait using the [`impl_feature_adder!`] macro.
pub trait LengthCheckedFeatureAdder<F, const N: usize> {
    /// Returns the `N` column headers of the features this adder produces.
    fn get_column_headers_array(&self) -> &[&str; N];

    /// Computes the `N` features for the given frame.
    fn get_features(
        &self,
        processor: &ReplayProcessor,
        frame: &boxcars::Frame,
        frame_count: usize,
        current_time: f32,
    ) -> SubtrActorResult<[F; N]>;
}
391
/// A macro to provide an automatic implementation of the [`FeatureAdder`] trait
/// for types that implement [`LengthCheckedFeatureAdder`]. This allows you to
/// take advantage of the compile-time guarantees provided by
/// [`LengthCheckedFeatureAdder`], while still being able to use your type in
/// contexts that require a [`FeatureAdder`] object. This macro is used to
/// bridge the gap between the two traits, as Rust's type system does not
/// currently provide a way to prove to the compiler that there will always be
/// exactly one implementation of [`LengthCheckedFeatureAdder`] for each type.
///
/// `$struct_name` must name a type that is generic over a single type
/// parameter (the feature element type). The generated `add_features` appends
/// the array returned by `get_features` to the output vector and propagates
/// any error without modifying the vector; `get_column_headers` returns the
/// header array as a slice.
#[macro_export]
macro_rules! impl_feature_adder {
    ($struct_name:ident) => {
        impl<F: TryFrom<f32>> FeatureAdder<F> for $struct_name<F>
        where
            <F as TryFrom<f32>>::Error: std::fmt::Debug,
        {
            fn add_features(
                &self,
                processor: &ReplayProcessor,
                frame: &boxcars::Frame,
                frame_count: usize,
                current_time: f32,
                vector: &mut Vec<F>,
            ) -> SubtrActorResult<()> {
                let features = self.get_features(processor, frame, frame_count, current_time)?;
                vector.extend(features);
                Ok(())
            }

            fn get_column_headers(&self) -> &[&str] {
                self.get_column_headers_array()
            }
        }
    };
}
431
/// This trait acts as an abstraction over a player-specific feature adder, and
/// is primarily used to allow for heterogeneous collections of player feature
/// adders in the [`NDArrayCollector`]. While it provides methods for adding
/// player-specific features and retrieving column headers, it is generally
/// recommended to implement the [`LengthCheckedPlayerFeatureAdder`] trait
/// instead, which provides compile-time guarantees about the number of features
/// returned.
pub trait PlayerFeatureAdder<F> {
    /// Returns the number of features this adder contributes per player per
    /// frame.
    ///
    /// Defaults to the number of column headers. [`NDArrayCollector`] derives
    /// its row width from this value, so [`Self::add_features`] should append
    /// exactly this many values.
    fn features_added(&self) -> usize {
        self.get_column_headers().len()
    }

    /// Returns the column headers of the features this adder contributes for
    /// a single player.
    fn get_column_headers(&self) -> &[&str];

    /// Appends this adder's features for the player identified by `player_id`
    /// in the given frame to `vector`.
    fn add_features(
        &self,
        player_id: &PlayerId,
        processor: &ReplayProcessor,
        frame: &boxcars::Frame,
        frame_count: usize,
        current_time: f32,
        vector: &mut Vec<F>,
    ) -> SubtrActorResult<()>;
}

/// A list of shareable, thread-safe [`PlayerFeatureAdder`] trait objects, as
/// accepted by [`NDArrayCollector::new`].
pub type PlayerFeatureAdders<F> = Vec<Arc<dyn PlayerFeatureAdder<F> + Send + Sync>>;
458
/// This trait is a stricter version of the [`PlayerFeatureAdder`] trait,
/// enforcing at compile time that the number of player-specific features added
/// is equal to the number of column headers provided: both are arrays of the
/// same const length `N`. Implementations of this trait can be automatically
/// adapted to the [`PlayerFeatureAdder`] trait using the
/// [`impl_player_feature_adder!`] macro.
pub trait LengthCheckedPlayerFeatureAdder<F, const N: usize> {
    /// Returns the `N` column headers of the features this adder produces for
    /// a single player.
    fn get_column_headers_array(&self) -> &[&str; N];

    /// Computes the `N` features for the player identified by `player_id` in
    /// the given frame.
    fn get_features(
        &self,
        player_id: &PlayerId,
        processor: &ReplayProcessor,
        frame: &boxcars::Frame,
        frame_count: usize,
        current_time: f32,
    ) -> SubtrActorResult<[F; N]>;
}
476
/// A macro to provide an automatic implementation of the [`PlayerFeatureAdder`]
/// trait for types that implement [`LengthCheckedPlayerFeatureAdder`]. This
/// allows you to take advantage of the compile-time guarantees provided by
/// [`LengthCheckedPlayerFeatureAdder`], while still being able to use your type
/// in contexts that require a [`PlayerFeatureAdder`] object. This macro is used
/// to bridge the gap between the two traits, as Rust's type system does not
/// currently provide a way to prove to the compiler that there will always be
/// exactly one implementation of [`LengthCheckedPlayerFeatureAdder`] for each
/// type.
///
/// `$struct_name` must name a type that is generic over a single type
/// parameter (the feature element type). The generated `add_features` appends
/// the array returned by `get_features` to the output vector and propagates
/// any error without modifying the vector; `get_column_headers` returns the
/// header array as a slice.
#[macro_export]
macro_rules! impl_player_feature_adder {
    ($struct_name:ident) => {
        impl<F: TryFrom<f32>> PlayerFeatureAdder<F> for $struct_name<F>
        where
            <F as TryFrom<f32>>::Error: std::fmt::Debug,
        {
            fn add_features(
                &self,
                player_id: &PlayerId,
                processor: &ReplayProcessor,
                frame: &boxcars::Frame,
                frame_count: usize,
                current_time: f32,
                vector: &mut Vec<F>,
            ) -> SubtrActorResult<()> {
                let features =
                    self.get_features(player_id, processor, frame, frame_count, current_time)?;
                vector.extend(features);
                Ok(())
            }

            fn get_column_headers(&self) -> &[&str] {
                self.get_column_headers_array()
            }
        }
    };
}
517
518impl<G, F, const N: usize> FeatureAdder<F> for (G, &[&str; N])
519where
520    G: Fn(&ReplayProcessor, &boxcars::Frame, usize, f32) -> SubtrActorResult<[F; N]>,
521{
522    fn add_features(
523        &self,
524        processor: &ReplayProcessor,
525        frame: &boxcars::Frame,
526        frame_count: usize,
527        current_time: f32,
528        vector: &mut Vec<F>,
529    ) -> SubtrActorResult<()> {
530        vector.extend(self.0(processor, frame, frame_count, current_time)?);
531        Ok(())
532    }
533
534    fn get_column_headers(&self) -> &[&str] {
535        self.1.as_slice()
536    }
537}
538
539impl<G, F, const N: usize> PlayerFeatureAdder<F> for (G, &[&str; N])
540where
541    G: Fn(&PlayerId, &ReplayProcessor, &boxcars::Frame, usize, f32) -> SubtrActorResult<[F; N]>,
542{
543    fn add_features(
544        &self,
545        player_id: &PlayerId,
546        processor: &ReplayProcessor,
547        frame: &boxcars::Frame,
548        frame_count: usize,
549        current_time: f32,
550        vector: &mut Vec<F>,
551    ) -> SubtrActorResult<()> {
552        vector.extend(self.0(
553            player_id,
554            processor,
555            frame,
556            frame_count,
557            current_time,
558        )?);
559        Ok(())
560    }
561
562    fn get_column_headers(&self) -> &[&str] {
563        self.1.as_slice()
564    }
565}
566
/// This macro creates a global [`FeatureAdder`] struct and implements the
/// necessary traits to add the calculated features to the data matrix. The
/// macro exports a struct with the same name as passed in the parameter. The
/// number of column names and the length of the feature array returned by
/// `$prop_getter` are checked at compile time to ensure they match, in line
/// with the [`LengthCheckedFeatureAdder`] trait. The output struct also
/// provides an implementation of the [`FeatureAdder`] trait via the
/// [`impl_feature_adder!`] macro, allowing it to be used in contexts where a
/// [`FeatureAdder`] object is required.
///
/// The generated struct is generic over the feature element type and also
/// gets an `arc_new()` constructor that returns it as an
/// `Arc<dyn FeatureAdder<F> + Send + Sync>`.
///
/// # Parameters
///
/// * `$struct_name`: The name of the struct to be created.
/// * `$prop_getter`: The function or closure used to calculate the features.
///   It is called with the adder itself, the processor, the frame, the frame
///   count and the current time, in that order.
/// * `$( $column_names:expr ),*`: A comma-separated list of column names as strings.
///
/// # Example
///
/// ```
/// use subtr_actor::*;
///
/// build_global_feature_adder!(
///     SecondsRemainingExample,
///     |_, processor: &ReplayProcessor, _frame, _index, _current_time| {
///         convert_all_floats!(processor.get_seconds_remaining()?.clone() as f32)
///     },
///     "seconds remaining"
/// );
/// ```
///
/// This will create a struct named `SecondsRemainingExample` and implement
/// the necessary traits to calculate features using the provided closure. The
/// feature will be added under the column name "seconds remaining". Note,
/// however, that it is possible to add more than one feature with each
/// feature adder.
#[macro_export]
macro_rules! build_global_feature_adder {
    ($struct_name:ident, $prop_getter:expr, $( $column_names:expr ),* $(,)?) => {

        #[derive(derive_new::new)]
        pub struct $struct_name<F> {
            _zero: std::marker::PhantomData<F>,
        }

        impl<F: Sync + Send + TryFrom<f32> + 'static> $struct_name<F> where
            <F as TryFrom<f32>>::Error: std::fmt::Debug,
        {
            pub fn arc_new() -> std::sync::Arc<dyn FeatureAdder<F> + Send + Sync + 'static> {
                std::sync::Arc::new(Self::new())
            }
        }

        global_feature_adder!(
            $struct_name,
            $prop_getter,
            $( $column_names ),*
        );
    }
}
625
/// This macro is used to implement necessary traits for an existing struct to
/// add the calculated features to the data matrix. This macro is particularly
/// useful when the feature adder needs to be instantiated with specific
/// parameters. The number of column names and the length of the feature array
/// returned by `$prop_getter` are checked at compile time to ensure they match.
///
/// The struct must be generic over a single type parameter (the feature
/// element type). The macro implements [`LengthCheckedFeatureAdder`] for it
/// and then invokes [`impl_feature_adder!`] to provide [`FeatureAdder`].
///
/// As a side effect, a constant holding the number of column names is defined
/// in the invoking scope. Its name is derived from `$struct_name` (converted
/// to upper snake case, with a `_LENGTH` suffix), so it must not clash with
/// another item in that scope.
///
/// # Parameters
///
/// * `$struct_name`: The name of the existing struct.
/// * `$prop_getter`: The function or closure used to calculate the features.
///   It is called with `self`, the processor, the frame, the frame count and
///   the current time, in that order.
/// * `$( $column_names:expr ),*`: A comma-separated list of column names as strings.
#[macro_export]
macro_rules! global_feature_adder {
    ($struct_name:ident, $prop_getter:expr, $( $column_names:expr ),* $(,)?) => {
        // Helper macro that receives the name of the length constant, so that
        // the same identifier can be used as the const generic argument and
        // in both array types below.
        macro_rules! _global_feature_adder {
            ($count:ident) => {
                impl<F: TryFrom<f32>> LengthCheckedFeatureAdder<F, $count> for $struct_name<F>
                where
                    <F as TryFrom<f32>>::Error: std::fmt::Debug,
                {
                    fn get_column_headers_array(&self) -> &[&str; $count] {
                        &[$( $column_names ),*]
                    }

                    fn get_features(
                        &self,
                        processor: &ReplayProcessor,
                        frame: &boxcars::Frame,
                        frame_count: usize,
                        current_time: f32,
                    ) -> SubtrActorResult<[F; $count]> {
                        $prop_getter(self, processor, frame, frame_count, current_time)
                    }
                }

                impl_feature_adder!($struct_name);
            };
        }
        paste::paste! {
            // The constant's value is the number of column names supplied.
            const [<$struct_name:snake:upper _LENGTH>]: usize = [$($column_names),*].len();
            _global_feature_adder!([<$struct_name:snake:upper _LENGTH>]);
        }
    }
}
670
/// This macro creates a player feature adder struct and implements the
/// necessary traits to add the calculated player-specific features to the data
/// matrix. The macro exports a struct with the same name as passed in the
/// parameter. The number of column names and the length of the feature array
/// returned by `$prop_getter` are checked at compile time to ensure they match,
/// in line with the [`LengthCheckedPlayerFeatureAdder`] trait. The output
/// struct also provides an implementation of the [`PlayerFeatureAdder`] trait
/// via the [`impl_player_feature_adder!`] macro, allowing it to be used in
/// contexts where a [`PlayerFeatureAdder`] object is required.
///
/// The generated struct is generic over the feature element type and also
/// gets an `arc_new()` constructor that returns it as an
/// `Arc<dyn PlayerFeatureAdder<F> + Send + Sync>`.
///
/// # Parameters
///
/// * `$struct_name`: The name of the struct to be created.
/// * `$prop_getter`: The function or closure used to calculate the features.
///   It is called with the adder itself, the player id, the processor, the
///   frame, the frame count and the current time, in that order.
/// * `$( $column_names:expr ),*`: A comma-separated list of column names as strings.
///
/// # Example
///
/// ```
/// use subtr_actor::*;
///
/// fn u8_get_f32(v: u8) -> SubtrActorResult<f32> {
///    v.try_into().map_err(convert_float_conversion_error)
/// }
///
/// build_player_feature_adder!(
///     PlayerJump,
///     |_,
///      player_id: &PlayerId,
///      processor: &ReplayProcessor,
///      _frame,
///      _frame_number,
///      _current_time: f32| {
///         convert_all_floats!(
///             processor
///                 .get_dodge_active(player_id)
///                 .and_then(u8_get_f32)
///                 .unwrap_or(0.0),
///             processor
///                 .get_jump_active(player_id)
///                 .and_then(u8_get_f32)
///                 .unwrap_or(0.0),
///             processor
///                 .get_double_jump_active(player_id)
///                 .and_then(u8_get_f32)
///                 .unwrap_or(0.0),
///         )
///     },
///     "dodge active",
///     "jump active",
///     "double jump active"
/// );
/// ```
///
/// This will create a struct named `PlayerJump` and implement the necessary
/// traits to calculate features using the provided closure. The player-specific
/// features will be added under the column names "dodge active",
/// "jump active", and "double jump active" respectively.
#[macro_export]
macro_rules! build_player_feature_adder {
    ($struct_name:ident, $prop_getter:expr, $( $column_names:expr ),* $(,)?) => {
        #[derive(derive_new::new)]
        pub struct $struct_name<F> {
            _zero: std::marker::PhantomData<F>,
        }

        impl<F: Sync + Send + TryFrom<f32> + 'static> $struct_name<F> where
            <F as TryFrom<f32>>::Error: std::fmt::Debug,
        {
            pub fn arc_new() -> std::sync::Arc<dyn PlayerFeatureAdder<F> + Send + Sync + 'static> {
                std::sync::Arc::new(Self::new())
            }
        }

        player_feature_adder!(
            $struct_name,
            $prop_getter,
            $( $column_names ),*
        );
    }
}
752
/// This macro is used to implement necessary traits for an existing struct to
/// add the calculated player-specific features to the data matrix. This macro
/// is particularly useful when the feature adder needs to be instantiated with
/// specific parameters. The number of column names and the length of the
/// feature array returned by `$prop_getter` are checked at compile time to
/// ensure they match.
///
/// The struct must be generic over a single type parameter (the feature
/// element type). The macro implements [`LengthCheckedPlayerFeatureAdder`] for
/// it and then invokes [`impl_player_feature_adder!`] to provide
/// [`PlayerFeatureAdder`].
///
/// As a side effect, a constant holding the number of column names is defined
/// in the invoking scope. Its name is derived from `$struct_name` (converted
/// to upper snake case, with a `_LENGTH` suffix), so it must not clash with
/// another item in that scope.
///
/// # Parameters
///
/// * `$struct_name`: The name of the existing struct.
/// * `$prop_getter`: The function or closure used to calculate the features.
///   It is called with `self`, the player id, the processor, the frame, the
///   frame count and the current time, in that order.
/// * `$( $column_names:expr ),*`: A comma-separated list of column names as strings.
#[macro_export]
macro_rules! player_feature_adder {
    ($struct_name:ident, $prop_getter:expr, $( $column_names:expr ),* $(,)?) => {
        // Helper macro that receives the name of the length constant, so that
        // the same identifier can be used as the const generic argument and
        // in both array types below.
        macro_rules! _player_feature_adder {
            ($count:ident) => {
                impl<F: TryFrom<f32>> LengthCheckedPlayerFeatureAdder<F, $count> for $struct_name<F>
                where
                    <F as TryFrom<f32>>::Error: std::fmt::Debug,
                {
                    fn get_column_headers_array(&self) -> &[&str; $count] {
                        &[$( $column_names ),*]
                    }

                    fn get_features(
                        &self,
                        player_id: &PlayerId,
                        processor: &ReplayProcessor,
                        frame: &boxcars::Frame,
                        frame_count: usize,
                        current_time: f32,
                    ) -> SubtrActorResult<[F; $count]> {
                        $prop_getter(self, player_id, processor, frame, frame_count, current_time)
                    }
                }

                impl_player_feature_adder!($struct_name);
            };
        }
        paste::paste! {
            // The constant's value is the number of column names supplied.
            const [<$struct_name:snake:upper _LENGTH>]: usize = [$($column_names),*].len();
            _player_feature_adder!([<$struct_name:snake:upper _LENGTH>]);
        }
    }
}
799
800/// Unconditionally convert any error into a [`SubtrActorError`] of with the
801/// [`SubtrActorErrorVariant::FloatConversionError`] variant.
802pub fn convert_float_conversion_error<T>(_: T) -> SubtrActorError {
803    SubtrActorError::new(SubtrActorErrorVariant::FloatConversionError)
804}
805
/// A macro that tries to convert each provided item into the element type of
/// the resulting array. If any of the conversions fail, it short-circuits and
/// returns the error.
///
/// The first argument `$err` is a function or closure that is passed to
/// [`Result::map_err`]: it receives the conversion error and returns the error
/// value that is then propagated with the `?` operator.
///
/// Subsequent arguments should be expressions that implement the [`TryInto`]
/// trait, with the type they're being converted into being the element type of
/// the array in the `Ok` variant of the return value. Because the expansion
/// uses `?`, the macro must be used inside a function (or closure) that
/// returns a compatible [`Result`].
#[macro_export]
macro_rules! convert_all {
    ($err:expr, $( $item:expr ),* $(,)?) => {{
		Ok([
			$( $item.try_into().map_err($err)? ),*
		])
	}};
}
824
/// A convenience macro that invokes [`convert_all`] with
/// [`convert_float_conversion_error`] as the error mapper.
///
/// Each item is converted with [`TryInto`] into the element type of the
/// returned array. If any conversion fails, the macro short-circuits and
/// returns the error. Because it uses the `?` operator, it must be used inside
/// a function or closure that returns a [`Result`]. It is primarily useful for
/// defining functions, like the one in the example below, that are generic in
/// a parameter implementing [`TryFrom<f32>`].
///
/// # Example
///
/// ```
/// use subtr_actor::*;
///
/// pub fn some_constant_function<F: TryFrom<f32>>() -> SubtrActorResult<[F; 3]> {
///     convert_all_floats!(42.0, 0.0, 1.234)
/// }
/// ```
#[macro_export]
macro_rules! convert_all_floats {
    ($( $item:expr ),* $(,)?) => {{
        // `$crate` paths keep the expansion independent of which names the
        // caller has imported.
        $crate::convert_all!($crate::convert_float_conversion_error, $( $item ),*)
    }};
}
852
853fn or_zero_boxcars_3f() -> boxcars::Vector3f {
854    boxcars::Vector3f {
855        x: 0.0,
856        y: 0.0,
857        z: 0.0,
858    }
859}
860
861type RigidBodyArrayResult<F> = SubtrActorResult<[F; 12]>;
862
863/// Extracts the location, rotation, linear velocity and angular velocity from a
864/// [`boxcars::RigidBody`] and converts them to a type implementing [`TryFrom<f32>`].
865///
866/// If any of the components of the rigid body are not set (`None`), they are
867/// treated as zero.
868///
869/// The returned array contains twelve elements in the following order: x, y, z
870/// location, x, y, z rotation (as Euler angles), x, y, z linear velocity, x, y,
871/// z angular velocity.
872pub fn get_rigid_body_properties<F: TryFrom<f32>>(
873    rigid_body: &boxcars::RigidBody,
874) -> RigidBodyArrayResult<F>
875where
876    <F as TryFrom<f32>>::Error: std::fmt::Debug,
877{
878    let linear_velocity = rigid_body
879        .linear_velocity
880        .unwrap_or_else(or_zero_boxcars_3f);
881    let angular_velocity = rigid_body
882        .angular_velocity
883        .unwrap_or_else(or_zero_boxcars_3f);
884    let rotation = rigid_body.rotation;
885    let location = rigid_body.location;
886    let (rx, ry, rz) =
887        glam::quat(rotation.x, rotation.y, rotation.z, rotation.w).to_euler(glam::EulerRot::XYZ);
888    convert_all_floats!(
889        location.x,
890        location.y,
891        location.z,
892        rx,
893        ry,
894        rz,
895        linear_velocity.x,
896        linear_velocity.y,
897        linear_velocity.z,
898        angular_velocity.x,
899        angular_velocity.y,
900        angular_velocity.z,
901    )
902}
903
904/// Extracts the location and rotation from a [`boxcars::RigidBody`] and
905/// converts them to a type implementing [`TryFrom<f32>`].
906///
907/// If any of the components of the rigid body are not set (`None`), they are
908/// treated as zero.
909///
910/// The returned array contains seven elements in the following order: x, y, z
911/// location, x, y, z, w rotation.
912pub fn get_rigid_body_properties_no_velocities<F: TryFrom<f32>>(
913    rigid_body: &boxcars::RigidBody,
914) -> SubtrActorResult<[F; 7]>
915where
916    <F as TryFrom<f32>>::Error: std::fmt::Debug,
917{
918    let rotation = rigid_body.rotation;
919    let location = rigid_body.location;
920    convert_all_floats!(
921        location.x, location.y, location.z, rotation.x, rotation.y, rotation.z, rotation.w
922    )
923}
924
925fn default_rb_state<F: TryFrom<f32>>() -> RigidBodyArrayResult<F>
926where
927    <F as TryFrom<f32>>::Error: std::fmt::Debug,
928{
929    convert_all!(
930        convert_float_conversion_error,
931        // We use huge values for location instead of 0s so that hopefully any
932        // model built on this data can understand that the player is not
933        // actually on the field.
934        0.0,
935        0.0,
936        0.0,
937        0.0,
938        0.0,
939        0.0,
940        0.0,
941        0.0,
942        0.0,
943        0.0,
944        0.0,
945        0.0,
946    )
947}
948
949fn default_rb_state_no_velocities<F: TryFrom<f32>>() -> SubtrActorResult<[F; 7]>
950where
951    <F as TryFrom<f32>>::Error: std::fmt::Debug,
952{
953    convert_all_floats!(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,)
954}
955
956build_global_feature_adder!(
957    SecondsRemaining,
958    |_, processor: &ReplayProcessor, _frame, _index, _current_time| {
959        convert_all_floats!(processor.get_seconds_remaining()?.clone() as f32)
960    },
961    "seconds remaining"
962);
963
964build_global_feature_adder!(
965    CurrentTime,
966    |_, _processor, _frame, _index, current_time: f32| { convert_all_floats!(current_time) },
967    "current time"
968);
969
970build_global_feature_adder!(
971    FrameTime,
972    |_, _processor, frame: &boxcars::Frame, _index, _current_time| {
973        convert_all_floats!(frame.time)
974    },
975    "frame time"
976);
977
978build_global_feature_adder!(
979    BallRigidBody,
980    |_, processor: &ReplayProcessor, _frame, _index, _current_time| {
981        get_rigid_body_properties(processor.get_ball_rigid_body()?)
982    },
983    "Ball - position x",
984    "Ball - position y",
985    "Ball - position z",
986    "Ball - rotation x",
987    "Ball - rotation y",
988    "Ball - rotation z",
989    "Ball - linear velocity x",
990    "Ball - linear velocity y",
991    "Ball - linear velocity z",
992    "Ball - angular velocity x",
993    "Ball - angular velocity y",
994    "Ball - angular velocity z",
995);
996
997build_global_feature_adder!(
998    BallRigidBodyNoVelocities,
999    |_, processor: &ReplayProcessor, _frame, _index, _current_time| {
1000        get_rigid_body_properties_no_velocities(processor.get_ball_rigid_body()?)
1001    },
1002    "Ball - position x",
1003    "Ball - position y",
1004    "Ball - position z",
1005    "Ball - rotation x",
1006    "Ball - rotation y",
1007    "Ball - rotation z",
1008    "Ball - rotation w",
1009);
1010
1011// XXX: This approach seems to give some unexpected results with rotation
1012// changes. There may be a unit mismatch or some other type of issue.
1013build_global_feature_adder!(
1014    VelocityAddedBallRigidBodyNoVelocities,
1015    |_, processor: &ReplayProcessor, _frame, _index, current_time: f32| {
1016        get_rigid_body_properties_no_velocities(
1017            &processor.get_velocity_applied_ball_rigid_body(current_time)?,
1018        )
1019    },
1020    "Ball - position x",
1021    "Ball - position y",
1022    "Ball - position z",
1023    "Ball - rotation x",
1024    "Ball - rotation y",
1025    "Ball - rotation z",
1026    "Ball - rotation w",
1027);
1028
1029#[derive(derive_new::new)]
1030pub struct InterpolatedBallRigidBodyNoVelocities<F> {
1031    close_enough_to_frame_time: f32,
1032    _zero: std::marker::PhantomData<F>,
1033}
1034
1035impl<F> InterpolatedBallRigidBodyNoVelocities<F> {
1036    pub fn arc_new(close_enough_to_frame_time: f32) -> Arc<Self> {
1037        Arc::new(Self::new(close_enough_to_frame_time))
1038    }
1039}
1040
1041global_feature_adder!(
1042    InterpolatedBallRigidBodyNoVelocities,
1043    |s: &InterpolatedBallRigidBodyNoVelocities<F>,
1044     processor: &ReplayProcessor,
1045     _frame: &boxcars::Frame,
1046     _index,
1047     current_time: f32| {
1048        processor
1049            .get_interpolated_ball_rigid_body(current_time, s.close_enough_to_frame_time)
1050            .map(|v| get_rigid_body_properties_no_velocities(&v))
1051            .unwrap_or_else(|_| default_rb_state_no_velocities())
1052    },
1053    "Ball - position x",
1054    "Ball - position y",
1055    "Ball - position z",
1056    "Ball - rotation x",
1057    "Ball - rotation y",
1058    "Ball - rotation z",
1059    "Ball - rotation w",
1060);
1061
1062build_player_feature_adder!(
1063    PlayerRigidBody,
1064    |_, player_id: &PlayerId, processor: &ReplayProcessor, _frame, _index, _current_time: f32| {
1065        if let Ok(rb) = processor.get_player_rigid_body(player_id) {
1066            get_rigid_body_properties(rb)
1067        } else {
1068            default_rb_state()
1069        }
1070    },
1071    "position x",
1072    "position y",
1073    "position z",
1074    "rotation x",
1075    "rotation y",
1076    "rotation z",
1077    "linear velocity x",
1078    "linear velocity y",
1079    "linear velocity z",
1080    "angular velocity x",
1081    "angular velocity y",
1082    "angular velocity z",
1083);
1084
1085build_player_feature_adder!(
1086    PlayerRigidBodyNoVelocities,
1087    |_, player_id: &PlayerId, processor: &ReplayProcessor, _frame, _index, _current_time: f32| {
1088        if let Ok(rb) = processor.get_player_rigid_body(player_id) {
1089            get_rigid_body_properties_no_velocities(rb)
1090        } else {
1091            default_rb_state_no_velocities()
1092        }
1093    },
1094    "position x",
1095    "position y",
1096    "position z",
1097    "rotation x",
1098    "rotation y",
1099    "rotation z",
1100    "rotation w"
1101);
1102
1103// XXX: This approach seems to give some unexpected results with rotation
1104// changes. There may be a unit mismatch or some other type of issue.
1105build_player_feature_adder!(
1106    VelocityAddedPlayerRigidBodyNoVelocities,
1107    |_, player_id: &PlayerId, processor: &ReplayProcessor, _frame, _index, current_time: f32| {
1108        if let Ok(rb) = processor.get_velocity_applied_player_rigid_body(player_id, current_time) {
1109            get_rigid_body_properties_no_velocities(&rb)
1110        } else {
1111            default_rb_state_no_velocities()
1112        }
1113    },
1114    "position x",
1115    "position y",
1116    "position z",
1117    "rotation x",
1118    "rotation y",
1119    "rotation z",
1120    "rotation w"
1121);
1122
1123#[derive(derive_new::new)]
1124pub struct InterpolatedPlayerRigidBodyNoVelocities<F> {
1125    close_enough_to_frame_time: f32,
1126    _zero: std::marker::PhantomData<F>,
1127}
1128
1129impl<F> InterpolatedPlayerRigidBodyNoVelocities<F> {
1130    pub fn arc_new(close_enough_to_frame_time: f32) -> Arc<Self> {
1131        Arc::new(Self::new(close_enough_to_frame_time))
1132    }
1133}
1134
1135player_feature_adder!(
1136    InterpolatedPlayerRigidBodyNoVelocities,
1137    |s: &InterpolatedPlayerRigidBodyNoVelocities<F>,
1138     player_id: &PlayerId,
1139     processor: &ReplayProcessor,
1140     _frame: &boxcars::Frame,
1141     _index,
1142     current_time: f32| {
1143        processor
1144            .get_interpolated_player_rigid_body(
1145                player_id,
1146                current_time,
1147                s.close_enough_to_frame_time,
1148            )
1149            .map(|v| get_rigid_body_properties_no_velocities(&v))
1150            .unwrap_or_else(|_| default_rb_state_no_velocities())
1151    },
1152    "i position x",
1153    "i position y",
1154    "i position z",
1155    "i rotation x",
1156    "i rotation y",
1157    "i rotation z",
1158    "i rotation w"
1159);
1160
1161build_player_feature_adder!(
1162    PlayerBoost,
1163    |_, player_id: &PlayerId, processor: &ReplayProcessor, _frame, _index, _current_time: f32| {
1164        convert_all_floats!(processor.get_player_boost_level(player_id).unwrap_or(0.0))
1165    },
1166    "boost level"
1167);
1168
1169fn u8_get_f32(v: u8) -> SubtrActorResult<f32> {
1170    Ok(v.into())
1171}
1172
1173build_player_feature_adder!(
1174    PlayerJump,
1175    |_,
1176     player_id: &PlayerId,
1177     processor: &ReplayProcessor,
1178     _frame,
1179     _frame_number,
1180     _current_time: f32| {
1181        convert_all_floats!(
1182            processor
1183                .get_dodge_active(player_id)
1184                .and_then(u8_get_f32)
1185                .unwrap_or(0.0),
1186            processor
1187                .get_jump_active(player_id)
1188                .and_then(u8_get_f32)
1189                .unwrap_or(0.0),
1190            processor
1191                .get_double_jump_active(player_id)
1192                .and_then(u8_get_f32)
1193                .unwrap_or(0.0),
1194        )
1195    },
1196    "dodge active",
1197    "jump active",
1198    "double jump active"
1199);
1200
1201build_player_feature_adder!(
1202    PlayerAnyJump,
1203    |_,
1204     player_id: &PlayerId,
1205     processor: &ReplayProcessor,
1206     _frame,
1207     _frame_number,
1208     _current_time: f32| {
1209        let dodge_is_active = processor.get_dodge_active(player_id).unwrap_or(0) % 2;
1210        let jump_is_active = processor.get_jump_active(player_id).unwrap_or(0) % 2;
1211        let double_jump_is_active = processor.get_double_jump_active(player_id).unwrap_or(0) % 2;
1212        let value: f32 = [dodge_is_active, jump_is_active, double_jump_is_active]
1213            .into_iter()
1214            .enumerate()
1215            .map(|(index, is_active)| (1 << index) * is_active)
1216            .sum::<u8>() as f32;
1217        convert_all_floats!(value)
1218    },
1219    "any_jump_active"
1220);
1221
1222const DEMOLISH_APPEARANCE_FRAME_COUNT: usize = 30;
1223
1224build_player_feature_adder!(
1225    PlayerDemolishedBy,
1226    |_,
1227     player_id: &PlayerId,
1228     processor: &ReplayProcessor,
1229     _frame,
1230     frame_number,
1231     _current_time: f32| {
1232        let demolisher_index = processor
1233            .demolishes
1234            .iter()
1235            .find(|demolish_info| {
1236                &demolish_info.victim == player_id
1237                    && frame_number - demolish_info.frame < DEMOLISH_APPEARANCE_FRAME_COUNT
1238            })
1239            .map(|demolish_info| {
1240                processor
1241                    .iter_player_ids_in_order()
1242                    .position(|player_id| player_id == &demolish_info.attacker)
1243                    .unwrap_or_else(|| processor.iter_player_ids_in_order().count())
1244            })
1245            .and_then(|v| i32::try_from(v).ok())
1246            .unwrap_or(-1);
1247        convert_all_floats!(demolisher_index as f32)
1248    },
1249    "player demolished by"
1250);
1251
1252build_player_feature_adder!(
1253    PlayerRigidBodyQuaternions,
1254    |_, player_id: &PlayerId, processor: &ReplayProcessor, _frame, _index, _current_time: f32| {
1255        if let Ok(rb) = processor.get_player_rigid_body(player_id) {
1256            let rotation = rb.rotation;
1257            let location = rb.location;
1258            convert_all_floats!(
1259                location.x, location.y, location.z, rotation.x, rotation.y, rotation.z, rotation.w
1260            )
1261        } else {
1262            convert_all_floats!(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0)
1263        }
1264    },
1265    "position x",
1266    "position y",
1267    "position z",
1268    "quaternion x",
1269    "quaternion y",
1270    "quaternion z",
1271    "quaternion w"
1272);
1273
1274build_global_feature_adder!(
1275    BallRigidBodyQuaternions,
1276    |_, processor: &ReplayProcessor, _frame, _index, _current_time| {
1277        let rb = processor.get_ball_rigid_body()?;
1278        let rotation = rb.rotation;
1279        let location = rb.location;
1280        convert_all_floats!(
1281            location.x, location.y, location.z, rotation.x, rotation.y, rotation.z, rotation.w
1282        )
1283    },
1284    "Ball - position x",
1285    "Ball - position y",
1286    "Ball - position z",
1287    "Ball - quaternion x",
1288    "Ball - quaternion y",
1289    "Ball - quaternion z",
1290    "Ball - quaternion w"
1291);
1292
1293lazy_static! {
1294    static ref NAME_TO_GLOBAL_FEATURE_ADDER: std::collections::HashMap<&'static str, Arc<dyn FeatureAdder<f32> + Send + Sync + 'static>> = {
1295        let mut m: std::collections::HashMap<
1296            &'static str,
1297            Arc<dyn FeatureAdder<f32> + Send + Sync + 'static>,
1298        > = std::collections::HashMap::new();
1299        macro_rules! insert_adder {
1300            ($adder_name:ident, $( $arguments:expr ),*) => {
1301                m.insert(stringify!($adder_name), $adder_name::<f32>::arc_new($ ( $arguments ),*));
1302            };
1303            ($adder_name:ident) => {
1304                insert_adder!($adder_name,)
1305            }
1306        }
1307        insert_adder!(BallRigidBody);
1308        insert_adder!(BallRigidBodyNoVelocities);
1309        insert_adder!(BallRigidBodyQuaternions);
1310        insert_adder!(VelocityAddedBallRigidBodyNoVelocities);
1311        insert_adder!(InterpolatedBallRigidBodyNoVelocities, 0.0);
1312        insert_adder!(SecondsRemaining);
1313        insert_adder!(CurrentTime);
1314        insert_adder!(FrameTime);
1315        m
1316    };
1317    static ref NAME_TO_PLAYER_FEATURE_ADDER: std::collections::HashMap<
1318        &'static str,
1319        Arc<dyn PlayerFeatureAdder<f32> + Send + Sync + 'static>,
1320    > = {
1321        let mut m: std::collections::HashMap<
1322            &'static str,
1323            Arc<dyn PlayerFeatureAdder<f32> + Send + Sync + 'static>,
1324        > = std::collections::HashMap::new();
1325        macro_rules! insert_adder {
1326            ($adder_name:ident, $( $arguments:expr ),*) => {
1327                m.insert(stringify!($adder_name), $adder_name::<f32>::arc_new($ ( $arguments ),*));
1328            };
1329            ($adder_name:ident) => {
1330                insert_adder!($adder_name,)
1331            };
1332        }
1333        insert_adder!(PlayerRigidBody);
1334        insert_adder!(PlayerRigidBodyNoVelocities);
1335        insert_adder!(PlayerRigidBodyQuaternions);
1336        insert_adder!(VelocityAddedPlayerRigidBodyNoVelocities);
1337        insert_adder!(InterpolatedPlayerRigidBodyNoVelocities, 0.003);
1338        insert_adder!(PlayerBoost);
1339        insert_adder!(PlayerJump);
1340        insert_adder!(PlayerAnyJump);
1341        insert_adder!(PlayerDemolishedBy);
1342        m
1343    };
1344}