Skip to main content

subtr_actor/collector/ndarray/
collector.rs

1use super::builtins::*;
2use super::traits::*;
3use crate::collector::{Collector, TimeAdvance};
4use crate::stats::analysis_graph::{AnalysisDependency, AnalysisGraph};
5use crate::stats::calculators::{FrameInput, ReplayFrameInputBuilder};
6use crate::*;
7use ::ndarray;
8use boxcars;
9use serde::Serialize;
10
11/// Column headers for the frame matrix emitted by [`NDArrayCollector`].
12#[derive(Debug, Clone, PartialEq, Serialize)]
13pub struct NDArrayColumnHeaders {
14    /// Column names emitted once per frame, independent of player ordering.
15    pub global_headers: Vec<String>,
16    /// Column names repeated once for each player in replay order.
17    pub player_headers: Vec<String>,
18}
19
20impl NDArrayColumnHeaders {
21    /// Builds a header set from global and per-player column names.
22    pub fn new(global_headers: Vec<String>, player_headers: Vec<String>) -> Self {
23        Self {
24            global_headers,
25            player_headers,
26        }
27    }
28}
29
30/// Replay metadata bundled with the ndarray column layout used to produce it.
31#[derive(Debug, Clone, PartialEq, Serialize)]
32pub struct ReplayMetaWithHeaders {
33    /// Replay metadata describing the teams and player ordering.
34    pub replay_meta: ReplayMeta,
35    /// Column headers associated with the emitted ndarray rows.
36    pub column_headers: NDArrayColumnHeaders,
37}
38
39impl ReplayMetaWithHeaders {
40    /// Flattens the global and per-player headers using a default player prefix.
41    pub fn headers_vec(&self) -> Vec<String> {
42        self.headers_vec_from(|_, _info, index| format!("Player {index} - "))
43    }
44
45    /// Flattens the global and per-player headers with a custom player prefix.
46    pub fn headers_vec_from<F>(&self, player_prefix_getter: F) -> Vec<String>
47    where
48        F: Fn(&Self, &PlayerInfo, usize) -> String,
49    {
50        self.column_headers
51            .global_headers
52            .iter()
53            .cloned()
54            .chain(self.replay_meta.player_order().enumerate().flat_map(
55                move |(player_index, info)| {
56                    let player_prefix = player_prefix_getter(self, info, player_index);
57                    self.column_headers
58                        .player_headers
59                        .iter()
60                        .map(move |header| format!("{player_prefix}{header}"))
61                },
62            ))
63            .collect()
64    }
65}
66
67/// Collects replay frames into a dense 2D feature matrix.
68pub struct NDArrayCollector<F> {
69    feature_adders: NDArrayFeatureAdders<F>,
70    player_feature_adders: NDArrayPlayerFeatureAdders<F>,
71    analysis_runtime: Option<NDArrayAnalysisRuntime>,
72    data: Vec<F>,
73    replay_meta: Option<ReplayMeta>,
74    frames_added: usize,
75}
76
77struct NDArrayAnalysisRuntime {
78    graph: AnalysisGraph,
79    dependencies: Vec<AnalysisDependency>,
80    frame_input_builder: ReplayFrameInputBuilder,
81    last_sample_time: Option<f32>,
82    last_replay_meta_player_count: Option<usize>,
83}
84
85impl NDArrayAnalysisRuntime {
86    fn new(dependencies: Vec<AnalysisDependency>) -> Self {
87        let mut graph = AnalysisGraph::new();
88        graph.register_input_state::<FrameInput>();
89        Self {
90            graph,
91            dependencies,
92            frame_input_builder: ReplayFrameInputBuilder::default(),
93            last_sample_time: None,
94            last_replay_meta_player_count: None,
95        }
96    }
97
98    fn process_frame(
99        &mut self,
100        processor: &dyn ProcessorView,
101        frame_number: usize,
102        current_time: f32,
103    ) -> SubtrActorResult<()> {
104        let player_count = processor.player_count();
105        if self.last_replay_meta_player_count != Some(player_count) {
106            self.graph
107                .ensure_dependencies(self.dependencies.iter().copied())?;
108            self.graph.on_replay_meta(&processor.get_replay_meta()?)?;
109            self.last_replay_meta_player_count = Some(player_count);
110        }
111
112        let dt = self
113            .last_sample_time
114            .map(|last_time| (current_time - last_time).max(0.0))
115            .unwrap_or(0.0);
116        let frame_input =
117            self.frame_input_builder
118                .aggregate(processor, frame_number, current_time, dt);
119        self.graph.evaluate_with_state(&frame_input)?;
120        self.last_sample_time = Some(current_time);
121        Ok(())
122    }
123
124    fn context(&self) -> AnalysisFeatureContext<'_> {
125        AnalysisFeatureContext::new(&self.graph)
126    }
127
128    fn finish_replay(&mut self) -> SubtrActorResult<()> {
129        self.graph.finish()
130    }
131}
132
133impl<F> NDArrayCollector<F> {
134    /// Creates a collector from ordered global and per-player feature-adder specs.
135    pub fn new(
136        feature_adders: NDArrayFeatureAdders<F>,
137        player_feature_adders: NDArrayPlayerFeatureAdders<F>,
138    ) -> Self {
139        let analysis_dependencies = feature_adders
140            .iter()
141            .flat_map(NDArrayFeatureAdder::analysis_dependencies)
142            .chain(
143                player_feature_adders
144                    .iter()
145                    .flat_map(NDArrayPlayerFeatureAdder::analysis_dependencies),
146            )
147            .collect();
148        let uses_analysis = feature_adders
149            .iter()
150            .any(NDArrayFeatureAdder::is_analysis_backed)
151            || player_feature_adders
152                .iter()
153                .any(NDArrayPlayerFeatureAdder::is_analysis_backed);
154        Self {
155            feature_adders,
156            player_feature_adders,
157            analysis_runtime: uses_analysis
158                .then(|| NDArrayAnalysisRuntime::new(analysis_dependencies)),
159            data: Vec::new(),
160            replay_meta: None,
161            frames_added: 0,
162        }
163    }
164
165    /// Returns the column headers implied by the configured feature adders.
166    pub fn get_column_headers(&self) -> NDArrayColumnHeaders {
167        let global_headers = self
168            .feature_adders
169            .iter()
170            .flat_map(move |fa| {
171                fa.get_column_headers()
172                    .iter()
173                    .map(move |column_name| column_name.to_string())
174            })
175            .collect();
176        let player_headers = self
177            .player_feature_adders
178            .iter()
179            .flat_map(move |pfa| {
180                pfa.get_column_headers()
181                    .iter()
182                    .map(move |base_name| base_name.to_string())
183            })
184            .collect();
185        NDArrayColumnHeaders::new(global_headers, player_headers)
186    }
187
188    /// Finalizes collection and returns only the ndarray payload.
189    pub fn get_ndarray(self) -> SubtrActorResult<ndarray::Array2<F>> {
190        self.get_meta_and_ndarray().map(|a| a.1)
191    }
192
193    /// Finalizes collection and returns replay metadata alongside the ndarray.
194    pub fn get_meta_and_ndarray(
195        self,
196    ) -> SubtrActorResult<(ReplayMetaWithHeaders, ndarray::Array2<F>)> {
197        let features_per_row = self.try_get_frame_feature_count()?;
198        let expected_length = features_per_row * self.frames_added;
199        assert!(self.data.len() == expected_length);
200        let column_headers = self.get_column_headers();
201        Ok((
202            ReplayMetaWithHeaders {
203                replay_meta: self.replay_meta.ok_or_else(|| {
204                    SubtrActorError::new(SubtrActorErrorVariant::CouldNotBuildReplayMeta)
205                })?,
206                column_headers,
207            },
208            ndarray::Array2::from_shape_vec((self.frames_added, features_per_row), self.data)
209                .map_err(SubtrActorErrorVariant::NDArrayShapeError)
210                .map_err(SubtrActorError::new)?,
211        ))
212    }
213
214    /// Processes enough of a replay to determine metadata and column headers.
215    pub fn process_and_get_meta_and_headers(
216        &mut self,
217        replay: &boxcars::Replay,
218    ) -> SubtrActorResult<ReplayMetaWithHeaders> {
219        let mut processor = ReplayProcessor::new(replay)?;
220        processor.process_long_enough_to_get_actor_ids()?;
221        self.maybe_set_replay_meta(&processor)?;
222        Ok(ReplayMetaWithHeaders {
223            replay_meta: self
224                .replay_meta
225                .as_ref()
226                .ok_or_else(|| {
227                    SubtrActorError::new(SubtrActorErrorVariant::CouldNotBuildReplayMeta)
228                })?
229                .clone(),
230            column_headers: self.get_column_headers(),
231        })
232    }
233
234    fn try_get_frame_feature_count(&self) -> SubtrActorResult<usize> {
235        let player_count = self
236            .replay_meta
237            .as_ref()
238            .ok_or_else(|| SubtrActorError::new(SubtrActorErrorVariant::CouldNotBuildReplayMeta))?
239            .player_count();
240        let global_feature_count: usize = self
241            .feature_adders
242            .iter()
243            .map(|fa| fa.features_added())
244            .sum();
245        let player_feature_count: usize = self
246            .player_feature_adders
247            .iter()
248            .map(|pfa| pfa.features_added() * player_count)
249            .sum();
250        Ok(global_feature_count + player_feature_count)
251    }
252
253    fn maybe_set_replay_meta(&mut self, processor: &dyn ProcessorView) -> SubtrActorResult<()> {
254        if self.replay_meta.is_none() {
255            self.replay_meta = Some(processor.get_replay_meta()?);
256        }
257        Ok(())
258    }
259}
260
261impl<F> Collector for NDArrayCollector<F> {
262    fn process_frame(
263        &mut self,
264        processor: &dyn ProcessorView,
265        frame: &boxcars::Frame,
266        frame_number: usize,
267        current_time: f32,
268    ) -> SubtrActorResult<TimeAdvance> {
269        self.maybe_set_replay_meta(processor)?;
270
271        if let Some(analysis_runtime) = self.analysis_runtime.as_mut() {
272            analysis_runtime.process_frame(processor, frame_number, current_time)?;
273        }
274        let analysis_context = self
275            .analysis_runtime
276            .as_ref()
277            .map(NDArrayAnalysisRuntime::context);
278
279        for feature_adder in &self.feature_adders {
280            match feature_adder {
281                NDArrayFeatureAdder::Plain(adder) => adder.add_features(
282                    processor,
283                    frame,
284                    frame_number,
285                    current_time,
286                    &mut self.data,
287                )?,
288                NDArrayFeatureAdder::Analysis(adder) => adder.add_features(
289                    analysis_context
290                        .as_ref()
291                        .expect("analysis runtime exists for analysis feature adders"),
292                    processor,
293                    frame,
294                    frame_number,
295                    current_time,
296                    &mut self.data,
297                )?,
298            }
299        }
300
301        for player_id in processor.iter_player_ids_in_order() {
302            for player_feature_adder in &self.player_feature_adders {
303                match player_feature_adder {
304                    NDArrayPlayerFeatureAdder::Plain(adder) => adder.add_features(
305                        player_id,
306                        processor,
307                        frame,
308                        frame_number,
309                        current_time,
310                        &mut self.data,
311                    )?,
312                    NDArrayPlayerFeatureAdder::Analysis(adder) => {
313                        let context = analysis_context
314                            .as_ref()
315                            .expect("analysis runtime exists for analysis feature adders");
316                        adder.add_features(
317                            AnalysisPlayerFeatureInput {
318                                context,
319                                player_id,
320                                processor,
321                                frame,
322                                frame_count: frame_number,
323                                current_time,
324                            },
325                            &mut self.data,
326                        )?
327                    }
328                }
329            }
330        }
331
332        self.frames_added += 1;
333
334        Ok(TimeAdvance::NextFrame)
335    }
336
337    fn finish_replay(&mut self, _processor: &dyn ProcessorView) -> SubtrActorResult<()> {
338        if let Some(analysis_runtime) = self.analysis_runtime.as_mut() {
339            analysis_runtime.finish_replay()?;
340        }
341        Ok(())
342    }
343}
344
345fn global_feature_adder_from_name<F>(name: &str) -> Option<NDArrayFeatureAdder<F>>
346where
347    F: TryFrom<f32> + Send + Sync + 'static,
348    <F as TryFrom<f32>>::Error: std::fmt::Debug,
349{
350    match name {
351        "BallRigidBody" => Some(NDArrayFeatureAdder::plain(BallRigidBody::<F>::arc_new())),
352        "BallRigidBodyNoVelocities" => Some(NDArrayFeatureAdder::plain(
353            BallRigidBodyNoVelocities::<F>::arc_new(),
354        )),
355        "BallRigidBodyQuaternions" => Some(NDArrayFeatureAdder::plain(
356            BallRigidBodyQuaternions::<F>::arc_new(),
357        )),
358        "BallRigidBodyQuaternionVelocities" => Some(NDArrayFeatureAdder::plain(
359            BallRigidBodyQuaternionVelocities::<F>::arc_new(),
360        )),
361        "BallRigidBodyBasis" => Some(NDArrayFeatureAdder::plain(
362            BallRigidBodyBasis::<F>::arc_new(),
363        )),
364        "VelocityAddedBallRigidBodyNoVelocities" => Some(NDArrayFeatureAdder::plain(
365            VelocityAddedBallRigidBodyNoVelocities::<F>::arc_new(),
366        )),
367        "InterpolatedBallRigidBodyNoVelocities" => Some(NDArrayFeatureAdder::plain(
368            InterpolatedBallRigidBodyNoVelocities::<F>::arc_new(0.0),
369        )),
370        "SecondsRemaining" => Some(NDArrayFeatureAdder::plain(SecondsRemaining::<F>::arc_new())),
371        "CurrentTime" => Some(NDArrayFeatureAdder::plain(CurrentTime::<F>::arc_new())),
372        "FrameTime" => Some(NDArrayFeatureAdder::plain(FrameTime::<F>::arc_new())),
373        "ReplicatedStateName" => Some(NDArrayFeatureAdder::plain(
374            ReplicatedStateName::<F>::arc_new(),
375        )),
376        "ReplicatedGameStateTimeRemaining" => Some(NDArrayFeatureAdder::plain(
377            ReplicatedGameStateTimeRemaining::<F>::arc_new(),
378        )),
379        "BallHasBeenHit" => Some(NDArrayFeatureAdder::plain(BallHasBeenHit::<F>::arc_new())),
380        _ => None,
381    }
382}
383
384fn player_feature_adder_from_name<F>(name: &str) -> Option<NDArrayPlayerFeatureAdder<F>>
385where
386    F: TryFrom<f32> + Send + Sync + 'static,
387    <F as TryFrom<f32>>::Error: std::fmt::Debug,
388{
389    match name {
390        "PlayerRigidBody" => Some(NDArrayPlayerFeatureAdder::plain(
391            PlayerRigidBody::<F>::arc_new(),
392        )),
393        "PlayerRigidBodyNoVelocities" => Some(NDArrayPlayerFeatureAdder::plain(
394            PlayerRigidBodyNoVelocities::<F>::arc_new(),
395        )),
396        "PlayerRigidBodyQuaternions" => Some(NDArrayPlayerFeatureAdder::plain(
397            PlayerRigidBodyQuaternions::<F>::arc_new(),
398        )),
399        "PlayerRigidBodyQuaternionVelocities" => Some(NDArrayPlayerFeatureAdder::plain(
400            PlayerRigidBodyQuaternionVelocities::<F>::arc_new(),
401        )),
402        "PlayerRigidBodyBasis" => Some(NDArrayPlayerFeatureAdder::plain(
403            PlayerRigidBodyBasis::<F>::arc_new(),
404        )),
405        "PlayerRelativeBallPosition" => Some(NDArrayPlayerFeatureAdder::plain(
406            PlayerRelativeBallPosition::<F>::arc_new(),
407        )),
408        "PlayerRelativeBallVelocity" => Some(NDArrayPlayerFeatureAdder::plain(
409            PlayerRelativeBallVelocity::<F>::arc_new(),
410        )),
411        "PlayerLocalRelativeBallPosition" => Some(NDArrayPlayerFeatureAdder::plain(
412            PlayerLocalRelativeBallPosition::<F>::arc_new(),
413        )),
414        "PlayerLocalRelativeBallVelocity" => Some(NDArrayPlayerFeatureAdder::plain(
415            PlayerLocalRelativeBallVelocity::<F>::arc_new(),
416        )),
417        "VelocityAddedPlayerRigidBodyNoVelocities" => Some(NDArrayPlayerFeatureAdder::plain(
418            VelocityAddedPlayerRigidBodyNoVelocities::<F>::arc_new(),
419        )),
420        "InterpolatedPlayerRigidBodyNoVelocities" => Some(NDArrayPlayerFeatureAdder::plain(
421            InterpolatedPlayerRigidBodyNoVelocities::<F>::arc_new(0.003),
422        )),
423        "PlayerBallDistance" | "PlayerDistanceToBall" => Some(NDArrayPlayerFeatureAdder::plain(
424            PlayerBallDistance::<F>::arc_new(),
425        )),
426        "PlayerBoost" => Some(NDArrayPlayerFeatureAdder::plain(PlayerBoost::<F>::arc_new())),
427        "PlayerJump" => Some(NDArrayPlayerFeatureAdder::plain(PlayerJump::<F>::arc_new())),
428        "PlayerAnyJump" => Some(NDArrayPlayerFeatureAdder::plain(
429            PlayerAnyJump::<F>::arc_new(),
430        )),
431        "PlayerDodgeRefreshed" => Some(NDArrayPlayerFeatureAdder::plain(
432            PlayerDodgeRefreshed::<F>::arc_new(),
433        )),
434        "PlayerDemolishedBy" => Some(NDArrayPlayerFeatureAdder::plain(
435            PlayerDemolishedBy::<F>::arc_new(),
436        )),
437        _ => analysis_player_event_feature_adder_from_name(name),
438    }
439}
440
441impl<F> NDArrayCollector<F>
442where
443    F: TryFrom<f32> + Send + Sync + 'static,
444    <F as TryFrom<f32>>::Error: std::fmt::Debug,
445{
446    /// Builds a collector from the registered string names of feature adders.
447    pub fn from_strings_typed(fa_names: &[&str], pfa_names: &[&str]) -> SubtrActorResult<Self> {
448        let feature_adders: NDArrayFeatureAdders<F> = fa_names
449            .iter()
450            .map(|name| {
451                global_feature_adder_from_name(name).ok_or_else(|| {
452                    SubtrActorError::new(SubtrActorErrorVariant::UnknownFeatureAdderName(
453                        name.to_string(),
454                    ))
455                })
456            })
457            .collect::<SubtrActorResult<Vec<_>>>()?;
458        let player_feature_adders: NDArrayPlayerFeatureAdders<F> = pfa_names
459            .iter()
460            .map(|name| {
461                player_feature_adder_from_name(name).ok_or_else(|| {
462                    SubtrActorError::new(SubtrActorErrorVariant::UnknownFeatureAdderName(
463                        name.to_string(),
464                    ))
465                })
466            })
467            .collect::<SubtrActorResult<Vec<_>>>()?;
468        Ok(Self::new(feature_adders, player_feature_adders))
469    }
470}
471
472impl NDArrayCollector<f32> {
473    /// Builds an `f32` collector from the registered string names of feature adders.
474    pub fn from_strings(fa_names: &[&str], pfa_names: &[&str]) -> SubtrActorResult<Self> {
475        Self::from_strings_typed(fa_names, pfa_names)
476    }
477}
478
479impl<F: TryFrom<f32> + Send + Sync + 'static> Default for NDArrayCollector<F>
480where
481    <F as TryFrom<f32>>::Error: std::fmt::Debug,
482{
483    fn default() -> Self {
484        NDArrayCollector::new(
485            vec![NDArrayFeatureAdder::plain(BallRigidBody::arc_new())],
486            vec![
487                NDArrayPlayerFeatureAdder::plain(PlayerRigidBody::arc_new()),
488                NDArrayPlayerFeatureAdder::plain(PlayerBoost::arc_new()),
489                NDArrayPlayerFeatureAdder::plain(PlayerAnyJump::arc_new()),
490            ],
491        )
492    }
493}
494
495#[cfg(test)]
496#[path = "collector_tests.rs"]
497mod tests;