rustitude_core/dataset.rs
1//! This module contains all the resources needed to load and examine datasets.
2//!
3//! A [`Dataset`] is, in essence, a list of [`Event`]s, each of which contain all the pertinent
4//! information about a single set of initial- and final-state particles, as well as an index
5//! and weight within the [`Dataset`].
6//!
7//! This crate currently supports loading [`Dataset`]s from ROOT and Parquet files (see
//! [`Dataset::from_root`] and [`Dataset::from_parquet`]). These methods require the following
9//! "branches" or "columns" to be present in the file:
10//!
11//! | Branch Name | Data Type | Notes |
12//! |---|---|---|
13//! | `Weight` | Float32 | |
14//! | `E_Beam` | Float32 | |
15//! | `Px_Beam` | Float32 | |
16//! | `Py_Beam` | Float32 | |
17//! | `Pz_Beam` | Float32 | |
18//! | `E_FinalState` | \[Float32\] | \[recoil, daughter #1, daughter #2, ...\] |
19//! | `Px_FinalState` | \[Float32\] | \[recoil, daughter #1, daughter #2, ...\] |
20//! | `Py_FinalState` | \[Float32\] | \[recoil, daughter #1, daughter #2, ...\] |
21//! | `Pz_FinalState` | \[Float32\] | \[recoil, daughter #1, daughter #2, ...\] |
22//! | `EPS` | \[Float32\] | \[$`P_\gamma \cos(\Phi)`$, $`P_\gamma \sin(\Phi)`$, $`0.0`$\] for linear polarization with magnitude $`P_\gamma`$ and angle $`\Phi`$ |
23//!
24//! The `EPS` branch is optional and files without such a branch can be loaded under the
25//! following conditions. First, if we don't care about polarization, and wish to set `EPS` =
//! `[0.0, 0.0, 0.0]`, we can do so using the method [`ReadMethod::EPS(0.0, 0.0, 0.0)`]. If
27//! a data file contains events with only one polarization, we can compute the `EPS` vector
28//! ourselves and use [`ReadMethod::EPS(x, y, z)`] to load the same vector for every event.
29//! Finally, to provide compatibility with the way polarization is sometimes included in
30//! `AmpTools` files, we can note that the beam is often only moving along the
31//! $`z`$-axis, so the $`x`$ and $`y`$ components are typically `0.0` anyway, so we can store
32//! the $`x`$, $`y`$, and $`z`$ components of `EPS` in the beam's three-momentum and use the
33//! [`ReadMethod::EPSInBeam`] to extract it. All of these methods are used as an input for either
34//! [`Dataset::from_parquet`] or [`Dataset::from_root`].
35//!
36//! There are also several methods used to split up [`Dataset`]s based on their component
//! values. The [`Dataset::get_selected_indices`] method returns two `Vec<usize>`s of event
//! indices: the first for events for which some input query returns `true`, and the second for
//! all remaining events.
39//!
40//! Often, we want to use a query to divide data into many bins, so there is a method
41//! [`Dataset::get_binned_indices`] which will bin data by a query which takes an [`Event`] and
//! returns a [`Field`] value (rather than a [`bool`]).
43//!
44//! This method also takes a `range: (Field, Field)` and a number of bins `nbins: usize`, and it
45//! returns a `(Vec<Vec<usize>>, Vec<usize>, Vec<usize>)`. These fields correspond to the binned
46//! datasets, the underflow bin, and the overflow bin respectively, so no data should ever be
47//! "lost" by this operation. There is also a convenience method, [`Dataset::split_m`], to split
48//! the dataset by the mass of the summed four-momentum of any of the daughter particles,
49//! specified by their index.
50use std::ops::Add;
51use std::{fmt::Display, fs::File, iter::repeat_with, path::Path, sync::Arc};
52
53use itertools::{izip, Either, Itertools};
54use nalgebra::Vector3;
55use oxyroot::{ReaderTree, RootFile, Slice};
56use parquet::record::Field as ParquetField;
57use parquet::{
58 file::reader::{FileReader, SerializedFileReader},
59 record::Row,
60};
61use rayon::prelude::*;
62use tracing::info;
63
64use crate::convert;
65use crate::{errors::RustitudeError, prelude::FourMomentum, Field};
66
67/// The [`Event`] struct contains all the information concerning a single interaction between
68/// particles in the experiment. See the individual fields for additional information.
69#[derive(Debug, Default, Clone)]
70pub struct Event<F: Field + 'static> {
71 /// The index of the event with the parent [`Dataset`].
72 pub index: usize,
73 /// The weight of the event with the parent [`Dataset`].
74 pub weight: F,
75 /// The beam [`FourMomentum`].
76 pub beam_p4: FourMomentum<F>,
77 /// The recoil (target after interaction) [`FourMomentum`].
78 pub recoil_p4: FourMomentum<F>,
79 /// [`FourMomentum`] of each other final state particle.
80 pub daughter_p4s: Vec<FourMomentum<F>>,
81 /// A vector corresponding to the polarization of the beam.
82 pub eps: Vector3<F>,
83}
84
85impl<F: Field + 'static> Display for Event<F> {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 writeln!(f, "Index: {}", self.index)?;
88 writeln!(f, "Weight: {}", self.weight)?;
89 writeln!(f, "Beam P4: {}", self.beam_p4)?;
90 writeln!(f, "Recoil P4: {}", self.recoil_p4)?;
91 writeln!(f, "Daughters:")?;
92 for (i, p4) in self.daughter_p4s.iter().enumerate() {
93 writeln!(f, "\t{i} -> {p4}")?;
94 }
95 writeln!(
96 f,
97 "EPS: [{}, {}, {}]",
98 self.eps[0], self.eps[1], self.eps[2]
99 )?;
100 Ok(())
101 }
102}
103
104/// An enum which lists various methods used to read data into [`Event`]s.
105#[derive(Copy, Clone)]
106pub enum ReadMethod<F: Field> {
107 /// The "standard" method assumes an `EPS` column/branch to read.
108 Standard,
109 /// This variant assumes the EPS vec is stored as the beam's 3-momentum.
110 EPSInBeam,
111 /// This variant can be used to provide a custom EPS vec for all events.
112 EPS(F, F, F),
113}
114impl<F: Field> ReadMethod<F> {
115 /// Creates the EPS vector from a polarization magnitude and angle (in radians).
116 pub fn from_linear_polarization(p_gamma: F, phi: F) -> Self {
117 Self::EPS(p_gamma * F::cos(phi), p_gamma * F::sin(phi), F::zero())
118 }
119}
120impl<F: Field> Event<F> {
121 /// Returns the magnitude of the EPS vector
122 pub fn eps_mag(&self) -> F {
123 F::sqrt(F::powi(self.eps.x, 2) + F::powi(self.eps.y, 2) + F::powi(self.eps.z, 2))
124 }
125 /// Reads an [`Event`] from a single [`Row`] in a Parquet file.
126 ///
127 /// # Panics
128 ///
129 /// This method currently panics if the list-like group types don't contain floats. This
130 /// eventually needs to be sorted out.
131 fn read_parquet_row(
132 index: usize,
133 row: Result<Row, parquet::errors::ParquetError>,
134 method: ReadMethod<F>,
135 ) -> Result<Self, RustitudeError> {
136 let mut event = Self {
137 index,
138 ..Default::default()
139 };
140 let mut e_fs: Vec<F> = Vec::new();
141 let mut px_fs: Vec<F> = Vec::new();
142 let mut py_fs: Vec<F> = Vec::new();
143 let mut pz_fs: Vec<F> = Vec::new();
144 for (name, field) in row?.get_column_iter() {
145 match (name.as_str(), field) {
146 ("E_Beam", ParquetField::Float(value)) => {
147 event.beam_p4.set_e(convert!(*value, F));
148 if matches!(method, ReadMethod::EPSInBeam) {
149 event.beam_p4.set_pz(convert!(*value, F));
150 }
151 }
152 ("Px_Beam", ParquetField::Float(value)) => {
153 if matches!(method, ReadMethod::EPSInBeam) {
154 event.eps[0] = convert!(*value, F);
155 } else {
156 event.beam_p4.set_px(convert!(*value, F));
157 }
158 }
159 ("Py_Beam", ParquetField::Float(value)) => {
160 if matches!(method, ReadMethod::EPSInBeam) {
161 event.eps[1] = convert!(*value, F);
162 } else {
163 event.beam_p4.set_py(convert!(*value, F));
164 }
165 }
166 ("Pz_Beam", ParquetField::Float(value)) => {
167 if !matches!(method, ReadMethod::EPSInBeam) {
168 event.beam_p4.set_pz(convert!(*value, F));
169 }
170 }
171 ("Weight", ParquetField::Float(value)) => {
172 event.weight = convert!(*value, F);
173 }
174 ("EPS", ParquetField::ListInternal(list)) => match method {
175 ReadMethod::Standard => {
176 event.eps = Vector3::from_vec(
177 list.elements()
178 .iter()
179 .map(|field| {
180 if let ParquetField::Float(value) = field {
181 convert!(*value, F)
182 } else {
183 panic!()
184 }
185 })
186 .collect(),
187 );
188 }
189 ReadMethod::EPS(x, y, z) => *event.eps = *Vector3::new(x, y, z),
190 _ => {}
191 },
192 ("E_FinalState", ParquetField::ListInternal(list)) => {
193 e_fs = list
194 .elements()
195 .iter()
196 .map(|field| {
197 if let ParquetField::Float(value) = field {
198 convert!(*value, F)
199 } else {
200 panic!()
201 }
202 })
203 .collect()
204 }
205 ("Px_FinalState", ParquetField::ListInternal(list)) => {
206 px_fs = list
207 .elements()
208 .iter()
209 .map(|field| {
210 if let ParquetField::Float(value) = field {
211 convert!(*value, F)
212 } else {
213 panic!()
214 }
215 })
216 .collect()
217 }
218 ("Py_FinalState", ParquetField::ListInternal(list)) => {
219 py_fs = list
220 .elements()
221 .iter()
222 .map(|field| {
223 if let ParquetField::Float(value) = field {
224 convert!(*value, F)
225 } else {
226 panic!()
227 }
228 })
229 .collect()
230 }
231 ("Pz_FinalState", ParquetField::ListInternal(list)) => {
232 pz_fs = list
233 .elements()
234 .iter()
235 .map(|field| {
236 if let ParquetField::Float(value) = field {
237 convert!(*value, F)
238 } else {
239 panic!()
240 }
241 })
242 .collect()
243 }
244 _ => {}
245 }
246 }
247 event.recoil_p4 = FourMomentum::new(e_fs[0], px_fs[0], py_fs[0], pz_fs[0]);
248 event.daughter_p4s = e_fs[1..]
249 .iter()
250 .zip(px_fs[1..].iter())
251 .zip(py_fs[1..].iter())
252 .zip(pz_fs[1..].iter())
253 .map(|(((e, px), py), pz)| FourMomentum::new(*e, *px, *py, *pz))
254 .collect();
255 // let final_state_p4 = event.recoil_p4 + event.daughter_p4s.iter().sum();
256 // event.beam_p4 = event.beam_p4.boost_along(&final_state_p4);
257 // event.recoil_p4 = event.recoil_p4.boost_along(&final_state_p4);
258 // for dp4 in event.daughter_p4s.iter_mut() {
259 // *dp4 = dp4.boost_along(&final_state_p4);
260 // }
261 Ok(event)
262 }
263}
264
265/// An array of [`Event`]s with some helpful methods for accessing and parsing the data they
266/// contain.
267///
268/// A [`Dataset`] can be loaded from either Parquet and ROOT files using the corresponding
269/// `Dataset::from_*` methods. Events are stored in an [`Arc<Vec<Event>>`], since we
270/// rarely need to write data to a dataset (splitting/selecting/rejecting events) but often need to
271/// read events from a dataset.
272#[derive(Default, Debug, Clone)]
273pub struct Dataset<F: Field + 'static> {
274 /// Storage for events.
275 pub events: Arc<Vec<Event<F>>>,
276}
277
278impl<F: Field + 'static> Dataset<F> {
279 /// Resets the indices of events in a dataset so they start at `0`.
280 pub fn reindex(&mut self) {
281 self.events = Arc::new(
282 (*self.events)
283 .clone()
284 .iter_mut()
285 .enumerate()
286 .map(|(i, event)| {
287 event.index = i;
288 event.clone()
289 })
290 .collect(),
291 )
292 }
293 // TODO: can we make an events(&self) -> &Vec<Field> method that actually works without cloning?
294
295 /// Retrieves the weights from the events in the dataset
296 pub fn weights(&self) -> Vec<F> {
297 self.events.iter().map(|e| e.weight).collect()
298 }
299
300 /// Retrieves the weights from the events in the dataset which have the given indices.
301 pub fn weights_indexed(&self, indices: &[usize]) -> Vec<F> {
302 indices
303 .iter()
304 .map(|index| self.events[*index].weight)
305 .collect()
306 }
307
308 /// Splits the dataset by the mass of the combination of specified daughter particles in the
309 /// event. If no daughters are given, the first and second particle are assumed to form the
310 /// desired combination. This method returns [`Vec<usize>`]s corresponding to the indices of
311 /// events in each bin, the underflow bin, and the overflow bin respectively. This is intended
312 /// to be used in conjunction with
313 /// [`Manager::evaluate_indexed`](`crate::manager::Manager::evaluate_indexed`).
314 pub fn split_m(
315 &self,
316 range: (F, F),
317 bins: usize,
318 daughter_indices: Option<Vec<usize>>,
319 ) -> (Vec<Vec<usize>>, Vec<usize>, Vec<usize>) {
320 let mass = |e: &Event<F>| {
321 let p4: FourMomentum<F> = daughter_indices
322 .clone()
323 .unwrap_or_else(|| vec![0, 1])
324 .iter()
325 .map(|i| e.daughter_p4s[*i])
326 .sum();
327 p4.m()
328 };
329 self.get_binned_indices(mass, range, bins)
330 }
331
332 /// Generates a new [`Dataset`] from a Parquet file.
333 ///
334 /// # Errors
335 ///
336 /// This method will fail if any individual event is missing all of the required fields, if
337 /// they have the wrong type, or if the file doesn't exist/can't be read for any reason.
338 pub fn from_parquet(path: &str, method: ReadMethod<F>) -> Result<Self, RustitudeError> {
339 let path = Path::new(path);
340 let file = File::open(path)?;
341 let reader = SerializedFileReader::new(file)?;
342 let row_iter = reader.get_row_iter(None)?;
343 Ok(Self::new(
344 row_iter
345 .enumerate()
346 .map(|(i, row)| Event::read_parquet_row(i, row, method))
347 .collect::<Result<Vec<Event<F>>, RustitudeError>>()?,
348 ))
349 }
350
351 /// Extract a branch from a ROOT `TTree` containing a [`Field`] (float in C). This method
352 /// converts the underlying element to an [`Field`].
353 fn extract_f32(path: &str, ttree: &ReaderTree, branch: &str) -> Result<Vec<F>, RustitudeError> {
354 let res = ttree
355 .branch(branch)
356 .ok_or_else(|| {
357 RustitudeError::OxyrootError(format!(
358 "Could not find {} branch in {}",
359 branch, path
360 ))
361 })?
362 .as_iter::<f64>()
363 .map_err(|err| RustitudeError::OxyrootError(err.to_string()))?
364 .map(|val| convert!(val, F))
365 .collect();
366 Ok(res)
367 }
368
369 /// Extract a branch from a ROOT `TTree` containing an array of [`Field`]s (floats in C). This
370 /// method converts the underlying elements to [`Field`]s.
371 fn extract_vec_f32(
372 path: &str,
373 ttree: &ReaderTree,
374 branch: &str,
375 ) -> Result<Vec<Vec<F>>, RustitudeError> {
376 let res: Vec<Vec<F>> = ttree
377 .branch(branch)
378 .ok_or_else(|| {
379 RustitudeError::OxyrootError(format!(
380 "Could not find {} branch in {}",
381 branch, path
382 ))
383 })?
384 .as_iter::<Slice<f64>>()
385 .map_err(|err| RustitudeError::OxyrootError(err.to_string()))?
386 .map(|v| {
387 v.into_vec()
388 .into_iter()
389 .map(|val| convert!(val, F))
390 .collect()
391 })
392 .collect();
393 Ok(res)
394 }
395
396 /// Generates a new [`Dataset`] from a ROOT file.
397 ///
398 /// # Errors
399 ///
400 /// This method will fail if any individual event is missing all of the required fields, if
401 /// they have the wrong type, or if the file doesn't exist/can't be read for any reason.
402 pub fn from_root(path: &str, method: ReadMethod<F>) -> Result<Self, RustitudeError> {
403 let ttree = RootFile::open(path)
404 .map_err(|err| RustitudeError::OxyrootError(err.to_string()))?
405 .get_tree("kin")
406 .map_err(|err| RustitudeError::OxyrootError(err.to_string()))?;
407 let weight: Vec<F> = Self::extract_f32(path, &ttree, "Weight")?;
408 let e_beam: Vec<F> = Self::extract_f32(path, &ttree, "E_Beam")?;
409 let px_beam: Vec<F> = Self::extract_f32(path, &ttree, "Px_Beam")?;
410 let py_beam: Vec<F> = Self::extract_f32(path, &ttree, "Py_Beam")?;
411 let pz_beam: Vec<F> = Self::extract_f32(path, &ttree, "Pz_Beam")?;
412 let e_fs: Vec<Vec<F>> = Self::extract_vec_f32(path, &ttree, "E_FinalState")?;
413 let px_fs: Vec<Vec<F>> = Self::extract_vec_f32(path, &ttree, "Px_FinalState")?;
414 let py_fs: Vec<Vec<F>> = Self::extract_vec_f32(path, &ttree, "Py_FinalState")?;
415 let pz_fs: Vec<Vec<F>> = Self::extract_vec_f32(path, &ttree, "Pz_FinalState")?;
416 let eps_extracted: Vec<Vec<F>> = if matches!(method, ReadMethod::Standard) {
417 Self::extract_vec_f32(path, &ttree, "EPS")?
418 } else {
419 vec![vec![F::zero(); 3]; weight.len()]
420 };
421 Ok(Self::new(
422 izip!(
423 weight,
424 e_beam,
425 px_beam,
426 py_beam,
427 pz_beam,
428 e_fs,
429 px_fs,
430 py_fs,
431 pz_fs,
432 eps_extracted
433 )
434 .enumerate()
435 .map(
436 |(i, (w, e_b, px_b, py_b, pz_b, e_f, px_f, py_f, pz_f, eps_vec))| {
437 let (beam_p4, eps) = match method {
438 ReadMethod::Standard => (
439 FourMomentum::new(e_b, px_b, py_b, pz_b),
440 Vector3::from_vec(eps_vec),
441 ),
442 ReadMethod::EPSInBeam => (
443 FourMomentum::new(e_b, F::zero(), F::zero(), e_b),
444 Vector3::new(px_b, py_b, pz_b),
445 ),
446 ReadMethod::EPS(x, y, z) => (
447 FourMomentum::new(e_b, px_b, py_b, pz_b),
448 Vector3::new(x, y, z),
449 ),
450 };
451 Event {
452 index: i,
453 weight: w,
454 beam_p4,
455 recoil_p4: FourMomentum::new(e_f[0], px_f[0], py_f[0], pz_f[0]),
456 daughter_p4s: izip!(
457 e_f[1..].iter(),
458 px_f[1..].iter(),
459 py_f[1..].iter(),
460 pz_f[1..].iter()
461 )
462 .map(|(e, px, py, pz)| FourMomentum::new(*e, *px, *py, *pz))
463 .collect(),
464 eps,
465 }
466 },
467 )
468 .collect(),
469 ))
470 }
471
472 /// Generate a new [`Dataset`] from a [`Vec<Event>`].
473 pub fn new(events: Vec<Event<F>>) -> Self {
474 info!("Dataset created with {} events", events.len());
475 Self {
476 events: Arc::new(events),
477 }
478 }
479
480 /// Checks if the dataset is empty.
481 pub fn is_empty(&self) -> bool {
482 self.events.is_empty()
483 }
484
485 /// Returns the number of events in the dataset.
486 pub fn len(&self) -> usize {
487 self.events.len()
488 }
489
490 /// Returns a set of indices which represent a bootstrapped [`Dataset`]. This method is to be
491 /// used in conjunction with
492 /// [`Manager::evaluate_indexed`](crate::manager::Manager::evaluate_indexed).
493 pub fn get_bootstrap_indices(&self, seed: usize) -> Vec<usize> {
494 fastrand::seed(seed as u64);
495 let mut inds: Vec<usize> = repeat_with(|| fastrand::usize(0..self.len()))
496 .take(self.len())
497 .collect();
498 inds.sort_unstable();
499 inds
500 }
501
502 /// Selects indices of events in a dataset using the given query. Indices of events for which
503 /// the query returns `true` will end up in the first member of the returned tuple, and indices
504 /// of events which return `false` will end up in the second member.
505 pub fn get_selected_indices(
506 &self,
507 query: impl Fn(&Event<F>) -> bool + Sync + Send,
508 ) -> (Vec<usize>, Vec<usize>) {
509 let (mut indices_selected, mut indices_rejected): (Vec<usize>, Vec<usize>) =
510 self.events.par_iter().partition_map(|event| {
511 if query(event) {
512 Either::Left(event.index)
513 } else {
514 Either::Right(event.index)
515 }
516 });
517 indices_selected.sort_unstable();
518 indices_rejected.sort_unstable();
519 (indices_selected, indices_rejected)
520 }
521
522 /// Splits the dataset by the given query. This method returns [`Vec<usize>`]s corresponding to
523 /// the indices of events in each bin, the underflow bin, and the overflow bin respectively.
524 /// This is intended to be used in conjunction with
525 /// [`Manager::evaluate_indexed`](`crate::manager::Manager::evaluate_indexed`).
526 pub fn get_binned_indices(
527 &self,
528 variable: impl Fn(&Event<F>) -> F + Sync + Send,
529 range: (F, F),
530 nbins: usize,
531 ) -> (Vec<Vec<usize>>, Vec<usize>, Vec<usize>) {
532 let mut bins: Vec<F> = Vec::with_capacity(nbins + 1);
533 let width = (range.1 - range.0) / convert!(nbins, F);
534 for m in 0..=nbins {
535 bins.push(F::mul_add(width, convert!(m, F), range.0));
536 }
537 let (underflow, _) = self.get_selected_indices(|event| variable(event) < bins[0]);
538 let (overflow, _) =
539 self.get_selected_indices(|event| variable(event) >= bins[bins.len() - 1]);
540 let binned_indices = bins
541 .into_iter()
542 .tuple_windows()
543 .map(|(lb, ub)| {
544 let (sel, _) = self.get_selected_indices(|event| {
545 let res = variable(event);
546 lb <= res && res < ub
547 });
548 sel
549 })
550 .collect();
551 (binned_indices, underflow, overflow)
552 }
553}
554
555impl<F: Field + 'static> Add for Dataset<F> {
556 type Output = Self;
557
558 fn add(self, other: Self) -> Self::Output {
559 let mut combined_events = Vec::with_capacity(self.events.len() + other.events.len());
560 combined_events.extend(Arc::try_unwrap(self.events).unwrap_or_else(|arc| (*arc).clone()));
561 combined_events.extend(Arc::try_unwrap(other.events).unwrap_or_else(|arc| (*arc).clone()));
562 Self {
563 events: Arc::new(combined_events),
564 }
565 }
566}