subtr_actor_spec/collector/
ndarray.rs

1use crate::*;
2use ::ndarray;
3use boxcars;
4pub use derive_new;
5use lazy_static::lazy_static;
6use ndarray::Axis;
7pub use paste;
8use serde::Serialize;
9use std::sync::Arc;
10
11/// Represents the column headers in the collected data of an [`NDArrayCollector`].
12///
13/// # Fields
14///
15/// * `global_headers`: A list of strings that represent the global,
16/// player-independent features' column headers.
17/// * `player_headers`: A list of strings that represent the player-specific
18/// features' column headers.
19///
20/// Use [`Self::new`] to construct an instance of this struct.
21#[derive(Debug, Clone, PartialEq, Serialize)]
22pub struct NDArrayColumnHeaders {
23    pub global_headers: Vec<String>,
24    pub player_headers: Vec<String>,
25}
26
27impl NDArrayColumnHeaders {
28    pub fn new(global_headers: Vec<String>, player_headers: Vec<String>) -> Self {
29        Self {
30            global_headers,
31            player_headers,
32        }
33    }
34}
35
36/// A struct that contains both the metadata of a replay and the associated
37/// column headers.
38///
39/// # Fields
40///
41/// * `replay_meta`: Contains metadata about a [`boxcars::Replay`].
42/// * `column_headers`: The [`NDArrayColumnHeaders`] associated with the data
43/// collected from the replay.
44#[derive(Debug, Clone, PartialEq, Serialize)]
45pub struct ReplayMetaWithHeaders {
46    pub replay_meta: ReplayMeta,
47    pub column_headers: NDArrayColumnHeaders,
48}
49
50impl ReplayMetaWithHeaders {
51    pub fn headers_vec(&self) -> Vec<String> {
52        self.headers_vec_from(|_, _info, index| format!("Player {} - ", index))
53    }
54
55    pub fn headers_vec_from<F>(&self, player_prefix_getter: F) -> Vec<String>
56    where
57        F: Fn(&Self, &PlayerInfo, usize) -> String,
58    {
59        self.column_headers
60            .global_headers
61            .iter()
62            .cloned()
63            .chain(self.replay_meta.player_order().enumerate().flat_map(
64                move |(player_index, info)| {
65                    let player_prefix = player_prefix_getter(self, info, player_index);
66                    self.column_headers
67                        .player_headers
68                        .iter()
69                        .map(move |header| format!("{}{}", player_prefix, header))
70                },
71            ))
72            .collect()
73    }
74}
75
76/// [`NDArrayCollector`] is a [`Collector`] which transforms frame-based replay
77/// data into a 2-dimensional array of type [`ndarray::Array2`], where each
78/// element is of a specified floating point type.
79///
80/// It's initialized with collections of [`FeatureAdder`] instances which
81/// extract global, player independent features for each frame, and
82/// [`PlayerFeatureAdder`], which add player specific features for each frame.
83///
84/// It's main entrypoint is [`Self::get_meta_and_ndarray`], which provides
85/// [`ndarray::Array2`] along with column headers and replay metadata.
86pub struct NDArrayCollector<F> {
87    feature_adders: FeatureAdders<F>,
88    player_feature_adders: PlayerFeatureAdders<F>,
89    data: Vec<F>,
90    replay_meta: Option<ReplayMeta>,
91    frames_added: usize,
92}
93
94impl<F> NDArrayCollector<F> {
95    /// Creates a new instance of `NDArrayCollector`.
96    ///
97    /// # Arguments
98    ///
99    /// * `feature_adders` - A vector of [`Arc<dyn FeatureAdder<F>>`], each
100    /// implementing the [`FeatureAdder`] trait. These are used to add global
101    /// features to the replay data.
102    ///
103    /// * `player_feature_adders` - A vector of [`Arc<dyn PlayerFeatureAdder<F>>`],
104    /// each implementing the [`PlayerFeatureAdder`]
105    /// trait. These are used to add player-specific features to the replay
106    /// data.
107    ///
108    /// # Returns
109    ///
110    /// A new [`NDArrayCollector`] instance. This instance is initialized with
111    /// empty data, no replay metadata and zero frames added.
112    pub fn new(
113        feature_adders: FeatureAdders<F>,
114        player_feature_adders: PlayerFeatureAdders<F>,
115    ) -> Self {
116        Self {
117            feature_adders,
118            player_feature_adders,
119            data: Vec::new(),
120            replay_meta: None,
121            frames_added: 0,
122        }
123    }
124
125    /// Returns the column headers of the 2-dimensional array produced by the
126    /// [`NDArrayCollector`].
127    ///
128    /// # Returns
129    ///
130    /// An instance of [`NDArrayColumnHeaders`] representing the column headers
131    /// in the collected data.
132    pub fn get_column_headers(&self) -> NDArrayColumnHeaders {
133        let global_headers = self
134            .feature_adders
135            .iter()
136            .flat_map(move |fa| {
137                fa.get_column_headers()
138                    .iter()
139                    .map(move |column_name| format!("{}", column_name))
140            })
141            .collect();
142        let player_headers = self
143            .player_feature_adders
144            .iter()
145            .flat_map(move |pfa| {
146                pfa.get_column_headers()
147                    .iter()
148                    .map(move |base_name| format!("{}", base_name))
149            })
150            .collect();
151        NDArrayColumnHeaders::new(global_headers, player_headers)
152    }
153
154    /// This function consumes the [`NDArrayCollector`] instance and returns the
155    /// data collected as an [`ndarray::Array2`].
156    ///
157    /// # Returns
158    ///
159    /// A [`SubtrActorResult`] containing the collected data as an
160    /// [`ndarray::Array2`].
161    ///
162    /// This method is a shorthand for calling [`Self::get_meta_and_ndarray`]
163    /// and discarding the replay metadata and headers.
164    pub fn get_ndarray(self) -> SubtrActorResult<ndarray::Array2<F>> {
165        self.get_meta_and_ndarray().map(|a| a.1)
166    }
167
168    pub fn get_shots_and_array(&self,
169        data: &[u8],
170    ) -> SubtrActorResult<(Vec<ShotMetadata>, Vec<Vec<f32>>, Vec<String>)> {
171        
172        let parsing = boxcars::ParserBuilder::new(&data[..])
173            .always_check_crc()
174            .must_parse_network_data()
175            .parse();
176        
177        let replay = parsing.unwrap();
178
179        let mut collector = NDArrayCollector::<f32>::from_strings(
180            &["InterpolatedBallRigidBodyNoVelocities"],
181            &[
182                "InterpolatedPlayerRigidBodyNoVelocities",
183                "PlayerBoost",
184                "PlayerAnyJump",
185                "PlayerDemolishedBy",
186            ],
187        )
188        .unwrap();
189
190        let mut collector2 = NDArrayCollector::<f32>::from_strings(
191            &["InterpolatedBallRigidBodyNoVelocities"],
192            &[
193                "InterpolatedPlayerRigidBodyNoVelocities",
194                "PlayerBoost",
195                "PlayerAnyJump",
196                "PlayerDemolishedBy",
197            ],
198        )
199        .unwrap();
200
201        FrameRateDecorator::new_from_fps(10.0, &mut collector)
202        .process_replay(&replay)
203        .unwrap();
204
205        let (meta, array) = collector.get_meta_and_ndarray().unwrap();
206        
207        let result = collector2.process_and_get_meta_and_headers(&replay).unwrap();
208
209        // Extract the shots metadata
210        let shots = result.replay_meta.shots.clone();  
211        let json_array: Vec<Vec<f32>> = array
212        .axis_iter(Axis(0))
213        .map(|row| row.to_vec())
214        .collect();
215
216
217        
218        Ok((shots, json_array, meta.headers_vec()))
219    }
220    
221
222    /// Consumes the [`NDArrayCollector`] and returns the collected features as a
223    /// 2D ndarray, along with replay metadata and headers.
224    ///
225    /// # Returns
226    ///
227    /// A [`SubtrActorResult`] containing a tuple:
228    /// - [`ReplayMetaWithHeaders`]: The replay metadata along with the headers
229    /// for each column in the ndarray.
230    /// - [`ndarray::Array2<F>`]: The collected features as a 2D ndarray.
231    pub fn get_meta_and_ndarray(
232        self,
233    ) -> SubtrActorResult<(ReplayMetaWithHeaders, ndarray::Array2<F>)> {
234        let features_per_row = self.try_get_frame_feature_count()?;
235        let expected_length = features_per_row * self.frames_added;
236        assert!(self.data.len() == expected_length);
237        let column_headers = self.get_column_headers();
238        Ok((
239            ReplayMetaWithHeaders {
240                replay_meta: self.replay_meta.ok_or(SubtrActorError::new(
241                    SubtrActorErrorVariant::CouldNotBuildReplayMeta,
242                ))?,
243                column_headers,
244            },
245            ndarray::Array2::from_shape_vec((self.frames_added, features_per_row), self.data)
246                .map_err(SubtrActorErrorVariant::NDArrayShapeError)
247                .map_err(SubtrActorError::new)?,
248        ))
249    }
250
251    /// Processes a [`boxcars::Replay`] and returns its metadata along with column headers.
252    ///
253    /// This method first processes the replay using a [`ReplayProcessor`]. It
254    /// then updates the `replay_meta` field if it's not already set, and
255    /// returns a clone of the `replay_meta` field along with column headers of
256    /// the data.
257    ///
258    /// # Arguments
259    ///
260    /// * `replay`: A reference to the [`boxcars::Replay`] to process.
261    ///
262    /// # Returns
263    ///
264    /// A [`SubtrActorResult`] containing a [`ReplayMetaWithHeaders`] that
265    /// includes the metadata of the replay and column headers.
266    pub fn process_and_get_meta_and_headers(
267        &mut self,
268        replay: &boxcars::Replay,
269    ) -> SubtrActorResult<ReplayMetaWithHeaders> {
270        let mut processor: ReplayProcessor = ReplayProcessor::new(replay)?;
271        processor.process_long_enough_to_get_actor_ids()?;
272        self.maybe_set_replay_meta(&processor)?;
273        Ok(ReplayMetaWithHeaders {
274            replay_meta: self
275                .replay_meta
276                .as_ref()
277                .ok_or(SubtrActorError::new(
278                    SubtrActorErrorVariant::CouldNotBuildReplayMeta,
279                ))?
280                .clone(),
281            column_headers: self.get_column_headers(),
282        })
283    }
284
285    fn try_get_frame_feature_count(&self) -> SubtrActorResult<usize> {
286        let player_count = self
287            .replay_meta
288            .as_ref()
289            .ok_or(SubtrActorError::new(
290                SubtrActorErrorVariant::CouldNotBuildReplayMeta,
291            ))?
292            .player_count();
293        let global_feature_count: usize = self
294            .feature_adders
295            .iter()
296            .map(|fa| fa.features_added())
297            .sum();
298        let player_feature_count: usize = self
299            .player_feature_adders
300            .iter() // iterate
301            .map(|pfa| pfa.features_added() * player_count)
302            .sum();
303        Ok(global_feature_count + player_feature_count)
304    }
305
306    fn maybe_set_replay_meta(&mut self, processor: &ReplayProcessor) -> SubtrActorResult<()> {
307        if let None = self.replay_meta {
308            self.replay_meta = Some(processor.get_replay_meta()?);
309        }
310        Ok(())
311    }
312}
313
314impl<F> Collector for NDArrayCollector<F> {
315    fn process_frame(
316        &mut self,
317        processor: &ReplayProcessor,
318        frame: &boxcars::Frame,
319        frame_number: usize,
320        current_time: f32,
321    ) -> SubtrActorResult<collector::TimeAdvance> {
322        self.maybe_set_replay_meta(processor)?;
323
324        if !processor.ball_rigid_body_exists()? {
325            return Ok(collector::TimeAdvance::NextFrame);
326        }
327
328        for feature_adder in self.feature_adders.iter() {
329            feature_adder.add_features(
330                processor,
331                frame,
332                frame_number,
333                current_time,
334                &mut self.data,
335            )?;
336        }
337
338        for player_id in processor.iter_player_ids_in_order() {
339            for player_feature_adder in self.player_feature_adders.iter() {
340                player_feature_adder.add_features(
341                    player_id,
342                    processor,
343                    frame,
344                    frame_number,
345                    current_time,
346                    &mut self.data,
347                )?;
348            }
349        }
350
351        self.frames_added += 1;
352
353        Ok(collector::TimeAdvance::NextFrame)
354    }
355}
356
357impl NDArrayCollector<f32> {
358    pub fn from_strings(fa_names: &[&str], pfa_names: &[&str]) -> SubtrActorResult<Self> {
359        let feature_adders: Vec<Arc<dyn FeatureAdder<f32> + Send + Sync>> = fa_names
360            .iter()
361            .map(|name| {
362                Ok(NAME_TO_GLOBAL_FEATURE_ADDER
363                    .get(name)
364                    .ok_or_else(|| {
365                        SubtrActorError::new(SubtrActorErrorVariant::UnknownFeatureAdderName(
366                            name.to_string(),
367                        ))
368                    })?
369                    .clone())
370            })
371            .collect::<SubtrActorResult<Vec<_>>>()?;
372        let player_feature_adders: Vec<Arc<dyn PlayerFeatureAdder<f32> + Send + Sync>> = pfa_names
373            .iter()
374            .map(|name| {
375                Ok(NAME_TO_PLAYER_FEATURE_ADDER
376                    .get(name)
377                    .ok_or_else(|| {
378                        SubtrActorError::new(SubtrActorErrorVariant::UnknownFeatureAdderName(
379                            name.to_string(),
380                        ))
381                    })?
382                    .clone())
383            })
384            .collect::<SubtrActorResult<Vec<_>>>()?;
385        Ok(Self::new(feature_adders, player_feature_adders))
386    }
387}
388
389impl<F: TryFrom<f32> + Send + Sync + 'static> Default for NDArrayCollector<F>
390where
391    <F as TryFrom<f32>>::Error: std::fmt::Debug,
392{
393    fn default() -> Self {
394        NDArrayCollector::new(
395            vec![BallRigidBody::arc_new()],
396            vec![
397                PlayerRigidBody::arc_new(),
398                PlayerBoost::arc_new(),
399                PlayerAnyJump::arc_new(),
400            ],
401        )
402    }
403}
404
405/// This trait acts as an abstraction over a feature adder, and is primarily
406/// used to allow for heterogeneous collections of feature adders in the
407/// [`NDArrayCollector`]. While it provides methods for adding features and
408/// retrieving column headers, it is generally recommended to implement the
409/// [`LengthCheckedFeatureAdder`] trait instead, which provides compile-time
410/// guarantees about the number of features returned.
411pub trait FeatureAdder<F> {
412    fn features_added(&self) -> usize {
413        self.get_column_headers().len()
414    }
415
416    fn get_column_headers(&self) -> &[&str];
417
418    fn add_features(
419        &self,
420        processor: &ReplayProcessor,
421        frame: &boxcars::Frame,
422        frame_count: usize,
423        current_time: f32,
424        vector: &mut Vec<F>,
425    ) -> SubtrActorResult<()>;
426}
427
428pub type FeatureAdders<F> = Vec<Arc<dyn FeatureAdder<F> + Send + Sync>>;
429
430/// This trait is stricter version of the [`FeatureAdder`] trait, enforcing at
431/// compile time that the number of features added is equal to the number of
432/// column headers provided. Implementations of this trait can be automatically
433/// adapted to the [`FeatureAdder`] trait using the [`impl_feature_adder!`]
434/// macro.
435pub trait LengthCheckedFeatureAdder<F, const N: usize> {
436    fn get_column_headers_array(&self) -> &[&str; N];
437
438    fn get_features(
439        &self,
440        processor: &ReplayProcessor,
441        frame: &boxcars::Frame,
442        frame_count: usize,
443        current_time: f32,
444    ) -> SubtrActorResult<[F; N]>;
445}
446
/// A macro to provide an automatic implementation of the [`FeatureAdder`] trait
/// for types that implement [`LengthCheckedFeatureAdder`]. This lets you keep
/// the compile-time guarantees of [`LengthCheckedFeatureAdder`] while still
/// using your type where a [`FeatureAdder`] object is required. The macro
/// bridges the two traits because Rust's type system does not currently
/// provide a way to prove to the compiler that there will always be exactly
/// one implementation of [`LengthCheckedFeatureAdder`] for each type.
///
/// `$struct_name` must be a type with a single generic parameter `F`. The
/// generated `add_features` appends the array returned by `get_features` to
/// the output vector. The generated `get_column_headers` returns the header
/// array as a slice. Names used in the expansion (`FeatureAdder`,
/// `ReplayProcessor`, `SubtrActorResult`, `boxcars`) are resolved at the call
/// site.
#[macro_export]
macro_rules! impl_feature_adder {
    ($struct_name:ident) => {
        impl<F: TryFrom<f32>> FeatureAdder<F> for $struct_name<F>
        where
            <F as TryFrom<f32>>::Error: std::fmt::Debug,
        {
            fn add_features(
                &self,
                processor: &ReplayProcessor,
                frame: &boxcars::Frame,
                frame_count: usize,
                current_time: f32,
                vector: &mut Vec<F>,
            ) -> SubtrActorResult<()> {
                let features =
                    self.get_features(processor, frame, frame_count, current_time)?;
                vector.extend(features);
                Ok(())
            }

            fn get_column_headers(&self) -> &[&str] {
                self.get_column_headers_array()
            }
        }
    };
}
486
487/// This trait acts as an abstraction over a player-specific feature adder, and
488/// is primarily used to allow for heterogeneous collections of player feature
489/// adders in the [`NDArrayCollector`]. While it provides methods for adding
490/// player-specific features and retrieving column headers, it is generally
491/// recommended to implement the [`LengthCheckedPlayerFeatureAdder`] trait
492/// instead, which provides compile-time guarantees about the number of features
493/// returned.
494pub trait PlayerFeatureAdder<F> {
495    fn features_added(&self) -> usize {
496        self.get_column_headers().len()
497    }
498
499    fn get_column_headers(&self) -> &[&str];
500
501    fn add_features(
502        &self,
503        player_id: &PlayerId,
504        processor: &ReplayProcessor,
505        frame: &boxcars::Frame,
506        frame_count: usize,
507        current_time: f32,
508        vector: &mut Vec<F>,
509    ) -> SubtrActorResult<()>;
510}
511
512pub type PlayerFeatureAdders<F> = Vec<Arc<dyn PlayerFeatureAdder<F> + Send + Sync>>;
513
514/// This trait is a more strict version of the [`PlayerFeatureAdder`] trait,
515/// enforcing at compile time that the number of player-specific features added
516/// is equal to the number of column headers provided. Implementations of this
517/// trait can be automatically adapted to the [`PlayerFeatureAdder`] trait using
518/// the [`impl_player_feature_adder!`] macro.
519pub trait LengthCheckedPlayerFeatureAdder<F, const N: usize> {
520    fn get_column_headers_array(&self) -> &[&str; N];
521
522    fn get_features(
523        &self,
524        player_id: &PlayerId,
525        processor: &ReplayProcessor,
526        frame: &boxcars::Frame,
527        frame_count: usize,
528        current_time: f32,
529    ) -> SubtrActorResult<[F; N]>;
530}
531
/// A macro to provide an automatic implementation of the [`PlayerFeatureAdder`]
/// trait for types that implement [`LengthCheckedPlayerFeatureAdder`]. This
/// lets you keep the compile-time guarantees of
/// [`LengthCheckedPlayerFeatureAdder`] while still using your type where a
/// [`PlayerFeatureAdder`] object is required. The macro bridges the two traits
/// because Rust's type system does not currently provide a way to prove to the
/// compiler that there will always be exactly one implementation of
/// [`LengthCheckedPlayerFeatureAdder`] for each type.
///
/// `$struct_name` must be a type with a single generic parameter `F`. The
/// generated `add_features` appends the array returned by `get_features` to
/// the output vector. Names used in the expansion are resolved at the call
/// site.
#[macro_export]
macro_rules! impl_player_feature_adder {
    ($struct_name:ident) => {
        impl<F: TryFrom<f32>> PlayerFeatureAdder<F> for $struct_name<F>
        where
            <F as TryFrom<f32>>::Error: std::fmt::Debug,
        {
            fn add_features(
                &self,
                player_id: &PlayerId,
                processor: &ReplayProcessor,
                frame: &boxcars::Frame,
                frame_count: usize,
                current_time: f32,
                vector: &mut Vec<F>,
            ) -> SubtrActorResult<()> {
                let features = self.get_features(
                    player_id,
                    processor,
                    frame,
                    frame_count,
                    current_time,
                )?;
                vector.extend(features);
                Ok(())
            }

            fn get_column_headers(&self) -> &[&str] {
                self.get_column_headers_array()
            }
        }
    };
}
572
573impl<G, F, const N: usize> FeatureAdder<F> for (G, &[&str; N])
574where
575    G: Fn(&ReplayProcessor, &boxcars::Frame, usize, f32) -> SubtrActorResult<[F; N]>,
576{
577    fn add_features(
578        &self,
579        processor: &ReplayProcessor,
580        frame: &boxcars::Frame,
581        frame_count: usize,
582        current_time: f32,
583        vector: &mut Vec<F>,
584    ) -> SubtrActorResult<()> {
585        Ok(vector.extend(self.0(processor, frame, frame_count, current_time)?))
586    }
587
588    fn get_column_headers(&self) -> &[&str] {
589        &self.1.as_slice()
590    }
591}
592
593impl<G, F, const N: usize> PlayerFeatureAdder<F> for (G, &[&str; N])
594where
595    G: Fn(&PlayerId, &ReplayProcessor, &boxcars::Frame, usize, f32) -> SubtrActorResult<[F; N]>,
596{
597    fn add_features(
598        &self,
599        player_id: &PlayerId,
600        processor: &ReplayProcessor,
601        frame: &boxcars::Frame,
602        frame_count: usize,
603        current_time: f32,
604        vector: &mut Vec<F>,
605    ) -> SubtrActorResult<()> {
606        Ok(vector.extend(self.0(
607            player_id,
608            processor,
609            frame,
610            frame_count,
611            current_time,
612        )?))
613    }
614
615    fn get_column_headers(&self) -> &[&str] {
616        &self.1.as_slice()
617    }
618}
619
/// This macro creates a global [`FeatureAdder`] struct and implements the
/// necessary traits to add the calculated features to the data matrix. The
/// macro exports a struct with the same name as passed in the parameter.
///
/// The number of column names and the length of the feature array returned by
/// `$prop_getter` are checked at compile time to ensure they match, in line
/// with the [`LengthCheckedFeatureAdder`] trait. The output struct also
/// implements the [`FeatureAdder`] trait via the [`impl_feature_adder!`] macro,
/// so it can be used wherever a [`FeatureAdder`] object is required.
///
/// The generated struct has a `new()` constructor (derived with
/// `derive_new`). It also has an `arc_new()` helper that returns it as an
/// `Arc<dyn FeatureAdder<F> + Send + Sync>`.
///
/// # Parameters
///
/// * `$struct_name`: The name of the struct to be created.
/// * `$prop_getter`: The function or closure used to calculate the features.
/// * `$( $column_names:expr ),*`: A comma-separated list of column names as strings.
///
/// # Example
///
/// ```
/// use subtr_actor_spec::*;
///
/// build_global_feature_adder!(
///     SecondsRemainingExample,
///     |_, processor: &ReplayProcessor, _frame, _index, _current_time| {
///         convert_all_floats!(processor.get_seconds_remaining()?.clone() as f32)
///     },
///     "seconds remaining"
/// );
/// ```
///
/// This creates a struct named `SecondsRemainingExample` and implements the
/// traits needed to calculate features with the provided closure. The feature
/// is added under the column name "seconds remaining". A single feature adder
/// can also add more than one feature: list one column name per feature.
#[macro_export]
macro_rules! build_global_feature_adder {
    ($struct_name:ident, $prop_getter:expr, $( $column_names:expr ),* $(,)?) => {

        #[derive(derive_new::new)]
        pub struct $struct_name<F> {
            _zero: std::marker::PhantomData<F>,
        }

        impl<F: Sync + Send + TryFrom<f32> + 'static> $struct_name<F> where
            <F as TryFrom<f32>>::Error: std::fmt::Debug,
        {
            pub fn arc_new() -> std::sync::Arc<dyn FeatureAdder<F> + Send + Sync + 'static> {
                std::sync::Arc::new(Self::new())
            }
        }

        global_feature_adder!(
            $struct_name,
            $prop_getter,
            $( $column_names ),*
        );
    }
}
678
/// This macro is used to implement necessary traits for an existing struct to
/// add the calculated features to the data matrix. This macro is particularly
/// useful when the feature adder needs to be instantiated with specific
/// parameters. The number of column names and the length of the feature array
/// returned by `$prop_getter` are checked at compile time to ensure they match.
///
/// The expansion implements [`LengthCheckedFeatureAdder`] for
/// `$struct_name<F>` and then invokes [`impl_feature_adder!`] for it. It also
/// emits two items at the call site:
///
/// * a constant named after the struct in upper snake case with a `_LENGTH`
///   suffix (e.g. `MyAdder` -> `MY_ADDER_LENGTH`), holding the number of
///   column names;
/// * a local helper macro, `_global_feature_adder`.
///
/// `$prop_getter` is called as
/// `$prop_getter(self, processor, frame, frame_count, current_time)` and must
/// return `SubtrActorResult<[F; N]>`. Names used in the expansion (`paste`,
/// `LengthCheckedFeatureAdder`, `ReplayProcessor`, `SubtrActorResult`,
/// `boxcars`, ...) are resolved at the call site.
///
/// # Parameters
///
/// * `$struct_name`: The name of the existing struct.
/// * `$prop_getter`: The function or closure used to calculate the features.
/// * `$( $column_names:expr ),*`: A comma-separated list of column names as strings.
#[macro_export]
macro_rules! global_feature_adder {
    ($struct_name:ident, $prop_getter:expr, $( $column_names:expr ),* $(,)?) => {
        // Helper macro: receives the length constant as an identifier so that it
        // can be used as the const generic argument `N`. `$count` is not an
        // outer metavariable, so it is passed through to this inner definition.
        macro_rules! _global_feature_adder {
            ($count:ident) => {
                impl<F: TryFrom<f32>> LengthCheckedFeatureAdder<F, $count> for $struct_name<F>
                where
                    <F as TryFrom<f32>>::Error: std::fmt::Debug,
                {
                    // NOTE(review): relies on constant promotion of the array
                    // literal; `$column_names` are expected to be string literals.
                    fn get_column_headers_array(&self) -> &[&str; $count] {
                        &[$( $column_names ),*]
                    }

                    fn get_features(
                        &self,
                        processor: &ReplayProcessor,
                        frame: &boxcars::Frame,
                        frame_count: usize,
                        current_time: f32,
                    ) -> SubtrActorResult<[F; $count]> {
                        $prop_getter(self, processor, frame, frame_count, current_time)
                    }
                }

                impl_feature_adder!($struct_name);
            };
        }
        paste::paste! {
            // Number of column names, e.g. `MyAdder` -> `MY_ADDER_LENGTH`.
            const [<$struct_name:snake:upper _LENGTH>]: usize = [$($column_names),*].len();
            _global_feature_adder!([<$struct_name:snake:upper _LENGTH>]);
        }
    }
}
723
/// This macro creates a player feature adder struct and implements the
/// necessary traits to add the calculated player-specific features to the data
/// matrix. The macro exports a struct with the same name as passed in the
/// parameter. The number of column names and the length of the feature array
/// returned by `$prop_getter` are checked at compile time to ensure they match,
/// in line with the [`LengthCheckedPlayerFeatureAdder`] trait. The output
/// struct also provides an implementation of the [`PlayerFeatureAdder`] trait
/// via the [`impl_player_feature_adder!`] macro, allowing it to be used in
/// contexts where a [`PlayerFeatureAdder`] object is required.
///
/// The generated struct has a `new()` constructor (derived with
/// `derive_new`). It also has an `arc_new()` helper that returns it as an
/// `Arc<dyn PlayerFeatureAdder<F> + Send + Sync>`.
///
/// # Parameters
///
/// * `$struct_name`: The name of the struct to be created.
/// * `$prop_getter`: The function or closure used to calculate the features.
/// * `$( $column_names:expr ),*`: A comma-separated list of column names as strings.
///
/// # Example
///
/// ```
/// use subtr_actor_spec::*;
///
/// fn u8_get_f32(v: u8) -> SubtrActorResult<f32> {
///    v.try_into().map_err(convert_float_conversion_error)
/// }
///
/// build_player_feature_adder!(
///     PlayerJump,
///     |_,
///      player_id: &PlayerId,
///      processor: &ReplayProcessor,
///      _frame,
///      _frame_number,
///      _current_time: f32| {
///         convert_all_floats!(
///             processor
///                 .get_dodge_active(player_id)
///                 .and_then(u8_get_f32)
///                 .unwrap_or(0.0),
///             processor
///                 .get_jump_active(player_id)
///                 .and_then(u8_get_f32)
///                 .unwrap_or(0.0),
///             processor
///                 .get_double_jump_active(player_id)
///                 .and_then(u8_get_f32)
///                 .unwrap_or(0.0),
///         )
///     },
///     "dodge active",
///     "jump active",
///     "double jump active"
/// );
/// ```
///
/// This will create a struct named `PlayerJump` and implement necessary
/// traits to calculate features using the provided closure. The player-specific
/// features will be added under the column names "dodge active",
/// "jump active", and "double jump active" respectively.
#[macro_export]
macro_rules! build_player_feature_adder {
    ($struct_name:ident, $prop_getter:expr, $( $column_names:expr ),* $(,)?) => {
        #[derive(derive_new::new)]
        pub struct $struct_name<F> {
            // Zero-sized marker that ties the struct to its element type `F`.
            _zero: std::marker::PhantomData<F>,
        }

        impl<F: Sync + Send + TryFrom<f32> + 'static> $struct_name<F> where
            <F as TryFrom<f32>>::Error: std::fmt::Debug,
        {
            // Convenience constructor returning the adder as a shareable trait object.
            pub fn arc_new() -> std::sync::Arc<dyn PlayerFeatureAdder<F> + Send + Sync + 'static> {
                std::sync::Arc::new(Self::new())
            }
        }

        // Implements the length-checked and object-safe traits for the new struct.
        player_feature_adder!(
            $struct_name,
            $prop_getter,
            $( $column_names ),*
        );
    }
}
805
/// This macro is used to implement necessary traits for an existing struct to
/// add the calculated player-specific features to the data matrix. This macro
/// is particularly useful when the feature adder needs to be instantiated with
/// specific parameters. The number of column names and the length of the
/// feature array returned by `$prop_getter` are checked at compile time to
/// ensure they match.
///
/// The expansion implements [`LengthCheckedPlayerFeatureAdder`] for
/// `$struct_name<F>` and then invokes [`impl_player_feature_adder!`] for it.
/// It also emits two items at the call site:
///
/// * a constant named after the struct in upper snake case with a `_LENGTH`
///   suffix (e.g. `PlayerJump` -> `PLAYER_JUMP_LENGTH`), holding the number of
///   column names;
/// * a local helper macro, `_player_feature_adder`.
///
/// `$prop_getter` is called as
/// `$prop_getter(self, player_id, processor, frame, frame_count, current_time)`
/// and must return `SubtrActorResult<[F; N]>`. Names used in the expansion are
/// resolved at the call site.
///
/// # Parameters
///
/// * `$struct_name`: The name of the existing struct.
/// * `$prop_getter`: The function or closure used to calculate the features.
/// * `$( $column_names:expr ),*`: A comma-separated list of column names as strings.
#[macro_export]
macro_rules! player_feature_adder {
    ($struct_name:ident, $prop_getter:expr, $( $column_names:expr ),* $(,)?) => {
        // Helper macro: receives the length constant as an identifier so that it
        // can be used as the const generic argument `N`. `$count` is not an
        // outer metavariable, so it is passed through to this inner definition.
        macro_rules! _player_feature_adder {
            ($count:ident) => {
                impl<F: TryFrom<f32>> LengthCheckedPlayerFeatureAdder<F, $count> for $struct_name<F>
                where
                    <F as TryFrom<f32>>::Error: std::fmt::Debug,
                {
                    // NOTE(review): relies on constant promotion of the array
                    // literal; `$column_names` are expected to be string literals.
                    fn get_column_headers_array(&self) -> &[&str; $count] {
                        &[$( $column_names ),*]
                    }

                    fn get_features(
                        &self,
                        player_id: &PlayerId,
                        processor: &ReplayProcessor,
                        frame: &boxcars::Frame,
                        frame_count: usize,
                        current_time: f32,
                    ) -> SubtrActorResult<[F; $count]> {
                        $prop_getter(self, player_id, processor, frame, frame_count, current_time)
                    }
                }

                impl_player_feature_adder!($struct_name);
            };
        }
        paste::paste! {
            // Number of column names, e.g. `PlayerJump` -> `PLAYER_JUMP_LENGTH`.
            const [<$struct_name:snake:upper _LENGTH>]: usize = [$($column_names),*].len();
            _player_feature_adder!([<$struct_name:snake:upper _LENGTH>]);
        }
    }
}
852
853/// Unconditionally convert any error into a [`SubtrActorError`] of with the
854/// [`SubtrActorErrorVariant::FloatConversionError`] variant.
855pub fn convert_float_conversion_error<T>(_: T) -> SubtrActorError {
856    SubtrActorError::new(SubtrActorErrorVariant::FloatConversionError)
857}
858
/// A macro that tries to convert each provided item with [`TryInto`]. If any
/// of the conversions fail, it short-circuits and returns the error from the
/// enclosing function.
///
/// The first argument `$err` is a function or closure that maps a conversion
/// error to an error value (for example [`convert_float_conversion_error`],
/// which produces a [`SubtrActorError`]). It is passed to [`Result::map_err`],
/// and the mapped error is propagated with `?`, so it is further converted
/// with [`From`] into the enclosing function's error type. Because `$err` is
/// expanded once per item, it should be a cheap, side-effect-free path or
/// closure.
///
/// Subsequent arguments should be expressions that implement the [`TryInto`]
/// trait. The target type is the element type of the array in the `Ok`
/// variant of the return value. On success the invocation evaluates to
/// `Ok([...])` with one converted element per item, in argument order. A
/// trailing comma is accepted.
#[macro_export]
macro_rules! convert_all {
    ($err:expr, $( $item:expr ),* $(,)?) => {{
        Ok([
            $( $item.try_into().map_err($err)? ),*
        ])
    }};
}
877
/// Converts every argument into the element type of the enclosing function's
/// `Ok` array (typically a float type `F: TryFrom<f32>`) by delegating to
/// [`convert_all`] with [`convert_float_conversion_error`] as the error mapper.
///
/// The expansion relies on the `?` operator, so it may only be used inside a
/// function that returns a [`Result`]. The first conversion that fails is
/// mapped to a [`SubtrActorError`] and returned early. It is mostly useful for
/// functions that are generic over a type implementing [`TryFrom`], like the
/// one in the example below.
///
/// # Example
///
/// ```
/// use subtr_actor_spec::*;
///
/// pub fn some_constant_function<F: TryFrom<f32>>(
///     rigid_body: &boxcars::RigidBody,
/// ) -> SubtrActorResult<[F; 3]> {
///     convert_all_floats!(42.0, 0.0, 1.234)
/// }
/// ```
#[macro_export]
macro_rules! convert_all_floats {
    ($( $value:expr ),* $(,)?) => {{
        convert_all!(convert_float_conversion_error, $( $value ),*)
    }};
}
905
906fn or_zero_boxcars_3f() -> boxcars::Vector3f {
907    boxcars::Vector3f {
908        x: 0.0,
909        y: 0.0,
910        z: 0.0,
911    }
912}
913
914type RigidBodyArrayResult<F> = SubtrActorResult<[F; 12]>;
915
916/// Extracts the location, rotation, linear velocity and angular velocity from a
917/// [`boxcars::RigidBody`] and converts them to a type implementing [`TryFrom<f32>`].
918///
919/// If any of the components of the rigid body are not set (`None`), they are
920/// treated as zero.
921///
922/// The returned array contains twelve elements in the following order: x, y, z
923/// location, x, y, z rotation (as Euler angles), x, y, z linear velocity, x, y,
924/// z angular velocity.
925pub fn get_rigid_body_properties<F: TryFrom<f32>>(
926    rigid_body: &boxcars::RigidBody,
927) -> RigidBodyArrayResult<F>
928where
929    <F as TryFrom<f32>>::Error: std::fmt::Debug,
930{
931    let linear_velocity = rigid_body
932        .linear_velocity
933        .unwrap_or_else(or_zero_boxcars_3f);
934    let angular_velocity = rigid_body
935        .angular_velocity
936        .unwrap_or_else(or_zero_boxcars_3f);
937    let rotation = rigid_body.rotation;
938    let location = rigid_body.location;
939    let (rx, ry, rz) =
940        glam::quat(rotation.x, rotation.y, rotation.z, rotation.w).to_euler(glam::EulerRot::XYZ);
941    convert_all_floats!(
942        location.x,
943        location.y,
944        location.z,
945        rx,
946        ry,
947        rz,
948        linear_velocity.x,
949        linear_velocity.y,
950        linear_velocity.z,
951        angular_velocity.x,
952        angular_velocity.y,
953        angular_velocity.z,
954    )
955}
956
957/// Extracts the location and rotation from a [`boxcars::RigidBody`] and
958/// converts them to a type implementing [`TryFrom<f32>`].
959///
960/// If any of the components of the rigid body are not set (`None`), they are
961/// treated as zero.
962///
963/// The returned array contains seven elements in the following order: x, y, z
964/// location, x, y, z, w rotation.
965pub fn get_rigid_body_properties_no_velocities<F: TryFrom<f32>>(
966    rigid_body: &boxcars::RigidBody,
967) -> SubtrActorResult<[F; 7]>
968where
969    <F as TryFrom<f32>>::Error: std::fmt::Debug,
970{
971    let rotation = rigid_body.rotation;
972    let location = rigid_body.location;
973    convert_all_floats!(
974        location.x, location.y, location.z, rotation.x, rotation.y, rotation.z, rotation.w
975    )
976}
977
978fn default_rb_state<F: TryFrom<f32>>() -> RigidBodyArrayResult<F>
979where
980    <F as TryFrom<f32>>::Error: std::fmt::Debug,
981{
982    convert_all!(
983        convert_float_conversion_error,
984        // We use huge values for location instead of 0s so that hopefully any
985        // model built on this data can understand that the player is not
986        // actually on the field.
987        0.0,
988        0.0,
989        0.0,
990        0.0,
991        0.0,
992        0.0,
993        0.0,
994        0.0,
995        0.0,
996        0.0,
997        0.0,
998        0.0,
999    )
1000}
1001
1002fn default_rb_state_no_velocities<F: TryFrom<f32>>() -> SubtrActorResult<[F; 7]>
1003where
1004    <F as TryFrom<f32>>::Error: std::fmt::Debug,
1005{
1006    convert_all_floats!(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,)
1007}
1008
1009build_global_feature_adder!(
1010    SecondsRemaining,
1011    |_, processor: &ReplayProcessor, _frame, _index, _current_time| {
1012        convert_all_floats!(processor.get_seconds_remaining()?.clone() as f32)
1013    },
1014    "seconds remaining"
1015);
1016
1017build_global_feature_adder!(
1018    CurrentTime,
1019    |_, _processor, _frame, _index, current_time: f32| { convert_all_floats!(current_time) },
1020    "current time"
1021);
1022
1023build_global_feature_adder!(
1024    FrameTime,
1025    |_, _processor, frame: &boxcars::Frame, _index, _current_time| {
1026        convert_all_floats!(frame.time)
1027    },
1028    "frame time"
1029);
1030
1031build_global_feature_adder!(
1032    BallRigidBody,
1033    |_, processor: &ReplayProcessor, _frame, _index, _current_time| {
1034        get_rigid_body_properties(processor.get_ball_rigid_body()?)
1035    },
1036    "Ball - position x",
1037    "Ball - position y",
1038    "Ball - position z",
1039    "Ball - rotation x",
1040    "Ball - rotation y",
1041    "Ball - rotation z",
1042    "Ball - linear velocity x",
1043    "Ball - linear velocity y",
1044    "Ball - linear velocity z",
1045    "Ball - angular velocity x",
1046    "Ball - angular velocity y",
1047    "Ball - angular velocity z",
1048);
1049
1050build_global_feature_adder!(
1051    BallRigidBodyNoVelocities,
1052    |_, processor: &ReplayProcessor, _frame, _index, _current_time| {
1053        get_rigid_body_properties_no_velocities(processor.get_ball_rigid_body()?)
1054    },
1055    "Ball - position x",
1056    "Ball - position y",
1057    "Ball - position z",
1058    "Ball - rotation x",
1059    "Ball - rotation y",
1060    "Ball - rotation z",
1061    "Ball - rotation w",
1062);
1063
1064// XXX: This approach seems to give some unexpected results with rotation
1065// changes. There may be a unit mismatch or some other type of issue.
1066build_global_feature_adder!(
1067    VelocityAddedBallRigidBodyNoVelocities,
1068    |_, processor: &ReplayProcessor, _frame, _index, current_time: f32| {
1069        get_rigid_body_properties_no_velocities(
1070            &processor.get_velocity_applied_ball_rigid_body(current_time)?,
1071        )
1072    },
1073    "Ball - position x",
1074    "Ball - position y",
1075    "Ball - position z",
1076    "Ball - rotation x",
1077    "Ball - rotation y",
1078    "Ball - rotation z",
1079    "Ball - rotation w",
1080);
1081
1082#[derive(derive_new::new)]
1083pub struct InterpolatedBallRigidBodyNoVelocities<F> {
1084    close_enough_to_frame_time: f32,
1085    _zero: std::marker::PhantomData<F>,
1086}
1087
1088impl<F> InterpolatedBallRigidBodyNoVelocities<F> {
1089    pub fn arc_new(close_enough_to_frame_time: f32) -> Arc<Self> {
1090        Arc::new(Self::new(close_enough_to_frame_time))
1091    }
1092}
1093
1094global_feature_adder!(
1095    InterpolatedBallRigidBodyNoVelocities,
1096    |s: &InterpolatedBallRigidBodyNoVelocities<F>,
1097     processor: &ReplayProcessor,
1098     _frame: &boxcars::Frame,
1099     _index,
1100     current_time: f32| {
1101        processor
1102            .get_interpolated_ball_rigid_body(current_time, s.close_enough_to_frame_time)
1103            .map(|v| get_rigid_body_properties_no_velocities(&v))
1104            .unwrap_or_else(|_| default_rb_state_no_velocities())
1105    },
1106    "Ball - position x",
1107    "Ball - position y",
1108    "Ball - position z",
1109    "Ball - rotation x",
1110    "Ball - rotation y",
1111    "Ball - rotation z",
1112    "Ball - rotation w",
1113);
1114
1115build_player_feature_adder!(
1116    PlayerRigidBody,
1117    |_, player_id: &PlayerId, processor: &ReplayProcessor, _frame, _index, _current_time: f32| {
1118        if let Ok(rb) = processor.get_player_rigid_body(player_id) {
1119            get_rigid_body_properties(rb)
1120        } else {
1121            default_rb_state()
1122        }
1123    },
1124    "position x",
1125    "position y",
1126    "position z",
1127    "rotation x",
1128    "rotation y",
1129    "rotation z",
1130    "linear velocity x",
1131    "linear velocity y",
1132    "linear velocity z",
1133    "angular velocity x",
1134    "angular velocity y",
1135    "angular velocity z",
1136);
1137
1138build_player_feature_adder!(
1139    PlayerRigidBodyNoVelocities,
1140    |_, player_id: &PlayerId, processor: &ReplayProcessor, _frame, _index, _current_time: f32| {
1141        if let Ok(rb) = processor.get_player_rigid_body(player_id) {
1142            get_rigid_body_properties_no_velocities(rb)
1143        } else {
1144            default_rb_state_no_velocities()
1145        }
1146    },
1147    "position x",
1148    "position y",
1149    "position z",
1150    "rotation x",
1151    "rotation y",
1152    "rotation z",
1153    "rotation w"
1154);
1155
1156// XXX: This approach seems to give some unexpected results with rotation
1157// changes. There may be a unit mismatch or some other type of issue.
1158build_player_feature_adder!(
1159    VelocityAddedPlayerRigidBodyNoVelocities,
1160    |_, player_id: &PlayerId, processor: &ReplayProcessor, _frame, _index, current_time: f32| {
1161        if let Ok(rb) = processor.get_velocity_applied_player_rigid_body(player_id, current_time) {
1162            get_rigid_body_properties_no_velocities(&rb)
1163        } else {
1164            default_rb_state_no_velocities()
1165        }
1166    },
1167    "position x",
1168    "position y",
1169    "position z",
1170    "rotation x",
1171    "rotation y",
1172    "rotation z",
1173    "rotation w"
1174);
1175
1176#[derive(derive_new::new)]
1177pub struct InterpolatedPlayerRigidBodyNoVelocities<F> {
1178    close_enough_to_frame_time: f32,
1179    _zero: std::marker::PhantomData<F>,
1180}
1181
1182impl<F> InterpolatedPlayerRigidBodyNoVelocities<F> {
1183    pub fn arc_new(close_enough_to_frame_time: f32) -> Arc<Self> {
1184        Arc::new(Self::new(close_enough_to_frame_time))
1185    }
1186}
1187
1188player_feature_adder!(
1189    InterpolatedPlayerRigidBodyNoVelocities,
1190    |s: &InterpolatedPlayerRigidBodyNoVelocities<F>,
1191     player_id: &PlayerId,
1192     processor: &ReplayProcessor,
1193     _frame: &boxcars::Frame,
1194     _index,
1195     current_time: f32| {
1196        processor
1197            .get_interpolated_player_rigid_body(
1198                player_id,
1199                current_time,
1200                s.close_enough_to_frame_time,
1201            )
1202            .map(|v| get_rigid_body_properties_no_velocities(&v))
1203            .unwrap_or_else(|_| default_rb_state_no_velocities())
1204    },
1205    "i position x",
1206    "i position y",
1207    "i position z",
1208    "i rotation x",
1209    "i rotation y",
1210    "i rotation z",
1211    "i rotation w"
1212);
1213
1214build_player_feature_adder!(
1215    PlayerBoost,
1216    |_, player_id: &PlayerId, processor: &ReplayProcessor, _frame, _index, _current_time: f32| {
1217        convert_all_floats!(processor.get_player_boost_level(player_id).unwrap_or(0.0))
1218    },
1219    "boost level"
1220);
1221
1222fn u8_get_f32(v: u8) -> SubtrActorResult<f32> {
1223    v.try_into().map_err(convert_float_conversion_error)
1224}
1225
1226build_player_feature_adder!(
1227    PlayerJump,
1228    |_,
1229     player_id: &PlayerId,
1230     processor: &ReplayProcessor,
1231     _frame,
1232     _frame_number,
1233     _current_time: f32| {
1234        convert_all_floats!(
1235            processor
1236                .get_dodge_active(player_id)
1237                .and_then(u8_get_f32)
1238                .unwrap_or(0.0),
1239            processor
1240                .get_jump_active(player_id)
1241                .and_then(u8_get_f32)
1242                .unwrap_or(0.0),
1243            processor
1244                .get_double_jump_active(player_id)
1245                .and_then(u8_get_f32)
1246                .unwrap_or(0.0),
1247        )
1248    },
1249    "dodge active",
1250    "jump active",
1251    "double jump active"
1252);
1253
1254build_player_feature_adder!(
1255    PlayerAnyJump,
1256    |_,
1257     player_id: &PlayerId,
1258     processor: &ReplayProcessor,
1259     _frame,
1260     _frame_number,
1261     _current_time: f32| {
1262        let dodge_is_active = processor.get_dodge_active(player_id).unwrap_or(0) % 2;
1263        let jump_is_active = processor.get_jump_active(player_id).unwrap_or(0) % 2;
1264        let double_jump_is_active = processor.get_double_jump_active(player_id).unwrap_or(0) % 2;
1265        let value: f32 = [dodge_is_active, jump_is_active, double_jump_is_active]
1266            .into_iter()
1267            .enumerate()
1268            .map(|(index, is_active)| (1 << index) * is_active)
1269            .sum::<u8>() as f32;
1270        convert_all_floats!(value)
1271    },
1272    "any_jump_active"
1273);
1274
1275const DEMOLISH_APPEARANCE_FRAME_COUNT: usize = 30;
1276
1277build_player_feature_adder!(
1278    PlayerDemolishedBy,
1279    |_,
1280     player_id: &PlayerId,
1281     processor: &ReplayProcessor,
1282     _frame,
1283     frame_number,
1284     _current_time: f32| {
1285        let demolisher_index = processor
1286            .demolishes
1287            .iter()
1288            .find(|demolish_info| {
1289                &demolish_info.victim == player_id
1290                    && frame_number - demolish_info.frame < DEMOLISH_APPEARANCE_FRAME_COUNT
1291            })
1292            .map(|demolish_info| {
1293                processor
1294                    .iter_player_ids_in_order()
1295                    .position(|player_id| player_id == &demolish_info.attacker)
1296                    .unwrap_or_else(|| processor.iter_player_ids_in_order().count())
1297            })
1298            .and_then(|v| i32::try_from(v).ok())
1299            .unwrap_or(-1);
1300        convert_all_floats!(demolisher_index as f32)
1301    },
1302    "player demolished by"
1303);
1304
1305lazy_static! {
1306    static ref NAME_TO_GLOBAL_FEATURE_ADDER: std::collections::HashMap<&'static str, Arc<dyn FeatureAdder<f32> + Send + Sync + 'static>> = {
1307        let mut m: std::collections::HashMap<
1308            &'static str,
1309            Arc<dyn FeatureAdder<f32> + Send + Sync + 'static>,
1310        > = std::collections::HashMap::new();
1311        macro_rules! insert_adder {
1312            ($adder_name:ident, $( $arguments:expr ),*) => {
1313                m.insert(stringify!($adder_name), $adder_name::<f32>::arc_new($ ( $arguments ),*));
1314            };
1315            ($adder_name:ident) => {
1316                insert_adder!($adder_name,)
1317            }
1318        }
1319        insert_adder!(BallRigidBody);
1320        insert_adder!(BallRigidBodyNoVelocities);
1321        insert_adder!(VelocityAddedBallRigidBodyNoVelocities);
1322        insert_adder!(InterpolatedBallRigidBodyNoVelocities, 0.0);
1323        insert_adder!(SecondsRemaining);
1324        insert_adder!(CurrentTime);
1325        insert_adder!(FrameTime);
1326        m
1327    };
1328    static ref NAME_TO_PLAYER_FEATURE_ADDER: std::collections::HashMap<
1329        &'static str,
1330        Arc<dyn PlayerFeatureAdder<f32> + Send + Sync + 'static>,
1331    > = {
1332        let mut m: std::collections::HashMap<
1333            &'static str,
1334            Arc<dyn PlayerFeatureAdder<f32> + Send + Sync + 'static>,
1335        > = std::collections::HashMap::new();
1336        macro_rules! insert_adder {
1337            ($adder_name:ident, $( $arguments:expr ),*) => {
1338                m.insert(stringify!($adder_name), $adder_name::<f32>::arc_new($ ( $arguments ),*));
1339            };
1340            ($adder_name:ident) => {
1341                insert_adder!($adder_name,)
1342            };
1343        }
1344        insert_adder!(PlayerRigidBody);
1345        insert_adder!(PlayerRigidBodyNoVelocities);
1346        insert_adder!(VelocityAddedPlayerRigidBodyNoVelocities);
1347        insert_adder!(InterpolatedPlayerRigidBodyNoVelocities, 0.003);
1348        insert_adder!(PlayerBoost);
1349        insert_adder!(PlayerJump);
1350        insert_adder!(PlayerAnyJump);
1351        insert_adder!(PlayerDemolishedBy);
1352        m
1353    };
1354}