rustitude_core/
manager.rs

//! This module contains methods to link [`Model`]s with [`Dataset`]s via a [`Manager::evaluate`]
//! method. This module also holds an [`ExtendedLogLikelihood`] struct which holds two [`Manager`]s
//! and, as the name suggests, calculates an extended log-likelihood using a very basic method over
//! data and (accepted) Monte-Carlo.
5
6use std::fmt::{Debug, Display};
7
8use ganesh::prelude::{DVector, Function};
9use rayon::prelude::*;
10
11use crate::{
12    convert,
13    errors::RustitudeError,
14    prelude::{Amplitude, Dataset, Event, Model, Parameter},
15    Field,
16};
17
18/// The [`Manager`] struct links a [`Model`] to a [`Dataset`] and provides methods to manipulate
19/// the [`Model`] and evaluate it over the [`Dataset`].
20#[derive(Clone)]
21pub struct Manager<F: Field + 'static> {
22    /// The associated [`Model`].
23    pub model: Model<F>,
24    /// The associated [`Dataset`].
25    pub dataset: Dataset<F>,
26}
27impl<F: Field> Debug for Manager<F> {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        write!(f, "Manager [ ")?;
30        write!(f, "{:?} ", self.model)?;
31        write!(f, "]")
32    }
33}
34impl<F: Field> Display for Manager<F> {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        writeln!(f, "{}", self.model)
37    }
38}
39impl<F: Field> Manager<F> {
40    /// Generates a new [`Manager`] from a [`Model`] and [`Dataset`].
41    ///
42    /// # Errors
43    ///
44    /// This method will return a [`RustitudeError`] if the precaluclation phase of the [`Model`]
45    /// fails for any events in the [`Dataset`]. See [`Model::load`] for more information.
46    pub fn new(model: &Model<F>, dataset: &Dataset<F>) -> Result<Self, RustitudeError> {
47        let mut model = model.deep_clone();
48        model.load(dataset)?;
49        Ok(Self {
50            model: model.clone(),
51            dataset: dataset.clone(),
52        })
53    }
54
55    /// Evaluate the [`Model`] over the [`Dataset`] with the given free parameters.
56    ///
57    /// # Errors
58    ///
59    /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
60    /// [`Model::compute`] for more information.
61    pub fn evaluate(&self, parameters: &[F]) -> Result<Vec<F>, RustitudeError> {
62        let pars: Vec<F> = self
63            .model
64            .parameters
65            .iter()
66            .map(|p| p.index.map_or_else(|| p.initial, |i| parameters[i]))
67            .collect();
68        let amplitudes = self.model.amplitudes.read();
69        self.dataset
70            .events
71            .iter()
72            .map(|event: &Event<F>| self.model.compute(&amplitudes, &pars, event))
73            .collect()
74    }
75
76    /// Evaluate the [`Model`] over the [`Dataset`] with the given free parameters.
77    ///
78    /// This method allows the user to supply a list of indices and will only evaluate events at
79    /// those indices. This can be used to evaluate only a subset of events or to resample events
80    /// with replacement, such as in a bootstrap.
81    ///
82    /// # Errors
83    ///
84    /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
85    /// [`Model::compute`] for more information.
86    pub fn evaluate_indexed(
87        &self,
88        parameters: &[F],
89        indices: &[usize],
90    ) -> Result<Vec<F>, RustitudeError> {
91        if self.model.contains_python_amplitudes {
92            return Err(RustitudeError::PythonError(
93                "Python amplitudes cannot be evaluated with Rust parallelism due to the GIL!"
94                    .to_string(),
95            ));
96        }
97        let pars: Vec<F> = self
98            .model
99            .parameters
100            .iter()
101            .map(|p| p.index.map_or_else(|| p.initial, |i| parameters[i]))
102            .collect();
103        let amplitudes = self.model.amplitudes.read();
104        indices
105            .iter()
106            .map(|index| {
107                self.model
108                    .compute(&amplitudes, &pars, &self.dataset.events[*index])
109            })
110            .collect()
111    }
112
113    /// Evaluate the [`Model`] over the [`Dataset`] with the given free parameters.
114    ///
115    /// This version uses a parallel loop over events.
116    ///
117    /// # Errors
118    ///
119    /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
120    /// [`Model::compute`] for more information.
121    pub fn par_evaluate(&self, parameters: &[F]) -> Result<Vec<F>, RustitudeError> {
122        if self.model.contains_python_amplitudes {
123            return Err(RustitudeError::PythonError(
124                "Python amplitudes cannot be evaluated with Rust parallelism due to the GIL!"
125                    .to_string(),
126            ));
127        }
128        let mut output = Vec::with_capacity(self.dataset.len());
129        let pars: Vec<F> = self
130            .model
131            .parameters
132            .iter()
133            .map(|p| p.index.map_or_else(|| p.initial, |i| parameters[i]))
134            .collect();
135        let amplitudes = self.model.amplitudes.read();
136        self.dataset
137            .events
138            .par_iter()
139            .map(|event| self.model.compute(&amplitudes, &pars, event))
140            .collect_into_vec(&mut output);
141        output.into_iter().collect()
142    }
143
144    /// Evaluate the [`Model`] over the [`Dataset`] with the given free parameters.
145    ///
146    /// This method allows the user to supply a list of indices and will only evaluate events at
147    /// those indices. This can be used to evaluate only a subset of events or to resample events
148    /// with replacement, such as in a bootstrap.
149    ///
150    /// This version uses a parallel loop over events.
151    ///
152    /// # Errors
153    ///
154    /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
155    /// [`Model::compute`] for more information.
156    pub fn par_evaluate_indexed(
157        &self,
158        parameters: &[F],
159        indices: &[usize],
160    ) -> Result<Vec<F>, RustitudeError> {
161        if self.model.contains_python_amplitudes {
162            return Err(RustitudeError::PythonError(
163                "Python amplitudes cannot be evaluated with Rust parallelism due to the GIL!"
164                    .to_string(),
165            ));
166        }
167        let mut output = Vec::with_capacity(indices.len());
168        let pars: Vec<F> = self
169            .model
170            .parameters
171            .iter()
172            .map(|p| p.index.map_or_else(|| p.initial, |i| parameters[i]))
173            .collect();
174        // indices
175        //     .par_iter()
176        //     .map(|index| self.model.compute(&pars, &self.dataset.events[*index]))
177        //     .collect_into_vec(&mut output);
178        let amplitudes = self.model.amplitudes.read();
179        let view: Vec<&Event<F>> = indices
180            .par_iter()
181            .map(|&index| &self.dataset.events[index])
182            .collect();
183        view.par_iter()
184            .map(|&event| self.model.compute(&amplitudes, &pars, event))
185            .collect_into_vec(&mut output);
186        output.into_iter().collect()
187    }
188
189    /// Get a copy of an [`Amplitude`] in the [`Model`] by name.
190    ///
191    /// # Errors
192    ///
193    /// This method will return a [`RustitudeError`] if there is no amplitude found with the given
194    /// name in the parent [`Model`]. See [`Model::get_amplitude`] for more information.
195    pub fn get_amplitude(&self, amplitude_name: &str) -> Result<Amplitude<F>, RustitudeError> {
196        self.model.get_amplitude(amplitude_name)
197    }
198
199    /// Get a copy of a [`Parameter`] in a [`Model`] by name and the name of the parent
200    /// [`Amplitude`].
201    ///
202    /// # Errors
203    ///
204    /// This method will return a [`RustitudeError`] if there is no parameter found with the given
205    /// name in the parent [`Model`]. It will also first check if the given amplitude exists, and
206    /// this method can also fail in the same way (see [`Model::get_amplitude`] and
207    /// [`Model::get_parameter`]).
208    pub fn get_parameter(
209        &self,
210        amplitude_name: &str,
211        parameter_name: &str,
212    ) -> Result<Parameter<F>, RustitudeError> {
213        self.model.get_parameter(amplitude_name, parameter_name)
214    }
215
216    /// Print the free parameters in the [`Model`]. See [`Model::print_parameters`] for more
217    /// information.
218    pub fn print_parameters(&self) {
219        self.model.print_parameters()
220    }
221
222    /// Returns a [`Vec<Parameter<F>>`] containing the free parameters in the [`Model`].
223    ///
224    /// See [`Model::free_parameters`] for more information.
225    pub fn free_parameters(&self) -> Vec<Parameter<F>> {
226        self.model.free_parameters()
227    }
228
229    /// Returns a [`Vec<Parameter<F>>`] containing the fixed parameters in the [`Model`].
230    ///
231    /// See [`Model::fixed_parameters`] for more information.
232    pub fn fixed_parameters(&self) -> Vec<Parameter<F>> {
233        self.model.fixed_parameters()
234    }
235
236    /// Constrain two parameters by name, reducing the number of free parameters by one.
237    ///
238    /// # Errors
239    ///
240    /// This method will fail if any of the given amplitude or parameter names don't correspond to
241    /// a valid amplitude-parameter pair. See [`Model::constrain`] for more information.
242    pub fn constrain(
243        &mut self,
244        amplitude_1: &str,
245        parameter_1: &str,
246        amplitude_2: &str,
247        parameter_2: &str,
248    ) -> Result<(), RustitudeError> {
249        self.model
250            .constrain(amplitude_1, parameter_1, amplitude_2, parameter_2)
251    }
252
253    /// Fix a parameter by name to the given value.
254    ///
255    /// # Errors
256    ///
257    /// This method will fail if the given amplitude-parameter pair does not exist. See
258    /// [`Model::fix`] for more information.
259    pub fn fix(
260        &mut self,
261        amplitude: &str,
262        parameter: &str,
263        value: F,
264    ) -> Result<(), RustitudeError> {
265        self.model.fix(amplitude, parameter, value)
266    }
267
268    /// Free a fixed parameter by name.
269    ///
270    /// # Errors
271    ///
272    /// This method will fail if the given amplitude-parameter pair does not exist. See
273    /// [`Model::free`] for more information.
274    pub fn free(&mut self, amplitude: &str, parameter: &str) -> Result<(), RustitudeError> {
275        self.model.free(amplitude, parameter)
276    }
277
278    /// Set the bounds of a parameter by name.
279    ///
280    /// # Errors
281    ///
282    /// This method will fail if the given amplitude-parameter pair does not exist. See
283    /// [`Model::set_bounds`] for more information.
284    pub fn set_bounds(
285        &mut self,
286        amplitude: &str,
287        parameter: &str,
288        bounds: (F, F),
289    ) -> Result<(), RustitudeError> {
290        self.model.set_bounds(amplitude, parameter, bounds)
291    }
292
293    /// Set the initial value of a parameter by name.
294    ///
295    /// # Errors
296    ///
297    /// This method will fail if the given amplitude-parameter pair does not exist. See
298    /// [`Model::set_initial`] for more information.
299    pub fn set_initial(
300        &mut self,
301        amplitude: &str,
302        parameter: &str,
303        initial: F,
304    ) -> Result<(), RustitudeError> {
305        self.model.set_initial(amplitude, parameter, initial)
306    }
307
308    /// Get a list of bounds for all free parameters in the [`Model`]. See
309    /// [`Model::get_bounds`] for more information.
310    pub fn get_bounds(&self) -> Vec<(F, F)> {
311        self.model.get_bounds()
312    }
313
314    /// Get a list of initial values for all free parameters in the [`Model`]. See
315    /// [`Model::get_initial`] for more information.
316    pub fn get_initial(&self) -> Vec<F> {
317        self.model.get_initial()
318    }
319
320    /// Get the number of free parameters in the [`Model`] See [`Model::get_n_free`] for
321    /// more information.
322    pub fn get_n_free(&self) -> usize {
323        self.model.get_n_free()
324    }
325
326    /// Activate an [`Amplitude`] by name. See [`Model::activate`] for more information.
327    ///
328    /// # Errors
329    ///
330    /// This function will return a [`RustitudeError::AmplitudeNotFoundError`] if the given
331    /// amplitude is not present in the [`Model`].
332    pub fn activate(&mut self, amplitude: &str) -> Result<(), RustitudeError> {
333        self.model.activate(amplitude)
334    }
335    /// Activate all [`Amplitude`]s by name. See [`Model::activate_all`] for more information.
336    pub fn activate_all(&mut self) {
337        self.model.activate_all()
338    }
339    /// Activate only the specified [`Amplitude`]s while deactivating the rest. See
340    /// [`Model::isolate`] for more information.
341    ///
342    /// # Errors
343    ///
344    /// This function will return a [`RustitudeError::AmplitudeNotFoundError`] if a given
345    /// amplitude is not present in the [`Model`].
346    pub fn isolate(&mut self, amplitudes: Vec<&str>) -> Result<(), RustitudeError> {
347        self.model.isolate(amplitudes)
348    }
349    /// Deactivate an [`Amplitude`] by name. See [`Model::deactivate`] for more information.
350    ///
351    /// # Errors
352    ///
353    /// This function will return a [`RustitudeError::AmplitudeNotFoundError`] if the given
354    /// amplitude is not present in the [`Model`].
355    pub fn deactivate(&mut self, amplitude: &str) -> Result<(), RustitudeError> {
356        self.model.deactivate(amplitude)
357    }
358    /// Deactivate all [`Amplitude`]s by name. See [`Model::deactivate_all`] for more information.
359    pub fn deactivate_all(&mut self) {
360        self.model.deactivate_all()
361    }
362}
363
364/// The [`ExtendedLogLikelihood`] stores two [`Manager`]s, one for data and one for a Monte-Carlo
365/// dataset used for acceptance correction. These should probably have the same [`Manager`] in
366/// practice, but this is left to the user.
367#[derive(Clone)]
368pub struct ExtendedLogLikelihood<F: Field + 'static> {
369    /// [`Manager`] for data
370    pub data_manager: Manager<F>,
371    /// [`Manager`] for Monte-Carlo
372    pub mc_manager: Manager<F>,
373}
374impl<F: Field> Debug for ExtendedLogLikelihood<F> {
375    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376        write!(f, "ExtendedLogLikelihood [ ")?;
377        write!(f, "{:?} ", self.data_manager)?;
378        write!(f, "{:?} ", self.mc_manager)?;
379        write!(f, "]")
380    }
381}
382impl<F: Field> Display for ExtendedLogLikelihood<F> {
383    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384        writeln!(f, "{}", self.data_manager)?;
385        writeln!(f, "{}", self.mc_manager)
386    }
387}
388impl<F: Field> ExtendedLogLikelihood<F> {
389    /// Create a new [`ExtendedLogLikelihood`] from a data and Monte-Carlo [`Manager`]s.
390    pub const fn new(data_manager: Manager<F>, mc_manager: Manager<F>) -> Self {
391        Self {
392            data_manager,
393            mc_manager,
394        }
395    }
396
397    /// Evaluate the [`ExtendedLogLikelihood`] over the [`Dataset`] with the given free parameters.
398    ///
399    /// # Errors
400    ///
401    /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
402    /// [`Model::compute`] for more information.
403    #[allow(clippy::suboptimal_flops)]
404    pub fn evaluate(&self, parameters: &[F]) -> Result<F, RustitudeError> {
405        let data_res = self.data_manager.evaluate(parameters)?;
406        let data_weights = self.data_manager.dataset.weights();
407        let n_data = data_weights.iter().copied().sum::<F>();
408        let mc_norm_int = self.mc_manager.evaluate(parameters)?;
409        let mc_weights = self.mc_manager.dataset.weights();
410        let n_mc = mc_weights.iter().copied().sum::<F>();
411        let ln_l = (data_res
412            .iter()
413            .zip(data_weights)
414            .map(|(l, w)| w * F::ln(*l))
415            .sum::<F>())
416            - (n_data / n_mc)
417                * (mc_norm_int
418                    .iter()
419                    .zip(mc_weights)
420                    .map(|(l, w)| w * *l)
421                    .sum::<F>());
422        Ok(convert!(-2, F) * ln_l)
423    }
424
425    /// Evaluate the [`ExtendedLogLikelihood`] over the [`Dataset`] with the given free parameters.
426    ///
427    /// This method allows the user to supply two lists of indices and will only evaluate events at
428    /// those indices. This can be used to evaluate only a subset of events or to resample events
429    /// with replacement, such as in a bootstrap.
430    ///
431    /// # Errors
432    ///
433    /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
434    /// [`Model::compute`] for more information.
435    #[allow(clippy::suboptimal_flops)]
436    pub fn evaluate_indexed(
437        &self,
438        parameters: &[F],
439        indices_data: &[usize],
440        indices_mc: &[usize],
441    ) -> Result<F, RustitudeError> {
442        let data_res = self
443            .data_manager
444            .evaluate_indexed(parameters, indices_data)?;
445        let data_weights = self.data_manager.dataset.weights_indexed(indices_data);
446        let n_data = data_weights.iter().copied().sum::<F>();
447        let mc_norm_int = self.mc_manager.evaluate_indexed(parameters, indices_mc)?;
448        let mc_weights = self.mc_manager.dataset.weights_indexed(indices_mc);
449        let n_mc = mc_weights.iter().copied().sum::<F>();
450        let ln_l = (data_res
451            .iter()
452            .zip(data_weights)
453            .map(|(l, w)| w * F::ln(*l))
454            .sum::<F>())
455            - (n_data / n_mc)
456                * (mc_norm_int
457                    .iter()
458                    .zip(mc_weights)
459                    .map(|(l, w)| w * *l)
460                    .sum::<F>());
461        Ok(convert!(-2, F) * ln_l)
462    }
463
464    /// Evaluate the [`ExtendedLogLikelihood`] over the [`Dataset`] with the given free parameters.
465    ///
466    /// This method also allows the user to input a maximum number of threads to use in the
467    /// calculation, as it uses a parallel loop over events.
468    ///
469    /// # Errors
470    ///
471    /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
472    /// [`Model::compute`] for more information.
473    #[allow(clippy::suboptimal_flops)]
474    pub fn par_evaluate(&self, parameters: &[F]) -> Result<F, RustitudeError> {
475        if self.data_manager.model.contains_python_amplitudes
476            || self.mc_manager.model.contains_python_amplitudes
477        {
478            return Err(RustitudeError::PythonError(
479                "Python amplitudes cannot be evaluated with Rust parallelism due to the GIL!"
480                    .to_string(),
481            ));
482        }
483        let data_res = self.data_manager.par_evaluate(parameters)?;
484        let data_weights = self.data_manager.dataset.weights();
485        let n_data = data_weights.iter().copied().sum::<F>();
486        let mc_norm_int = self.mc_manager.par_evaluate(parameters)?;
487        let mc_weights = self.mc_manager.dataset.weights();
488        let n_mc = mc_weights.iter().copied().sum::<F>();
489        let ln_l = (data_res
490            .par_iter()
491            .zip(data_weights)
492            .map(|(l, w)| w * F::ln(*l))
493            .sum::<F>())
494            - (n_data / n_mc)
495                * (mc_norm_int
496                    .par_iter()
497                    .zip(mc_weights)
498                    .map(|(l, w)| w * *l)
499                    .sum::<F>());
500        Ok(convert!(-2, F) * ln_l)
501    }
502
503    /// Evaluate the [`ExtendedLogLikelihood`] over the [`Dataset`] with the given free parameters.
504    ///
505    /// This method allows the user to supply two lists of indices and will only evaluate events at
506    /// those indices. This can be used to evaluate only a subset of events or to resample events
507    /// with replacement, such as in a bootstrap.
508    ///
509    /// This method also allows the user to input a maximum number of threads to use in the
510    /// calculation, as it uses a parallel loop over events.
511    ///
512    /// # Errors
513    ///
514    /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
515    /// [`Model::compute`] for more information.
516    #[allow(clippy::suboptimal_flops)]
517    pub fn par_evaluate_indexed(
518        &self,
519        parameters: &[F],
520        indices_data: &[usize],
521        indices_mc: &[usize],
522    ) -> Result<F, RustitudeError> {
523        if self.data_manager.model.contains_python_amplitudes
524            || self.mc_manager.model.contains_python_amplitudes
525        {
526            return Err(RustitudeError::PythonError(
527                "Python amplitudes cannot be evaluated with Rust parallelism due to the GIL!"
528                    .to_string(),
529            ));
530        }
531        let data_res = self
532            .data_manager
533            .par_evaluate_indexed(parameters, indices_data)?;
534        let data_weights = self.data_manager.dataset.weights_indexed(indices_data);
535        let n_data = data_weights.iter().copied().sum::<F>();
536        let mc_norm_int = self
537            .mc_manager
538            .par_evaluate_indexed(parameters, indices_mc)?;
539        let mc_weights = self.mc_manager.dataset.weights_indexed(indices_mc);
540        let n_mc = mc_weights.iter().copied().sum::<F>();
541        let ln_l = (data_res
542            .par_iter()
543            .zip(data_weights)
544            .map(|(l, w)| w * F::ln(*l))
545            .sum::<F>())
546            - (n_data / n_mc)
547                * (mc_norm_int
548                    .par_iter()
549                    .zip(mc_weights)
550                    .map(|(l, w)| w * *l)
551                    .sum::<F>());
552        Ok(convert!(-2, F) * ln_l)
553    }
554
555    /// Evaluate the normalized intensity function over the given Monte-Carlo [`Dataset`] with the
556    /// given free parameters. This is intended to be used to plot a model over the dataset, usually
557    /// with the generated or accepted Monte-Carlo as the input.
558    ///
559    /// # Errors
560    ///
561    /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
562    /// [`Model::compute`] for more information.
563    #[allow(clippy::suboptimal_flops)]
564    pub fn intensity(
565        &self,
566        parameters: &[F],
567        dataset_mc: &Dataset<F>,
568    ) -> Result<Vec<F>, RustitudeError> {
569        let mc_manager = Manager::new(&self.data_manager.model, dataset_mc)?;
570        let data_len_weighted: F = self.data_manager.dataset.weights().iter().copied().sum();
571        let mc_len_weighted: F = dataset_mc.weights().iter().copied().sum();
572        mc_manager.evaluate(parameters).map(|r_vec| {
573            r_vec
574                .into_iter()
575                .zip(dataset_mc.events.iter())
576                .map(|(r, e)| r * data_len_weighted / mc_len_weighted * e.weight)
577                .collect()
578        })
579    }
580
581    /// Evaluate the normalized intensity function over the given Monte-Carlo [`Dataset`] with the
582    /// given free parameters. This is intended to be used to plot a model over the dataset, usually
583    /// with the generated or accepted Monte-Carlo as the input.
584    ///
585    /// This method allows the user to supply a list of indices and will only evaluate events at
586    /// those indices. This can be used to evaluate only a subset of events or to resample events
587    /// with replacement, such as in a bootstrap.
588    ///
589    /// # Errors
590    ///
591    /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
592    /// [`Model::compute`] for more information.
593    #[allow(clippy::suboptimal_flops)]
594    pub fn intensity_indexed(
595        &self,
596        parameters: &[F],
597        dataset_mc: &Dataset<F>,
598        indices_data: &[usize],
599        indices_mc: &[usize],
600    ) -> Result<Vec<F>, RustitudeError> {
601        let mc_manager = Manager::new(&self.data_manager.model, dataset_mc)?;
602        let data_len_weighted = self
603            .data_manager
604            .dataset
605            .weights_indexed(indices_data)
606            .iter()
607            .copied()
608            .sum::<F>();
609        let mc_len_weighted = dataset_mc
610            .weights_indexed(indices_mc)
611            .iter()
612            .copied()
613            .sum::<F>();
614        let view: Vec<&Event<F>> = indices_mc
615            .par_iter()
616            .map(|&index| &mc_manager.dataset.events[index])
617            .collect();
618        mc_manager
619            .evaluate_indexed(parameters, indices_mc)
620            .map(|r_vec| {
621                r_vec
622                    .into_iter()
623                    .zip(view.iter())
624                    .map(|(r, e)| r * data_len_weighted / mc_len_weighted * e.weight)
625                    .collect()
626            })
627    }
628    /// Evaluate the normalized intensity function over the given [`Dataset`] with the given
629    /// free parameters. This is intended to be used to plot a model over the dataset, usually
630    /// with the generated or accepted Monte-Carlo as the input.
631    ///
632    /// This method also allows the user to input a maximum number of threads to use in the
633    /// calculation, as it uses a parallel loop over events.
634    ///
635    /// # Errors
636    ///
637    /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
638    /// [`Model::compute`] for more information.
639    #[allow(clippy::suboptimal_flops)]
640    pub fn par_intensity(
641        &self,
642        parameters: &[F],
643        dataset_mc: &Dataset<F>,
644    ) -> Result<Vec<F>, RustitudeError> {
645        if self.data_manager.model.contains_python_amplitudes
646            || self.mc_manager.model.contains_python_amplitudes
647        {
648            return Err(RustitudeError::PythonError(
649                "Python amplitudes cannot be evaluated with Rust parallelism due to the GIL!"
650                    .to_string(),
651            ));
652        }
653        let mc_manager = Manager::new(&self.data_manager.model, dataset_mc)?;
654        let data_len_weighted: F = self.data_manager.dataset.weights().iter().copied().sum();
655        let mc_len_weighted: F = dataset_mc.weights().iter().copied().sum();
656        mc_manager.par_evaluate(parameters).map(|r_vec| {
657            r_vec
658                .into_iter()
659                .zip(dataset_mc.events.iter())
660                .map(|(r, e)| r * data_len_weighted / mc_len_weighted * e.weight)
661                .collect()
662        })
663    }
664
665    /// Evaluate the normalized intensity function over the given Monte-Carlo [`Dataset`] with the
666    /// given free parameters. This is intended to be used to plot a model over the dataset, usually
667    /// with the generated or accepted Monte-Carlo as the input.
668    ///
669    /// This method allows the user to supply a list of indices and will only evaluate events at
670    /// those indices. This can be used to evaluate only a subset of events or to resample events
671    /// with replacement, such as in a bootstrap.
672    ///
673    /// This method also allows the user to input a maximum number of threads to use in the
674    /// calculation, as it uses a parallel loop over events.
675    ///
676    /// # Errors
677    ///
678    /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
679    /// [`Model::compute`] for more information.
680    #[allow(clippy::suboptimal_flops)]
681    pub fn par_intensity_indexed(
682        &self,
683        parameters: &[F],
684        dataset_mc: &Dataset<F>,
685        indices_data: &[usize],
686        indices_mc: &[usize],
687    ) -> Result<Vec<F>, RustitudeError> {
688        let mc_manager = Manager::new(&self.data_manager.model, dataset_mc)?;
689        let data_len_weighted: F = self
690            .data_manager
691            .dataset
692            .weights_indexed(indices_data)
693            .iter()
694            .copied()
695            .sum();
696        let mc_len_weighted: F = dataset_mc.weights_indexed(indices_mc).iter().copied().sum();
697        let view: Vec<&Event<F>> = indices_mc
698            .par_iter()
699            .map(|&index| &mc_manager.dataset.events[index])
700            .collect();
701        mc_manager
702            .par_evaluate_indexed(parameters, indices_mc)
703            .map(|r_vec| {
704                r_vec
705                    .into_par_iter()
706                    .zip(view.par_iter())
707                    .map(|(r, e)| r * data_len_weighted / mc_len_weighted * e.weight)
708                    .collect()
709            })
710    }
711
712    /// Get a copy of an [`Amplitude`] in the [`Model`] by name.
713    ///
714    /// # Errors
715    ///
716    /// This method will return a [`RustitudeError`] if there is no amplitude found with the given
717    /// name in the parent [`Model`]. See [`Model::get_amplitude`] for more information.
718    pub fn get_amplitude(&self, amplitude_name: &str) -> Result<Amplitude<F>, RustitudeError> {
719        self.data_manager.get_amplitude(amplitude_name)
720    }
721
722    /// Get a copy of a [`Parameter`] in a [`Model`] by name and the name of the parent
723    /// [`Amplitude`].
724    ///
725    /// # Errors
726    ///
727    /// This method will return a [`RustitudeError`] if there is no parameter found with the given
728    /// name in the parent [`Model`]. It will also first check if the given amplitude exists, and
729    /// this method can also fail in the same way (see [`Model::get_amplitude`] and
730    /// [`Model::get_parameter`]).
731    pub fn get_parameter(
732        &self,
733        amplitude_name: &str,
734        parameter_name: &str,
735    ) -> Result<Parameter<F>, RustitudeError> {
736        self.data_manager
737            .get_parameter(amplitude_name, parameter_name)
738    }
739
740    /// Print the free parameters in the [`Model`]. See [`Model::print_parameters`] for more
741    /// information.
742    pub fn print_parameters(&self) {
743        self.data_manager.print_parameters()
744    }
745
746    /// Returns a [`Vec<Parameter<F>>`] containing the free parameters in the data [`Manager`].
747    ///
748    /// See [`Model::free_parameters`] for more information.
749    pub fn free_parameters(&self) -> Vec<Parameter<F>> {
750        self.data_manager.free_parameters()
751    }
752
753    /// Returns a [`Vec<Parameter<F>>`] containing the fixed parameters in the data [`Manager`].
754    ///
755    /// See [`Model::fixed_parameters`] for more information.
756    pub fn fixed_parameters(&self) -> Vec<Parameter<F>> {
757        self.data_manager.fixed_parameters()
758    }
759
760    /// Constrain two parameters by name, reducing the number of free parameters by one.
761    ///
762    /// # Errors
763    ///
764    /// This method will fail if any of the given amplitude or parameter names don't correspond to
765    /// a valid amplitude-parameter pair. See [`Model::constrain`] for more information.
766    pub fn constrain(
767        &mut self,
768        amplitude_1: &str,
769        parameter_1: &str,
770        amplitude_2: &str,
771        parameter_2: &str,
772    ) -> Result<(), RustitudeError> {
773        self.data_manager
774            .constrain(amplitude_1, parameter_1, amplitude_2, parameter_2)?;
775        self.mc_manager
776            .constrain(amplitude_1, parameter_1, amplitude_2, parameter_2)
777    }
778
779    /// Fix a parameter by name to the given value.
780    ///
781    /// # Errors
782    ///
783    /// This method will fail if the given amplitude-parameter pair does not exist. See
784    /// [`Model::fix`] for more information.
785    pub fn fix(
786        &mut self,
787        amplitude: &str,
788        parameter: &str,
789        value: F,
790    ) -> Result<(), RustitudeError> {
791        self.data_manager.fix(amplitude, parameter, value)?;
792        self.mc_manager.fix(amplitude, parameter, value)
793    }
794
795    /// Free a fixed parameter by name.
796    ///
797    /// # Errors
798    ///
799    /// This method will fail if the given amplitude-parameter pair does not exist. See
800    /// [`Model::free`] for more information.
801    pub fn free(&mut self, amplitude: &str, parameter: &str) -> Result<(), RustitudeError> {
802        self.data_manager.free(amplitude, parameter)?;
803        self.mc_manager.free(amplitude, parameter)
804    }
805
806    /// Set the bounds of a parameter by name.
807    ///
808    /// # Errors
809    ///
810    /// This method will fail if the given amplitude-parameter pair does not exist. See
811    /// [`Model::set_bounds`] for more information.
812    pub fn set_bounds(
813        &mut self,
814        amplitude: &str,
815        parameter: &str,
816        bounds: (F, F),
817    ) -> Result<(), RustitudeError> {
818        self.data_manager.set_bounds(amplitude, parameter, bounds)?;
819        self.mc_manager.set_bounds(amplitude, parameter, bounds)
820    }
821
822    /// Set the initial value of a parameter by name.
823    ///
824    /// # Errors
825    ///
826    /// This method will fail if the given amplitude-parameter pair does not exist. See
827    /// [`Model::set_initial`] for more information.
828    pub fn set_initial(
829        &mut self,
830        amplitude: &str,
831        parameter: &str,
832        initial: F,
833    ) -> Result<(), RustitudeError> {
834        self.data_manager
835            .set_initial(amplitude, parameter, initial)?;
836        self.mc_manager.set_initial(amplitude, parameter, initial)
837    }
838
839    /// Get a list of bounds for all free parameters in the [`Model`]. See
840    /// [`Model::get_bounds`] for more information.
841    pub fn get_bounds(&self) -> Vec<(F, F)> {
842        self.data_manager.get_bounds();
843        self.mc_manager.get_bounds()
844    }
845
846    /// Get a list of initial values for all free parameters in the [`Model`]. See
847    /// [`Model::get_initial`] for more information.
848    pub fn get_initial(&self) -> Vec<F> {
849        self.data_manager.get_initial();
850        self.mc_manager.get_initial()
851    }
852
853    /// Get the number of free parameters in the [`Model`] See [`Model::get_n_free`] for
854    /// more information.
855    pub fn get_n_free(&self) -> usize {
856        self.data_manager.get_n_free();
857        self.mc_manager.get_n_free()
858    }
859
860    /// Activate an [`Amplitude`] by name. See [`Model::activate`] for more information.
861    ///
862    /// # Errors
863    ///
864    /// This function will return a [`RustitudeError::AmplitudeNotFoundError`] if the given
865    /// amplitude is not present in the [`Model`].
866    pub fn activate(&mut self, amplitude: &str) -> Result<(), RustitudeError> {
867        self.data_manager.activate(amplitude)?;
868        self.mc_manager.activate(amplitude)
869    }
870    /// Activates all [`Amplitude`]s by name. See [`Model::activate_all`] for more information.
871    pub fn activate_all(&mut self) {
872        self.data_manager.activate_all();
873        self.mc_manager.activate_all()
874    }
875    /// Activate only the specified [`Amplitude`]s while deactivating the rest. See
876    /// [`Model::isolate`] for more information.
877    ///
878    /// # Errors
879    ///
880    /// This function will return a [`RustitudeError::AmplitudeNotFoundError`] if a given
881    /// amplitude is not present in the [`Model`].
882    pub fn isolate(&mut self, amplitudes: Vec<&str>) -> Result<(), RustitudeError> {
883        self.data_manager.isolate(amplitudes.clone())?;
884        self.mc_manager.isolate(amplitudes)
885    }
886    /// Deactivate an [`Amplitude`] by name. See [`Model::deactivate`] for more information.
887    ///
888    /// # Errors
889    ///
890    /// This function will return a [`RustitudeError::AmplitudeNotFoundError`] if the given
891    /// amplitude is not present in the [`Model`].
892    pub fn deactivate(&mut self, amplitude: &str) -> Result<(), RustitudeError> {
893        self.data_manager.deactivate(amplitude)?;
894        self.mc_manager.deactivate(amplitude)
895    }
896    /// Deactivates all [`Amplitude`]s by name. See [`Model::deactivate_all`] for more information.
897    pub fn deactivate_all(&mut self) {
898        self.data_manager.deactivate_all();
899        self.mc_manager.deactivate_all()
900    }
901}
902
903impl<F: Field + ganesh::core::Field> Function<F, (), RustitudeError> for ExtendedLogLikelihood<F> {
904    fn evaluate(&self, x: &DVector<F>, _args: Option<&()>) -> Result<F, RustitudeError> {
905        self.par_evaluate(x.as_slice())
906    }
907}