subtr_actor/collector/ndarray/
collector.rs1use super::builtins::*;
2use super::traits::*;
3use crate::collector::{Collector, TimeAdvance};
4use crate::*;
5use ::ndarray;
6use boxcars;
7use serde::Serialize;
8use std::sync::Arc;
9
10#[derive(Debug, Clone, PartialEq, Serialize)]
12pub struct NDArrayColumnHeaders {
13 pub global_headers: Vec<String>,
15 pub player_headers: Vec<String>,
17}
18
19impl NDArrayColumnHeaders {
20 pub fn new(global_headers: Vec<String>, player_headers: Vec<String>) -> Self {
22 Self {
23 global_headers,
24 player_headers,
25 }
26 }
27}
28
29#[derive(Debug, Clone, PartialEq, Serialize)]
31pub struct ReplayMetaWithHeaders {
32 pub replay_meta: ReplayMeta,
34 pub column_headers: NDArrayColumnHeaders,
36}
37
38impl ReplayMetaWithHeaders {
39 pub fn headers_vec(&self) -> Vec<String> {
41 self.headers_vec_from(|_, _info, index| format!("Player {index} - "))
42 }
43
44 pub fn headers_vec_from<F>(&self, player_prefix_getter: F) -> Vec<String>
46 where
47 F: Fn(&Self, &PlayerInfo, usize) -> String,
48 {
49 self.column_headers
50 .global_headers
51 .iter()
52 .cloned()
53 .chain(self.replay_meta.player_order().enumerate().flat_map(
54 move |(player_index, info)| {
55 let player_prefix = player_prefix_getter(self, info, player_index);
56 self.column_headers
57 .player_headers
58 .iter()
59 .map(move |header| format!("{player_prefix}{header}"))
60 },
61 ))
62 .collect()
63 }
64}
65
66pub struct NDArrayCollector<F> {
68 feature_adders: FeatureAdders<F>,
69 player_feature_adders: PlayerFeatureAdders<F>,
70 data: Vec<F>,
71 replay_meta: Option<ReplayMeta>,
72 frames_added: usize,
73}
74
75impl<F> NDArrayCollector<F> {
76 pub fn new(
78 feature_adders: FeatureAdders<F>,
79 player_feature_adders: PlayerFeatureAdders<F>,
80 ) -> Self {
81 Self {
82 feature_adders,
83 player_feature_adders,
84 data: Vec::new(),
85 replay_meta: None,
86 frames_added: 0,
87 }
88 }
89
90 pub fn get_column_headers(&self) -> NDArrayColumnHeaders {
92 let global_headers = self
93 .feature_adders
94 .iter()
95 .flat_map(move |fa| {
96 fa.get_column_headers()
97 .iter()
98 .map(move |column_name| column_name.to_string())
99 })
100 .collect();
101 let player_headers = self
102 .player_feature_adders
103 .iter()
104 .flat_map(move |pfa| {
105 pfa.get_column_headers()
106 .iter()
107 .map(move |base_name| base_name.to_string())
108 })
109 .collect();
110 NDArrayColumnHeaders::new(global_headers, player_headers)
111 }
112
113 pub fn get_ndarray(self) -> SubtrActorResult<ndarray::Array2<F>> {
115 self.get_meta_and_ndarray().map(|a| a.1)
116 }
117
118 pub fn get_meta_and_ndarray(
120 self,
121 ) -> SubtrActorResult<(ReplayMetaWithHeaders, ndarray::Array2<F>)> {
122 let features_per_row = self.try_get_frame_feature_count()?;
123 let expected_length = features_per_row * self.frames_added;
124 assert!(self.data.len() == expected_length);
125 let column_headers = self.get_column_headers();
126 Ok((
127 ReplayMetaWithHeaders {
128 replay_meta: self.replay_meta.ok_or(SubtrActorError::new(
129 SubtrActorErrorVariant::CouldNotBuildReplayMeta,
130 ))?,
131 column_headers,
132 },
133 ndarray::Array2::from_shape_vec((self.frames_added, features_per_row), self.data)
134 .map_err(SubtrActorErrorVariant::NDArrayShapeError)
135 .map_err(SubtrActorError::new)?,
136 ))
137 }
138
139 pub fn process_and_get_meta_and_headers(
141 &mut self,
142 replay: &boxcars::Replay,
143 ) -> SubtrActorResult<ReplayMetaWithHeaders> {
144 let mut processor = ReplayProcessor::new(replay)?;
145 processor.process_long_enough_to_get_actor_ids()?;
146 self.maybe_set_replay_meta(&processor)?;
147 Ok(ReplayMetaWithHeaders {
148 replay_meta: self
149 .replay_meta
150 .as_ref()
151 .ok_or(SubtrActorError::new(
152 SubtrActorErrorVariant::CouldNotBuildReplayMeta,
153 ))?
154 .clone(),
155 column_headers: self.get_column_headers(),
156 })
157 }
158
159 fn try_get_frame_feature_count(&self) -> SubtrActorResult<usize> {
160 let player_count = self
161 .replay_meta
162 .as_ref()
163 .ok_or(SubtrActorError::new(
164 SubtrActorErrorVariant::CouldNotBuildReplayMeta,
165 ))?
166 .player_count();
167 let global_feature_count: usize = self
168 .feature_adders
169 .iter()
170 .map(|fa| fa.features_added())
171 .sum();
172 let player_feature_count: usize = self
173 .player_feature_adders
174 .iter()
175 .map(|pfa| pfa.features_added() * player_count)
176 .sum();
177 Ok(global_feature_count + player_feature_count)
178 }
179
180 fn maybe_set_replay_meta(&mut self, processor: &ReplayProcessor) -> SubtrActorResult<()> {
181 if self.replay_meta.is_none() {
182 self.replay_meta = Some(processor.get_replay_meta()?);
183 }
184 Ok(())
185 }
186}
187
188impl<F> Collector for NDArrayCollector<F> {
189 fn process_frame(
190 &mut self,
191 processor: &ReplayProcessor,
192 frame: &boxcars::Frame,
193 frame_number: usize,
194 current_time: f32,
195 ) -> SubtrActorResult<TimeAdvance> {
196 self.maybe_set_replay_meta(processor)?;
197
198 for feature_adder in &self.feature_adders {
199 feature_adder.add_features(
200 processor,
201 frame,
202 frame_number,
203 current_time,
204 &mut self.data,
205 )?;
206 }
207
208 for player_id in processor.iter_player_ids_in_order() {
209 for player_feature_adder in &self.player_feature_adders {
210 player_feature_adder.add_features(
211 player_id,
212 processor,
213 frame,
214 frame_number,
215 current_time,
216 &mut self.data,
217 )?;
218 }
219 }
220
221 self.frames_added += 1;
222
223 Ok(TimeAdvance::NextFrame)
224 }
225}
226
227fn global_feature_adder_from_name<F>(
228 name: &str,
229) -> Option<Arc<dyn FeatureAdder<F> + Send + Sync + 'static>>
230where
231 F: TryFrom<f32> + Send + Sync + 'static,
232 <F as TryFrom<f32>>::Error: std::fmt::Debug,
233{
234 match name {
235 "BallRigidBody" => Some(BallRigidBody::<F>::arc_new()),
236 "BallRigidBodyNoVelocities" => Some(BallRigidBodyNoVelocities::<F>::arc_new()),
237 "BallRigidBodyQuaternions" => Some(BallRigidBodyQuaternions::<F>::arc_new()),
238 "BallRigidBodyQuaternionVelocities" => {
239 Some(BallRigidBodyQuaternionVelocities::<F>::arc_new())
240 }
241 "BallRigidBodyBasis" => Some(BallRigidBodyBasis::<F>::arc_new()),
242 "VelocityAddedBallRigidBodyNoVelocities" => {
243 Some(VelocityAddedBallRigidBodyNoVelocities::<F>::arc_new())
244 }
245 "InterpolatedBallRigidBodyNoVelocities" => {
246 Some(InterpolatedBallRigidBodyNoVelocities::<F>::arc_new(0.0))
247 }
248 "SecondsRemaining" => Some(SecondsRemaining::<F>::arc_new()),
249 "CurrentTime" => Some(CurrentTime::<F>::arc_new()),
250 "FrameTime" => Some(FrameTime::<F>::arc_new()),
251 "ReplicatedStateName" => Some(ReplicatedStateName::<F>::arc_new()),
252 "ReplicatedGameStateTimeRemaining" => {
253 Some(ReplicatedGameStateTimeRemaining::<F>::arc_new())
254 }
255 "BallHasBeenHit" => Some(BallHasBeenHit::<F>::arc_new()),
256 _ => None,
257 }
258}
259
260fn player_feature_adder_from_name<F>(
261 name: &str,
262) -> Option<Arc<dyn PlayerFeatureAdder<F> + Send + Sync + 'static>>
263where
264 F: TryFrom<f32> + Send + Sync + 'static,
265 <F as TryFrom<f32>>::Error: std::fmt::Debug,
266{
267 match name {
268 "PlayerRigidBody" => Some(PlayerRigidBody::<F>::arc_new()),
269 "PlayerRigidBodyNoVelocities" => Some(PlayerRigidBodyNoVelocities::<F>::arc_new()),
270 "PlayerRigidBodyQuaternions" => Some(PlayerRigidBodyQuaternions::<F>::arc_new()),
271 "PlayerRigidBodyQuaternionVelocities" => {
272 Some(PlayerRigidBodyQuaternionVelocities::<F>::arc_new())
273 }
274 "PlayerRigidBodyBasis" => Some(PlayerRigidBodyBasis::<F>::arc_new()),
275 "PlayerRelativeBallPosition" => Some(PlayerRelativeBallPosition::<F>::arc_new()),
276 "PlayerRelativeBallVelocity" => Some(PlayerRelativeBallVelocity::<F>::arc_new()),
277 "PlayerLocalRelativeBallPosition" => Some(PlayerLocalRelativeBallPosition::<F>::arc_new()),
278 "PlayerLocalRelativeBallVelocity" => Some(PlayerLocalRelativeBallVelocity::<F>::arc_new()),
279 "VelocityAddedPlayerRigidBodyNoVelocities" => {
280 Some(VelocityAddedPlayerRigidBodyNoVelocities::<F>::arc_new())
281 }
282 "InterpolatedPlayerRigidBodyNoVelocities" => {
283 Some(InterpolatedPlayerRigidBodyNoVelocities::<F>::arc_new(0.003))
284 }
285 "PlayerBallDistance" | "PlayerDistanceToBall" => Some(PlayerBallDistance::<F>::arc_new()),
286 "PlayerBoost" => Some(PlayerBoost::<F>::arc_new()),
287 "PlayerJump" => Some(PlayerJump::<F>::arc_new()),
288 "PlayerAnyJump" => Some(PlayerAnyJump::<F>::arc_new()),
289 "PlayerDodgeRefreshed" => Some(PlayerDodgeRefreshed::<F>::arc_new()),
290 "PlayerDemolishedBy" => Some(PlayerDemolishedBy::<F>::arc_new()),
291 _ => None,
292 }
293}
294
295impl<F> NDArrayCollector<F>
296where
297 F: TryFrom<f32> + Send + Sync + 'static,
298 <F as TryFrom<f32>>::Error: std::fmt::Debug,
299{
300 pub fn from_strings_typed(fa_names: &[&str], pfa_names: &[&str]) -> SubtrActorResult<Self> {
302 let feature_adders: Vec<Arc<dyn FeatureAdder<F> + Send + Sync>> = fa_names
303 .iter()
304 .map(|name| {
305 global_feature_adder_from_name(name).ok_or_else(|| {
306 SubtrActorError::new(SubtrActorErrorVariant::UnknownFeatureAdderName(
307 name.to_string(),
308 ))
309 })
310 })
311 .collect::<SubtrActorResult<Vec<_>>>()?;
312 let player_feature_adders: Vec<Arc<dyn PlayerFeatureAdder<F> + Send + Sync>> = pfa_names
313 .iter()
314 .map(|name| {
315 player_feature_adder_from_name(name).ok_or_else(|| {
316 SubtrActorError::new(SubtrActorErrorVariant::UnknownFeatureAdderName(
317 name.to_string(),
318 ))
319 })
320 })
321 .collect::<SubtrActorResult<Vec<_>>>()?;
322 Ok(Self::new(feature_adders, player_feature_adders))
323 }
324}
325
326impl NDArrayCollector<f32> {
327 pub fn from_strings(fa_names: &[&str], pfa_names: &[&str]) -> SubtrActorResult<Self> {
329 Self::from_strings_typed(fa_names, pfa_names)
330 }
331}
332
333impl<F: TryFrom<f32> + Send + Sync + 'static> Default for NDArrayCollector<F>
334where
335 <F as TryFrom<f32>>::Error: std::fmt::Debug,
336{
337 fn default() -> Self {
338 NDArrayCollector::new(
339 vec![BallRigidBody::arc_new()],
340 vec![
341 PlayerRigidBody::arc_new(),
342 PlayerBoost::arc_new(),
343 PlayerAnyJump::arc_new(),
344 ],
345 )
346 }
347}