rustitude_core/
dataset.rs

1//! This module contains all the resources needed to load and examine datasets.
2//!
//! A [`Dataset`] is, in essence, a list of [`Event`]s, each of which contains all the pertinent
//! information about a single set of initial- and final-state particles, as well as an index
//! and weight within the [`Dataset`].
6//!
//! This crate currently supports loading [`Dataset`]s from ROOT and Parquet files (see
//! [`Dataset::from_root`] and [`Dataset::from_parquet`]). These methods require the following
//! "branches" or "columns" to be present in the file:
10//!
11//! | Branch Name | Data Type | Notes |
12//! |---|---|---|
13//! | `Weight` | Float32 |  |
14//! | `E_Beam` | Float32 |  |
15//! | `Px_Beam` | Float32 |  |
16//! | `Py_Beam` | Float32 |  |
17//! | `Pz_Beam` | Float32 |  |
18//! | `E_FinalState` | \[Float32\] | \[recoil, daughter #1, daughter #2, ...\] |
19//! | `Px_FinalState` | \[Float32\] | \[recoil, daughter #1, daughter #2, ...\] |
20//! | `Py_FinalState` | \[Float32\] | \[recoil, daughter #1, daughter #2, ...\] |
21//! | `Pz_FinalState` | \[Float32\] | \[recoil, daughter #1, daughter #2, ...\] |
22//! | `EPS` | \[Float32\] | \[$`P_\gamma \cos(\Phi)`$, $`P_\gamma \sin(\Phi)`$, $`0.0`$\] for linear polarization with magnitude $`P_\gamma`$ and angle $`\Phi`$ |
23//!
24//! The `EPS` branch is optional and files without such a branch can be loaded under the
25//! following conditions. First, if we don't care about polarization, and wish to set `EPS` =
26//! `[0.0, 0.0, 0.0]`, we can do so using the methods [`ReadMethod::EPS(0.0, 0.0, 0.0)`]. If
27//! a data file contains events with only one polarization, we can compute the `EPS` vector
28//! ourselves and use [`ReadMethod::EPS(x, y, z)`] to load the same vector for every event.
29//! Finally, to provide compatibility with the way polarization is sometimes included in
30//! `AmpTools` files, we can note that the beam is often only moving along the
31//! $`z`$-axis, so the $`x`$ and $`y`$ components are typically `0.0` anyway, so we can store
32//! the $`x`$, $`y`$, and $`z`$ components of `EPS` in the beam's three-momentum and use the
33//! [`ReadMethod::EPSInBeam`] to extract it. All of these methods are used as an input for either
34//! [`Dataset::from_parquet`] or [`Dataset::from_root`].
35//!
//! There are also several methods used to split up [`Dataset`]s based on their component
//! values. The [`Dataset::get_selected_indices`] method returns a pair of `Vec<usize>`s: the
//! indices of events for which some input query returns `true`, followed by the indices of
//! events for which it returns `false`.
39//!
//! Often, we want to use a query to divide data into many bins, so there is a method
//! [`Dataset::get_binned_indices`] which will bin data by a query which takes an [`Event`] and
//! returns a [`Field`] value (rather than a [`bool`]).
43//!
//! This method also takes a `range: (Field, Field)` and a number of bins `nbins: usize`, and it
//! returns a `(Vec<Vec<usize>>, Vec<usize>, Vec<usize>)`. These values hold the event indices
//! for each bin, the underflow bin, and the overflow bin respectively, so no data should ever be
//! "lost" by this operation. There is also a convenience method, [`Dataset::split_m`], to split
//! the dataset by the mass of the summed four-momentum of any of the daughter particles,
//! specified by their index.
50use std::ops::Add;
51use std::{fmt::Display, fs::File, iter::repeat_with, path::Path, sync::Arc};
52
53use itertools::{izip, Either, Itertools};
54use nalgebra::Vector3;
55use oxyroot::{ReaderTree, RootFile, Slice};
56use parquet::record::Field as ParquetField;
57use parquet::{
58    file::reader::{FileReader, SerializedFileReader},
59    record::Row,
60};
61use rayon::prelude::*;
62use tracing::info;
63
64use crate::convert;
65use crate::{errors::RustitudeError, prelude::FourMomentum, Field};
66
67/// The [`Event`] struct contains all the information concerning a single interaction between
68/// particles in the experiment. See the individual fields for additional information.
69#[derive(Debug, Default, Clone)]
70pub struct Event<F: Field + 'static> {
71    /// The index of the event with the parent [`Dataset`].
72    pub index: usize,
73    /// The weight of the event with the parent [`Dataset`].
74    pub weight: F,
75    /// The beam [`FourMomentum`].
76    pub beam_p4: FourMomentum<F>,
77    /// The recoil (target after interaction) [`FourMomentum`].
78    pub recoil_p4: FourMomentum<F>,
79    /// [`FourMomentum`] of each other final state particle.
80    pub daughter_p4s: Vec<FourMomentum<F>>,
81    /// A vector corresponding to the polarization of the beam.
82    pub eps: Vector3<F>,
83}
84
85impl<F: Field + 'static> Display for Event<F> {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        writeln!(f, "Index: {}", self.index)?;
88        writeln!(f, "Weight: {}", self.weight)?;
89        writeln!(f, "Beam P4: {}", self.beam_p4)?;
90        writeln!(f, "Recoil P4: {}", self.recoil_p4)?;
91        writeln!(f, "Daughters:")?;
92        for (i, p4) in self.daughter_p4s.iter().enumerate() {
93            writeln!(f, "\t{i} -> {p4}")?;
94        }
95        writeln!(
96            f,
97            "EPS: [{}, {}, {}]",
98            self.eps[0], self.eps[1], self.eps[2]
99        )?;
100        Ok(())
101    }
102}
103
104/// An enum which lists various methods used to read data into [`Event`]s.
105#[derive(Copy, Clone)]
106pub enum ReadMethod<F: Field> {
107    /// The "standard" method assumes an `EPS` column/branch to read.
108    Standard,
109    /// This variant assumes the EPS vec is stored as the beam's 3-momentum.
110    EPSInBeam,
111    /// This variant can be used to provide a custom EPS vec for all events.
112    EPS(F, F, F),
113}
114impl<F: Field> ReadMethod<F> {
115    /// Creates the EPS vector from a polarization magnitude and angle (in radians).
116    pub fn from_linear_polarization(p_gamma: F, phi: F) -> Self {
117        Self::EPS(p_gamma * F::cos(phi), p_gamma * F::sin(phi), F::zero())
118    }
119}
120impl<F: Field> Event<F> {
121    /// Returns the magnitude of the EPS vector
122    pub fn eps_mag(&self) -> F {
123        F::sqrt(F::powi(self.eps.x, 2) + F::powi(self.eps.y, 2) + F::powi(self.eps.z, 2))
124    }
125    /// Reads an [`Event`] from a single [`Row`] in a Parquet file.
126    ///
127    /// # Panics
128    ///
129    /// This method currently panics if the list-like group types don't contain floats. This
130    /// eventually needs to be sorted out.
131    fn read_parquet_row(
132        index: usize,
133        row: Result<Row, parquet::errors::ParquetError>,
134        method: ReadMethod<F>,
135    ) -> Result<Self, RustitudeError> {
136        let mut event = Self {
137            index,
138            ..Default::default()
139        };
140        let mut e_fs: Vec<F> = Vec::new();
141        let mut px_fs: Vec<F> = Vec::new();
142        let mut py_fs: Vec<F> = Vec::new();
143        let mut pz_fs: Vec<F> = Vec::new();
144        for (name, field) in row?.get_column_iter() {
145            match (name.as_str(), field) {
146                ("E_Beam", ParquetField::Float(value)) => {
147                    event.beam_p4.set_e(convert!(*value, F));
148                    if matches!(method, ReadMethod::EPSInBeam) {
149                        event.beam_p4.set_pz(convert!(*value, F));
150                    }
151                }
152                ("Px_Beam", ParquetField::Float(value)) => {
153                    if matches!(method, ReadMethod::EPSInBeam) {
154                        event.eps[0] = convert!(*value, F);
155                    } else {
156                        event.beam_p4.set_px(convert!(*value, F));
157                    }
158                }
159                ("Py_Beam", ParquetField::Float(value)) => {
160                    if matches!(method, ReadMethod::EPSInBeam) {
161                        event.eps[1] = convert!(*value, F);
162                    } else {
163                        event.beam_p4.set_py(convert!(*value, F));
164                    }
165                }
166                ("Pz_Beam", ParquetField::Float(value)) => {
167                    if !matches!(method, ReadMethod::EPSInBeam) {
168                        event.beam_p4.set_pz(convert!(*value, F));
169                    }
170                }
171                ("Weight", ParquetField::Float(value)) => {
172                    event.weight = convert!(*value, F);
173                }
174                ("EPS", ParquetField::ListInternal(list)) => match method {
175                    ReadMethod::Standard => {
176                        event.eps = Vector3::from_vec(
177                            list.elements()
178                                .iter()
179                                .map(|field| {
180                                    if let ParquetField::Float(value) = field {
181                                        convert!(*value, F)
182                                    } else {
183                                        panic!()
184                                    }
185                                })
186                                .collect(),
187                        );
188                    }
189                    ReadMethod::EPS(x, y, z) => *event.eps = *Vector3::new(x, y, z),
190                    _ => {}
191                },
192                ("E_FinalState", ParquetField::ListInternal(list)) => {
193                    e_fs = list
194                        .elements()
195                        .iter()
196                        .map(|field| {
197                            if let ParquetField::Float(value) = field {
198                                convert!(*value, F)
199                            } else {
200                                panic!()
201                            }
202                        })
203                        .collect()
204                }
205                ("Px_FinalState", ParquetField::ListInternal(list)) => {
206                    px_fs = list
207                        .elements()
208                        .iter()
209                        .map(|field| {
210                            if let ParquetField::Float(value) = field {
211                                convert!(*value, F)
212                            } else {
213                                panic!()
214                            }
215                        })
216                        .collect()
217                }
218                ("Py_FinalState", ParquetField::ListInternal(list)) => {
219                    py_fs = list
220                        .elements()
221                        .iter()
222                        .map(|field| {
223                            if let ParquetField::Float(value) = field {
224                                convert!(*value, F)
225                            } else {
226                                panic!()
227                            }
228                        })
229                        .collect()
230                }
231                ("Pz_FinalState", ParquetField::ListInternal(list)) => {
232                    pz_fs = list
233                        .elements()
234                        .iter()
235                        .map(|field| {
236                            if let ParquetField::Float(value) = field {
237                                convert!(*value, F)
238                            } else {
239                                panic!()
240                            }
241                        })
242                        .collect()
243                }
244                _ => {}
245            }
246        }
247        event.recoil_p4 = FourMomentum::new(e_fs[0], px_fs[0], py_fs[0], pz_fs[0]);
248        event.daughter_p4s = e_fs[1..]
249            .iter()
250            .zip(px_fs[1..].iter())
251            .zip(py_fs[1..].iter())
252            .zip(pz_fs[1..].iter())
253            .map(|(((e, px), py), pz)| FourMomentum::new(*e, *px, *py, *pz))
254            .collect();
255        // let final_state_p4 = event.recoil_p4 + event.daughter_p4s.iter().sum();
256        // event.beam_p4 = event.beam_p4.boost_along(&final_state_p4);
257        // event.recoil_p4 = event.recoil_p4.boost_along(&final_state_p4);
258        // for dp4 in event.daughter_p4s.iter_mut() {
259        //     *dp4 = dp4.boost_along(&final_state_p4);
260        // }
261        Ok(event)
262    }
263}
264
265/// An array of [`Event`]s with some helpful methods for accessing and parsing the data they
266/// contain.
267///
268/// A [`Dataset`] can be loaded from either Parquet and ROOT files using the corresponding
269/// `Dataset::from_*` methods. Events are stored in an [`Arc<Vec<Event>>`], since we
270/// rarely need to write data to a dataset (splitting/selecting/rejecting events) but often need to
271/// read events from a dataset.
272#[derive(Default, Debug, Clone)]
273pub struct Dataset<F: Field + 'static> {
274    /// Storage for events.
275    pub events: Arc<Vec<Event<F>>>,
276}
277
278impl<F: Field + 'static> Dataset<F> {
279    /// Resets the indices of events in a dataset so they start at `0`.
280    pub fn reindex(&mut self) {
281        self.events = Arc::new(
282            (*self.events)
283                .clone()
284                .iter_mut()
285                .enumerate()
286                .map(|(i, event)| {
287                    event.index = i;
288                    event.clone()
289                })
290                .collect(),
291        )
292    }
293    // TODO: can we make an events(&self) -> &Vec<Field> method that actually works without cloning?
294
295    /// Retrieves the weights from the events in the dataset
296    pub fn weights(&self) -> Vec<F> {
297        self.events.iter().map(|e| e.weight).collect()
298    }
299
300    /// Retrieves the weights from the events in the dataset which have the given indices.
301    pub fn weights_indexed(&self, indices: &[usize]) -> Vec<F> {
302        indices
303            .iter()
304            .map(|index| self.events[*index].weight)
305            .collect()
306    }
307
308    /// Splits the dataset by the mass of the combination of specified daughter particles in the
309    /// event. If no daughters are given, the first and second particle are assumed to form the
310    /// desired combination. This method returns [`Vec<usize>`]s corresponding to the indices of
311    /// events in each bin, the underflow bin, and the overflow bin respectively. This is intended
312    /// to be used in conjunction with
313    /// [`Manager::evaluate_indexed`](`crate::manager::Manager::evaluate_indexed`).
314    pub fn split_m(
315        &self,
316        range: (F, F),
317        bins: usize,
318        daughter_indices: Option<Vec<usize>>,
319    ) -> (Vec<Vec<usize>>, Vec<usize>, Vec<usize>) {
320        let mass = |e: &Event<F>| {
321            let p4: FourMomentum<F> = daughter_indices
322                .clone()
323                .unwrap_or_else(|| vec![0, 1])
324                .iter()
325                .map(|i| e.daughter_p4s[*i])
326                .sum();
327            p4.m()
328        };
329        self.get_binned_indices(mass, range, bins)
330    }
331
332    /// Generates a new [`Dataset`] from a Parquet file.
333    ///
334    /// # Errors
335    ///
336    /// This method will fail if any individual event is missing all of the required fields, if
337    /// they have the wrong type, or if the file doesn't exist/can't be read for any reason.
338    pub fn from_parquet(path: &str, method: ReadMethod<F>) -> Result<Self, RustitudeError> {
339        let path = Path::new(path);
340        let file = File::open(path)?;
341        let reader = SerializedFileReader::new(file)?;
342        let row_iter = reader.get_row_iter(None)?;
343        Ok(Self::new(
344            row_iter
345                .enumerate()
346                .map(|(i, row)| Event::read_parquet_row(i, row, method))
347                .collect::<Result<Vec<Event<F>>, RustitudeError>>()?,
348        ))
349    }
350
351    /// Extract a branch from a ROOT `TTree` containing a [`Field`] (float in C). This method
352    /// converts the underlying element to an [`Field`].
353    fn extract_f32(path: &str, ttree: &ReaderTree, branch: &str) -> Result<Vec<F>, RustitudeError> {
354        let res = ttree
355            .branch(branch)
356            .ok_or_else(|| {
357                RustitudeError::OxyrootError(format!(
358                    "Could not find {} branch in {}",
359                    branch, path
360                ))
361            })?
362            .as_iter::<f64>()
363            .map_err(|err| RustitudeError::OxyrootError(err.to_string()))?
364            .map(|val| convert!(val, F))
365            .collect();
366        Ok(res)
367    }
368
369    /// Extract a branch from a ROOT `TTree` containing an array of [`Field`]s (floats in C). This
370    /// method converts the underlying elements to [`Field`]s.
371    fn extract_vec_f32(
372        path: &str,
373        ttree: &ReaderTree,
374        branch: &str,
375    ) -> Result<Vec<Vec<F>>, RustitudeError> {
376        let res: Vec<Vec<F>> = ttree
377            .branch(branch)
378            .ok_or_else(|| {
379                RustitudeError::OxyrootError(format!(
380                    "Could not find {} branch in {}",
381                    branch, path
382                ))
383            })?
384            .as_iter::<Slice<f64>>()
385            .map_err(|err| RustitudeError::OxyrootError(err.to_string()))?
386            .map(|v| {
387                v.into_vec()
388                    .into_iter()
389                    .map(|val| convert!(val, F))
390                    .collect()
391            })
392            .collect();
393        Ok(res)
394    }
395
396    /// Generates a new [`Dataset`] from a ROOT file.
397    ///
398    /// # Errors
399    ///
400    /// This method will fail if any individual event is missing all of the required fields, if
401    /// they have the wrong type, or if the file doesn't exist/can't be read for any reason.
402    pub fn from_root(path: &str, method: ReadMethod<F>) -> Result<Self, RustitudeError> {
403        let ttree = RootFile::open(path)
404            .map_err(|err| RustitudeError::OxyrootError(err.to_string()))?
405            .get_tree("kin")
406            .map_err(|err| RustitudeError::OxyrootError(err.to_string()))?;
407        let weight: Vec<F> = Self::extract_f32(path, &ttree, "Weight")?;
408        let e_beam: Vec<F> = Self::extract_f32(path, &ttree, "E_Beam")?;
409        let px_beam: Vec<F> = Self::extract_f32(path, &ttree, "Px_Beam")?;
410        let py_beam: Vec<F> = Self::extract_f32(path, &ttree, "Py_Beam")?;
411        let pz_beam: Vec<F> = Self::extract_f32(path, &ttree, "Pz_Beam")?;
412        let e_fs: Vec<Vec<F>> = Self::extract_vec_f32(path, &ttree, "E_FinalState")?;
413        let px_fs: Vec<Vec<F>> = Self::extract_vec_f32(path, &ttree, "Px_FinalState")?;
414        let py_fs: Vec<Vec<F>> = Self::extract_vec_f32(path, &ttree, "Py_FinalState")?;
415        let pz_fs: Vec<Vec<F>> = Self::extract_vec_f32(path, &ttree, "Pz_FinalState")?;
416        let eps_extracted: Vec<Vec<F>> = if matches!(method, ReadMethod::Standard) {
417            Self::extract_vec_f32(path, &ttree, "EPS")?
418        } else {
419            vec![vec![F::zero(); 3]; weight.len()]
420        };
421        Ok(Self::new(
422            izip!(
423                weight,
424                e_beam,
425                px_beam,
426                py_beam,
427                pz_beam,
428                e_fs,
429                px_fs,
430                py_fs,
431                pz_fs,
432                eps_extracted
433            )
434            .enumerate()
435            .map(
436                |(i, (w, e_b, px_b, py_b, pz_b, e_f, px_f, py_f, pz_f, eps_vec))| {
437                    let (beam_p4, eps) = match method {
438                        ReadMethod::Standard => (
439                            FourMomentum::new(e_b, px_b, py_b, pz_b),
440                            Vector3::from_vec(eps_vec),
441                        ),
442                        ReadMethod::EPSInBeam => (
443                            FourMomentum::new(e_b, F::zero(), F::zero(), e_b),
444                            Vector3::new(px_b, py_b, pz_b),
445                        ),
446                        ReadMethod::EPS(x, y, z) => (
447                            FourMomentum::new(e_b, px_b, py_b, pz_b),
448                            Vector3::new(x, y, z),
449                        ),
450                    };
451                    Event {
452                        index: i,
453                        weight: w,
454                        beam_p4,
455                        recoil_p4: FourMomentum::new(e_f[0], px_f[0], py_f[0], pz_f[0]),
456                        daughter_p4s: izip!(
457                            e_f[1..].iter(),
458                            px_f[1..].iter(),
459                            py_f[1..].iter(),
460                            pz_f[1..].iter()
461                        )
462                        .map(|(e, px, py, pz)| FourMomentum::new(*e, *px, *py, *pz))
463                        .collect(),
464                        eps,
465                    }
466                },
467            )
468            .collect(),
469        ))
470    }
471
472    /// Generate a new [`Dataset`] from a [`Vec<Event>`].
473    pub fn new(events: Vec<Event<F>>) -> Self {
474        info!("Dataset created with {} events", events.len());
475        Self {
476            events: Arc::new(events),
477        }
478    }
479
480    /// Checks if the dataset is empty.
481    pub fn is_empty(&self) -> bool {
482        self.events.is_empty()
483    }
484
485    /// Returns the number of events in the dataset.
486    pub fn len(&self) -> usize {
487        self.events.len()
488    }
489
490    /// Returns a set of indices which represent a bootstrapped [`Dataset`]. This method is to be
491    /// used in conjunction with
492    /// [`Manager::evaluate_indexed`](crate::manager::Manager::evaluate_indexed).
493    pub fn get_bootstrap_indices(&self, seed: usize) -> Vec<usize> {
494        fastrand::seed(seed as u64);
495        let mut inds: Vec<usize> = repeat_with(|| fastrand::usize(0..self.len()))
496            .take(self.len())
497            .collect();
498        inds.sort_unstable();
499        inds
500    }
501
502    /// Selects indices of events in a dataset using the given query. Indices of events for which
503    /// the query returns `true` will end up in the first member of the returned tuple, and indices
504    /// of events which return `false` will end up in the second member.
505    pub fn get_selected_indices(
506        &self,
507        query: impl Fn(&Event<F>) -> bool + Sync + Send,
508    ) -> (Vec<usize>, Vec<usize>) {
509        let (mut indices_selected, mut indices_rejected): (Vec<usize>, Vec<usize>) =
510            self.events.par_iter().partition_map(|event| {
511                if query(event) {
512                    Either::Left(event.index)
513                } else {
514                    Either::Right(event.index)
515                }
516            });
517        indices_selected.sort_unstable();
518        indices_rejected.sort_unstable();
519        (indices_selected, indices_rejected)
520    }
521
522    /// Splits the dataset by the given query. This method returns [`Vec<usize>`]s corresponding to
523    /// the indices of events in each bin, the underflow bin, and the overflow bin respectively.
524    /// This is intended to be used in conjunction with
525    /// [`Manager::evaluate_indexed`](`crate::manager::Manager::evaluate_indexed`).
526    pub fn get_binned_indices(
527        &self,
528        variable: impl Fn(&Event<F>) -> F + Sync + Send,
529        range: (F, F),
530        nbins: usize,
531    ) -> (Vec<Vec<usize>>, Vec<usize>, Vec<usize>) {
532        let mut bins: Vec<F> = Vec::with_capacity(nbins + 1);
533        let width = (range.1 - range.0) / convert!(nbins, F);
534        for m in 0..=nbins {
535            bins.push(F::mul_add(width, convert!(m, F), range.0));
536        }
537        let (underflow, _) = self.get_selected_indices(|event| variable(event) < bins[0]);
538        let (overflow, _) =
539            self.get_selected_indices(|event| variable(event) >= bins[bins.len() - 1]);
540        let binned_indices = bins
541            .into_iter()
542            .tuple_windows()
543            .map(|(lb, ub)| {
544                let (sel, _) = self.get_selected_indices(|event| {
545                    let res = variable(event);
546                    lb <= res && res < ub
547                });
548                sel
549            })
550            .collect();
551        (binned_indices, underflow, overflow)
552    }
553}
554
555impl<F: Field + 'static> Add for Dataset<F> {
556    type Output = Self;
557
558    fn add(self, other: Self) -> Self::Output {
559        let mut combined_events = Vec::with_capacity(self.events.len() + other.events.len());
560        combined_events.extend(Arc::try_unwrap(self.events).unwrap_or_else(|arc| (*arc).clone()));
561        combined_events.extend(Arc::try_unwrap(other.events).unwrap_or_else(|arc| (*arc).clone()));
562        Self {
563            events: Arc::new(combined_events),
564        }
565    }
566}