//! # Rustitude
//! ## Demystifying Amplitude Analysis with Rust and Python
//!
//! The `rustitude-core` crate aims to implement common amplitude analysis techniques in Rust with
//! bindings to Python. This crate does not include the Python bindings; see the [GitHub
//! repo](https://github.com/denehoffman/rustitude) for more information on the Python API.
//!
//! The three core principles of `rustitude-core` are:
//! 1. Parallelization over events is automatically handled by a [`Manager`](`crate::manager::Manager`).
//! 2. Amplitudes are written to do as much work as possible ahead of time, and evaluations
//!    automatically use caching wherever possible.
//! 3. Developers just need to implement the [`Node`](`crate::amplitude::Node`) trait to write a new
//!    amplitude; everything else is handled by the crate.
//!
//! ## Table of Contents
//!
//! * [Dataset Structure](#dataset-structure)
//! * [Creating a New Amplitude](#creating-a-new-amplitude)
//! * [Combining Amplitudes into Models](#combining-amplitudes-into-models)
//! * [Managing Parameters](#managing-parameters)
//! * [Evaluating Likelihoods](#evaluating-likelihoods)
//! * [Fitting Amplitudes to Data](#fitting-amplitudes-to-data)
//!
//! # Dataset Structure
//!
//! A [`Dataset`](`crate::dataset::Dataset`) is essentially just a wrapper for a [`Vec`] of
//! [`Event`](`crate::dataset::Event`)s. The current [`Event`](`crate::dataset::Event`) structure is as follows:
//!
//! ```ignore
//! pub struct Event<F: Field> {
//!     pub index: usize,                       // Position of event within dataset
//!     pub weight: F,                          // Event weight
//!     pub beam_p4: FourMomentum<F>,           // Beam four-momentum
//!     pub recoil_p4: FourMomentum<F>,         // Recoil four-momentum
//!     pub daughter_p4s: Vec<FourMomentum<F>>, // Four-momenta of final state particles sans recoil
//!     pub eps: Vector3<F>,                    // Beam polarization vector
//! }
//! ```
//!
//! In the Rust API, we can create [`Dataset`](`crate::dataset::Dataset`)s from `ROOT` files as well as
//! `Parquet` files. `ROOT` file reading is done through [`oxyroot`], which still has some issues:
//! large files or files with user metadata might fail to load. The alternative `Parquet`
//! format can be obtained from a `ROOT` file by using a conversion script like the one provided
//! [here](https://github.com/denehoffman/rustitude/blob/main/bin/convert). By default, we expect
//! all of the [`Event`](`crate::dataset::Event`) fields to be mirrored as the following branches:
//!
//! | Branch Name | Data Type | Notes |
//! |---|---|---|
//! | `Weight` | Float32 |  |
//! | `E_Beam` | Float32 |  |
//! | `Px_Beam` | Float32 |  |
//! | `Py_Beam` | Float32 |  |
//! | `Pz_Beam` | Float32 |  |
//! | `E_FinalState` | \[Float32\] | \[recoil, daughter #1, daughter #2, ...\] |
//! | `Px_FinalState` | \[Float32\] | \[recoil, daughter #1, daughter #2, ...\] |
//! | `Py_FinalState` | \[Float32\] | \[recoil, daughter #1, daughter #2, ...\] |
//! | `Pz_FinalState` | \[Float32\] | \[recoil, daughter #1, daughter #2, ...\] |
//! | `EPS` | \[Float32\] | \[$`P_\gamma \cos(\Phi)`$, $`P_\gamma \sin(\Phi)`$, $`0.0`$\] for linear polarization with magnitude $`P_\gamma`$ and angle $`\Phi`$ |
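//!
//! The `*_FinalState` branches store the recoil and every daughter in parallel arrays, recoil
//! first. As a minimal, std-only sketch of that layout (the `P4` type and `split_final_state`
//! function are hypothetical stand-ins, not the crate's `FourMomentum` or its actual reader), one
//! event's arrays could be split back into a recoil and a list of daughters like this:
//!
//! ```rust
//! /// Hypothetical stand-in for a four-momentum (E, px, py, pz).
//! #[derive(Debug, Clone, Copy, PartialEq)]
//! struct P4 {
//!     e: f32,
//!     px: f32,
//!     py: f32,
//!     pz: f32,
//! }
//!
//! /// Splits one event's `E/Px/Py/Pz_FinalState` arrays into (recoil, daughters).
//! /// Index 0 is the recoil; indices 1.. are the daughters in file order.
//! /// Returns `None` if the arrays are empty or their lengths differ.
//! fn split_final_state(e: &[f32], px: &[f32], py: &[f32], pz: &[f32]) -> Option<(P4, Vec<P4>)> {
//!     let n = e.len();
//!     if n == 0 || px.len() != n || py.len() != n || pz.len() != n {
//!         return None;
//!     }
//!     let mut p4s = (0..n).map(|i| P4 { e: e[i], px: px[i], py: py[i], pz: pz[i] });
//!     let recoil = p4s.next()?; // first entry is the recoil
//!     Some((recoil, p4s.collect())) // the remaining entries are the daughters
//! }
//! ```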
//!
//! A `Parquet` file with these columns can be loaded with the following:
//! ```ignore
//! use rustitude_core::prelude::*;
//! fn main() -> Result<(), RustitudeError> {
//!     let dataset = Dataset::from_parquet("path/to/file.parquet")?;
//!     println!("{}", dataset.events()[0]); // print first event
//!     Ok(())
//! }
//! ```
//!
//! Because the beam is often directed along the $`z`$-axis, there is an alternative way to store
//! the `EPS` vector without a new branch (for linear polarization). The $`x`$ and $`y`$ components
//! of `EPS` can be stored as `Px_Beam` and `Py_Beam`, respectively, and the format can be loaded
//! using [`Dataset::from_parquet`](`crate::dataset::Dataset::from_parquet`) with the
//! [`ReadMethod::EPSInBeam`](`crate::dataset::ReadMethod::EPSInBeam`) method.
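//!
//! As a rough illustration of the `EPSInBeam` layout, assuming the beam travels exactly along
//! $`+z`$ so that its true $`p_x`$ and $`p_y`$ vanish, one event's beam columns could be unpacked
//! as below. Both functions are hypothetical helpers, and the crate's reader may differ in detail:
//!
//! ```rust
//! /// Unpacks beam columns stored in the `EPSInBeam` layout.
//! /// Assumption: the beam moves along +z, so `Px_Beam`/`Py_Beam` carry the polarization
//! /// vector instead of momentum. Returns ([E, px, py, pz] of the beam, EPS).
//! fn unpack_eps_in_beam(e_beam: f32, px_beam: f32, py_beam: f32, pz_beam: f32) -> ([f32; 4], [f32; 3]) {
//!     let beam_p4 = [e_beam, 0.0, 0.0, pz_beam]; // transverse momentum taken to be zero
//!     let eps = [px_beam, py_beam, 0.0]; // EPS = (P cos(Phi), P sin(Phi), 0)
//!     (beam_p4, eps)
//! }
//!
//! /// Builds EPS from a polarization magnitude and angle (in radians),
//! /// following the `EPS` row of the branch table: [P cos(Phi), P sin(Phi), 0].
//! fn eps_from_polarization(p_gamma: f32, phi: f32) -> [f32; 3] {
//!     [p_gamma * phi.cos(), p_gamma * phi.sin(), 0.0]
//! }
//! ```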
//!
//! # Creating a New Amplitude
//!
//! To make a new amplitude, we first create a new struct and then implement
//! [`Node`](`crate::amplitude::Node`). Let's start with a trivial example: an amplitude which returns a
//! complex scalar. This particular amplitude is already implemented as a convenience struct called
//! [`ComplexScalar`](`crate::amplitude::ComplexScalar`).
//!
//! ```ignore
//! use rustitude_core::prelude::*;
//!
//! #[derive(Clone)]
//! pub struct ComplexScalar;
//! impl<F: Field> Node<F> for ComplexScalar {
//!     // The amplitude is just the complex number built from the two parameters.
//!     fn calculate(&self, parameters: &[F], _event: &Event<F>) -> Result<Complex<F>, RustitudeError> {
//!         Ok(Complex::new(parameters[0], parameters[1]))
//!     }
//!
//!     // Parameter names, in the same order as they are indexed in `calculate`.
//!     fn parameters(&self) -> Vec<String> {
//!         vec!["real".to_string(), "imag".to_string()]
//!     }
//! }
//! ```
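//!
//! The division of labor between precalculation and per-event evaluation can be sketched without
//! any crate types. In the std-only toy below, `ToyEvent`, `ToyNode`, `CachedSqrt`, and
//! `evaluate_all` are hypothetical stand-ins for `Event`, `Node`, an amplitude, and the
//! [`Manager`](`crate::manager::Manager`), and complex values are plain `(re, im)` tuples:
//!
//! ```rust
//! /// Hypothetical minimal event carrying only what this sketch needs.
//! struct ToyEvent {
//!     index: usize, // position in the dataset, used to look up cached values
//!     x: f64,       // some per-event kinematic input
//! }
//!
//! /// Hypothetical stand-in for `Node`: precompute once, then evaluate per event.
//! trait ToyNode {
//!     fn precalculate(&mut self, _events: &[ToyEvent]) {}
//!     fn calculate(&self, parameters: &[f64], event: &ToyEvent) -> (f64, f64);
//!     fn parameters(&self) -> Vec<String>;
//! }
//!
//! /// Caches an "expensive" per-event quantity, sqrt(|x|), ahead of time.
//! #[derive(Default)]
//! struct CachedSqrt {
//!     cache: Vec<f64>,
//! }
//!
//! impl ToyNode for CachedSqrt {
//!     fn precalculate(&mut self, events: &[ToyEvent]) {
//!         self.cache = events.iter().map(|ev| ev.x.abs().sqrt()).collect();
//!     }
//!     fn calculate(&self, parameters: &[f64], event: &ToyEvent) -> (f64, f64) {
//!         // Cheap lookup by event index, scaled by the single free parameter.
//!         (parameters[0] * self.cache[event.index], 0.0)
//!     }
//!     fn parameters(&self) -> Vec<String> {
//!         vec!["scale".to_string()]
//!     }
//! }
//!
//! /// Evaluates every event; assumes `precalculate` has already been called once.
//! fn evaluate_all<N: ToyNode>(node: &N, events: &[ToyEvent], parameters: &[f64]) -> Vec<(f64, f64)> {
//!     events.iter().map(|ev| node.calculate(parameters, ev)).collect()
//! }
//! ```
//!
//! In this toy, `precalculate` runs once and `evaluate_all` can then be called repeatedly with
//! different parameter values, so only the cheap lookup is repeated.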
//!
//! For a second example, we can look at the precalculation feature. Here's a Dalitz-like
//! amplitude for the $`\omega`$ particle:
//! ```ignore
//! use rayon::prelude::*;
//! use rustitude_core::prelude::*;
//!
//! #[derive(Default, Clone)]
//! pub struct OmegaDalitz<F: Field> {
//!     dalitz_z: Vec<F>,
//!     dalitz_sin3theta: Vec<F>,
//!     lambda: Vec<F>,
//! }
//!
//! impl<F: Field> Node<F> for OmegaDalitz<F> {
//!     // Runs once per dataset: cache the per-event Dalitz variables.
//!     fn precalculate(&mut self, dataset: &Dataset<F>) -> Result<(), RustitudeError> {
//!         // Numeric constants in the generic field `F` (`Field` gets `one()` and the math
//!         // functions used below from `num::Float`).
//!         let (one, two, three, four, nine): (F, F, F, F, F) = (
//!             F::one(),
//!             convert!(2.0, F),
//!             convert!(3.0, F),
//!             convert!(4.0, F),
//!             convert!(9.0, F),
//!         );
//!         (self.dalitz_z, (self.dalitz_sin3theta, self.lambda)) = dataset
//!             .events
//!             .par_iter()
//!             .map(|event| {
//!                 let pi0 = event.daughter_p4s[0];
//!                 let pip = event.daughter_p4s[1];
//!                 let pim = event.daughter_p4s[2];
//!                 let omega = pi0 + pip + pim;
//!
//!                 let dalitz_s = (pip + pim).m2();
//!                 let dalitz_t = (pip + pi0).m2();
//!                 let dalitz_u = (pim + pi0).m2();
//!
//!                 let m3pi = (two * pip.m()) + pi0.m();
//!                 let dalitz_d = two * omega.m() * (omega.m() - m3pi);
//!                 let dalitz_sc = (one / three) * (omega.m2() + pip.m2() + pim.m2() + pi0.m2());
//!                 let dalitz_x = three.sqrt() * (dalitz_t - dalitz_u) / dalitz_d;
//!                 let dalitz_y = three * (dalitz_sc - dalitz_s) / dalitz_d;
//!
//!                 let dalitz_z = dalitz_x * dalitz_x + dalitz_y * dalitz_y;
//!                 let dalitz_sin3theta = F::sin(three * F::asin(dalitz_y / dalitz_z.sqrt()));
//!
//!                 let pip_omega = pip.boost_along(&omega);
//!                 let pim_omega = pim.boost_along(&omega);
//!                 let pi_cross = pip_omega.momentum().cross(&pim_omega.momentum());
//!
//!                 let lambda = (four / three) * F::abs(pi_cross.dot(&pi_cross))
//!                     / ((one / nine) * (omega.m2() - m3pi.powi(2)).powi(2));
//!
//!                 (dalitz_z, (dalitz_sin3theta, lambda))
//!             })
//!             .unzip();
//!         Ok(())
//!     }
//!
//!     // Runs per event: look up the cached values and combine them with the parameters.
//!     fn calculate(&self, parameters: &[F], event: &Event<F>) -> Result<Complex<F>, RustitudeError> {
//!         let (one, two): (F, F) = (F::one(), convert!(2.0, F));
//!         let dalitz_z = self.dalitz_z[event.index];
//!         let dalitz_sin3theta = self.dalitz_sin3theta[event.index];
//!         let lambda = self.lambda[event.index];
//!         let alpha = parameters[0];
//!         let beta = parameters[1];
//!         let gamma = parameters[2];
//!         let delta = parameters[3];
//!         Ok(F::sqrt(F::abs(
//!             lambda
//!                 * (one
//!                     + two * alpha * dalitz_z
//!                     + two * beta * dalitz_z.powf(convert!(1.5, F)) * dalitz_sin3theta
//!                     + two * gamma * dalitz_z.powi(2)
//!                     + two * delta * dalitz_z.powf(convert!(2.5, F)) * dalitz_sin3theta),
//!         ))
//!         .into())
//!     }
//!
//!     fn parameters(&self) -> Vec<String> {
//!         vec![
//!             "alpha".to_string(),
//!             "beta".to_string(),
//!             "gamma".to_string(),
//!             "delta".to_string(),
//!         ]
//!     }
//! }
//! ```
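//!
//! For reference, the quantities cached and combined above are as follows, where $`s`$, $`t`$,
//! and $`u`$ are the squared invariant masses of the $`\pi^+\pi^-`$, $`\pi^+\pi^0`$, and
//! $`\pi^-\pi^0`$ pairs, $`m_{3\pi} = 2m_{\pi^+} + m_{\pi^0}`$, and the pion three-momenta in
//! $`\lambda`$ are taken after `boost_along(&omega)`:
//!
//! ```math
//! \begin{aligned}
//! D &= 2 m_\omega \left(m_\omega - m_{3\pi}\right), \qquad
//! s_c = \tfrac{1}{3}\left(m_\omega^2 + m_{\pi^+}^2 + m_{\pi^-}^2 + m_{\pi^0}^2\right), \\
//! X &= \frac{\sqrt{3}\,(t - u)}{D}, \qquad
//! Y = \frac{3\,(s_c - s)}{D}, \qquad
//! Z = X^2 + Y^2, \qquad
//! \sin 3\theta = \sin\!\left(3 \arcsin \frac{Y}{\sqrt{Z}}\right), \\
//! \lambda &= \frac{\tfrac{4}{3}\left|\vec{p}_{\pi^+} \times \vec{p}_{\pi^-}\right|^2}
//!   {\tfrac{1}{9}\left(m_\omega^2 - m_{3\pi}^2\right)^2}, \\
//! A &= \sqrt{\left|\lambda \left(1 + 2\alpha Z + 2\beta Z^{3/2} \sin 3\theta
//!   + 2\gamma Z^2 + 2\delta Z^{5/2} \sin 3\theta\right)\right|}.
//! \end{aligned}
//! ```
//!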
//! Note the generic type parameter `F: Field`, which allows this amplitude to be used with different
//! numeric data types. Because it isn't specifically written for 64-bit floats (`f64`s), we can
//! conduct analyses that use the same code with 32-bit floats (`f32`s), which saves on memory and
//! time while sacrificing a bit of precision. In fact, we can go a step further and conduct the
//! majority of an analysis in 32-bit mode, switching over to 64-bit mode when we actually get near
//! a solution and want the increased accuracy!
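//!
//! The same idea in miniature: in the std-only sketch below, a hypothetical `Real` trait plays
//! the role that [`Field`] plays in this crate (which gets its operations from [`Float`] and its
//! other supertraits instead), so one generic function body runs in both precisions. Summing
//! many small terms makes the precision gap visible:
//!
//! ```rust
//! use std::ops::Add;
//!
//! /// Hypothetical minimal numeric trait standing in for `Field`.
//! trait Real: Copy + Add<Output = Self> {
//!     fn zero() -> Self;
//!     fn from_f64(x: f64) -> Self;
//!     fn to_f64(self) -> f64;
//! }
//!
//! impl Real for f32 {
//!     fn zero() -> Self { 0.0 }
//!     fn from_f64(x: f64) -> Self { x as f32 }
//!     fn to_f64(self) -> f64 { self as f64 }
//! }
//!
//! impl Real for f64 {
//!     fn zero() -> Self { 0.0 }
//!     fn from_f64(x: f64) -> Self { x }
//!     fn to_f64(self) -> f64 { self }
//! }
//!
//! /// Written once, usable with any `Real`: adds `step` to zero `n` times in precision `F`.
//! fn accumulate<F: Real>(step: f64, n: usize) -> F {
//!     let step = F::from_f64(step);
//!     (0..n).fold(F::zero(), |acc, _| acc + step)
//! }
//! ```
//!
//! Calling `accumulate::<f32>(0.1, 1_000_000)` drifts visibly away from `100000`, while the
//! `f64` version stays within a tiny fraction of a unit.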
//!
//! The [`Field`] trait provides a few mathematical constants like
//! [`Field::PI()`][`num::traits::FloatConst::PI()`] and
//! [`Field::SQRT_2()`][`num::traits::FloatConst::SQRT_2()`], as well as most standard
//! mathematical functions through its supertraits. See the [`Float`] trait for more details.
//!
//! # Combining Amplitudes into Models
//!
//! We can use several operations to modify and combine amplitudes. Since amplitudes yield complex
//! values, the following convenience methods are provided:
//! [`real`](`crate::amplitude::AmpLike::real`) and [`imag`](`crate::amplitude::AmpLike::imag`) give the real and
//! imaginary parts of the amplitude, respectively. Additionally, amplitudes can be added and multiplied
//! together using operator overloading. [`Model`](`crate::amplitude::Model`)s implicitly take the
//! absolute square of each provided term in their constructor and add those results incoherently.
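//!
//! To make the operator-overloading idea concrete, here is a toy, std-only expression tree in
//! which `+` and `*` build sum and product nodes, `real()`/`imag()` wrap a subtree, and
//! evaluation walks the tree recursively. The `Expr` type is hypothetical and only loosely mirrors
//! the crate's [`Sum`](`crate::amplitude::Sum`), [`Product`](`crate::amplitude::Product`),
//! [`Real`](`crate::amplitude::Real`), and [`Imag`](`crate::amplitude::Imag`) types:
//!
//! ```rust
//! use std::ops::{Add, Mul};
//!
//! /// A complex number as (re, im), to keep the sketch dependency-free.
//! type C = (f64, f64);
//!
//! /// Hypothetical amplitude expression tree.
//! #[derive(Clone, Debug)]
//! enum Expr {
//!     Leaf(C),                       // fixed complex value standing in for an amplitude
//!     Sum(Box<Expr>, Box<Expr>),     // built by `+`
//!     Product(Box<Expr>, Box<Expr>), // built by `*`
//!     Real(Box<Expr>),               // keeps only the real part
//!     Imag(Box<Expr>),               // keeps only the imaginary part, as a real value
//! }
//!
//! impl Expr {
//!     fn real(self) -> Expr {
//!         Expr::Real(Box::new(self))
//!     }
//!     fn imag(self) -> Expr {
//!         Expr::Imag(Box::new(self))
//!     }
//!     /// Recursively evaluates the tree to a complex value.
//!     fn eval(&self) -> C {
//!         match self {
//!             Expr::Leaf(c) => *c,
//!             Expr::Sum(a, b) => {
//!                 let (x, y) = (a.eval(), b.eval());
//!                 (x.0 + y.0, x.1 + y.1)
//!             }
//!             Expr::Product(a, b) => {
//!                 let (x, y) = (a.eval(), b.eval());
//!                 (x.0 * y.0 - x.1 * y.1, x.0 * y.1 + x.1 * y.0)
//!             }
//!             Expr::Real(a) => (a.eval().0, 0.0),
//!             Expr::Imag(a) => (a.eval().1, 0.0),
//!         }
//!     }
//! }
//!
//! impl Add for Expr {
//!     type Output = Expr;
//!     fn add(self, rhs: Expr) -> Expr {
//!         Expr::Sum(Box::new(self), Box::new(rhs))
//!     }
//! }
//!
//! impl Mul for Expr {
//!     type Output = Expr;
//!     fn mul(self, rhs: Expr) -> Expr {
//!         Expr::Product(Box::new(self), Box::new(rhs))
//!     }
//! }
//! ```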
//!
//! To incoherently sum two [`Amplitude`](`crate::amplitude::Amplitude`)s, say `amp1` and `amp2`, we
//! first assume that we actually want the absolute square of each term (or write each
//! amplitude as the square root of what we really want), and then include them both in our model:
//!
//! ```ignore
//! use rustitude_core::prelude::*;
//! // Define amp1/amp2: Amplitude here...
//! let model = model!(amp1, amp2);
//! ```
//!
//! To reiterate, this would yield something like $`\left|\text{amp}_1\right|^2 + \left|\text{amp}_2\right|^2`$.
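//!
//! The difference between one absolute square of a sum and a sum of absolute squares is easy to
//! check numerically. This std-only sketch works on plain `(re, im)` tuples rather than crate
//! types; `coherent` and `incoherent` are hypothetical helper names:
//!
//! ```rust
//! /// Complex value as (re, im).
//! type C = (f64, f64);
//!
//! /// |z|^2
//! fn norm_sqr(z: C) -> f64 {
//!     z.0 * z.0 + z.1 * z.1
//! }
//!
//! /// |a_1 + a_2 + ...|^2: the terms interfere.
//! fn coherent(terms: &[C]) -> f64 {
//!     let sum = terms.iter().fold((0.0, 0.0), |acc, z| (acc.0 + z.0, acc.1 + z.1));
//!     norm_sqr(sum)
//! }
//!
//! /// |a_1|^2 + |a_2|^2 + ...: no interference between terms, as in `model!(amp1, amp2)`.
//! fn incoherent(terms: &[C]) -> f64 {
//!     terms.iter().map(|&z| norm_sqr(z)).sum()
//! }
//! ```
//!
//! Two equal and opposite amplitudes cancel completely when summed coherently, but both still
//! contribute when their absolute squares are added.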
//!
//! The [`Scalar`](`crate::amplitude::Scalar`),
//! [`ComplexScalar`](`crate::amplitude::ComplexScalar`), and
//! [`PolarComplexScalar`](`crate::amplitude::PolarComplexScalar`) amplitudes have convenience
//! constructor functions, [`scalar`](`crate::amplitude::scalar`), [`cscalar`](`crate::amplitude::cscalar`), and
//! [`pcscalar`](`crate::amplitude::pcscalar`), respectively. We then wrap the final expression in a
//! [`Model`](`crate::amplitude::Model`), which can manage all of the
//! [`Parameter`](`crate::amplitude::Parameter`)s.
//!
//! ```ignore
//! use rustitude_core::prelude::*;
//!
//! // `OmegaDalitz<F>` and its `Node<F>` implementation as defined above.
//! #[derive(Default, Clone)]
//! pub struct OmegaDalitz<F: Field> { ... }
//! impl<F: Field> Node<F> for OmegaDalitz<F> { ... }
//!
//! let complex_term = cscalar("my complex scalar");
//! let omega_dalitz = Amplitude::new("omega dalitz", OmegaDalitz::default());
//! let term = complex_term * omega_dalitz;
//! term.print_tree();
//! // [ norm sqr ]
//! //   ┗━[ * ]
//! //       ┣━ !my complex scalar(real, imag)
//! //       ┗━ !omega dalitz(alpha, beta, gamma, delta)
//! let model = model!(term);
//! ```
//!
//! # Managing Parameters
//!
//! Now that we have a model, we might want to constrain or fix parameters. Parameters are
//! identified solely by their name and the name of the amplitude they are associated with. This
//! means that two amplitudes with the same name will share parameters which also have the same
//! name. If we want to intentionally set one parameter in a particular amplitude equal to another,
//! we can use [`Model::constrain`](`crate::amplitude::Model::constrain`). This will reduce the
//! number of free parameters in the fit, and will yield a
//! [`RustitudeError`](`crate::errors::RustitudeError`) if either of the parameters is not found.
//! Parameters can also be fixed and freed using [`Model::fix`](`crate::amplitude::Model::fix`) and
//! [`Model::free`](`crate::amplitude::Model::free`), respectively, and these methods are mirrored in
//! [`Manager`](`crate::manager::Manager`) and
//! [`ExtendedLogLikelihood`](`crate::manager::ExtendedLogLikelihood`) for convenience.
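//!
//! A toy model of this bookkeeping: in the std-only sketch below, parameters are keyed by
//! `(amplitude, parameter)` name pairs, `constrain` points two keys at one underlying parameter,
//! and `fix`/`free` toggle whether it counts as free. `ParamRegistry` and its methods are
//! hypothetical and only mimic the behavior described above; they are not the crate's `Model` API.
//!
//! ```rust
//! use std::collections::HashMap;
//!
//! /// Hypothetical parameter bookkeeping keyed by (amplitude name, parameter name).
//! struct ParamRegistry {
//!     slot: HashMap<(String, String), usize>, // key -> index of the underlying parameter
//!     fixed: HashMap<usize, f64>,             // underlying parameters held at a fixed value
//! }
//!
//! impl ParamRegistry {
//!     /// Registers keys in order; repeated (amplitude, parameter) pairs share one slot.
//!     fn new(keys: &[(&str, &str)]) -> Self {
//!         let mut slot = HashMap::new();
//!         for &(amp, par) in keys {
//!             let next = slot.len();
//!             slot.entry((amp.to_string(), par.to_string())).or_insert(next);
//!         }
//!         ParamRegistry { slot, fixed: HashMap::new() }
//!     }
//!
//!     fn get(&self, amp: &str, par: &str) -> Result<usize, String> {
//!         self.slot
//!             .get(&(amp.to_string(), par.to_string()))
//!             .copied()
//!             .ok_or_else(|| format!("Parameter not found: {amp}::{par}"))
//!     }
//!
//!     /// Makes `b` share the underlying parameter of `a`; errors if either is missing.
//!     fn constrain(&mut self, a: (&str, &str), b: (&str, &str)) -> Result<(), String> {
//!         let target = self.get(a.0, a.1)?;
//!         let old = self.get(b.0, b.1)?;
//!         for s in self.slot.values_mut() {
//!             if *s == old {
//!                 *s = target;
//!             }
//!         }
//!         Ok(())
//!     }
//!
//!     fn fix(&mut self, amp: &str, par: &str, value: f64) -> Result<(), String> {
//!         let s = self.get(amp, par)?;
//!         self.fixed.insert(s, value);
//!         Ok(())
//!     }
//!
//!     fn free(&mut self, amp: &str, par: &str) -> Result<(), String> {
//!         let s = self.get(amp, par)?;
//!         self.fixed.remove(&s);
//!         Ok(())
//!     }
//!
//!     /// Number of distinct underlying parameters that are not fixed.
//!     fn n_free(&self) -> usize {
//!         let mut free: Vec<usize> = self
//!             .slot
//!             .values()
//!             .copied()
//!             .filter(|s| !self.fixed.contains_key(s))
//!             .collect();
//!         free.sort_unstable();
//!         free.dedup();
//!         free.len()
//!     }
//! }
//! ```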
//!
//! # Evaluating Likelihoods
//!
//! To obtain the negative log-likelihood for this particular model, we need to
//! link our [`Model`](`crate::amplitude::Model`) to a [`Dataset`](`crate::dataset::Dataset`). This is done using a
//! [`Manager`](`crate::manager::Manager`). Finally, two [`Manager`](`crate::manager::Manager`)s may be combined into an
//! [`ExtendedLogLikelihood`](`crate::manager::ExtendedLogLikelihood`). Both of these manager-like structs have an
//! `evaluate` method that takes some parameters as a `&[f32]` (along with a [`usize`] for the
//! number of threads to use for the [`ExtendedLogLikelihood`](`crate::manager::ExtendedLogLikelihood`)).
//!
//! ```ignore
//! use rustitude_core::prelude::*;
//!
//! // `OmegaDalitz<F>` and its `Node<F>` implementation as defined above.
//! #[derive(Default, Clone)]
//! pub struct OmegaDalitz<F: Field> { ... }
//! impl<F: Field> Node<F> for OmegaDalitz<F> { ... }
//!
//! fn main() -> Result<(), RustitudeError> {
//!     let complex_term = cscalar("my complex scalar");
//!     let omega_dalitz = Amplitude::new("omega dalitz", OmegaDalitz::default());
//!     let term = complex_term * omega_dalitz;
//!     let model = model!(term);
//!     let dataset = Dataset::from_parquet("path/to/file.parquet")?;
//!     let dataset_mc = Dataset::from_parquet("path/to/monte_carlo_file.parquet")?;
//!     // Data first, then Monte Carlo.
//!     let nll = ExtendedLogLikelihood::new(
//!         Manager::new(&model, &dataset)?,
//!         Manager::new(&model, &dataset_mc)?,
//!     );
//!     println!("NLL: {}", nll.evaluate(&nll.get_initial())?);
//!     Ok(())
//! }
//! ```
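//!
//! For orientation, one common form of the weighted, extended, unbinned negative log-likelihood
//! (up to constant terms) is
//! $`-2\ln\mathcal{L} = -2\left(\sum_{i \in \text{data}} w_i \ln I(x_i) - \frac{N_\text{data}}{N_\text{MC}} \sum_{j \in \text{MC}} w_j I(x_j)\right)`$,
//! where $`I`$ is the model intensity and $`N_\text{data}`$, $`N_\text{MC}`$ are the summed event
//! weights. This form is an assumption for illustration; the exact normalization used by
//! [`ExtendedLogLikelihood`](`crate::manager::ExtendedLogLikelihood`) may differ. The hypothetical
//! `extended_nll` below combines precomputed intensities from the two datasets:
//!
//! ```rust
//! /// Weighted extended negative log-likelihood (times two), up to constants:
//! /// -2 * (sum_data w_i ln I_i - (N_data / N_mc) * sum_mc w_j I_j),
//! /// where N_data and N_mc are the sums of the event weights.
//! /// Returns `None` for empty or mismatched inputs, or a non-positive data intensity.
//! fn extended_nll(
//!     data_intensities: &[f64],
//!     data_weights: &[f64],
//!     mc_intensities: &[f64],
//!     mc_weights: &[f64],
//! ) -> Option<f64> {
//!     if data_intensities.is_empty()
//!         || mc_intensities.is_empty()
//!         || data_intensities.len() != data_weights.len()
//!         || mc_intensities.len() != mc_weights.len()
//!     {
//!         return None;
//!     }
//!     if data_intensities.iter().any(|&i| i <= 0.0) {
//!         return None; // ln I would be undefined
//!     }
//!     let n_data: f64 = data_weights.iter().sum();
//!     let n_mc: f64 = mc_weights.iter().sum();
//!     let ln_l: f64 = data_intensities.iter().zip(data_weights).map(|(i, w)| w * i.ln()).sum();
//!     let mc_term: f64 = mc_intensities.iter().zip(mc_weights).map(|(i, w)| w * i).sum();
//!     Some(-2.0 * (ln_l - (n_data / n_mc) * mc_term))
//! }
//! ```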
//!
//! # Fitting Amplitudes to Data
//!
//! Of course, the goal of all of this is to be able to construct a
//! [`Model`](`crate::amplitude::Model`), load up a [`Dataset`](`crate::dataset::Dataset`), create
//! an [`ExtendedLogLikelihood`](`crate::manager::ExtendedLogLikelihood`), and fit the model to
//! data. Here's an example to show how that might be accomplished:
//!
//! ```ignore
//! use ganesh::algorithms::NelderMead;
//! use ganesh::prelude::*;
//! use rustitude::gluex::harmonics::Zlm;
//! use rustitude::gluex::{
//!     resonances::BreitWigner,
//!     utils::{Frame, Reflectivity, Wave},
//! };
//! use rustitude::prelude::*;
//! fn main() -> Result<(), RustitudeError> {
//!     let a2_1320 = BreitWigner::new(&[0], &[1], 2).named("a2_1320");
//!     let a2_1700 = BreitWigner::new(&[0], &[1], 2).named("a2_1700");
//!     let pw_s_wave = piecewise_m("pw_s_wave", 40, (1.04, 1.72));
//!     let zlm_s0p = Zlm::new(Wave::S0, Reflectivity::Positive, Frame::Helicity).named("zlm_s0p");
//!     let zlm_s0n = Zlm::new(Wave::S0, Reflectivity::Negative, Frame::Helicity).named("zlm_s0n");
//!     let zlm_dn2p = Zlm::new(Wave::Dn2, Reflectivity::Positive, Frame::Helicity).named("zlm_dn2p");
//!     let zlm_dn1p = Zlm::new(Wave::Dn1, Reflectivity::Positive, Frame::Helicity).named("zlm_dn1p");
//!     let zlm_d0p = Zlm::new(Wave::D0, Reflectivity::Positive, Frame::Helicity).named("zlm_d0p");
//!     let zlm_d1p = Zlm::new(Wave::D1, Reflectivity::Positive, Frame::Helicity).named("zlm_d1p");
//!     let zlm_d2p = Zlm::new(Wave::D2, Reflectivity::Positive, Frame::Helicity).named("zlm_d2p");
//!     let zlm_dn2n = Zlm::new(Wave::Dn2, Reflectivity::Negative, Frame::Helicity).named("zlm_dn2n");
//!     let zlm_dn1n = Zlm::new(Wave::Dn1, Reflectivity::Negative, Frame::Helicity).named("zlm_dn1n");
//!     let zlm_d0n = Zlm::new(Wave::D0, Reflectivity::Negative, Frame::Helicity).named("zlm_d0n");
//!     let zlm_d1n = Zlm::new(Wave::D1, Reflectivity::Negative, Frame::Helicity).named("zlm_d1n");
//!     let zlm_d2n = Zlm::new(Wave::D2, Reflectivity::Negative, Frame::Helicity).named("zlm_d2n");
//!     // Coherent sums of the D-wave harmonics in each reflectivity sector.
//!     let pos_d_wave = zlm_dn2p + zlm_dn1p + zlm_d0p + zlm_d1p + zlm_d2p;
//!     let neg_d_wave = zlm_dn2n + zlm_dn1n + zlm_d0n + zlm_d1n + zlm_d2n;
//!     let pos_real =
//!         zlm_s0p.real() * &pw_s_wave + &a2_1320 * &pos_d_wave.real() + &a2_1700 * &pos_d_wave.real();
//!     let pos_imag =
//!         zlm_s0p.imag() * &pw_s_wave + &a2_1320 * &pos_d_wave.imag() + &a2_1700 * &pos_d_wave.imag();
//!     let neg_real =
//!         zlm_s0n.real() * &pw_s_wave + &a2_1320 * &neg_d_wave.real() + &a2_1700 * &neg_d_wave.real();
//!     let neg_imag =
//!         zlm_s0n.imag() * &pw_s_wave + &a2_1320 * &neg_d_wave.imag() + &a2_1700 * &neg_d_wave.imag();
//!     let model = model!(pos_real, pos_imag, neg_real, neg_imag);
//!     // Data and accepted Monte Carlo, both stored in the `EPSInBeam` layout.
//!     let ds_data = Dataset::from_parquet("path/to/data.parquet", ReadMethod::EPSInBeam)?;
//!     let ds_accmc = Dataset::from_parquet("path/to/accmc.parquet", ReadMethod::EPSInBeam)?;
//!     let mut ell = ExtendedLogLikelihood::new(
//!         Manager::new(&model, &ds_data)?,
//!         Manager::new(&model, &ds_accmc)?,
//!     );
//!     ell.set_initial("a2_1320", "mass", 1.3182)?;
//!     ell.set_initial("a2_1320", "width", 0.1111)?;
//!     ell.fix("a2_1700", "mass", 1.698)?;
//!     ell.fix("a2_1700", "width", 0.265)?;
//!     ell.fix("pw_s_wave", "bin 10 im", 0.0)?;
//!
//!     let mut nm = NelderMead::new(ell.clone(), &ell.get_initial(), None);
//!     minimize!(nm, 1000)?; // Run 1000 steps
//!     let (best_pars, best_fx) = nm.best();
//!     for (par_name, par_value) in ell.free_parameters().iter().zip(best_pars) {
//!         println!("{} -> {} (NLL = {})", par_name, par_value, best_fx);
//!     }
//!     Ok(())
//! }
//! ```
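//!
//! The example above delegates minimization to `ganesh`'s `NelderMead`. For readers unfamiliar
//! with the method, here is a compact, std-only sketch of the classic Nelder-Mead simplex
//! algorithm (reflection 1, expansion 2, contraction 0.5, shrink 0.5) minimizing an arbitrary
//! closure. `nelder_mead` and `lerp` are hypothetical functions for illustration, not `ganesh`'s
//! implementation or API:
//!
//! ```rust
//! /// Point a + t * (b - a), component-wise.
//! fn lerp(a: &[f64], b: &[f64], t: f64) -> Vec<f64> {
//!     a.iter().zip(b).map(|(ai, bi)| ai + t * (bi - ai)).collect()
//! }
//!
//! /// Minimizes `f` from `x0`; `step` sets the initial simplex size.
//! /// Returns (best point, best value) after `max_iter` iterations.
//! fn nelder_mead<Fun: Fn(&[f64]) -> f64>(f: Fun, x0: &[f64], step: f64, max_iter: usize) -> (Vec<f64>, f64) {
//!     let n = x0.len();
//!     if n == 0 {
//!         return (Vec::new(), f(x0));
//!     }
//!     // Initial simplex: x0 plus one vertex displaced along each axis.
//!     let mut simplex: Vec<(Vec<f64>, f64)> = vec![(x0.to_vec(), f(x0))];
//!     for i in 0..n {
//!         let mut v = x0.to_vec();
//!         v[i] += step;
//!         let fv = f(v.as_slice());
//!         simplex.push((v, fv));
//!     }
//!     for _ in 0..max_iter {
//!         // Order vertices from best (lowest f) to worst.
//!         simplex.sort_by(|a, b| a.1.total_cmp(&b.1));
//!         // Centroid of every vertex except the worst.
//!         let mut centroid = vec![0.0; n];
//!         for (v, _) in &simplex[..n] {
//!             for (c, vi) in centroid.iter_mut().zip(v) {
//!                 *c += vi / n as f64;
//!             }
//!         }
//!         let (worst, f_worst) = simplex[n].clone();
//!         let (f_best, f_second_worst) = (simplex[0].1, simplex[n - 1].1);
//!         let xr = lerp(&centroid, &worst, -1.0); // reflection
//!         let fr = f(xr.as_slice());
//!         if fr < f_best {
//!             let xe = lerp(&centroid, &worst, -2.0); // expansion
//!             let fe = f(xe.as_slice());
//!             simplex[n] = if fe < fr { (xe, fe) } else { (xr, fr) };
//!         } else if fr < f_second_worst {
//!             simplex[n] = (xr, fr);
//!         } else {
//!             // Outside contraction if the reflection improved on the worst point, inside otherwise.
//!             let xc = if fr < f_worst { lerp(&centroid, &xr, 0.5) } else { lerp(&centroid, &worst, 0.5) };
//!             let fc = f(xc.as_slice());
//!             if fc < f_worst.min(fr) {
//!                 simplex[n] = (xc, fc);
//!             } else {
//!                 // Shrink every vertex halfway toward the best one.
//!                 let best = simplex[0].0.clone();
//!                 for (v, fv) in simplex.iter_mut().skip(1) {
//!                     *v = lerp(&best, v.as_slice(), 0.5);
//!                     *fv = f(v.as_slice());
//!                 }
//!             }
//!         }
//!     }
//!     simplex.sort_by(|a, b| a.1.total_cmp(&b.1));
//!     simplex.swap_remove(0)
//! }
//! ```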
#![warn(
    clippy::nursery,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::doc_markdown,
    clippy::doc_link_with_quotes,
    clippy::missing_safety_doc,
    clippy::missing_panics_doc,
    clippy::missing_errors_doc,
    clippy::perf,
    clippy::style,
    missing_docs
)]
#![allow(deprecated)]

use std::fmt::{Debug, Display};
use std::iter::{Product, Sum};

use nalgebra::Vector3;
use num::{
    traits::{FloatConst, NumAssignOps},
    Float, FromPrimitive,
};
pub mod amplitude;
pub mod dataset;
pub mod four_momentum;
pub mod manager;
/// Recommended namespace for use and development.
pub mod prelude {
    pub use crate::amplitude::{
        cscalar, pcscalar, piecewise_m, scalar, AmpLike, Amplitude, AsTree, Imag, Model, Node,
        Parameter, Piecewise, Product, Real, Sum,
    };
    pub use crate::dataset::{Dataset, Event, ReadMethod};
    pub use crate::errors::RustitudeError;
    pub use crate::four_momentum::FourMomentum;
    pub use crate::manager::{ExtendedLogLikelihood, Manager};
    pub use crate::{convert, convert_array, convert_vec, model, Field, UnitVector};
    pub use nalgebra::Vector3;
    pub use num::Complex;
}

/// A trait representing a numeric field which can be used in calculating amplitudes.
pub trait Field:
    Float
    + Sum
    + Product
    + FloatConst
    + NumAssignOps
    + Debug
    + Display
    + Default
    + Send
    + Sync
    + FromPrimitive
{
}
impl Field for f64 {}
impl Field for f32 {}

#[macro_export]
/// Convenience macro for converting raw numeric values to a generic numeric type.
///
/// Panics if the value cannot be represented in the target type.
macro_rules! convert {
    ($value:expr, $type:ty) => {{
        #[allow(clippy::unwrap_used)]
        <$type as num::NumCast>::from($value).unwrap()
    }};
}

#[macro_export]
/// Convenience macro for converting a raw numeric [`Vec`] to a generic [`Vec`].
macro_rules! convert_vec {
    ($vec:expr, $type:ty) => {{
        $vec.into_iter()
            .map(|value| $crate::convert!(value, $type))
            .collect::<Vec<$type>>()
    }};
}

#[macro_export]
/// Convenience macro for converting a raw numeric array to a generic array.
macro_rules! convert_array {
    ($arr:expr, $type:ty) => {{
        let temp_vec: Vec<_> = $arr
            .iter()
            .map(|&value| $crate::convert!(value, $type))
            .collect();
        #[allow(clippy::unwrap_used)]
        temp_vec.try_into().unwrap()
    }};
}

/// A trait to normalize structs (mostly to use on nalgebra vectors without needing [`nalgebra::RealField`]).
pub trait UnitVector {
    /// Returns a normalized form of the input.
    fn unit(&self) -> Self;
}

impl<F: Field + 'static> UnitVector for Vector3<F> {
    fn unit(&self) -> Self {
        let mag = F::sqrt(self.x * self.x + self.y * self.y + self.z * self.z);
        self / mag
    }
}

#[macro_export]
/// Convenience macro for boxing up coherent sum terms into a [`Model`](`crate::amplitude::Model`).
macro_rules! model {
    ($($term:expr),+ $(,)?) => {
        Model::new(&[$(Box::new($term),)+])
    };
}

pub mod errors {
    //! This module contains an all-encompassing error enum that almost every fallible method in
    //! the crate returns.
    use pyo3::{exceptions::PyException, PyErr};
    use thiserror::Error;

    /// The main [`Error`] type for `rustitude_core`. All errors internal to the crate should
    /// eventually pass through here, since it provides a single-location interface for `PyO3`
    /// errors.
    #[derive(Debug, Error)]
    pub enum RustitudeError {
        #[allow(missing_docs)]
        #[error(transparent)]
        IOError(#[from] std::io::Error),

        #[allow(missing_docs)]
        #[error(transparent)]
        ParquetError(#[from] parquet::errors::ParquetError),

        #[allow(missing_docs)]
        #[error("Oxyroot: {0}")]
        OxyrootError(String),

        #[allow(missing_docs)]
        #[error(transparent)]
        ThreadPoolBuildError(#[from] rayon::ThreadPoolBuildError),

        #[allow(missing_docs)]
        #[error("Could not cast value from {0} (type in file) to {1} (required type)")]
        DatasetReadError(String, String),

        #[allow(missing_docs)]
        #[error("Parameter not found: {0}")]
        ParameterNotFoundError(String),

        #[allow(missing_docs)]
        #[error("Amplitude not found: {0}")]
        AmplitudeNotFoundError(String),

        #[allow(missing_docs)]
        #[error("Invalid parameter value: {0}")]
        InvalidParameterValue(String),

        #[allow(missing_docs)]
        #[error("Evaluation error: {0}")]
        EvaluationError(String),

        #[allow(missing_docs)]
        #[error("Python error: {0}")]
        PythonError(String),

        #[allow(missing_docs)]
        #[error("Parsing error: {0}")]
        ParseError(String),
    }
    impl From<RustitudeError> for PyErr {
        fn from(err: RustitudeError) -> Self {
            PyException::new_err(err.to_string())
        }
    }
    impl From<PyErr> for RustitudeError {
        fn from(err: PyErr) -> Self {
            Self::PythonError(err.to_string())
        }
    }
}

pub mod utils {
    //! This module holds some convenience functions for writing test functions for amplitudes.
    use crate::prelude::*;

    /// Generate a test event for the reaction $`\gamma p \to K_S K_S p`$ with 64-bit precision.
    pub fn generate_test_event_f64() -> Event<f64> {
        Event {
            index: 0,
            weight: -0.48,
            beam_p4: FourMomentum::new(8.747_921, 0.0, 0.0, 8.747_921),
            recoil_p4: FourMomentum::new(1.040_902_7, 0.119_110_32, 0.373_947_23, 0.221_585_83),
            daughter_p4s: vec![
                FourMomentum::new(3.136_247_2, -0.111_774_68, 0.293_426_28, 3.080_557_3),
                FourMomentum::new(5.509_043, -0.007_335_639, -0.667_373_54, 5.445_778),
            ],
            eps: Vector3::from([0.385_109_57, 0.022_205_278, 0.0]),
        }
    }

    /// Generate a test dataset for the reaction $`\gamma p \to K_S K_S p`$ with 64-bit precision.
    pub fn generate_test_dataset_f64() -> Dataset<f64> {
        Dataset::new(vec![
            Event {
                index: 0,
                weight: -0.138_917,
                beam_p4: FourMomentum::new(8.383_563, 0.0, 0.0, 8.383_563),
                recoil_p4: FourMomentum::new(1.311_736, 0.664_397, 0.327_881, 0.539_785),
                daughter_p4s: vec![
                    FourMomentum::new(3.140_736, -0.074_363, 0.335_501, 3.081_966),
                    FourMomentum::new(4.869_362, -0.590_033, -0.663_383, 4.761_812),
                ],
                eps: Vector3::from([-0.016_172, 0.319_243, 0.0]),
            },
            Event {
                index: 1,
                weight: 0.967_937,
                beam_p4: FourMomentum::new(8.373_471, 0.0, 0.0, 8.373_471),
                recoil_p4: FourMomentum::new(1.099_134, -0.318_113, -0.241_351, 0.410_238),
                daughter_p4s: vec![
                    FourMomentum::new(6.803_817, 0.662_458, -0.146_496, 6.751_592),
                    FourMomentum::new(1.408_791, -0.344_344, 0.387_849, 1.211_640),
                ],
                eps: Vector3::from([-0.016_172, 0.319_243, 0.0]),
            },
            Event {
                index: 2,
                weight: 0.016_893,
                beam_p4: FourMomentum::new(8.686_482, 0.0, 0.0, 8.686_482),
                recoil_p4: FourMomentum::new(1.041_158, 0.141_536, 0.374_024, 0.209_115),
                daughter_p4s: vec![
                    FourMomentum::new(3.348_294, -0.007_810, 0.232_603, 3.302_921),
                    FourMomentum::new(5.235_301, -0.133_726, -0.606_628, 5.174_445),
                ],
                eps: Vector3::from([-0.018_940, 0.373_890, 0.0]),
            },
            Event {
                index: 3,
                weight: -0.022_154,
                beam_p4: FourMomentum::new(8.799_066, 0.0, 0.0, 8.799_066),
                recoil_p4: FourMomentum::new(1.078_011, -0.411_542, 0.243_270, 0.230_664),
                daughter_p4s: vec![
                    FourMomentum::new(5.382_554, 0.240_169, 0.105_882, 5.353_071),
                    FourMomentum::new(3.276_772, 0.171_372, -0.349_153, 3.215_329),
                ],
                eps: Vector3::from([-0.018_940, 0.373_890, 0.0]),
            },
            Event {
                index: 4,
                weight: 0.012_900,
                beam_p4: FourMomentum::new(8.561_700, 0.0, 0.0, 8.561_700),
                recoil_p4: FourMomentum::new(1.078_375, -0.409_737, 0.245_940, 0.232_739),
                daughter_p4s: vec![
                    FourMomentum::new(5.221_115, 0.242_604, 0.099_132, 5.190_736),
                    FourMomentum::new(3.200_482, 0.167_133, -0.345_072, 3.138_225),
                ],
                eps: Vector3::from([-0.016_448, 0.324_690, 0.0]),
            },
            Event {
                index: 5,
                weight: -0.138_917,
                beam_p4: FourMomentum::new(8.714_853, 0.0, 0.0, 8.714_853),
                recoil_p4: FourMomentum::new(1.458_814, -0.309_093, -0.853_077, 0.651_541),
                daughter_p4s: vec![
                    FourMomentum::new(3.879_303, -0.067_345, 0.225_269, 3.840_064),
                    FourMomentum::new(4.315_006, 0.376_439, 0.627_807, 4.223_246),
                ],
                eps: Vector3::from([-0.018_940, 0.373_890, 0.0]),
            },
            Event {
                index: 6,
                weight: 1.111_018,
                beam_p4: FourMomentum::new(8.271_341, 0.0, 0.0, 8.271_341),
                recoil_p4: FourMomentum::new(1.296_389, -0.275_474, 0.706_565, 0.474_499),
                daughter_p4s: vec![
                    FourMomentum::new(5.433_060, 0.203_167, -0.343_429, 5.395_489),
                    FourMomentum::new(2.480_163, 0.072_306, -0.363_136, 2.401_352),
                ],
                eps: Vector3::from([-0.016_172, 0.319_243, 0.0]),
            },
            Event {
                index: 7,
                weight: 1.111_339,
                beam_p4: FourMomentum::new(8.743_071, 0.0, 0.0, 8.743_071),
                recoil_p4: FourMomentum::new(1.126_252, -0.317_043, 0.461_564, 0.273_006),
                daughter_p4s: vec![
                    FourMomentum::new(5.651_356, 0.200_123, -0.228_232, 5.621_215),
                    FourMomentum::new(2.903_734, 0.116_919, -0.233_331, 2.848_849),
                ],
                eps: Vector3::from([-0.018_940, 0.373_890, 0.0]),
            },
            Event {
                index: 8,
                weight: -0.138_917,
                beam_p4: FourMomentum::new(8.657_957, 0.0, 0.0, 8.657_957),
                recoil_p4: FourMomentum::new(1.125_095, -0.315_415, 0.460_129, 0.272_539),
                daughter_p4s: vec![
                    FourMomentum::new(5.604_545, 0.200_701, -0.230_638, 5.574_032),
                    FourMomentum::new(2.866_588, 0.114_713, -0.229_491, 2.811_384),
                ],
                eps: Vector3::from([-0.018_940, 0.373_890, 0.0]),
            },
            Event {
                index: 9,
                weight: -0.138_917,
                beam_p4: FourMomentum::new(8.403_684, 0.0, 0.0, 8.403_684),
                recoil_p4: FourMomentum::new(1.109_429, 0.481_598, -0.076_590, 0.335_673),
                daughter_p4s: vec![
                    FourMomentum::new(1.882_555, -0.201_094, -0.392_549, 1.761_210),
                    FourMomentum::new(6.349_971, -0.280_504, 0.469_139, 6.306_800),
                ],
                eps: Vector3::from([-0.016_448, 0.324_690, 0.0]),
            },
        ])
    }

    /// Generate a test event for the reaction $`\gamma p \to K_S K_S p`$ with 32-bit precision.
    pub fn generate_test_event_f32() -> Event<f32> {
663        Event {
664            index: 0,
665            weight: -0.48,
666            beam_p4: FourMomentum::new(8.747_921, 0.0, 0.0, 8.747_921),
667            recoil_p4: FourMomentum::new(1.040_902_7, 0.119_110_32, 0.373_947_23, 0.221_585_83),
668            daughter_p4s: vec![
669                FourMomentum::new(3.136_247_2, -0.111_774_68, 0.293_426_28, 3.080_557_3),
670                FourMomentum::new(5.509_043, -0.007_335_639, -0.667_373_54, 5.445_778),
671            ],
672            eps: Vector3::from([0.385_109_57, 0.022_205_278, 0.0]),
673        }
674    }
675
    /// Generate a test dataset for the reaction $`\gamma p \to K_S K_S p`$ with 32-bit precision.
    pub fn generate_test_dataset_f32() -> Dataset<f32> {
        Dataset::new(vec![
            Event {
                index: 0,
                weight: -0.138_917,
                beam_p4: FourMomentum::new(8.383_563, 0.0, 0.0, 8.383_563),
                recoil_p4: FourMomentum::new(1.311_736, 0.664_397, 0.327_881, 0.539_785),
                daughter_p4s: vec![
                    FourMomentum::new(3.140_736, -0.074_363, 0.335_501, 3.081_966),
                    FourMomentum::new(4.869_362, -0.590_033, -0.663_383, 4.761_812),
                ],
                eps: Vector3::from([-0.016_172, 0.319_243, 0.0]),
            },
            Event {
                index: 1,
                weight: 0.967_937,
                beam_p4: FourMomentum::new(8.373_471, 0.0, 0.0, 8.373_471),
                recoil_p4: FourMomentum::new(1.099_134, -0.318_113, -0.241_351, 0.410_238),
                daughter_p4s: vec![
                    FourMomentum::new(6.803_817, 0.662_458, -0.146_496, 6.751_592),
                    FourMomentum::new(1.408_791, -0.344_344, 0.387_849, 1.211_640),
                ],
                eps: Vector3::from([-0.016_172, 0.319_243, 0.0]),
            },
            Event {
                index: 2,
                weight: 0.016_893,
                beam_p4: FourMomentum::new(8.686_482, 0.0, 0.0, 8.686_482),
                recoil_p4: FourMomentum::new(1.041_158, 0.141_536, 0.374_024, 0.209_115),
                daughter_p4s: vec![
                    FourMomentum::new(3.348_294, -0.007_810, 0.232_603, 3.302_921),
                    FourMomentum::new(5.235_301, -0.133_726, -0.606_628, 5.174_445),
                ],
                eps: Vector3::from([-0.018_940, 0.373_890, 0.0]),
            },
            Event {
                index: 3,
                weight: -0.022_154,
                beam_p4: FourMomentum::new(8.799_066, 0.0, 0.0, 8.799_066),
                recoil_p4: FourMomentum::new(1.078_011, -0.411_542, 0.243_270, 0.230_664),
                daughter_p4s: vec![
                    FourMomentum::new(5.382_554, 0.240_169, 0.105_882, 5.353_071),
                    FourMomentum::new(3.276_772, 0.171_372, -0.349_153, 3.215_329),
                ],
                eps: Vector3::from([-0.018_940, 0.373_890, 0.0]),
            },
            Event {
                index: 4,
                weight: 0.012_900,
                beam_p4: FourMomentum::new(8.561_700, 0.0, 0.0, 8.561_700),
                recoil_p4: FourMomentum::new(1.078_375, -0.409_737, 0.245_940, 0.232_739),
                daughter_p4s: vec![
                    FourMomentum::new(5.221_115, 0.242_604, 0.099_132, 5.190_736),
                    FourMomentum::new(3.200_482, 0.167_133, -0.345_072, 3.138_225),
                ],
                eps: Vector3::from([-0.016_448, 0.324_690, 0.0]),
            },
            Event {
                index: 5,
                weight: -0.138_917,
                beam_p4: FourMomentum::new(8.714_853, 0.0, 0.0, 8.714_853),
                recoil_p4: FourMomentum::new(1.458_814, -0.309_093, -0.853_077, 0.651_541),
                daughter_p4s: vec![
                    FourMomentum::new(3.879_303, -0.067_345, 0.225_269, 3.840_064),
                    FourMomentum::new(4.315_006, 0.376_439, 0.627_807, 4.223_246),
                ],
                eps: Vector3::from([-0.018_940, 0.373_890, 0.0]),
            },
            Event {
                index: 6,
                weight: 1.111_018,
                beam_p4: FourMomentum::new(8.271_341, 0.0, 0.0, 8.271_341),
                recoil_p4: FourMomentum::new(1.296_389, -0.275_474, 0.706_565, 0.474_499),
                daughter_p4s: vec![
                    FourMomentum::new(5.433_060, 0.203_167, -0.343_429, 5.395_489),
                    FourMomentum::new(2.480_163, 0.072_306, -0.363_136, 2.401_352),
                ],
                eps: Vector3::from([-0.016_172, 0.319_243, 0.0]),
            },
            Event {
                index: 7,
                weight: 1.111_339,
                beam_p4: FourMomentum::new(8.743_071, 0.0, 0.0, 8.743_071),
                recoil_p4: FourMomentum::new(1.126_252, -0.317_043, 0.461_564, 0.273_006),
                daughter_p4s: vec![
                    FourMomentum::new(5.651_356, 0.200_123, -0.228_232, 5.621_215),
                    FourMomentum::new(2.903_734, 0.116_919, -0.233_331, 2.848_849),
                ],
                eps: Vector3::from([-0.018_940, 0.373_890, 0.0]),
            },
            Event {
                index: 8,
                weight: -0.138_917,
                beam_p4: FourMomentum::new(8.657_957, 0.0, 0.0, 8.657_957),
                recoil_p4: FourMomentum::new(1.125_095, -0.315_415, 0.460_129, 0.272_539),
                daughter_p4s: vec![
                    FourMomentum::new(5.604_545, 0.200_701, -0.230_638, 5.574_032),
                    FourMomentum::new(2.866_588, 0.114_713, -0.229_491, 2.811_384),
                ],
                eps: Vector3::from([-0.018_940, 0.373_890, 0.0]),
            },
            Event {
                index: 9,
                weight: -0.138_917,
                beam_p4: FourMomentum::new(8.403_684, 0.0, 0.0, 8.403_684),
                recoil_p4: FourMomentum::new(1.109_429, 0.481_598, -0.076_590, 0.335_673),
                daughter_p4s: vec![
                    FourMomentum::new(1.882_555, -0.201_094, -0.392_549, 1.761_210),
                    FourMomentum::new(6.349_971, -0.280_504, 0.469_139, 6.306_800),
                ],
                eps: Vector3::from([-0.016_448, 0.324_690, 0.0]),
            },
        ])
    }

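    /// Checks whether two floating-point numbers are approximately equal.
    ///
    /// The relative difference `|a - b| / (|a| + |b|)` is compared against `epsilon`. If either
    /// value is zero, or `|a| + |b|` is below `F::min_positive_value()`, relative error is not
    /// meaningful, so the absolute difference is compared against
    /// `epsilon * F::min_positive_value()` instead.
    ///
    /// See <https://floating-point-gui.de/errors/comparison/>.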
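    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming `utils` is public (as the `$crate::utils` path used by
    /// `assert_is_close!` suggests) and that `f64` implements `Field`:
    ///
    /// ```
    /// use rustitude_core::utils::is_close;
    ///
    /// // Relative difference of about 5e-8, inside a 1e-6 tolerance.
    /// assert!(is_close(1.0_f64, 1.000_000_1, 1e-6));
    /// // Relative difference of about 5e-3, well outside it.
    /// assert!(!is_close(1.0_f64, 1.01, 1e-6));
    /// // Against zero only the absolute bound `epsilon * F::min_positive_value()`
    /// // applies, so even a tiny nonzero value is not "close" to zero.
    /// assert!(!is_close(0.0_f64, 1e-12, 1e-6));
    /// ```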
    pub fn is_close<F: Field>(a: F, b: F, epsilon: F) -> bool {
        let abs_a = F::abs(a);
        let abs_b = F::abs(b);
        let diff = F::abs(a - b);
        if a == b {
            // Shortcut for exact equality (this also covers equal infinities).
            true
        } else if a == F::zero() || b == F::zero() || (abs_a + abs_b < F::min_positive_value()) {
            // Relative error is meaningless at or extremely near zero, so use an absolute bound.
            diff < (epsilon * F::min_positive_value())
        } else {
            // Relative error; the denominator is capped in case |a| + |b| overflows.
            diff / F::min(abs_a + abs_b, F::max_value()) < epsilon
        }
    }

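    /// Asserts that two floating-point numbers are approximately equal, in the sense of
    /// [`is_close`](crate::utils::is_close). Similar in spirit to the `approx` crate.
    ///
    /// The float type (`f64` or `f32`) must be passed as the last argument. An optional relative
    /// tolerance may be passed before it; otherwise `1e-5` is used.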
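    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming the macro is used from outside the crate, where
    /// `#[macro_export]` makes it available at the crate root:
    ///
    /// ```
    /// use rustitude_core::assert_is_close;
    ///
    /// // 0.1 + 0.2 is not exactly 0.3 in binary floating point, but it is well within
    /// // the default relative tolerance of 1e-5.
    /// assert_is_close!(0.1_f64 + 0.2, 0.3, f64);
    /// // An explicit tolerance goes before the float type.
    /// assert_is_close!(1.0_f32, 1.000_1, 1e-3, f32);
    /// ```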
    #[macro_export]
    macro_rules! assert_is_close {
        // Without an explicit tolerance, fall back to a relative tolerance of 1e-5.
        ($given:expr, $expected:expr, f64) => {
            $crate::assert_is_close!($given, $expected, 1e-5, f64)
        };
        ($given:expr, $expected:expr, f32) => {
            $crate::assert_is_close!($given, $expected, 1e-5, f32)
        };
        ($given:expr, $expected:expr, $eps:expr, f64) => {{
            // Evaluate each argument exactly once.
            let given: f64 = $given;
            let expected: f64 = $expected;
            let eps: f64 = $eps;
            // Relative difference, reported only in the failure message.
            let rel_diff = f64::abs(given - expected)
                / f64::min(f64::abs(given) + f64::abs(expected), f64::MAX);
            assert!(
                $crate::utils::is_close(given, expected, eps),
                "assert_is_close!({}, {}, {})

    a = {:?}
    b = {:?}
    |a - b| / (|a| + |b|) = {:?} > {:?}

",
                stringify!($given),
                stringify!($expected),
                stringify!($eps),
                given,
                expected,
                rel_diff,
                eps
            );
        }};
        ($given:expr, $expected:expr, $eps:expr, f32) => {{
            // Evaluate each argument exactly once.
            let given: f32 = $given;
            let expected: f32 = $expected;
            let eps: f32 = $eps;
            // Relative difference, reported only in the failure message.
            let rel_diff = f32::abs(given - expected)
                / f32::min(f32::abs(given) + f32::abs(expected), f32::MAX);
            assert!(
                $crate::utils::is_close(given, expected, eps),
                "assert_is_close!({}, {}, {})

    a = {:?}
    b = {:?}
    |a - b| / (|a| + |b|) = {:?} > {:?}

",
                stringify!($given),
                stringify!($expected),
                stringify!($eps),
                given,
                expected,
                rel_diff,
                eps
            );
        }};
    }
}