rustitude_core/
amplitude.rs

1//! The amplitude module contains structs and methods for defining and manipulating [`Amplitude`]s
2//! and [`Model`]s
3//!
4//! To create a new [`Amplitude`] in Rust, we simply need to implement the [`Node`] trait on a
5//! struct. You can then provide a convenience method for creating a new implementation of your
6//! [`Amplitude`].
7//!
8//! Amplitudes are typically defined first, and then [`Model`]s are built by adding, multiplying
9//! and taking the real/imaginary part of [`Amplitude`]s. [`Model`]s can be built using the
10//! provided [`Model::new`] constructor or with the [`model!`](`crate::model!`) macro. The terms
11//! provided to either of these will be treated as separate coherent sums. The [`Model`] will
12//! implicitly take their absolute square and then add those sums incoherently.
13//!
//! We can then use [`Manager`](crate::manager::Manager)-like structs to handle computation
15//! over [`Dataset`]s.
16//!
17//! # Example:
18//!
19//! An example (with no particular physical meaning) is given as follows:
20//!
21//! ```ignore
22//! use rustitude_core::prelude::*;
23//! fn main() {
24//!     let a = scalar("a");   // a(value) = value
25//!     let b = scalar("b");   // b(value) = value
26//!     let c = cscalar("c");  // c(real, imag) = real + i * imag
27//!     let d = pcscalar("d"); // d(mag, phi) = mag * e^{i * phi}
28//!     let abc = a + b + &c; // references avoid losing ownership
29//!     let x = abc * &d + c.real();
30//!     let model = model!(x, d);
31//!     // |(a.value + b.value + c.real + i * c.imag) * (d.mag * e^{i * d.phi}) + c.real|^2 + |d.mag * e^{i * d.phi}|^2
32//! }
33//! ```
34//!
35//! With Rust's ownership rules, if we want to use amplitudes in multiple places, we need to either
36//! reference them or clone them (`a.clone()`, for instance). References typically look nicer and
37//! are more readable, but a clone will happen regardless (although it isn't expensive, only one
38//! copy of each amplitude will ever hold any data).
39use dyn_clone::DynClone;
40use itertools::Itertools;
41use nalgebra::Complex;
42use parking_lot::RwLock;
43use rayon::prelude::*;
44use std::{
45    collections::HashSet,
46    fmt::{Debug, Display},
47    ops::{Add, Mul},
48    sync::Arc,
49};
50use tracing::{debug, info};
51
52use crate::{
53    convert,
54    dataset::{Dataset, Event},
55    errors::RustitudeError,
56    Field,
57};
58
59/// A single parameter within an [`Amplitude`].
60#[derive(Clone)]
61pub struct Parameter<F: Field> {
62    /// Name of the parent [`Amplitude`] containing this parameter.
63    pub amplitude: String,
64    /// Name of the parameter.
65    pub name: String,
66    /// Index of the parameter with respect to the [`Model`]. This will be [`Option::None`] if
67    /// the parameter is fixed.
68    pub index: Option<usize>,
69    /// A separate index for fixed parameters to ensure they stay constrained properly if freed.
70    /// This will be [`Option::None`] if the parameter is free in the [`Model`].
71    pub fixed_index: Option<usize>,
72    /// The initial value the parameter takes, or alternatively the value of the parameter if it is
73    /// fixed in the fit.
74    pub initial: F,
75    /// Bounds for the given parameter (defaults to +/- infinity). This is mostly optional and
76    /// isn't used in any Rust code asside from being able to get and set it.
77    pub bounds: (F, F),
78}
79impl<F: Field> Parameter<F> {
80    /// Creates a new [`Parameter`] within an [`Amplitude`] using the name of the [`Amplitude`],
81    /// the name of the [`Parameter`], and the index of the parameter within the [`Model`].
82    ///
83    /// By default, new [`Parameter`]s are free, have an initial value of `0.0`, and their bounds
84    /// are set to `(Field::NEG_INFINITY, Field::INFINITY)`.
85    pub fn new(amplitude: &str, name: &str, index: usize) -> Self {
86        Self {
87            amplitude: amplitude.to_string(),
88            name: name.to_string(),
89            index: Some(index),
90            fixed_index: None,
91            initial: F::one(),
92            bounds: (F::neg_infinity(), F::infinity()),
93        }
94    }
95
96    /// Returns `true` if the [`Parameter`] is free, `false` otherwise.
97    pub const fn is_free(&self) -> bool {
98        self.index.is_some()
99    }
100
101    /// Returns `true` if the [`Parameter`] is fixed, `false` otherwise.
102    pub const fn is_fixed(&self) -> bool {
103        self.index.is_none()
104    }
105}
106
107impl<F: Field> Debug for Parameter<F> {
108    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109        if self.index.is_none() {
110            write!(
111                f,
112                "Parameter(name={}, value={} (fixed), bounds=({}, {}), parent={})",
113                self.name, self.initial, self.bounds.0, self.bounds.1, self.amplitude
114            )
115        } else {
116            write!(
117                f,
118                "Parameter(name={}, value={}, bounds=({}, {}), parent={})",
119                self.name, self.initial, self.bounds.0, self.bounds.1, self.amplitude
120            )
121        }
122    }
123}
124impl<F: Field> Display for Parameter<F> {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        write!(f, "{}", self.name)
127    }
128}
129
130/// A trait which contains all the required methods for a functioning [`Amplitude`].
131///
132/// The [`Node`] trait represents any mathematical structure which takes in some parameters and some
133/// [`Event`] data and computes a [`Complex`] for each [`Event`]. This is the fundamental
134/// building block of all analyses built with Rustitude. Nodes are intended to be optimized at the
135/// user level, so they should be implemented on structs which can store some precalculated data.
136///
137/// # Examples:
138///
139/// A [`Node`] for calculating spherical harmonics:
140///
141/// ```
142/// use rustitude_core::prelude::*;
143///
144/// use nalgebra::{SMatrix, SVector};
145/// use rayon::prelude::*;
146/// use sphrs::SHEval;
147/// use sphrs::{ComplexSH, Coordinates};
148///
149/// #[derive(Clone, Copy, Default)]
150/// #[rustfmt::skip]
151/// enum Wave {
152///     #[default]
153///     S,
154///     S0,
155///     Pn1, P0, P1, P,
156///     Dn2, Dn1, D0, D1, D2, D,
157///     Fn3, Fn2, Fn1, F0, F1, F2, F3, F,
158/// }
159///
160/// #[rustfmt::skip]
161/// impl Wave {
162///     fn l(&self) -> i64 {
163///         match self {
164///             Self::S0 | Self::S => 0,
165///             Self::Pn1 | Self::P0 | Self::P1 | Self::P => 1,
166///             Self::Dn2 | Self::Dn1 | Self::D0 | Self::D1 | Self::D2 | Self::D => 2,
167///             Self::Fn3 | Self::Fn2 | Self::Fn1 | Self::F0 | Self::F1 | Self::F2 | Self::F3 | Self::F => 3,
168///         }
169///     }
170///     fn m(&self) -> i64 {
171///         match self {
172///             Self::S | Self::P | Self::D | Self::F => 0,
173///             Self::S0 | Self::P0 | Self::D0 | Self::F0 => 0,
174///             Self::Pn1 | Self::Dn1 | Self::Fn1 => -1,
175///             Self::P1 | Self::D1 | Self::F1 => 1,
176///             Self::Dn2 | Self::Fn2 => -2,
177///             Self::D2 | Self::F2 => 2,
178///             Self::Fn3 => -3,
179///             Self::F3 => 3,
180///         }
181///     }
182/// }
183///
184/// #[derive(Clone)]
185/// pub struct Ylm<F: Field> {
186///     wave: Wave,
187///     data: Vec<Complex<F>>,
188/// }
189/// impl<F: Field> Ylm<F> {
190///     pub fn new(wave: Wave) -> Self {
191///         Self {
192///             wave,
193///             data: Vec::default(),
194///         }
195///     }
196/// }
197/// impl<F: Field> Node<F> for Ylm<F> {
198///     fn precalculate(&mut self, dataset: &Dataset<F>) -> Result<(), RustitudeError> {
199///         self.data = dataset
200///             .events
201///             .par_iter()
202///             .map(|event| {
203///                 let resonance = event.daughter_p4s[0] + event.daughter_p4s[1];
204///                 let beam_res_vec = event.beam_p4.boost_along(&resonance).momentum();
205///                 let recoil_res_vec = event.recoil_p4.boost_along(&resonance).momentum();
206///                 let daughter_res_vec = event.daughter_p4s[0].boost_along(&resonance).momentum();
207///                 let z = -recoil_res_vec.unit();
208///                 let y = event
209///                     .beam_p4
210///                     .momentum()
211///                     .cross(&(-recoil_res_vec))
212///                     .unit();
213///                 let x = y.cross(&z);
214///                 let p = Coordinates::cartesian(
215///                     daughter_res_vec.dot(&x),
216///                     daughter_res_vec.dot(&y),
217///                     daughter_res_vec.dot(&z)
218///                 );
219///                 ComplexSH::Spherical.eval(self.wave.l(), self.wave.m(), &p)
220///             })
221///             .collect();
222///         Ok(())
223///     }
224///
225///     fn calculate(&self, _parameters: &[F], event: &Event<F>) -> Result<Complex<F>, RustitudeError> {
226///         Ok(self.data[event.index])
227///     }
228/// }
229/// ```
230///
231/// A [`Node`] which computes a single complex scalar entirely determined by input parameters:
232///
233/// ```
234/// use rustitude_core::prelude::*;
235/// #[derive(Clone)]
236/// struct ComplexScalar;
237/// impl<F: Field> Node<F> for ComplexScalar {
238///     fn calculate(&self, parameters: &[F], _event: &Event<F>) -> Result<Complex<F>, RustitudeError> {
239///         Ok(Complex::new(parameters[0], parameters[1]))
240///     }
241///
242///     fn parameters(&self) -> Vec<String> {
243///         vec!["real".to_string(), "imag".to_string()]
244///     }
245/// }
246/// ```
247pub trait Node<F: Field>: Sync + Send + DynClone {
248    /// A method that is run once and stores some precalculated values given a [`Dataset`] input.
249    ///
250    /// This method is intended to run expensive calculations which don't actually depend on the
251    /// parameters. For instance, to calculate a spherical harmonic, we don't actually need any
252    /// other information than what is contained in the [`Event`], so we can calculate a spherical
253    /// harmonic for every event once and then retrieve the data in the [`Node::calculate`] method.
254    ///
255    /// # Errors
256    ///
257    /// This function should be written to return a [`RustitudeError`] if any part of the
258    /// calculation fails.
259    fn precalculate(&mut self, _dataset: &Dataset<F>) -> Result<(), RustitudeError> {
260        Ok(())
261    }
262
263    /// A method which runs every time the amplitude is evaluated and produces a [`Complex`].
264    ///
265    /// Because this method is run on every evaluation, it should be as lean as possible.
266    /// Additionally, you should avoid [`rayon`]'s parallel loops inside this method since we
267    /// already parallelize over the [`Dataset`]. This method expects a single [`Event`] as well as
268    /// a slice of [`Field`]s. This slice is guaranteed to have the same length and order as
269    /// specified in the [`Node::parameters`] method, or it will be empty if that method returns
270    /// [`None`].
271    ///
272    /// # Errors
273    ///
274    /// This function should be written to return a [`RustitudeError`] if any part of the
275    /// calculation fails.
276    fn calculate(&self, parameters: &[F], event: &Event<F>) -> Result<Complex<F>, RustitudeError>;
277
278    /// A method which specifies the number and order of parameters used by the [`Node`].
279    ///
280    /// This method tells the [`crate::manager::Manager`] how to assign its input [`Vec`] of parameter values to
281    /// each [`Node`]. If this method returns [`None`], it is implied that the [`Node`] takes no
282    /// parameters as input. Otherwise, the parameter names should be listed in the same order they
283    /// are expected to be given as input to the [`Node::calculate`] method.
284    fn parameters(&self) -> Vec<String> {
285        vec![]
286    }
287
288    /// A convenience method for turning [`Node`]s into [`Amplitude`]s.
289    fn into_amplitude(self, name: &str) -> Amplitude<F>
290    where
291        Self: std::marker::Sized + 'static,
292    {
293        Amplitude::new(name, self)
294    }
295
296    /// A convenience method for turning [`Node`]s into [`Amplitude`]s. This method has a
297    /// shorter name than [`Node::into_amplitude`], which it calls.
298    fn named(self, name: &str) -> Amplitude<F>
299    where
300        Self: std::marker::Sized + 'static,
301    {
302        self.into_amplitude(name)
303    }
304
305    /// A flag which says if the [`Node`] was written in Python. This matters because the GIL
306    /// cannot currently play nice with [`rayon`] multithreading. You will probably never need to
307    /// set this, as the only object which returns `True` is in the `py_rustitude` crate which
308    /// binds this crate to Python.
309    fn is_python_node(&self) -> bool {
310        false
311    }
312}
313
314dyn_clone::clone_trait_object!(<F> Node<F>);
315
316/// This trait is used to implement operations which can be performed on [`Amplitude`]s (and other
317/// operations themselves). Currently, there are only a limited number of defined operations,
318/// namely [`Real`], [`Imag`], and [`Product`]. Others may be added in the future, but they
319/// should probably only be added through this crate and not externally, since they require several
320/// operator overloads to be implemented for nice syntax.
321pub trait AmpLike<F: Field>: Send + Sync + Debug + Display + AsTree + DynClone {
322    /// This method walks through an [`AmpLike`] struct and recursively amalgamates a list of
323    /// [`Amplitude`]s contained within. Note that these [`Amplitude`]s are owned clones of the
324    /// interior structures.
325    fn walk(&self) -> Vec<Amplitude<F>>;
326    /// This method is similar to [`AmpLike::walk`], but returns mutable references rather than
327    /// clones.
328    fn walk_mut(&mut self) -> Vec<&mut Amplitude<F>>;
329    /// Given a cache of complex values calculated from a list of amplitudes, this method will
330    /// calculate the desired mathematical structure given by the [`AmpLike`] and any
331    /// [`AmpLike`]s it contains.
332    fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<Complex<F>>;
333    /// This method returns clones of any [`AmpLike`]s wrapped by the given [`AmpLike`].
334    fn get_cloned_terms(&self) -> Option<Vec<Box<dyn AmpLike<F>>>> {
335        None
336    }
337    /// Take the real part of an [`Amplitude`] or [`Amplitude-like`](`AmpLike`) struct.
338    fn real(&self) -> Real<F>
339    where
340        Self: std::marker::Sized + 'static,
341    {
342        Real(dyn_clone::clone_box(self))
343    }
344    /// Take the imaginary part of an [`Amplitude`] or [`Amplitude-like`](`AmpLike`) struct.
345    fn imag(&self) -> Imag<F>
346    where
347        Self: Sized + 'static,
348    {
349        Imag(dyn_clone::clone_box(self))
350    }
351
352    /// Take the product of a [`Vec`] of [`Amplitude-like`](`AmpLike`) structs.
353    fn prod(als: &Vec<Box<dyn AmpLike<F>>>) -> Product<F>
354    where
355        Self: Sized + 'static,
356    {
357        Product(*dyn_clone::clone_box(als))
358    }
359
360    /// Take the sum of a [`Vec`] of [`Amplitude-like`](`AmpLike`) structs.
361    fn sum(als: &Vec<Box<dyn AmpLike<F>>>) -> Sum<F>
362    where
363        Self: Sized + 'static,
364    {
365        Sum(*dyn_clone::clone_box(als))
366    }
367}
368dyn_clone::clone_trait_object!(<F> AmpLike<F>);
369
370/// This trait defines some simple methods for pretty-printing tree-like structures.
371pub trait AsTree {
372    /// Returns a string representing the node and its children with tree formatting.
373    fn get_tree(&self) -> String {
374        self._get_tree(&mut vec![])
375    }
376    /// Returns a string with the proper indents for a given entry in
377    /// [`AsTree::get_tree`]. A `true` bit will yield a vertical line, while a
378    /// `false` bit will not.
379    fn _get_indent(&self, bits: Vec<bool>) -> String {
380        bits.iter()
381            .map(|b| if *b { "  ┃ " } else { "    " })
382            .join("")
383    }
384    /// Returns a string with the intermediate branch symbol for a given entry in
385    /// [`AsTree::get_tree`].
386    fn _get_intermediate(&self) -> String {
387        String::from("  ┣━")
388    }
389    /// Prints the a final branch for a given entry in [`AsTree::get_tree`].
390    fn _get_end(&self) -> String {
391        String::from("  ┗━")
392    }
393    /// Prints the tree of an [`AsTree`]-implementor starting with a particular indentation structure
394    /// defined by `bits`. A `true` bit will print a vertical line, while a `false` bit
395    /// will not.
396    fn _get_tree(&self, bits: &mut Vec<bool>) -> String;
397}
398
399/// A struct which stores a named [`Node`].
400///
401/// The [`Amplitude`] struct turns a [`Node`] trait into a concrete type and also stores a name
402/// associated with the [`Node`]. This allows us to distinguish multiple uses of the same [`Node`]
403/// in an analysis, and makes each [`Node`]'s parameters unique.
404#[derive(Clone)]
405pub struct Amplitude<F: Field> {
406    /// A name which uniquely identifies an [`Amplitude`] within a sum and group.
407    pub name: String,
408    /// A [`Node`] which contains all of the operations needed to compute a [`Complex`] from an
409    /// [`Event`] in a [`Dataset`], a [`Vec<Field>`] of parameter values, and possibly some
410    /// precomputed values.
411    pub node: Box<dyn Node<F>>,
412    /// Indicates whether the amplitude should be included in calculations or skipped.
413    pub active: bool,
414    /// Contains the parameter names associated with this amplitude.
415    pub parameters: Vec<String>,
416    /// Indicates the reserved position in the cache for shortcutting computation with a
417    /// precomputed cache.
418    pub cache_position: usize,
419    /// Indicates the position in the final parameter vector that coincides with the starting index
420    /// for parameters in this [`Amplitude`]
421    pub parameter_index_start: usize,
422}
423
424impl<F: Field> Debug for Amplitude<F> {
425    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
426        write!(f, "{}", self.name)
427    }
428}
429impl<F: Field> Display for Amplitude<F> {
430    #[rustfmt::skip]
431    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432        writeln!(f, "Amplitude")?;
433        writeln!(f, "  Name:                     {}", self.name)?;
434        writeln!(f, "  Active:                   {}", self.active)?;
435        writeln!(f, "  Cache Position:           {}", self.cache_position)?;
436        writeln!(f, "  Index of First Parameter: {}", self.parameter_index_start)
437    }
438}
439impl<F: Field> AsTree for Amplitude<F> {
440    fn _get_tree(&self, _bits: &mut Vec<bool>) -> String {
441        let name = if self.active {
442            self.name.clone()
443        } else {
444            format!("/* {} */", self.name)
445        };
446        if self.parameters().len() > 7 {
447            format!(" {}({},...)\n", name, self.parameters()[0..7].join(", "))
448        } else {
449            format!(" {}({})\n", name, self.parameters().join(", "))
450        }
451    }
452}
453impl<F: Field> Amplitude<F> {
454    /// Creates a new [`Amplitude`] from a name and a [`Node`]-implementing struct.
455    pub fn new(name: &str, node: impl Node<F> + 'static) -> Self {
456        info!("Created new amplitude named {name}");
457        let parameters = node.parameters();
458        Self {
459            name: name.to_string(),
460            node: Box::new(node),
461            parameters,
462            active: true,
463            cache_position: 0,
464            parameter_index_start: 0,
465        }
466    }
467    /// Set the [`Amplitude::cache_position`] and [`Amplitude::parameter_index_start`] and runs
468    /// [`Amplitude::precalculate`] over the given [`Dataset`].
469    ///
470    /// # Errors
471    /// This function will raise a [`RustitudeError`] if the precalculation step fails.
472    pub fn register(
473        &mut self,
474        cache_position: usize,
475        parameter_index_start: usize,
476        dataset: &Dataset<F>,
477    ) -> Result<(), RustitudeError> {
478        self.cache_position = cache_position;
479        self.parameter_index_start = parameter_index_start;
480        self.precalculate(dataset)
481    }
482}
483impl<F: Field> Node<F> for Amplitude<F> {
484    fn precalculate(&mut self, dataset: &Dataset<F>) -> Result<(), RustitudeError> {
485        self.node.precalculate(dataset)?;
486        debug!("Precalculated amplitude {}", self.name);
487        Ok(())
488    }
489    fn calculate(&self, parameters: &[F], event: &Event<F>) -> Result<Complex<F>, RustitudeError> {
490        let res = self.node.calculate(
491            &parameters
492                [self.parameter_index_start..self.parameter_index_start + self.parameters.len()],
493            event,
494        );
495        debug!(
496            "{}({:?}, event #{}) = {}",
497            self.name,
498            &parameters
499                [self.parameter_index_start..self.parameter_index_start + self.parameters.len()],
500            event.index,
501            res.as_ref()
502                .map(|c| c.to_string())
503                .unwrap_or_else(|e| e.to_string())
504        );
505        res
506    }
507    fn parameters(&self) -> Vec<String> {
508        self.node.parameters()
509    }
510}
511impl<F: Field> AmpLike<F> for Amplitude<F> {
512    fn walk(&self) -> Vec<Self> {
513        vec![self.clone()]
514    }
515
516    fn walk_mut(&mut self) -> Vec<&mut Self> {
517        vec![self]
518    }
519
520    fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<Complex<F>> {
521        let res = cache[self.cache_position];
522        debug!(
523            "Computing {} from cache: {:?}",
524            self.name,
525            res.as_ref().map(|c| c.to_string())
526        );
527        res
528    }
529}
530
531/// An [`AmpLike`] representing the real part of the [`AmpLike`] it contains.
532#[derive(Clone)]
533pub struct Real<F: Field>(Box<dyn AmpLike<F>>);
534impl<F: Field> Debug for Real<F> {
535    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
536        write!(f, "Real [ {:?} ]", self.0)
537    }
538}
539impl<F: Field> Display for Real<F> {
540    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
541        writeln!(f, "{}", self.get_tree())
542    }
543}
544impl<F: Field> AmpLike<F> for Real<F> {
545    fn walk(&self) -> Vec<Amplitude<F>> {
546        self.0.walk()
547    }
548
549    fn walk_mut(&mut self) -> Vec<&mut Amplitude<F>> {
550        self.0.walk_mut()
551    }
552
553    fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<Complex<F>> {
554        let res: Option<Complex<F>> = self.0.compute(cache).map(|r| r.re.into());
555        debug!(
556            "Computing {:?} from cache: {:?}",
557            self,
558            res.as_ref().map(|c| c.to_string())
559        );
560        res
561    }
562}
563impl<F: Field> AsTree for Real<F> {
564    fn _get_tree(&self, bits: &mut Vec<bool>) -> String {
565        let mut res = String::from("[ real ]\n");
566        res.push_str(&self._get_indent(bits.to_vec()));
567        res.push_str(&self._get_end());
568        bits.push(false);
569        res.push_str(&self.0._get_tree(&mut bits.clone()));
570        bits.pop();
571        res
572    }
573}
574
575/// An [`AmpLike`] representing the imaginary part of the [`AmpLike`] it contains.
576#[derive(Clone)]
577pub struct Imag<F: Field>(Box<dyn AmpLike<F>>);
578impl<F: Field> Debug for Imag<F> {
579    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
580        write!(f, "Imag [ {:?} ]", self.0)
581    }
582}
583impl<F: Field> Display for Imag<F> {
584    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
585        writeln!(f, "{}", self.get_tree())
586    }
587}
588impl<F: Field> AmpLike<F> for Imag<F> {
589    fn walk(&self) -> Vec<Amplitude<F>> {
590        self.0.walk()
591    }
592
593    fn walk_mut(&mut self) -> Vec<&mut Amplitude<F>> {
594        self.0.walk_mut()
595    }
596
597    fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<Complex<F>> {
598        let res: Option<Complex<F>> = self.0.compute(cache).map(|r| r.im.into());
599        debug!(
600            "Computing {:?} from cache: {:?}",
601            self,
602            res.as_ref().map(|c| c.to_string())
603        );
604        res
605    }
606}
607impl<F: Field> AsTree for Imag<F> {
608    fn _get_tree(&self, bits: &mut Vec<bool>) -> String {
609        let mut res = String::from("[ imag ]\n");
610        res.push_str(&self._get_indent(bits.to_vec()));
611        res.push_str(&self._get_end());
612        bits.push(false);
613        res.push_str(&self.0._get_tree(&mut bits.clone()));
614        bits.pop();
615        res
616    }
617}
618
619/// An [`AmpLike`] representing the product of the [`AmpLike`]s it contains.
620#[derive(Clone)]
621pub struct Product<F: Field>(Vec<Box<dyn AmpLike<F>>>);
622impl<F: Field> Debug for Product<F> {
623    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
624        write!(f, "Product [ ")?;
625        for op in &self.0 {
626            write!(f, "{:?} ", op)?;
627        }
628        write!(f, "]")
629    }
630}
631impl<F: Field> Display for Product<F> {
632    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
633        writeln!(f, "{}", self.get_tree())
634    }
635}
636impl<F: Field> AsTree for Product<F> {
637    fn _get_tree(&self, bits: &mut Vec<bool>) -> String {
638        let mut res = String::from("[ * ]\n");
639        for (i, op) in self.0.iter().enumerate() {
640            res.push_str(&self._get_indent(bits.to_vec()));
641            if i == self.0.len() - 1 {
642                res.push_str(&self._get_end());
643                bits.push(false);
644            } else {
645                res.push_str(&self._get_intermediate());
646                bits.push(true);
647            }
648            res.push_str(&op._get_tree(&mut bits.clone()));
649            bits.pop();
650        }
651        res
652    }
653}
654impl<F: Field> AmpLike<F> for Product<F> {
655    fn get_cloned_terms(&self) -> Option<Vec<Box<dyn AmpLike<F>>>> {
656        Some(self.0.clone())
657    }
658    fn walk(&self) -> Vec<Amplitude<F>> {
659        self.0.iter().flat_map(|op| op.walk()).collect()
660    }
661
662    fn walk_mut(&mut self) -> Vec<&mut Amplitude<F>> {
663        self.0.iter_mut().flat_map(|op| op.walk_mut()).collect()
664    }
665
666    fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<Complex<F>> {
667        let mut values = self.0.iter().filter_map(|op| op.compute(cache)).peekable();
668        let res: Option<Complex<F>> = if values.peek().is_none() {
669            Some(Complex::default())
670        } else {
671            Some(values.product())
672        };
673        debug!(
674            "Computing {:?} from cache: {:?}",
675            self,
676            res.as_ref().map(|c| c.to_string())
677        );
678        res
679    }
680}
681
682/// An [`AmpLike`] representing the sum of the [`AmpLike`]s it contains.
683#[derive(Clone)]
684pub struct Sum<F: Field>(pub Vec<Box<dyn AmpLike<F>>>);
685impl<F: Field> Debug for Sum<F> {
686    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
687        write!(f, "Sum [ ")?;
688        for op in &self.0 {
689            write!(f, "{:?} ", op)?;
690        }
691        write!(f, "]")
692    }
693}
694impl<F: Field> Display for Sum<F> {
695    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
696        writeln!(f, "{}", self.get_tree())
697    }
698}
699impl<F: Field> AsTree for Sum<F> {
700    fn _get_tree(&self, bits: &mut Vec<bool>) -> String {
701        let mut res = String::from("[ + ]\n");
702        for (i, op) in self.0.iter().enumerate() {
703            res.push_str(&self._get_indent(bits.to_vec()));
704            if i == self.0.len() - 1 {
705                res.push_str(&self._get_end());
706                bits.push(false);
707            } else {
708                res.push_str(&self._get_intermediate());
709                bits.push(true);
710            }
711            res.push_str(&op._get_tree(&mut bits.clone()));
712            bits.pop();
713        }
714        res
715    }
716}
717impl<F: Field> AmpLike<F> for Sum<F> {
718    fn get_cloned_terms(&self) -> Option<Vec<Box<dyn AmpLike<F>>>> {
719        Some(self.0.clone())
720    }
721    fn walk(&self) -> Vec<Amplitude<F>> {
722        self.0.iter().flat_map(|op| op.walk()).collect()
723    }
724
725    fn walk_mut(&mut self) -> Vec<&mut Amplitude<F>> {
726        self.0.iter_mut().flat_map(|op| op.walk_mut()).collect()
727    }
728
729    fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<Complex<F>> {
730        let res = Some(
731            self.0
732                .iter()
733                .filter_map(|al| al.compute(cache))
734                .sum::<Complex<F>>(),
735        );
736        debug!(
737            "Computing {:?} from cache: {:?}",
738            self,
739            res.as_ref().map(|c| c.to_string())
740        );
741        res
742    }
743}
744
745/// Struct to hold a coherent sum of [`AmpLike`]s
746#[derive(Clone)]
747pub struct NormSqr<F: Field>(pub Box<dyn AmpLike<F>>);
748
749impl<F: Field> Debug for NormSqr<F> {
750    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
751        write!(f, "NormSqr[ {:?} ]", self.0)
752    }
753}
754impl<F: Field> Display for NormSqr<F> {
755    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
756        writeln!(f, "{}", self.get_tree())
757    }
758}
759impl<F: Field> AsTree for NormSqr<F> {
760    fn _get_tree(&self, bits: &mut Vec<bool>) -> String {
761        let mut res = String::from("[ |_|^2 ]\n");
762        res.push_str(&self._get_indent(bits.to_vec()));
763        res.push_str(&self._get_end());
764        bits.push(false);
765        res.push_str(&self.0._get_tree(&mut bits.clone()));
766        bits.pop();
767        res
768    }
769}
770impl<F: Field> NormSqr<F> {
771    /// Shortcut for computation using a cache of precomputed values. This method will return
772    /// [`None`] if the cache value at the corresponding [`Amplitude`]'s
773    /// [`Amplitude::cache_position`] is also [`None`], otherwise it just returns the corresponding
774    /// cached value. The computation is run across the [`NormSqr`]'s contained term, and the absolute
775    /// square of the result is returned.
776    pub fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<F> {
777        self.0.compute(cache).map(|res| res.norm_sqr())
778    }
779
780    /// Walks through a [`NormSqr`] and collects all the contained [`Amplitude`]s recursively.
781    pub fn walk(&self) -> Vec<Amplitude<F>> {
782        self.0.walk()
783    }
784
785    /// Walks through an [`NormSqr`] and collects all the contained [`Amplitude`]s recursively. This
786    /// method gives mutable access to said [`Amplitude`]s.
787    pub fn walk_mut(&mut self) -> Vec<&mut Amplitude<F>> {
788        self.0.walk_mut()
789    }
790}
791
/// A model contains an API to interact with a group of coherent sums by managing their amplitudes
/// and parameters. Models are typically passed to [`Manager`](crate::manager::Manager)-like
/// structs.
///
/// Note that the derived [`Clone`] shares the `amplitudes` field (only the [`Arc`] reference count
/// is increased); use [`Model::deep_clone`] to obtain an independent copy of the amplitudes.
#[derive(Clone)]
pub struct Model<F: Field> {
    /// The set of coherent sums included in the [`Model`]. The model evaluates to the sum of the
    /// absolute squares of these terms (see [`Model::compute`]).
    pub cohsums: Vec<NormSqr<F>>,
    /// The unique amplitudes (deduplicated by name) located within all coherent sums.
    pub amplitudes: Arc<RwLock<Vec<Amplitude<F>>>>,
    /// The parameters of all unique amplitudes. Constrained free parameters share an `index`;
    /// fixed parameters have `index == None` and are grouped by `fixed_index`.
    pub parameters: Vec<Parameter<F>>,
    /// Flag which is `true` iff at least one [`Amplitude`] is written in Python and has a [`Node`]
    /// for which [`Node::is_python_node`] returns `true`.
    pub contains_python_amplitudes: bool,
}
807impl<F: Field> Debug for Model<F> {
808    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
809        write!(f, "Model [ ")?;
810        for op in &self.cohsums {
811            write!(f, "{:?} ", op)?;
812        }
813        write!(f, "]")
814    }
815}
816impl<F: Field> Display for Model<F> {
817    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
818        writeln!(f, "{}", self.get_tree())
819    }
820}
821impl<F: Field> AsTree for Model<F> {
822    fn _get_tree(&self, bits: &mut Vec<bool>) -> String {
823        let mut res = String::from("[ + ]\n");
824        for (i, op) in self.cohsums.iter().enumerate() {
825            res.push_str(&self._get_indent(bits.to_vec()));
826            if i == self.cohsums.len() - 1 {
827                res.push_str(&self._get_end());
828                bits.push(false);
829            } else {
830                res.push_str(&self._get_intermediate());
831                bits.push(true);
832            }
833            res.push_str(&op._get_tree(&mut bits.clone()));
834            bits.pop();
835        }
836        res
837    }
838}
839impl<F: Field> Model<F> {
840    /// Creates a new [`Model`] from a list of [`Box<AmpLike>`]s.
841    pub fn new(amps: &[Box<dyn AmpLike<F>>]) -> Self {
842        let mut amp_names = HashSet::new();
843        let amplitudes: Vec<Amplitude<F>> = amps
844            .iter()
845            .flat_map(|cohsum| cohsum.walk())
846            .filter_map(|amp| {
847                if amp_names.insert(amp.name.clone()) {
848                    Some(amp)
849                } else {
850                    None
851                }
852            })
853            .collect();
854        let parameter_tags: Vec<(String, String)> = amplitudes
855            .iter()
856            .flat_map(|amp| {
857                amp.parameters()
858                    .iter()
859                    .map(|p| (amp.name.clone(), p.clone()))
860                    .collect::<Vec<_>>()
861            })
862            .collect();
863        let parameters = parameter_tags
864            .iter()
865            .enumerate()
866            .map(|(i, (amp_name, par_name))| Parameter::new(amp_name, par_name, i))
867            .collect();
868        let contains_python_amplitudes = amplitudes.iter().any(|amp| amp.node.is_python_node());
869        Self {
870            cohsums: amps.iter().map(|inner| NormSqr(inner.clone())).collect(),
871            amplitudes: Arc::new(RwLock::new(amplitudes)),
872            parameters,
873            contains_python_amplitudes,
874        }
875    }
876    /// Creates a true clone (deep copy) of the [`Model`] where the `amplitudes` field is
877    /// duplicated rather than having its reference count increased.
878    pub fn deep_clone(&self) -> Self {
879        Self {
880            cohsums: self.cohsums.clone(),
881            amplitudes: Arc::new(RwLock::new(self.amplitudes.read().clone())),
882            parameters: self.parameters.clone(),
883            contains_python_amplitudes: self.contains_python_amplitudes,
884        }
885    }
886    /// Computes the result of evaluating the terms in the model with the given [`Parameter`]s for
887    /// the given [`Event`] by summing the result of [`NormSqr::compute`] for each [`NormSqr`]
888    /// contained in the [`Model`] (see the `cohsum` field of [`Model`]).
889    ///
890    /// # Errors
891    ///
892    /// This method yields a [`RustitudeError`] if any of the [`Amplitude::calculate`] steps fail.
893    pub fn compute(
894        &self,
895        amplitudes: &[Amplitude<F>],
896        parameters: &[F],
897        event: &Event<F>,
898    ) -> Result<F, RustitudeError> {
899        // TODO: Stop reallocating?
900
901        // NOTE: This seems to be just as fast as using a Vec<ComplexField> and replacing active
902        // amplitudes by multiplying their cached values by 0.0. Branch prediction doesn't get us
903        // any performance here I guess.
904        let cache: Vec<Option<Complex<F>>> = amplitudes
905            .iter()
906            .map(|amp| {
907                if amp.active {
908                    amp.calculate(parameters, event).map(Some)
909                } else {
910                    Ok(None)
911                }
912            })
913            .collect::<Result<Vec<Option<Complex<F>>>, RustitudeError>>()?;
914        Ok(self
915            .cohsums
916            .iter()
917            .filter_map(|cohsum| cohsum.compute(&cache))
918            .sum::<F>())
919    }
920    /// Registers the [`Model`] with the [`Dataset`] by [`Amplitude::register`]ing each
921    /// [`Amplitude`] and setting the proper cache position and parameter starting index.
922    ///
923    /// # Errors
924    ///
925    /// This method will yield a [`RustitudeError`] if any [`Amplitude::precalculate`] steps fail.
926    pub fn load(&mut self, dataset: &Dataset<F>) -> Result<(), RustitudeError> {
927        let mut next_cache_pos = 0;
928        let mut parameter_index = 0;
929        self.amplitudes.write().iter_mut().try_for_each(|amp| {
930            amp.register(next_cache_pos, parameter_index, dataset)?;
931            self.cohsums.iter_mut().for_each(|cohsum| {
932                cohsum.walk_mut().iter_mut().for_each(|r_amp| {
933                    if r_amp.name == amp.name {
934                        r_amp.cache_position = next_cache_pos;
935                        r_amp.parameter_index_start = parameter_index;
936                    }
937                })
938            });
939            next_cache_pos += 1;
940            parameter_index += amp.parameters().len();
941            Ok(())
942        })
943    }
944
945    /// Retrieves a copy of an [`Amplitude`] in the [`Model`] by name.
946    ///
947    /// # Errors
948    /// This will throw a [`RustitudeError`] if the amplitude name is not located within the model.
949    pub fn get_amplitude(&self, amplitude_name: &str) -> Result<Amplitude<F>, RustitudeError> {
950        self.amplitudes
951            .read()
952            .iter()
953            .find(|a: &&Amplitude<F>| a.name == amplitude_name)
954            .ok_or_else(|| RustitudeError::AmplitudeNotFoundError(amplitude_name.to_string()))
955            .cloned()
956    }
957    /// Retrieves a copy of a [`Parameter`] in the [`Model`] by name.
958    ///
959    /// # Errors
960    /// This will throw a [`RustitudeError`] if the parameter name is not located within the model
961    /// or if the amplitude name is not located within the model (this is checked first).
962    pub fn get_parameter(
963        &self,
964        amplitude_name: &str,
965        parameter_name: &str,
966    ) -> Result<Parameter<F>, RustitudeError> {
967        self.get_amplitude(amplitude_name)?;
968        self.parameters
969            .iter()
970            .find(|p: &&Parameter<F>| p.amplitude == amplitude_name && p.name == parameter_name)
971            .ok_or_else(|| RustitudeError::ParameterNotFoundError(parameter_name.to_string()))
972            .cloned()
973    }
974    /// Pretty-prints all parameters in the model
975    pub fn print_parameters(&self) {
976        let any_fixed = if self.any_fixed() { 1 } else { 0 };
977        if self.any_fixed() {
978            println!(
979                "Fixed: {}",
980                self.group_by_index()[0]
981                    .iter()
982                    .map(|p| format!("{:?}", p))
983                    .join(", ")
984            );
985        }
986        for (i, group) in self.group_by_index().iter().skip(any_fixed).enumerate() {
987            println!(
988                "{}: {}",
989                i,
990                group.iter().map(|p| format!("{:?}", p)).join(", ")
991            );
992        }
993    }
994
995    /// Returns a [`Vec<Parameter<F>>`] containing the free parameters in the [`Model`].
996    pub fn free_parameters(&self) -> Vec<Parameter<F>> {
997        self.parameters
998            .iter()
999            .filter(|p| p.is_free())
1000            .cloned()
1001            .collect()
1002    }
1003
1004    /// Returns a [`Vec<Parameter<F>>`] containing the fixed parameters in the [`Model`].
1005    pub fn fixed_parameters(&self) -> Vec<Parameter<F>> {
1006        self.parameters
1007            .iter()
1008            .filter(|p| p.is_fixed())
1009            .cloned()
1010            .collect()
1011    }
1012
1013    /// Constrains two [`Parameter`]s in the [`Model`] to be equal to each other when evaluated.
1014    ///
1015    /// # Errors
1016    ///
1017    /// This method will yield a [`RustitudeError`] if either of the parameters is not found by
1018    /// name.
1019    pub fn constrain(
1020        &mut self,
1021        amplitude_1: &str,
1022        parameter_1: &str,
1023        amplitude_2: &str,
1024        parameter_2: &str,
1025    ) -> Result<(), RustitudeError> {
1026        let p1 = self.get_parameter(amplitude_1, parameter_1)?;
1027        let p2 = self.get_parameter(amplitude_2, parameter_2)?;
1028        for par in self.parameters.iter_mut() {
1029            // None < Some(0)
1030            match p1.index.cmp(&p2.index) {
1031                // p1 < p2
1032                std::cmp::Ordering::Less => {
1033                    if par.index == p2.index {
1034                        par.index = p1.index;
1035                        par.initial = p1.initial;
1036                        par.fixed_index = p1.fixed_index;
1037                    }
1038                }
1039                std::cmp::Ordering::Equal => unimplemented!(),
1040                // p2 < p1
1041                std::cmp::Ordering::Greater => {
1042                    if par.index == p1.index {
1043                        par.index = p2.index;
1044                        par.initial = p2.initial;
1045                        par.fixed_index = p2.fixed_index;
1046                    }
1047                }
1048            }
1049        }
1050        self.reindex_parameters();
1051        Ok(())
1052    }
1053
1054    /// Fixes a [`Parameter`] in the [`Model`] to a given value.
1055    ///
1056    /// This method technically sets the [`Parameter`] to be fixed and gives it an initial value of
1057    /// the given value. This method also handles groups of constrained parameters.
1058    ///
1059    /// # Errors
1060    ///
1061    /// This method yields a [`RustitudeError`] if the parameter is not found by name.
1062    pub fn fix(
1063        &mut self,
1064        amplitude: &str,
1065        parameter: &str,
1066        value: F,
1067    ) -> Result<(), RustitudeError> {
1068        let search_par = self.get_parameter(amplitude, parameter)?;
1069        let fixed_index = self.get_min_fixed_index();
1070        for par in self.parameters.iter_mut() {
1071            if par.index == search_par.index {
1072                par.index = None;
1073                par.initial = value;
1074                par.fixed_index = fixed_index;
1075            }
1076        }
1077        self.reindex_parameters();
1078        Ok(())
1079    }
1080    /// Frees a [`Parameter`] in the [`Model`].
1081    ///
1082    /// This method does not modify the initial value of the parameter. This method
1083    /// also handles groups of constrained parameters.
1084    ///
1085    /// # Errors
1086    ///
1087    /// This method yields a [`RustitudeError`] if the parameter is not found by name.
1088    pub fn free(&mut self, amplitude: &str, parameter: &str) -> Result<(), RustitudeError> {
1089        let search_par = self.get_parameter(amplitude, parameter)?;
1090        let index = self.get_min_free_index();
1091        for par in self.parameters.iter_mut() {
1092            if par.fixed_index == search_par.fixed_index {
1093                par.index = index;
1094                par.fixed_index = None;
1095            }
1096        }
1097        self.reindex_parameters();
1098        Ok(())
1099    }
1100    /// Sets the bounds on a [`Parameter`] in the [`Model`].
1101    ///
1102    /// # Errors
1103    ///
1104    /// This method yields a [`RustitudeError`] if the parameter is not found by name.
1105    pub fn set_bounds(
1106        &mut self,
1107        amplitude: &str,
1108        parameter: &str,
1109        bounds: (F, F),
1110    ) -> Result<(), RustitudeError> {
1111        let search_par = self.get_parameter(amplitude, parameter)?;
1112        if search_par.index.is_some() {
1113            for par in self.parameters.iter_mut() {
1114                if par.index == search_par.index {
1115                    par.bounds = bounds;
1116                }
1117            }
1118        } else {
1119            for par in self.parameters.iter_mut() {
1120                if par.fixed_index == search_par.fixed_index {
1121                    par.bounds = bounds;
1122                }
1123            }
1124        }
1125        Ok(())
1126    }
1127    /// Sets the initial value of a [`Parameter`] in the [`Model`].
1128    ///
1129    /// # Errors
1130    ///
1131    /// This method yields a [`RustitudeError`] if the parameter is not found by name.
1132    pub fn set_initial(
1133        &mut self,
1134        amplitude: &str,
1135        parameter: &str,
1136        initial: F,
1137    ) -> Result<(), RustitudeError> {
1138        let search_par = self.get_parameter(amplitude, parameter)?;
1139        if search_par.index.is_some() {
1140            for par in self.parameters.iter_mut() {
1141                if par.index == search_par.index {
1142                    par.initial = initial;
1143                }
1144            }
1145        } else {
1146            for par in self.parameters.iter_mut() {
1147                if par.fixed_index == search_par.fixed_index {
1148                    par.initial = initial;
1149                }
1150            }
1151        }
1152        Ok(())
1153    }
1154    /// Returns a list of bounds of free [`Parameter`]s in the [`Model`].
1155    pub fn get_bounds(&self) -> Vec<(F, F)> {
1156        let any_fixed = if self.any_fixed() { 1 } else { 0 };
1157        self.group_by_index()
1158            .iter()
1159            .skip(any_fixed)
1160            .filter_map(|group| group.first().map(|par| par.bounds))
1161            .collect()
1162    }
1163    /// Returns a list of initial values of free [`Parameter`]s in the [`Model`].
1164    pub fn get_initial(&self) -> Vec<F> {
1165        let any_fixed = if self.any_fixed() { 1 } else { 0 };
1166        self.group_by_index()
1167            .iter()
1168            .skip(any_fixed)
1169            .filter_map(|group| group.first().map(|par| par.initial))
1170            .collect()
1171    }
1172    /// Returns the number of free [`Parameter`]s in the [`Model`].
1173    pub fn get_n_free(&self) -> usize {
1174        self.get_min_free_index().unwrap_or(0)
1175    }
1176    /// Activates an [`Amplitude`] in the [`Model`] by name.
1177    ///
1178    /// # Errors
1179    ///
1180    /// This function will return a [`RustitudeError::AmplitudeNotFoundError`] if the given
1181    /// amplitude is not present in the [`Model`].
1182    pub fn activate(&mut self, amplitude: &str) -> Result<(), RustitudeError> {
1183        if !self.amplitudes.read().iter().any(|a| a.name == amplitude) {
1184            return Err(RustitudeError::AmplitudeNotFoundError(
1185                amplitude.to_string(),
1186            ));
1187        }
1188        self.amplitudes.write().iter_mut().for_each(|amp| {
1189            if amp.name == amplitude {
1190                amp.active = true
1191            }
1192        });
1193        self.cohsums.iter_mut().for_each(|cohsum| {
1194            cohsum.walk_mut().iter_mut().for_each(|amp| {
1195                if amp.name == amplitude {
1196                    amp.active = true
1197                }
1198            })
1199        });
1200        Ok(())
1201    }
1202    /// Activates all [`Amplitude`]s in the [`Model`].
1203    pub fn activate_all(&mut self) {
1204        self.amplitudes
1205            .write()
1206            .iter_mut()
1207            .for_each(|amp| amp.active = true);
1208        self.cohsums.iter_mut().for_each(|cohsum| {
1209            cohsum
1210                .walk_mut()
1211                .iter_mut()
1212                .for_each(|amp| amp.active = true)
1213        });
1214    }
1215    /// Activate only the specified [`Amplitude`]s while deactivating the rest.
1216    ///
1217    /// # Errors
1218    ///
1219    /// This function will return a [`RustitudeError::AmplitudeNotFoundError`] if a given
1220    /// amplitude is not present in the [`Model`].
1221    pub fn isolate(&mut self, amplitudes: Vec<&str>) -> Result<(), RustitudeError> {
1222        self.deactivate_all();
1223        for amplitude in amplitudes {
1224            self.activate(amplitude)?;
1225        }
1226        Ok(())
1227    }
1228    /// Deactivates an [`Amplitude`] in the [`Model`] by name.
1229    ///
1230    /// # Errors
1231    ///
1232    /// This function will return a [`RustitudeError::AmplitudeNotFoundError`] if the given
1233    /// amplitude is not present in the [`Model`].
1234    pub fn deactivate(&mut self, amplitude: &str) -> Result<(), RustitudeError> {
1235        if !self.amplitudes.read().iter().any(|a| a.name == amplitude) {
1236            return Err(RustitudeError::AmplitudeNotFoundError(
1237                amplitude.to_string(),
1238            ));
1239        }
1240        self.amplitudes.write().iter_mut().for_each(|amp| {
1241            if amp.name == amplitude {
1242                amp.active = false
1243            }
1244        });
1245        self.cohsums.iter_mut().for_each(|cohsum| {
1246            cohsum.walk_mut().iter_mut().for_each(|amp| {
1247                if amp.name == amplitude {
1248                    amp.active = false
1249                }
1250            })
1251        });
1252        Ok(())
1253    }
1254    /// Deactivates all [`Amplitude`]s in the [`Model`].
1255    pub fn deactivate_all(&mut self) {
1256        self.amplitudes
1257            .write()
1258            .iter_mut()
1259            .for_each(|amp| amp.active = false);
1260        self.cohsums.iter_mut().for_each(|cohsum| {
1261            cohsum
1262                .walk_mut()
1263                .iter_mut()
1264                .for_each(|amp| amp.active = false)
1265        });
1266    }
1267    fn group_by_index(&self) -> Vec<Vec<&Parameter<F>>> {
1268        self.parameters
1269            .iter()
1270            .sorted_by_key(|par| par.index)
1271            .chunk_by(|par| par.index)
1272            .into_iter()
1273            .map(|(_, group)| group.collect::<Vec<_>>())
1274            .collect()
1275    }
1276    fn group_by_index_mut(&mut self) -> Vec<Vec<&mut Parameter<F>>> {
1277        self.parameters
1278            .iter_mut()
1279            .sorted_by_key(|par| par.index)
1280            .chunk_by(|par| par.index)
1281            .into_iter()
1282            .map(|(_, group)| group.collect())
1283            .collect()
1284    }
1285    fn any_fixed(&self) -> bool {
1286        self.parameters.iter().any(|p| p.index.is_none())
1287    }
1288    fn reindex_parameters(&mut self) {
1289        let any_fixed = if self.any_fixed() { 1 } else { 0 };
1290        self.group_by_index_mut()
1291            .iter_mut()
1292            .skip(any_fixed) // first element could be index = None
1293            .enumerate()
1294            .for_each(|(ind, par_group)| par_group.iter_mut().for_each(|par| par.index = Some(ind)))
1295    }
1296    fn get_min_free_index(&self) -> Option<usize> {
1297        self.parameters
1298            .iter()
1299            .filter_map(|p| p.index)
1300            .max()
1301            .map_or(Some(0), |max| Some(max + 1))
1302    }
1303    fn get_min_fixed_index(&self) -> Option<usize> {
1304        self.parameters
1305            .iter()
1306            .filter_map(|p| p.fixed_index)
1307            .max()
1308            .map_or(Some(0), |max| Some(max + 1))
1309    }
1310}
1311
1312/// A [`Node`] for computing a single scalar value from an input parameter.
1313///
1314/// This struct implements [`Node`] to generate a single new parameter called `value`.
1315///
1316/// # Parameters:
1317///
1318/// - `value`: The value of the scalar.
1319#[derive(Clone)]
1320pub struct Scalar;
1321impl<F: Field> Node<F> for Scalar {
1322    fn parameters(&self) -> Vec<String> {
1323        vec!["value".to_string()]
1324    }
1325    fn calculate(&self, parameters: &[F], _event: &Event<F>) -> Result<Complex<F>, RustitudeError> {
1326        Ok(Complex::new(parameters[0], F::zero()))
1327    }
1328}
1329
1330/// Creates a named [`Scalar`].
1331///
1332/// This is a convenience method to generate an [`Amplitude`] which is just a single free
1333/// parameter called `value`.
1334///
1335/// # Examples
1336///
1337/// Basic usage:
1338///
1339/// ```
1340/// use rustitude_core::prelude::*;
1341/// let my_scalar: Amplitude<f64> = scalar("MyScalar");
1342/// assert_eq!(my_scalar.parameters, vec!["value".to_string()]);
1343/// ```
1344pub fn scalar<F: Field>(name: &str) -> Amplitude<F> {
1345    Amplitude::new(name, Scalar)
1346}
1347/// A [`Node`] for computing a single complex value from two input parameters.
1348///
1349/// This struct implements [`Node`] to generate a complex value from two input parameters called
1350/// `real` and `imag`.
1351///
1352/// # Parameters:
1353///
1354/// - `real`: The real part of the complex scalar.
1355/// - `imag`: The imaginary part of the complex scalar.
1356#[derive(Clone)]
1357pub struct ComplexScalar;
1358impl<F: Field> Node<F> for ComplexScalar {
1359    fn calculate(&self, parameters: &[F], _event: &Event<F>) -> Result<Complex<F>, RustitudeError> {
1360        Ok(Complex::new(parameters[0], parameters[1]))
1361    }
1362
1363    fn parameters(&self) -> Vec<String> {
1364        vec!["real".to_string(), "imag".to_string()]
1365    }
1366}
1367/// Creates a named [`ComplexScalar`].
1368///
1369/// This is a convenience method to generate an [`Amplitude`] which represents a complex
1370/// value determined by two parameters, `real` and `imag`.
1371///
1372/// # Examples
1373///
1374/// Basic usage:
1375///
1376/// ```
1377/// use rustitude_core::prelude::*;
1378/// let my_cscalar: Amplitude<f64> = cscalar("MyComplexScalar");
1379/// assert_eq!(my_cscalar.parameters, vec!["real".to_string(), "imag".to_string()]);
1380/// ```
1381pub fn cscalar<F: Field>(name: &str) -> Amplitude<F> {
1382    Amplitude::new(name, ComplexScalar)
1383}
1384
1385/// A [`Node`] for computing a single complex value from two input parameters in polar form.
1386///
1387/// This struct implements [`Node`] to generate a complex value from two input parameters called
1388/// `mag` and `phi`.
1389///
1390/// # Parameters:
1391///
1392/// - `mag`: The magnitude of the complex scalar.
1393/// - `phi`: The phase of the complex scalar.
1394#[derive(Clone)]
1395pub struct PolarComplexScalar;
1396impl<F: Field> Node<F> for PolarComplexScalar {
1397    fn calculate(&self, parameters: &[F], _event: &Event<F>) -> Result<Complex<F>, RustitudeError> {
1398        Ok(Complex::cis(parameters[1]).mul(parameters[0]))
1399    }
1400
1401    fn parameters(&self) -> Vec<String> {
1402        vec!["mag".to_string(), "phi".to_string()]
1403    }
1404}
1405
1406/// Creates a named [`PolarComplexScalar`].
1407///
1408/// This is a convenience method to generate an [`Amplitude `] which represents a complex
1409/// value determined by two parameters, `real` and `imag`.
1410///
1411/// # Examples
1412///
1413/// Basic usage:
1414///
1415/// ```
1416/// use rustitude_core::prelude::*;
1417/// let my_pcscalar: Amplitude<f64> = pcscalar("MyPolarComplexScalar");
1418/// assert_eq!(my_pcscalar.parameters, vec!["mag".to_string(), "phi".to_string()]);
1419/// ```
1420pub fn pcscalar<F: Field>(name: &str) -> Amplitude<F> {
1421    Amplitude::new(name, PolarComplexScalar)
1422}
1423
1424/// A generic struct which can be used to create any kind of piecewise function.
1425#[derive(Clone)]
1426pub struct Piecewise<V, F>
1427where
1428    V: Fn(&Event<F>) -> F + Send + Sync + Copy,
1429    F: Field,
1430{
1431    edges: Vec<(F, F)>,
1432    variable: V,
1433    calculated_variable: Vec<F>,
1434}
1435
1436impl<V, F> Piecewise<V, F>
1437where
1438    V: Fn(&Event<F>) -> F + Send + Sync + Copy,
1439    F: Field,
1440{
1441    /// Create a new [`Piecewise`] struct from a number of bins, a range of values, and a callable
1442    /// which defines a variable over the [`Event`]s in a [`Dataset`].
1443    pub fn new(bins: usize, range: (F, F), variable: V) -> Self {
1444        let diff = (range.1 - range.0) / convert!(bins, F);
1445        let edges = (0..bins)
1446            .map(|i| {
1447                (
1448                    F::mul_add(convert!(i, F), diff, range.0),
1449                    F::mul_add(convert!(i + 1, F), diff, range.0),
1450                )
1451            })
1452            .collect();
1453        Self {
1454            edges,
1455            variable,
1456            calculated_variable: Vec::default(),
1457        }
1458    }
1459}
1460
1461impl<V, F> Node<F> for Piecewise<V, F>
1462where
1463    V: Fn(&Event<F>) -> F + Send + Sync + Copy,
1464    F: Field,
1465{
1466    fn precalculate(&mut self, dataset: &Dataset<F>) -> Result<(), RustitudeError> {
1467        self.calculated_variable = dataset.events.par_iter().map(self.variable).collect();
1468        Ok(())
1469    }
1470
1471    fn calculate(&self, parameters: &[F], event: &Event<F>) -> Result<Complex<F>, RustitudeError> {
1472        let val = self.calculated_variable[event.index];
1473        let opt_i_bin = self.edges.iter().position(|&(l, r)| val >= l && val <= r);
1474        opt_i_bin.map_or_else(
1475            || Ok(Complex::default()),
1476            |i_bin| {
1477                Ok(Complex::new(
1478                    parameters[i_bin * 2],
1479                    parameters[(i_bin * 2) + 1],
1480                ))
1481            },
1482        )
1483    }
1484
1485    fn parameters(&self) -> Vec<String> {
1486        (0..self.edges.len())
1487            .flat_map(|i| vec![format!("bin {} re", i), format!("bin {} im", i)])
1488            .collect()
1489    }
1490}
1491
1492pub fn piecewise_m<F: Field + 'static>(name: &str, bins: usize, range: (F, F)) -> Amplitude<F> {
1493    //! Creates a named [`Piecewise`] amplitude with the resonance mass as the binning variable.
1494    Amplitude::new(
1495        name,
1496        Piecewise::new(bins, range, |e: &Event<F>| {
1497            (e.daughter_p4s[0] + e.daughter_p4s[1]).m()
1498        }),
1499    )
1500}
1501
1502macro_rules! impl_sum {
1503    ($t:ident, $a:ty, $b:ty) => {
1504        impl<$t: Field + 'static> Add<$b> for $a {
1505            type Output = Sum<$t>;
1506
1507            fn add(self, rhs: $b) -> Self::Output {
1508                Sum(vec![Box::new(self), Box::new(rhs)])
1509            }
1510        }
1511
1512        impl<$t: Field + 'static> Add<&$b> for &$a {
1513            type Output = <$a as Add<$b>>::Output;
1514
1515            fn add(self, rhs: &$b) -> Self::Output {
1516                <$a as Add<$b>>::add(self.clone(), rhs.clone())
1517            }
1518        }
1519
1520        impl<$t: Field + 'static> Add<&$b> for $a {
1521            type Output = <$a as Add<$b>>::Output;
1522
1523            fn add(self, rhs: &$b) -> Self::Output {
1524                <$a as Add<$b>>::add(self, rhs.clone())
1525            }
1526        }
1527
1528        impl<$t: Field + 'static> Add<$b> for &$a {
1529            type Output = <$a as Add<$b>>::Output;
1530
1531            fn add(self, rhs: $b) -> Self::Output {
1532                <$a as Add<$b>>::add(self.clone(), rhs)
1533            }
1534        }
1535
1536        impl<$t: Field + 'static> Add<$a> for $b {
1537            type Output = Sum<$t>;
1538
1539            fn add(self, rhs: $a) -> Self::Output {
1540                Sum(vec![Box::new(self), Box::new(rhs)])
1541            }
1542        }
1543
1544        impl<$t: Field + 'static> Add<&$a> for &$b {
1545            type Output = <$b as Add<$a>>::Output;
1546
1547            fn add(self, rhs: &$a) -> Self::Output {
1548                <$b as Add<$a>>::add(self.clone(), rhs.clone())
1549            }
1550        }
1551
1552        impl<$t: Field + 'static> Add<&$a> for $b {
1553            type Output = <$b as Add<$a>>::Output;
1554
1555            fn add(self, rhs: &$a) -> Self::Output {
1556                <$b as Add<$a>>::add(self, rhs.clone())
1557            }
1558        }
1559
1560        impl<$t: Field + 'static> Add<$a> for &$b {
1561            type Output = <$b as Add<$a>>::Output;
1562
1563            fn add(self, rhs: $a) -> Self::Output {
1564                <$b as Add<$a>>::add(self.clone(), rhs)
1565            }
1566        }
1567    };
1568    ($t:ident, $a:ty) => {
1569        impl<$t: Field + 'static> Add<$a> for $a {
1570            type Output = Sum<$t>;
1571
1572            fn add(self, rhs: $a) -> Self::Output {
1573                Sum(vec![Box::new(self), Box::new(rhs)])
1574            }
1575        }
1576
1577        impl<$t: Field + 'static> Add<&$a> for &$a {
1578            type Output = <$a as Add<$a>>::Output;
1579
1580            fn add(self, rhs: &$a) -> Self::Output {
1581                <$a as Add<$a>>::add(self.clone(), rhs.clone())
1582            }
1583        }
1584
1585        impl<$t: Field + 'static> Add<&$a> for $a {
1586            type Output = <$a as Add<$a>>::Output;
1587
1588            fn add(self, rhs: &$a) -> Self::Output {
1589                <$a as Add<$a>>::add(self, rhs.clone())
1590            }
1591        }
1592
1593        impl<$t: Field + 'static> Add<$a> for &$a {
1594            type Output = <$a as Add<$a>>::Output;
1595
1596            fn add(self, rhs: $a) -> Self::Output {
1597                <$a as Add<$a>>::add(self.clone(), rhs)
1598            }
1599        }
1600    };
1601}
1602macro_rules! impl_appending_sum {
1603    ($t:ident, $a:ty) => {
1604        impl<$t: Field + 'static> Add<Sum<$t>> for $a {
1605            type Output = Sum<$t>;
1606
1607            fn add(self, rhs: Sum<$t>) -> Self::Output {
1608                let mut terms = rhs.0;
1609                terms.insert(0, Box::new(self));
1610                Sum(terms)
1611            }
1612        }
1613
1614        impl<$t: Field + 'static> Add<$a> for Sum<$t> {
1615            type Output = Sum<$t>;
1616
1617            fn add(self, rhs: $a) -> Self::Output {
1618                let mut terms = self.0;
1619                terms.push(Box::new(rhs));
1620                Sum(terms)
1621            }
1622        }
1623
1624        impl<$t: Field + 'static> Add<&Sum<$t>> for &$a {
1625            type Output = <$a as Add<Sum<$t>>>::Output;
1626
1627            fn add(self, rhs: &Sum<$t>) -> Self::Output {
1628                <$a as Add<Sum<$t>>>::add(self.clone(), rhs.clone())
1629            }
1630        }
1631
1632        impl<$t: Field + 'static> Add<&Sum<$t>> for $a {
1633            type Output = <$a as Add<Sum<$t>>>::Output;
1634
1635            fn add(self, rhs: &Sum<$t>) -> Self::Output {
1636                <$a as Add<Sum<$t>>>::add(self, rhs.clone())
1637            }
1638        }
1639
1640        impl<$t: Field + 'static> Add<Sum<$t>> for &$a {
1641            type Output = <$a as Add<Sum<$t>>>::Output;
1642
1643            fn add(self, rhs: Sum<$t>) -> Self::Output {
1644                <$a as Add<Sum<$t>>>::add(self.clone(), rhs)
1645            }
1646        }
1647
1648        impl<$t: Field + 'static> Add<&$a> for &Sum<$t> {
1649            type Output = <Sum<$t> as Add<$a>>::Output;
1650
1651            fn add(self, rhs: &$a) -> Self::Output {
1652                <Sum<$t> as Add<$a>>::add(self.clone(), rhs.clone())
1653            }
1654        }
1655
1656        impl<$t: Field + 'static> Add<&$a> for Sum<$t> {
1657            type Output = <Sum<$t> as Add<$a>>::Output;
1658
1659            fn add(self, rhs: &$a) -> Self::Output {
1660                <Sum<$t> as Add<$a>>::add(self, rhs.clone())
1661            }
1662        }
1663
1664        impl<$t: Field + 'static> Add<$a> for &Sum<$t> {
1665            type Output = <Sum<$t> as Add<$a>>::Output;
1666
1667            fn add(self, rhs: $a) -> Self::Output {
1668                <Sum<$t> as Add<$a>>::add(self.clone(), rhs)
1669            }
1670        }
1671    };
1672}
/// Implements `Mul` (for every owned/borrowed operand combination) producing a `Product<$t>`.
///
/// * `impl_prod!(F, A, B)` implements both `A * B` and `B * A`.
/// * `impl_prod!(F, A)` implements `A * A`.
///
/// An operand whose `get_cloned_terms` returns `Some(terms)` contributes those terms
/// directly to the result. Any other operand is boxed and kept as a single factor.
/// Factors always appear in operand order (left operand first). Borrowed variants
/// clone their operands and forward to the owned implementation.
macro_rules! impl_prod {
    ($t:ident, $a:ty, $b:ty) => {
        impl<$t: Field + 'static> Mul<$b> for $a {
            type Output = Product<$t>;

            fn mul(self, rhs: $b) -> Self::Output {
                match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                    // `get_cloned_terms` already hands back owned vectors, so the right-hand
                    // terms are moved in rather than cloned a second time (as `concat` would).
                    (Some(mut terms), Some(rhs_terms)) => {
                        terms.extend(rhs_terms);
                        Product(terms)
                    }
                    (None, Some(mut terms)) => {
                        terms.insert(0, Box::new(self));
                        Product(terms)
                    }
                    (Some(mut terms), None) => {
                        terms.push(Box::new(rhs));
                        Product(terms)
                    }
                    (None, None) => Product(vec![Box::new(self), Box::new(rhs)]),
                }
            }
        }

        impl<$t: Field + 'static> Mul<&$b> for &$a {
            type Output = Product<$t>;

            fn mul(self, rhs: &$b) -> Self::Output {
                let lhs: $a = self.clone();
                let rhs: $b = rhs.clone();
                lhs * rhs
            }
        }

        impl<$t: Field + 'static> Mul<&$b> for $a {
            type Output = Product<$t>;

            fn mul(self, rhs: &$b) -> Self::Output {
                let rhs: $b = rhs.clone();
                self * rhs
            }
        }

        impl<$t: Field + 'static> Mul<$b> for &$a {
            type Output = Product<$t>;

            fn mul(self, rhs: $b) -> Self::Output {
                let lhs: $a = self.clone();
                lhs * rhs
            }
        }

        impl<$t: Field + 'static> Mul<$a> for $b {
            type Output = Product<$t>;

            fn mul(self, rhs: $a) -> Self::Output {
                match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                    // Move the already-owned term vectors instead of re-cloning them.
                    (Some(mut terms), Some(rhs_terms)) => {
                        terms.extend(rhs_terms);
                        Product(terms)
                    }
                    (None, Some(mut terms)) => {
                        terms.insert(0, Box::new(self));
                        Product(terms)
                    }
                    (Some(mut terms), None) => {
                        terms.push(Box::new(rhs));
                        Product(terms)
                    }
                    (None, None) => Product(vec![Box::new(self), Box::new(rhs)]),
                }
            }
        }

        impl<$t: Field + 'static> Mul<&$a> for &$b {
            type Output = Product<$t>;

            fn mul(self, rhs: &$a) -> Self::Output {
                let lhs: $b = self.clone();
                let rhs: $a = rhs.clone();
                lhs * rhs
            }
        }

        impl<$t: Field + 'static> Mul<&$a> for $b {
            type Output = Product<$t>;

            fn mul(self, rhs: &$a) -> Self::Output {
                let rhs: $a = rhs.clone();
                self * rhs
            }
        }

        impl<$t: Field + 'static> Mul<$a> for &$b {
            type Output = Product<$t>;

            fn mul(self, rhs: $a) -> Self::Output {
                let lhs: $b = self.clone();
                lhs * rhs
            }
        }
    };
    ($t:ident, $a:ty) => {
        impl<$t: Field + 'static> Mul<$a> for $a {
            type Output = Product<$t>;

            fn mul(self, rhs: $a) -> Self::Output {
                match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                    // Move the already-owned term vectors instead of re-cloning them.
                    (Some(mut terms), Some(rhs_terms)) => {
                        terms.extend(rhs_terms);
                        Product(terms)
                    }
                    (None, Some(mut terms)) => {
                        terms.insert(0, Box::new(self));
                        Product(terms)
                    }
                    (Some(mut terms), None) => {
                        terms.push(Box::new(rhs));
                        Product(terms)
                    }
                    (None, None) => Product(vec![Box::new(self), Box::new(rhs)]),
                }
            }
        }

        impl<$t: Field + 'static> Mul<&$a> for &$a {
            type Output = Product<$t>;

            fn mul(self, rhs: &$a) -> Self::Output {
                let lhs: $a = self.clone();
                let rhs: $a = rhs.clone();
                lhs * rhs
            }
        }

        impl<$t: Field + 'static> Mul<&$a> for $a {
            type Output = Product<$t>;

            fn mul(self, rhs: &$a) -> Self::Output {
                let rhs: $a = rhs.clone();
                self * rhs
            }
        }

        impl<$t: Field + 'static> Mul<$a> for &$a {
            type Output = Product<$t>;

            fn mul(self, rhs: $a) -> Self::Output {
                let lhs: $a = self.clone();
                lhs * rhs
            }
        }
    };
}
/// Implements `Mul` between a concrete operand type `$a` and a boxed `dyn AmpLike<$t>`,
/// in both operand orders, producing a `Product<$t>`.
///
/// An operand whose `get_cloned_terms` returns `Some(terms)` contributes those terms
/// directly to the result. Any other operand is kept as a single boxed factor. Factors
/// always appear in operand order (left operand first), and each operand appears exactly once.
macro_rules! impl_box_prod {
    ($t:ident, $a:ty) => {
        impl<$t: Field + 'static> Mul<Box<dyn AmpLike<$t>>> for $a {
            type Output = Product<$t>;
            fn mul(self, rhs: Box<dyn AmpLike<$t>>) -> Self::Output {
                match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                    // Both vectors are already owned clones; move them rather than re-cloning.
                    (Some(mut terms), Some(rhs_terms)) => {
                        terms.extend(rhs_terms);
                        Product(terms)
                    }
                    (None, Some(mut terms)) => {
                        terms.insert(0, Box::new(self));
                        Product(terms)
                    }
                    (Some(mut terms), None) => {
                        // The right-hand operand (not `self`, whose terms are already
                        // present) is appended as the final factor.
                        terms.push(rhs);
                        Product(terms)
                    }
                    (None, None) => Product(vec![Box::new(self), rhs]),
                }
            }
        }
        impl<$t: Field + 'static> Mul<$a> for Box<dyn AmpLike<$t>> {
            type Output = Product<$t>;
            fn mul(self, rhs: $a) -> Self::Output {
                match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                    // Both vectors are already owned clones; move them rather than re-cloning.
                    (Some(mut terms), Some(rhs_terms)) => {
                        terms.extend(rhs_terms);
                        Product(terms)
                    }
                    (None, Some(mut terms)) => {
                        terms.insert(0, self);
                        Product(terms)
                    }
                    (Some(mut terms), None) => {
                        // The right-hand operand (not `self`, whose terms are already
                        // present) is appended as the final factor.
                        terms.push(Box::new(rhs));
                        Product(terms)
                    }
                    (None, None) => Product(vec![self, Box::new(rhs)]),
                }
            }
        }
    };
}
/// Implements `Add` between a concrete operand type `$a` and a boxed `dyn AmpLike<$t>`,
/// in both operand orders, producing a `Sum<$t>`.
///
/// An operand whose `get_cloned_terms` returns `Some(terms)` contributes those terms
/// directly to the result. Any other operand is kept as a single boxed term. Terms
/// always appear in operand order (left operand first), and each operand appears exactly once.
///
/// NOTE(review): flattening relies on `get_cloned_terms`, which `impl_prod!` also uses to
/// flatten *factors*. If a `Product` operand returns `Some(factors)` here, its factors would
/// become separate sum terms. Confirm `get_cloned_terms` semantics for `Product` vs. `Sum`.
macro_rules! impl_box_sum {
    ($t:ident, $a:ty) => {
        impl<$t: Field + 'static> Add<Box<dyn AmpLike<$t>>> for $a {
            type Output = Sum<$t>;
            fn add(self, rhs: Box<dyn AmpLike<$t>>) -> Self::Output {
                match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                    // Both vectors are already owned clones; move them rather than re-cloning.
                    (Some(mut terms), Some(rhs_terms)) => {
                        terms.extend(rhs_terms);
                        Sum(terms)
                    }
                    (None, Some(mut terms)) => {
                        terms.insert(0, Box::new(self));
                        Sum(terms)
                    }
                    (Some(mut terms), None) => {
                        // The right-hand operand (not `self`, whose terms are already
                        // present) is appended as the final term.
                        terms.push(rhs);
                        Sum(terms)
                    }
                    (None, None) => Sum(vec![Box::new(self), rhs]),
                }
            }
        }
        impl<$t: Field + 'static> Add<$a> for Box<dyn AmpLike<$t>> {
            type Output = Sum<$t>;
            fn add(self, rhs: $a) -> Self::Output {
                match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                    // Both vectors are already owned clones; move them rather than re-cloning.
                    (Some(mut terms), Some(rhs_terms)) => {
                        terms.extend(rhs_terms);
                        Sum(terms)
                    }
                    (None, Some(mut terms)) => {
                        terms.insert(0, self);
                        Sum(terms)
                    }
                    (Some(mut terms), None) => {
                        // The right-hand operand (not `self`, whose terms are already
                        // present) is appended as the final term.
                        terms.push(Box::new(rhs));
                        Sum(terms)
                    }
                    (None, None) => Sum(vec![self, Box::new(rhs)]),
                }
            }
        }
    };
}
/// Implements `Mul` between an operand type `$a` and a `Sum<$t>` by distributing the
/// product over the sum's terms, for every owned/borrowed operand combination.
///
/// `a * (t1 + t2 + ...)` becomes `a*t1 + a*t2 + ...`, and `(t1 + t2 + ...) * a` becomes
/// `t1*a + t2*a + ...`. Each product is boxed as one term of the resulting `Sum`, and `a`
/// is cloned once per term. The per-term products rely on `$a`/`Box<dyn AmpLike<$t>>`
/// `Mul` impls (as generated by `impl_box_prod!`).
macro_rules! impl_dist {
    ($t:ident, $a:ty) => {
        impl<$t: Field + 'static> Mul<Sum<$t>> for $a {
            type Output = Sum<$t>;

            fn mul(self, rhs: Sum<$t>) -> Self::Output {
                Sum(rhs
                    .0
                    .into_iter()
                    .map(|term| Box::new(self.clone() * term) as Box<dyn AmpLike<$t>>)
                    .collect())
            }
        }

        impl<$t: Field + 'static> Mul<$a> for Sum<$t> {
            type Output = Sum<$t>;

            fn mul(self, rhs: $a) -> Self::Output {
                Sum(self
                    .0
                    .into_iter()
                    .map(|term| Box::new(term * rhs.clone()) as Box<dyn AmpLike<$t>>)
                    .collect())
            }
        }

        impl<$t: Field + 'static> Mul<&$a> for &Sum<$t> {
            type Output = Sum<$t>;

            fn mul(self, rhs: &$a) -> Self::Output {
                let lhs: Sum<$t> = self.clone();
                let rhs: $a = rhs.clone();
                lhs * rhs
            }
        }

        impl<$t: Field + 'static> Mul<&$a> for Sum<$t> {
            type Output = Sum<$t>;

            fn mul(self, rhs: &$a) -> Self::Output {
                let rhs: $a = rhs.clone();
                self * rhs
            }
        }

        impl<$t: Field + 'static> Mul<$a> for &Sum<$t> {
            type Output = Sum<$t>;

            fn mul(self, rhs: $a) -> Self::Output {
                let lhs: Sum<$t> = self.clone();
                lhs * rhs
            }
        }

        impl<$t: Field + 'static> Mul<&Sum<$t>> for &$a {
            type Output = Sum<$t>;

            fn mul(self, rhs: &Sum<$t>) -> Self::Output {
                let lhs: $a = self.clone();
                let rhs: Sum<$t> = rhs.clone();
                lhs * rhs
            }
        }

        impl<$t: Field + 'static> Mul<&Sum<$t>> for $a {
            type Output = Sum<$t>;

            fn mul(self, rhs: &Sum<$t>) -> Self::Output {
                let rhs: Sum<$t> = rhs.clone();
                self * rhs
            }
        }

        impl<$t: Field + 'static> Mul<Sum<$t>> for &$a {
            type Output = Sum<$t>;

            fn mul(self, rhs: Sum<$t>) -> Self::Output {
                let lhs: $a = self.clone();
                lhs * rhs
            }
        }
    };
}
1971
1972impl_sum!(F, Amplitude<F>);
1973impl_box_sum!(F, Amplitude<F>);
1974impl_sum!(F, Real<F>);
1975impl_box_sum!(F, Real<F>);
1976impl_sum!(F, Imag<F>);
1977impl_box_sum!(F, Imag<F>);
1978impl_sum!(F, Product<F>);
1979impl_box_sum!(F, Product<F>);
1980impl_box_sum!(F, Sum<F>);
1981
1982impl_sum!(F, Amplitude<F>, Real<F>);
1983impl_sum!(F, Amplitude<F>, Imag<F>);
1984impl_sum!(F, Amplitude<F>, Product<F>);
1985impl_sum!(F, Real<F>, Imag<F>);
1986impl_sum!(F, Real<F>, Product<F>);
1987impl_sum!(F, Imag<F>, Product<F>);
1988
1989impl_appending_sum!(F, Amplitude<F>);
1990impl_appending_sum!(F, Real<F>);
1991impl_appending_sum!(F, Imag<F>);
1992impl_appending_sum!(F, Product<F>);
1993
1994impl_prod!(F, Amplitude<F>);
1995impl_box_prod!(F, Amplitude<F>);
1996impl_prod!(F, Real<F>);
1997impl_box_prod!(F, Real<F>);
1998impl_prod!(F, Imag<F>);
1999impl_box_prod!(F, Imag<F>);
2000impl_prod!(F, Product<F>);
2001impl_box_prod!(F, Product<F>);
2002
2003impl_prod!(F, Amplitude<F>, Real<F>);
2004impl_prod!(F, Amplitude<F>, Imag<F>);
2005impl_prod!(F, Amplitude<F>, Product<F>);
2006impl_prod!(F, Real<F>, Imag<F>);
2007impl_prod!(F, Real<F>, Product<F>);
2008impl_prod!(F, Imag<F>, Product<F>);
2009
2010impl_dist!(F, Amplitude<F>);
2011impl_dist!(F, Real<F>);
2012impl_dist!(F, Imag<F>);
2013impl_dist!(F, Product<F>);
2014
2015impl<F: Field> Add<Self> for Sum<F> {
2016    type Output = Self;
2017
2018    fn add(self, rhs: Self) -> Self::Output {
2019        Self([self.0, rhs.0].concat())
2020    }
2021}
2022
2023impl<F: Field> Add<&Sum<F>> for &Sum<F> {
2024    type Output = <Sum<F> as Add<Sum<F>>>::Output;
2025
2026    fn add(self, rhs: &Sum<F>) -> Self::Output {
2027        <Sum<F> as Add<Sum<F>>>::add(self.clone(), rhs.clone())
2028    }
2029}
2030
2031impl<F: Field> Add<&Self> for Sum<F> {
2032    type Output = <Self as Add<Self>>::Output;
2033
2034    fn add(self, rhs: &Self) -> Self::Output {
2035        <Self as Add<Self>>::add(self, rhs.clone())
2036    }
2037}
2038
2039impl<F: Field> Add<Sum<F>> for &Sum<F> {
2040    type Output = <Sum<F> as Add<Sum<F>>>::Output;
2041
2042    fn add(self, rhs: Sum<F>) -> Self::Output {
2043        <Sum<F> as Add<Sum<F>>>::add(self.clone(), rhs)
2044    }
2045}