Skip to main content

subtr_actor/collector/ndarray/
collector.rs

1use super::builtins::*;
2use super::traits::*;
3use crate::collector::{Collector, TimeAdvance};
4use crate::*;
5use ::ndarray;
6use boxcars;
7use serde::Serialize;
8use std::sync::Arc;
9
10/// Column headers for the frame matrix emitted by [`NDArrayCollector`].
11#[derive(Debug, Clone, PartialEq, Serialize)]
12pub struct NDArrayColumnHeaders {
13    /// Column names emitted once per frame, independent of player ordering.
14    pub global_headers: Vec<String>,
15    /// Column names repeated once for each player in replay order.
16    pub player_headers: Vec<String>,
17}
18
19impl NDArrayColumnHeaders {
20    /// Builds a header set from global and per-player column names.
21    pub fn new(global_headers: Vec<String>, player_headers: Vec<String>) -> Self {
22        Self {
23            global_headers,
24            player_headers,
25        }
26    }
27}
28
29/// Replay metadata bundled with the ndarray column layout used to produce it.
30#[derive(Debug, Clone, PartialEq, Serialize)]
31pub struct ReplayMetaWithHeaders {
32    /// Replay metadata describing the teams and player ordering.
33    pub replay_meta: ReplayMeta,
34    /// Column headers associated with the emitted ndarray rows.
35    pub column_headers: NDArrayColumnHeaders,
36}
37
38impl ReplayMetaWithHeaders {
39    /// Flattens the global and per-player headers using a default player prefix.
40    pub fn headers_vec(&self) -> Vec<String> {
41        self.headers_vec_from(|_, _info, index| format!("Player {index} - "))
42    }
43
44    /// Flattens the global and per-player headers with a custom player prefix.
45    pub fn headers_vec_from<F>(&self, player_prefix_getter: F) -> Vec<String>
46    where
47        F: Fn(&Self, &PlayerInfo, usize) -> String,
48    {
49        self.column_headers
50            .global_headers
51            .iter()
52            .cloned()
53            .chain(self.replay_meta.player_order().enumerate().flat_map(
54                move |(player_index, info)| {
55                    let player_prefix = player_prefix_getter(self, info, player_index);
56                    self.column_headers
57                        .player_headers
58                        .iter()
59                        .map(move |header| format!("{player_prefix}{header}"))
60                },
61            ))
62            .collect()
63    }
64}
65
66/// Collects replay frames into a dense 2D feature matrix.
67pub struct NDArrayCollector<F> {
68    feature_adders: FeatureAdders<F>,
69    player_feature_adders: PlayerFeatureAdders<F>,
70    data: Vec<F>,
71    replay_meta: Option<ReplayMeta>,
72    frames_added: usize,
73}
74
75impl<F> NDArrayCollector<F> {
76    /// Creates a collector from explicit global and per-player feature adders.
77    pub fn new(
78        feature_adders: FeatureAdders<F>,
79        player_feature_adders: PlayerFeatureAdders<F>,
80    ) -> Self {
81        Self {
82            feature_adders,
83            player_feature_adders,
84            data: Vec::new(),
85            replay_meta: None,
86            frames_added: 0,
87        }
88    }
89
90    /// Returns the column headers implied by the configured feature adders.
91    pub fn get_column_headers(&self) -> NDArrayColumnHeaders {
92        let global_headers = self
93            .feature_adders
94            .iter()
95            .flat_map(move |fa| {
96                fa.get_column_headers()
97                    .iter()
98                    .map(move |column_name| column_name.to_string())
99            })
100            .collect();
101        let player_headers = self
102            .player_feature_adders
103            .iter()
104            .flat_map(move |pfa| {
105                pfa.get_column_headers()
106                    .iter()
107                    .map(move |base_name| base_name.to_string())
108            })
109            .collect();
110        NDArrayColumnHeaders::new(global_headers, player_headers)
111    }
112
113    /// Finalizes collection and returns only the ndarray payload.
114    pub fn get_ndarray(self) -> SubtrActorResult<ndarray::Array2<F>> {
115        self.get_meta_and_ndarray().map(|a| a.1)
116    }
117
118    /// Finalizes collection and returns replay metadata alongside the ndarray.
119    pub fn get_meta_and_ndarray(
120        self,
121    ) -> SubtrActorResult<(ReplayMetaWithHeaders, ndarray::Array2<F>)> {
122        let features_per_row = self.try_get_frame_feature_count()?;
123        let expected_length = features_per_row * self.frames_added;
124        assert!(self.data.len() == expected_length);
125        let column_headers = self.get_column_headers();
126        Ok((
127            ReplayMetaWithHeaders {
128                replay_meta: self.replay_meta.ok_or(SubtrActorError::new(
129                    SubtrActorErrorVariant::CouldNotBuildReplayMeta,
130                ))?,
131                column_headers,
132            },
133            ndarray::Array2::from_shape_vec((self.frames_added, features_per_row), self.data)
134                .map_err(SubtrActorErrorVariant::NDArrayShapeError)
135                .map_err(SubtrActorError::new)?,
136        ))
137    }
138
139    /// Processes enough of a replay to determine metadata and column headers.
140    pub fn process_and_get_meta_and_headers(
141        &mut self,
142        replay: &boxcars::Replay,
143    ) -> SubtrActorResult<ReplayMetaWithHeaders> {
144        let mut processor = ReplayProcessor::new(replay)?;
145        processor.process_long_enough_to_get_actor_ids()?;
146        self.maybe_set_replay_meta(&processor)?;
147        Ok(ReplayMetaWithHeaders {
148            replay_meta: self
149                .replay_meta
150                .as_ref()
151                .ok_or(SubtrActorError::new(
152                    SubtrActorErrorVariant::CouldNotBuildReplayMeta,
153                ))?
154                .clone(),
155            column_headers: self.get_column_headers(),
156        })
157    }
158
159    fn try_get_frame_feature_count(&self) -> SubtrActorResult<usize> {
160        let player_count = self
161            .replay_meta
162            .as_ref()
163            .ok_or(SubtrActorError::new(
164                SubtrActorErrorVariant::CouldNotBuildReplayMeta,
165            ))?
166            .player_count();
167        let global_feature_count: usize = self
168            .feature_adders
169            .iter()
170            .map(|fa| fa.features_added())
171            .sum();
172        let player_feature_count: usize = self
173            .player_feature_adders
174            .iter()
175            .map(|pfa| pfa.features_added() * player_count)
176            .sum();
177        Ok(global_feature_count + player_feature_count)
178    }
179
180    fn maybe_set_replay_meta(&mut self, processor: &ReplayProcessor) -> SubtrActorResult<()> {
181        if self.replay_meta.is_none() {
182            self.replay_meta = Some(processor.get_replay_meta()?);
183        }
184        Ok(())
185    }
186}
187
188impl<F> Collector for NDArrayCollector<F> {
189    fn process_frame(
190        &mut self,
191        processor: &ReplayProcessor,
192        frame: &boxcars::Frame,
193        frame_number: usize,
194        current_time: f32,
195    ) -> SubtrActorResult<TimeAdvance> {
196        self.maybe_set_replay_meta(processor)?;
197
198        for feature_adder in &self.feature_adders {
199            feature_adder.add_features(
200                processor,
201                frame,
202                frame_number,
203                current_time,
204                &mut self.data,
205            )?;
206        }
207
208        for player_id in processor.iter_player_ids_in_order() {
209            for player_feature_adder in &self.player_feature_adders {
210                player_feature_adder.add_features(
211                    player_id,
212                    processor,
213                    frame,
214                    frame_number,
215                    current_time,
216                    &mut self.data,
217                )?;
218            }
219        }
220
221        self.frames_added += 1;
222
223        Ok(TimeAdvance::NextFrame)
224    }
225}
226
227fn global_feature_adder_from_name<F>(
228    name: &str,
229) -> Option<Arc<dyn FeatureAdder<F> + Send + Sync + 'static>>
230where
231    F: TryFrom<f32> + Send + Sync + 'static,
232    <F as TryFrom<f32>>::Error: std::fmt::Debug,
233{
234    match name {
235        "BallRigidBody" => Some(BallRigidBody::<F>::arc_new()),
236        "BallRigidBodyNoVelocities" => Some(BallRigidBodyNoVelocities::<F>::arc_new()),
237        "BallRigidBodyQuaternions" => Some(BallRigidBodyQuaternions::<F>::arc_new()),
238        "BallRigidBodyQuaternionVelocities" => {
239            Some(BallRigidBodyQuaternionVelocities::<F>::arc_new())
240        }
241        "BallRigidBodyBasis" => Some(BallRigidBodyBasis::<F>::arc_new()),
242        "VelocityAddedBallRigidBodyNoVelocities" => {
243            Some(VelocityAddedBallRigidBodyNoVelocities::<F>::arc_new())
244        }
245        "InterpolatedBallRigidBodyNoVelocities" => {
246            Some(InterpolatedBallRigidBodyNoVelocities::<F>::arc_new(0.0))
247        }
248        "SecondsRemaining" => Some(SecondsRemaining::<F>::arc_new()),
249        "CurrentTime" => Some(CurrentTime::<F>::arc_new()),
250        "FrameTime" => Some(FrameTime::<F>::arc_new()),
251        "ReplicatedStateName" => Some(ReplicatedStateName::<F>::arc_new()),
252        "ReplicatedGameStateTimeRemaining" => {
253            Some(ReplicatedGameStateTimeRemaining::<F>::arc_new())
254        }
255        "BallHasBeenHit" => Some(BallHasBeenHit::<F>::arc_new()),
256        _ => None,
257    }
258}
259
260fn player_feature_adder_from_name<F>(
261    name: &str,
262) -> Option<Arc<dyn PlayerFeatureAdder<F> + Send + Sync + 'static>>
263where
264    F: TryFrom<f32> + Send + Sync + 'static,
265    <F as TryFrom<f32>>::Error: std::fmt::Debug,
266{
267    match name {
268        "PlayerRigidBody" => Some(PlayerRigidBody::<F>::arc_new()),
269        "PlayerRigidBodyNoVelocities" => Some(PlayerRigidBodyNoVelocities::<F>::arc_new()),
270        "PlayerRigidBodyQuaternions" => Some(PlayerRigidBodyQuaternions::<F>::arc_new()),
271        "PlayerRigidBodyQuaternionVelocities" => {
272            Some(PlayerRigidBodyQuaternionVelocities::<F>::arc_new())
273        }
274        "PlayerRigidBodyBasis" => Some(PlayerRigidBodyBasis::<F>::arc_new()),
275        "PlayerRelativeBallPosition" => Some(PlayerRelativeBallPosition::<F>::arc_new()),
276        "PlayerRelativeBallVelocity" => Some(PlayerRelativeBallVelocity::<F>::arc_new()),
277        "PlayerLocalRelativeBallPosition" => Some(PlayerLocalRelativeBallPosition::<F>::arc_new()),
278        "PlayerLocalRelativeBallVelocity" => Some(PlayerLocalRelativeBallVelocity::<F>::arc_new()),
279        "VelocityAddedPlayerRigidBodyNoVelocities" => {
280            Some(VelocityAddedPlayerRigidBodyNoVelocities::<F>::arc_new())
281        }
282        "InterpolatedPlayerRigidBodyNoVelocities" => {
283            Some(InterpolatedPlayerRigidBodyNoVelocities::<F>::arc_new(0.003))
284        }
285        "PlayerBallDistance" | "PlayerDistanceToBall" => Some(PlayerBallDistance::<F>::arc_new()),
286        "PlayerBoost" => Some(PlayerBoost::<F>::arc_new()),
287        "PlayerJump" => Some(PlayerJump::<F>::arc_new()),
288        "PlayerAnyJump" => Some(PlayerAnyJump::<F>::arc_new()),
289        "PlayerDodgeRefreshed" => Some(PlayerDodgeRefreshed::<F>::arc_new()),
290        "PlayerDemolishedBy" => Some(PlayerDemolishedBy::<F>::arc_new()),
291        _ => None,
292    }
293}
294
295impl<F> NDArrayCollector<F>
296where
297    F: TryFrom<f32> + Send + Sync + 'static,
298    <F as TryFrom<f32>>::Error: std::fmt::Debug,
299{
300    /// Builds a collector from the registered string names of feature adders.
301    pub fn from_strings_typed(fa_names: &[&str], pfa_names: &[&str]) -> SubtrActorResult<Self> {
302        let feature_adders: Vec<Arc<dyn FeatureAdder<F> + Send + Sync>> = fa_names
303            .iter()
304            .map(|name| {
305                global_feature_adder_from_name(name).ok_or_else(|| {
306                    SubtrActorError::new(SubtrActorErrorVariant::UnknownFeatureAdderName(
307                        name.to_string(),
308                    ))
309                })
310            })
311            .collect::<SubtrActorResult<Vec<_>>>()?;
312        let player_feature_adders: Vec<Arc<dyn PlayerFeatureAdder<F> + Send + Sync>> = pfa_names
313            .iter()
314            .map(|name| {
315                player_feature_adder_from_name(name).ok_or_else(|| {
316                    SubtrActorError::new(SubtrActorErrorVariant::UnknownFeatureAdderName(
317                        name.to_string(),
318                    ))
319                })
320            })
321            .collect::<SubtrActorResult<Vec<_>>>()?;
322        Ok(Self::new(feature_adders, player_feature_adders))
323    }
324}
325
326impl NDArrayCollector<f32> {
327    /// Builds an `f32` collector from the registered string names of feature adders.
328    pub fn from_strings(fa_names: &[&str], pfa_names: &[&str]) -> SubtrActorResult<Self> {
329        Self::from_strings_typed(fa_names, pfa_names)
330    }
331}
332
333impl<F: TryFrom<f32> + Send + Sync + 'static> Default for NDArrayCollector<F>
334where
335    <F as TryFrom<f32>>::Error: std::fmt::Debug,
336{
337    fn default() -> Self {
338        NDArrayCollector::new(
339            vec![BallRigidBody::arc_new()],
340            vec![
341                PlayerRigidBody::arc_new(),
342                PlayerBoost::arc_new(),
343                PlayerAnyJump::arc_new(),
344            ],
345        )
346    }
347}