1use super::builtins::*;
2use super::traits::*;
3use crate::collector::{Collector, TimeAdvance};
4use crate::stats::analysis_graph::{AnalysisDependency, AnalysisGraph};
5use crate::stats::calculators::{FrameInput, ReplayFrameInputBuilder};
6use crate::*;
7use ::ndarray;
8use boxcars;
9use serde::Serialize;
10
11#[derive(Debug, Clone, PartialEq, Serialize)]
13pub struct NDArrayColumnHeaders {
14 pub global_headers: Vec<String>,
16 pub player_headers: Vec<String>,
18}
19
20impl NDArrayColumnHeaders {
21 pub fn new(global_headers: Vec<String>, player_headers: Vec<String>) -> Self {
23 Self {
24 global_headers,
25 player_headers,
26 }
27 }
28}
29
30#[derive(Debug, Clone, PartialEq, Serialize)]
32pub struct ReplayMetaWithHeaders {
33 pub replay_meta: ReplayMeta,
35 pub column_headers: NDArrayColumnHeaders,
37}
38
39impl ReplayMetaWithHeaders {
40 pub fn headers_vec(&self) -> Vec<String> {
42 self.headers_vec_from(|_, _info, index| format!("Player {index} - "))
43 }
44
45 pub fn headers_vec_from<F>(&self, player_prefix_getter: F) -> Vec<String>
47 where
48 F: Fn(&Self, &PlayerInfo, usize) -> String,
49 {
50 self.column_headers
51 .global_headers
52 .iter()
53 .cloned()
54 .chain(self.replay_meta.player_order().enumerate().flat_map(
55 move |(player_index, info)| {
56 let player_prefix = player_prefix_getter(self, info, player_index);
57 self.column_headers
58 .player_headers
59 .iter()
60 .map(move |header| format!("{player_prefix}{header}"))
61 },
62 ))
63 .collect()
64 }
65}
66
67pub struct NDArrayCollector<F> {
69 feature_adders: NDArrayFeatureAdders<F>,
70 player_feature_adders: NDArrayPlayerFeatureAdders<F>,
71 analysis_runtime: Option<NDArrayAnalysisRuntime>,
72 data: Vec<F>,
73 replay_meta: Option<ReplayMeta>,
74 frames_added: usize,
75}
76
77struct NDArrayAnalysisRuntime {
78 graph: AnalysisGraph,
79 dependencies: Vec<AnalysisDependency>,
80 frame_input_builder: ReplayFrameInputBuilder,
81 last_sample_time: Option<f32>,
82 last_replay_meta_player_count: Option<usize>,
83}
84
85impl NDArrayAnalysisRuntime {
86 fn new(dependencies: Vec<AnalysisDependency>) -> Self {
87 let mut graph = AnalysisGraph::new();
88 graph.register_input_state::<FrameInput>();
89 Self {
90 graph,
91 dependencies,
92 frame_input_builder: ReplayFrameInputBuilder::default(),
93 last_sample_time: None,
94 last_replay_meta_player_count: None,
95 }
96 }
97
98 fn process_frame(
99 &mut self,
100 processor: &dyn ProcessorView,
101 frame_number: usize,
102 current_time: f32,
103 ) -> SubtrActorResult<()> {
104 let player_count = processor.player_count();
105 if self.last_replay_meta_player_count != Some(player_count) {
106 self.graph
107 .ensure_dependencies(self.dependencies.iter().copied())?;
108 self.graph.on_replay_meta(&processor.get_replay_meta()?)?;
109 self.last_replay_meta_player_count = Some(player_count);
110 }
111
112 let dt = self
113 .last_sample_time
114 .map(|last_time| (current_time - last_time).max(0.0))
115 .unwrap_or(0.0);
116 let frame_input =
117 self.frame_input_builder
118 .aggregate(processor, frame_number, current_time, dt);
119 self.graph.evaluate_with_state(&frame_input)?;
120 self.last_sample_time = Some(current_time);
121 Ok(())
122 }
123
124 fn context(&self) -> AnalysisFeatureContext<'_> {
125 AnalysisFeatureContext::new(&self.graph)
126 }
127
128 fn finish_replay(&mut self) -> SubtrActorResult<()> {
129 self.graph.finish()
130 }
131}
132
133impl<F> NDArrayCollector<F> {
134 pub fn new(
136 feature_adders: NDArrayFeatureAdders<F>,
137 player_feature_adders: NDArrayPlayerFeatureAdders<F>,
138 ) -> Self {
139 let analysis_dependencies = feature_adders
140 .iter()
141 .flat_map(NDArrayFeatureAdder::analysis_dependencies)
142 .chain(
143 player_feature_adders
144 .iter()
145 .flat_map(NDArrayPlayerFeatureAdder::analysis_dependencies),
146 )
147 .collect();
148 let uses_analysis = feature_adders
149 .iter()
150 .any(NDArrayFeatureAdder::is_analysis_backed)
151 || player_feature_adders
152 .iter()
153 .any(NDArrayPlayerFeatureAdder::is_analysis_backed);
154 Self {
155 feature_adders,
156 player_feature_adders,
157 analysis_runtime: uses_analysis
158 .then(|| NDArrayAnalysisRuntime::new(analysis_dependencies)),
159 data: Vec::new(),
160 replay_meta: None,
161 frames_added: 0,
162 }
163 }
164
165 pub fn get_column_headers(&self) -> NDArrayColumnHeaders {
167 let global_headers = self
168 .feature_adders
169 .iter()
170 .flat_map(move |fa| {
171 fa.get_column_headers()
172 .iter()
173 .map(move |column_name| column_name.to_string())
174 })
175 .collect();
176 let player_headers = self
177 .player_feature_adders
178 .iter()
179 .flat_map(move |pfa| {
180 pfa.get_column_headers()
181 .iter()
182 .map(move |base_name| base_name.to_string())
183 })
184 .collect();
185 NDArrayColumnHeaders::new(global_headers, player_headers)
186 }
187
188 pub fn get_ndarray(self) -> SubtrActorResult<ndarray::Array2<F>> {
190 self.get_meta_and_ndarray().map(|a| a.1)
191 }
192
193 pub fn get_meta_and_ndarray(
195 self,
196 ) -> SubtrActorResult<(ReplayMetaWithHeaders, ndarray::Array2<F>)> {
197 let features_per_row = self.try_get_frame_feature_count()?;
198 let expected_length = features_per_row * self.frames_added;
199 assert!(self.data.len() == expected_length);
200 let column_headers = self.get_column_headers();
201 Ok((
202 ReplayMetaWithHeaders {
203 replay_meta: self.replay_meta.ok_or_else(|| {
204 SubtrActorError::new(SubtrActorErrorVariant::CouldNotBuildReplayMeta)
205 })?,
206 column_headers,
207 },
208 ndarray::Array2::from_shape_vec((self.frames_added, features_per_row), self.data)
209 .map_err(SubtrActorErrorVariant::NDArrayShapeError)
210 .map_err(SubtrActorError::new)?,
211 ))
212 }
213
214 pub fn process_and_get_meta_and_headers(
216 &mut self,
217 replay: &boxcars::Replay,
218 ) -> SubtrActorResult<ReplayMetaWithHeaders> {
219 let mut processor = ReplayProcessor::new(replay)?;
220 processor.process_long_enough_to_get_actor_ids()?;
221 self.maybe_set_replay_meta(&processor)?;
222 Ok(ReplayMetaWithHeaders {
223 replay_meta: self
224 .replay_meta
225 .as_ref()
226 .ok_or_else(|| {
227 SubtrActorError::new(SubtrActorErrorVariant::CouldNotBuildReplayMeta)
228 })?
229 .clone(),
230 column_headers: self.get_column_headers(),
231 })
232 }
233
234 fn try_get_frame_feature_count(&self) -> SubtrActorResult<usize> {
235 let player_count = self
236 .replay_meta
237 .as_ref()
238 .ok_or_else(|| SubtrActorError::new(SubtrActorErrorVariant::CouldNotBuildReplayMeta))?
239 .player_count();
240 let global_feature_count: usize = self
241 .feature_adders
242 .iter()
243 .map(|fa| fa.features_added())
244 .sum();
245 let player_feature_count: usize = self
246 .player_feature_adders
247 .iter()
248 .map(|pfa| pfa.features_added() * player_count)
249 .sum();
250 Ok(global_feature_count + player_feature_count)
251 }
252
253 fn maybe_set_replay_meta(&mut self, processor: &dyn ProcessorView) -> SubtrActorResult<()> {
254 if self.replay_meta.is_none() {
255 self.replay_meta = Some(processor.get_replay_meta()?);
256 }
257 Ok(())
258 }
259}
260
261impl<F> Collector for NDArrayCollector<F> {
262 fn process_frame(
263 &mut self,
264 processor: &dyn ProcessorView,
265 frame: &boxcars::Frame,
266 frame_number: usize,
267 current_time: f32,
268 ) -> SubtrActorResult<TimeAdvance> {
269 self.maybe_set_replay_meta(processor)?;
270
271 if let Some(analysis_runtime) = self.analysis_runtime.as_mut() {
272 analysis_runtime.process_frame(processor, frame_number, current_time)?;
273 }
274 let analysis_context = self
275 .analysis_runtime
276 .as_ref()
277 .map(NDArrayAnalysisRuntime::context);
278
279 for feature_adder in &self.feature_adders {
280 match feature_adder {
281 NDArrayFeatureAdder::Plain(adder) => adder.add_features(
282 processor,
283 frame,
284 frame_number,
285 current_time,
286 &mut self.data,
287 )?,
288 NDArrayFeatureAdder::Analysis(adder) => adder.add_features(
289 analysis_context
290 .as_ref()
291 .expect("analysis runtime exists for analysis feature adders"),
292 processor,
293 frame,
294 frame_number,
295 current_time,
296 &mut self.data,
297 )?,
298 }
299 }
300
301 for player_id in processor.iter_player_ids_in_order() {
302 for player_feature_adder in &self.player_feature_adders {
303 match player_feature_adder {
304 NDArrayPlayerFeatureAdder::Plain(adder) => adder.add_features(
305 player_id,
306 processor,
307 frame,
308 frame_number,
309 current_time,
310 &mut self.data,
311 )?,
312 NDArrayPlayerFeatureAdder::Analysis(adder) => {
313 let context = analysis_context
314 .as_ref()
315 .expect("analysis runtime exists for analysis feature adders");
316 adder.add_features(
317 AnalysisPlayerFeatureInput {
318 context,
319 player_id,
320 processor,
321 frame,
322 frame_count: frame_number,
323 current_time,
324 },
325 &mut self.data,
326 )?
327 }
328 }
329 }
330 }
331
332 self.frames_added += 1;
333
334 Ok(TimeAdvance::NextFrame)
335 }
336
337 fn finish_replay(&mut self, _processor: &dyn ProcessorView) -> SubtrActorResult<()> {
338 if let Some(analysis_runtime) = self.analysis_runtime.as_mut() {
339 analysis_runtime.finish_replay()?;
340 }
341 Ok(())
342 }
343}
344
345fn global_feature_adder_from_name<F>(name: &str) -> Option<NDArrayFeatureAdder<F>>
346where
347 F: TryFrom<f32> + Send + Sync + 'static,
348 <F as TryFrom<f32>>::Error: std::fmt::Debug,
349{
350 match name {
351 "BallRigidBody" => Some(NDArrayFeatureAdder::plain(BallRigidBody::<F>::arc_new())),
352 "BallRigidBodyNoVelocities" => Some(NDArrayFeatureAdder::plain(
353 BallRigidBodyNoVelocities::<F>::arc_new(),
354 )),
355 "BallRigidBodyQuaternions" => Some(NDArrayFeatureAdder::plain(
356 BallRigidBodyQuaternions::<F>::arc_new(),
357 )),
358 "BallRigidBodyQuaternionVelocities" => Some(NDArrayFeatureAdder::plain(
359 BallRigidBodyQuaternionVelocities::<F>::arc_new(),
360 )),
361 "BallRigidBodyBasis" => Some(NDArrayFeatureAdder::plain(
362 BallRigidBodyBasis::<F>::arc_new(),
363 )),
364 "VelocityAddedBallRigidBodyNoVelocities" => Some(NDArrayFeatureAdder::plain(
365 VelocityAddedBallRigidBodyNoVelocities::<F>::arc_new(),
366 )),
367 "InterpolatedBallRigidBodyNoVelocities" => Some(NDArrayFeatureAdder::plain(
368 InterpolatedBallRigidBodyNoVelocities::<F>::arc_new(0.0),
369 )),
370 "SecondsRemaining" => Some(NDArrayFeatureAdder::plain(SecondsRemaining::<F>::arc_new())),
371 "CurrentTime" => Some(NDArrayFeatureAdder::plain(CurrentTime::<F>::arc_new())),
372 "FrameTime" => Some(NDArrayFeatureAdder::plain(FrameTime::<F>::arc_new())),
373 "ReplicatedStateName" => Some(NDArrayFeatureAdder::plain(
374 ReplicatedStateName::<F>::arc_new(),
375 )),
376 "ReplicatedGameStateTimeRemaining" => Some(NDArrayFeatureAdder::plain(
377 ReplicatedGameStateTimeRemaining::<F>::arc_new(),
378 )),
379 "BallHasBeenHit" => Some(NDArrayFeatureAdder::plain(BallHasBeenHit::<F>::arc_new())),
380 _ => None,
381 }
382}
383
384fn player_feature_adder_from_name<F>(name: &str) -> Option<NDArrayPlayerFeatureAdder<F>>
385where
386 F: TryFrom<f32> + Send + Sync + 'static,
387 <F as TryFrom<f32>>::Error: std::fmt::Debug,
388{
389 match name {
390 "PlayerRigidBody" => Some(NDArrayPlayerFeatureAdder::plain(
391 PlayerRigidBody::<F>::arc_new(),
392 )),
393 "PlayerRigidBodyNoVelocities" => Some(NDArrayPlayerFeatureAdder::plain(
394 PlayerRigidBodyNoVelocities::<F>::arc_new(),
395 )),
396 "PlayerRigidBodyQuaternions" => Some(NDArrayPlayerFeatureAdder::plain(
397 PlayerRigidBodyQuaternions::<F>::arc_new(),
398 )),
399 "PlayerRigidBodyQuaternionVelocities" => Some(NDArrayPlayerFeatureAdder::plain(
400 PlayerRigidBodyQuaternionVelocities::<F>::arc_new(),
401 )),
402 "PlayerRigidBodyBasis" => Some(NDArrayPlayerFeatureAdder::plain(
403 PlayerRigidBodyBasis::<F>::arc_new(),
404 )),
405 "PlayerRelativeBallPosition" => Some(NDArrayPlayerFeatureAdder::plain(
406 PlayerRelativeBallPosition::<F>::arc_new(),
407 )),
408 "PlayerRelativeBallVelocity" => Some(NDArrayPlayerFeatureAdder::plain(
409 PlayerRelativeBallVelocity::<F>::arc_new(),
410 )),
411 "PlayerLocalRelativeBallPosition" => Some(NDArrayPlayerFeatureAdder::plain(
412 PlayerLocalRelativeBallPosition::<F>::arc_new(),
413 )),
414 "PlayerLocalRelativeBallVelocity" => Some(NDArrayPlayerFeatureAdder::plain(
415 PlayerLocalRelativeBallVelocity::<F>::arc_new(),
416 )),
417 "VelocityAddedPlayerRigidBodyNoVelocities" => Some(NDArrayPlayerFeatureAdder::plain(
418 VelocityAddedPlayerRigidBodyNoVelocities::<F>::arc_new(),
419 )),
420 "InterpolatedPlayerRigidBodyNoVelocities" => Some(NDArrayPlayerFeatureAdder::plain(
421 InterpolatedPlayerRigidBodyNoVelocities::<F>::arc_new(0.003),
422 )),
423 "PlayerBallDistance" | "PlayerDistanceToBall" => Some(NDArrayPlayerFeatureAdder::plain(
424 PlayerBallDistance::<F>::arc_new(),
425 )),
426 "PlayerBoost" => Some(NDArrayPlayerFeatureAdder::plain(PlayerBoost::<F>::arc_new())),
427 "PlayerJump" => Some(NDArrayPlayerFeatureAdder::plain(PlayerJump::<F>::arc_new())),
428 "PlayerAnyJump" => Some(NDArrayPlayerFeatureAdder::plain(
429 PlayerAnyJump::<F>::arc_new(),
430 )),
431 "PlayerDodgeRefreshed" => Some(NDArrayPlayerFeatureAdder::plain(
432 PlayerDodgeRefreshed::<F>::arc_new(),
433 )),
434 "PlayerDemolishedBy" => Some(NDArrayPlayerFeatureAdder::plain(
435 PlayerDemolishedBy::<F>::arc_new(),
436 )),
437 _ => analysis_player_event_feature_adder_from_name(name),
438 }
439}
440
441impl<F> NDArrayCollector<F>
442where
443 F: TryFrom<f32> + Send + Sync + 'static,
444 <F as TryFrom<f32>>::Error: std::fmt::Debug,
445{
446 pub fn from_strings_typed(fa_names: &[&str], pfa_names: &[&str]) -> SubtrActorResult<Self> {
448 let feature_adders: NDArrayFeatureAdders<F> = fa_names
449 .iter()
450 .map(|name| {
451 global_feature_adder_from_name(name).ok_or_else(|| {
452 SubtrActorError::new(SubtrActorErrorVariant::UnknownFeatureAdderName(
453 name.to_string(),
454 ))
455 })
456 })
457 .collect::<SubtrActorResult<Vec<_>>>()?;
458 let player_feature_adders: NDArrayPlayerFeatureAdders<F> = pfa_names
459 .iter()
460 .map(|name| {
461 player_feature_adder_from_name(name).ok_or_else(|| {
462 SubtrActorError::new(SubtrActorErrorVariant::UnknownFeatureAdderName(
463 name.to_string(),
464 ))
465 })
466 })
467 .collect::<SubtrActorResult<Vec<_>>>()?;
468 Ok(Self::new(feature_adders, player_feature_adders))
469 }
470}
471
472impl NDArrayCollector<f32> {
473 pub fn from_strings(fa_names: &[&str], pfa_names: &[&str]) -> SubtrActorResult<Self> {
475 Self::from_strings_typed(fa_names, pfa_names)
476 }
477}
478
479impl<F: TryFrom<f32> + Send + Sync + 'static> Default for NDArrayCollector<F>
480where
481 <F as TryFrom<f32>>::Error: std::fmt::Debug,
482{
483 fn default() -> Self {
484 NDArrayCollector::new(
485 vec![NDArrayFeatureAdder::plain(BallRigidBody::arc_new())],
486 vec![
487 NDArrayPlayerFeatureAdder::plain(PlayerRigidBody::arc_new()),
488 NDArrayPlayerFeatureAdder::plain(PlayerBoost::arc_new()),
489 NDArrayPlayerFeatureAdder::plain(PlayerAnyJump::arc_new()),
490 ],
491 )
492 }
493}
494
495#[cfg(test)]
496#[path = "collector_tests.rs"]
497mod tests;