rustitude_core/manager.rs
//! This module contains methods to link [`Model`]s with [`Dataset`]s via a [`Manager::evaluate`]
//! method. This module also holds an [`ExtendedLogLikelihood`] struct which holds two [`Manager`]s
//! and, as the name suggests, calculates an extended log-likelihood using a very basic method over
//! data and (accepted) Monte-Carlo.
5
6use std::fmt::{Debug, Display};
7
8use ganesh::prelude::{DVector, Function};
9use rayon::prelude::*;
10
11use crate::{
12 convert,
13 errors::RustitudeError,
14 prelude::{Amplitude, Dataset, Event, Model, Parameter},
15 Field,
16};
17
18/// The [`Manager`] struct links a [`Model`] to a [`Dataset`] and provides methods to manipulate
19/// the [`Model`] and evaluate it over the [`Dataset`].
20#[derive(Clone)]
21pub struct Manager<F: Field + 'static> {
22 /// The associated [`Model`].
23 pub model: Model<F>,
24 /// The associated [`Dataset`].
25 pub dataset: Dataset<F>,
26}
27impl<F: Field> Debug for Manager<F> {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 write!(f, "Manager [ ")?;
30 write!(f, "{:?} ", self.model)?;
31 write!(f, "]")
32 }
33}
34impl<F: Field> Display for Manager<F> {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 writeln!(f, "{}", self.model)
37 }
38}
39impl<F: Field> Manager<F> {
40 /// Generates a new [`Manager`] from a [`Model`] and [`Dataset`].
41 ///
42 /// # Errors
43 ///
44 /// This method will return a [`RustitudeError`] if the precaluclation phase of the [`Model`]
45 /// fails for any events in the [`Dataset`]. See [`Model::load`] for more information.
46 pub fn new(model: &Model<F>, dataset: &Dataset<F>) -> Result<Self, RustitudeError> {
47 let mut model = model.deep_clone();
48 model.load(dataset)?;
49 Ok(Self {
50 model: model.clone(),
51 dataset: dataset.clone(),
52 })
53 }
54
55 /// Evaluate the [`Model`] over the [`Dataset`] with the given free parameters.
56 ///
57 /// # Errors
58 ///
59 /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
60 /// [`Model::compute`] for more information.
61 pub fn evaluate(&self, parameters: &[F]) -> Result<Vec<F>, RustitudeError> {
62 let pars: Vec<F> = self
63 .model
64 .parameters
65 .iter()
66 .map(|p| p.index.map_or_else(|| p.initial, |i| parameters[i]))
67 .collect();
68 let amplitudes = self.model.amplitudes.read();
69 self.dataset
70 .events
71 .iter()
72 .map(|event: &Event<F>| self.model.compute(&litudes, &pars, event))
73 .collect()
74 }
75
76 /// Evaluate the [`Model`] over the [`Dataset`] with the given free parameters.
77 ///
78 /// This method allows the user to supply a list of indices and will only evaluate events at
79 /// those indices. This can be used to evaluate only a subset of events or to resample events
80 /// with replacement, such as in a bootstrap.
81 ///
82 /// # Errors
83 ///
84 /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
85 /// [`Model::compute`] for more information.
86 pub fn evaluate_indexed(
87 &self,
88 parameters: &[F],
89 indices: &[usize],
90 ) -> Result<Vec<F>, RustitudeError> {
91 if self.model.contains_python_amplitudes {
92 return Err(RustitudeError::PythonError(
93 "Python amplitudes cannot be evaluated with Rust parallelism due to the GIL!"
94 .to_string(),
95 ));
96 }
97 let pars: Vec<F> = self
98 .model
99 .parameters
100 .iter()
101 .map(|p| p.index.map_or_else(|| p.initial, |i| parameters[i]))
102 .collect();
103 let amplitudes = self.model.amplitudes.read();
104 indices
105 .iter()
106 .map(|index| {
107 self.model
108 .compute(&litudes, &pars, &self.dataset.events[*index])
109 })
110 .collect()
111 }
112
113 /// Evaluate the [`Model`] over the [`Dataset`] with the given free parameters.
114 ///
115 /// This version uses a parallel loop over events.
116 ///
117 /// # Errors
118 ///
119 /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
120 /// [`Model::compute`] for more information.
121 pub fn par_evaluate(&self, parameters: &[F]) -> Result<Vec<F>, RustitudeError> {
122 if self.model.contains_python_amplitudes {
123 return Err(RustitudeError::PythonError(
124 "Python amplitudes cannot be evaluated with Rust parallelism due to the GIL!"
125 .to_string(),
126 ));
127 }
128 let mut output = Vec::with_capacity(self.dataset.len());
129 let pars: Vec<F> = self
130 .model
131 .parameters
132 .iter()
133 .map(|p| p.index.map_or_else(|| p.initial, |i| parameters[i]))
134 .collect();
135 let amplitudes = self.model.amplitudes.read();
136 self.dataset
137 .events
138 .par_iter()
139 .map(|event| self.model.compute(&litudes, &pars, event))
140 .collect_into_vec(&mut output);
141 output.into_iter().collect()
142 }
143
144 /// Evaluate the [`Model`] over the [`Dataset`] with the given free parameters.
145 ///
146 /// This method allows the user to supply a list of indices and will only evaluate events at
147 /// those indices. This can be used to evaluate only a subset of events or to resample events
148 /// with replacement, such as in a bootstrap.
149 ///
150 /// This version uses a parallel loop over events.
151 ///
152 /// # Errors
153 ///
154 /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
155 /// [`Model::compute`] for more information.
156 pub fn par_evaluate_indexed(
157 &self,
158 parameters: &[F],
159 indices: &[usize],
160 ) -> Result<Vec<F>, RustitudeError> {
161 if self.model.contains_python_amplitudes {
162 return Err(RustitudeError::PythonError(
163 "Python amplitudes cannot be evaluated with Rust parallelism due to the GIL!"
164 .to_string(),
165 ));
166 }
167 let mut output = Vec::with_capacity(indices.len());
168 let pars: Vec<F> = self
169 .model
170 .parameters
171 .iter()
172 .map(|p| p.index.map_or_else(|| p.initial, |i| parameters[i]))
173 .collect();
174 // indices
175 // .par_iter()
176 // .map(|index| self.model.compute(&pars, &self.dataset.events[*index]))
177 // .collect_into_vec(&mut output);
178 let amplitudes = self.model.amplitudes.read();
179 let view: Vec<&Event<F>> = indices
180 .par_iter()
181 .map(|&index| &self.dataset.events[index])
182 .collect();
183 view.par_iter()
184 .map(|&event| self.model.compute(&litudes, &pars, event))
185 .collect_into_vec(&mut output);
186 output.into_iter().collect()
187 }
188
189 /// Get a copy of an [`Amplitude`] in the [`Model`] by name.
190 ///
191 /// # Errors
192 ///
193 /// This method will return a [`RustitudeError`] if there is no amplitude found with the given
194 /// name in the parent [`Model`]. See [`Model::get_amplitude`] for more information.
195 pub fn get_amplitude(&self, amplitude_name: &str) -> Result<Amplitude<F>, RustitudeError> {
196 self.model.get_amplitude(amplitude_name)
197 }
198
199 /// Get a copy of a [`Parameter`] in a [`Model`] by name and the name of the parent
200 /// [`Amplitude`].
201 ///
202 /// # Errors
203 ///
204 /// This method will return a [`RustitudeError`] if there is no parameter found with the given
205 /// name in the parent [`Model`]. It will also first check if the given amplitude exists, and
206 /// this method can also fail in the same way (see [`Model::get_amplitude`] and
207 /// [`Model::get_parameter`]).
208 pub fn get_parameter(
209 &self,
210 amplitude_name: &str,
211 parameter_name: &str,
212 ) -> Result<Parameter<F>, RustitudeError> {
213 self.model.get_parameter(amplitude_name, parameter_name)
214 }
215
216 /// Print the free parameters in the [`Model`]. See [`Model::print_parameters`] for more
217 /// information.
218 pub fn print_parameters(&self) {
219 self.model.print_parameters()
220 }
221
222 /// Returns a [`Vec<Parameter<F>>`] containing the free parameters in the [`Model`].
223 ///
224 /// See [`Model::free_parameters`] for more information.
225 pub fn free_parameters(&self) -> Vec<Parameter<F>> {
226 self.model.free_parameters()
227 }
228
229 /// Returns a [`Vec<Parameter<F>>`] containing the fixed parameters in the [`Model`].
230 ///
231 /// See [`Model::fixed_parameters`] for more information.
232 pub fn fixed_parameters(&self) -> Vec<Parameter<F>> {
233 self.model.fixed_parameters()
234 }
235
236 /// Constrain two parameters by name, reducing the number of free parameters by one.
237 ///
238 /// # Errors
239 ///
240 /// This method will fail if any of the given amplitude or parameter names don't correspond to
241 /// a valid amplitude-parameter pair. See [`Model::constrain`] for more information.
242 pub fn constrain(
243 &mut self,
244 amplitude_1: &str,
245 parameter_1: &str,
246 amplitude_2: &str,
247 parameter_2: &str,
248 ) -> Result<(), RustitudeError> {
249 self.model
250 .constrain(amplitude_1, parameter_1, amplitude_2, parameter_2)
251 }
252
253 /// Fix a parameter by name to the given value.
254 ///
255 /// # Errors
256 ///
257 /// This method will fail if the given amplitude-parameter pair does not exist. See
258 /// [`Model::fix`] for more information.
259 pub fn fix(
260 &mut self,
261 amplitude: &str,
262 parameter: &str,
263 value: F,
264 ) -> Result<(), RustitudeError> {
265 self.model.fix(amplitude, parameter, value)
266 }
267
268 /// Free a fixed parameter by name.
269 ///
270 /// # Errors
271 ///
272 /// This method will fail if the given amplitude-parameter pair does not exist. See
273 /// [`Model::free`] for more information.
274 pub fn free(&mut self, amplitude: &str, parameter: &str) -> Result<(), RustitudeError> {
275 self.model.free(amplitude, parameter)
276 }
277
278 /// Set the bounds of a parameter by name.
279 ///
280 /// # Errors
281 ///
282 /// This method will fail if the given amplitude-parameter pair does not exist. See
283 /// [`Model::set_bounds`] for more information.
284 pub fn set_bounds(
285 &mut self,
286 amplitude: &str,
287 parameter: &str,
288 bounds: (F, F),
289 ) -> Result<(), RustitudeError> {
290 self.model.set_bounds(amplitude, parameter, bounds)
291 }
292
293 /// Set the initial value of a parameter by name.
294 ///
295 /// # Errors
296 ///
297 /// This method will fail if the given amplitude-parameter pair does not exist. See
298 /// [`Model::set_initial`] for more information.
299 pub fn set_initial(
300 &mut self,
301 amplitude: &str,
302 parameter: &str,
303 initial: F,
304 ) -> Result<(), RustitudeError> {
305 self.model.set_initial(amplitude, parameter, initial)
306 }
307
308 /// Get a list of bounds for all free parameters in the [`Model`]. See
309 /// [`Model::get_bounds`] for more information.
310 pub fn get_bounds(&self) -> Vec<(F, F)> {
311 self.model.get_bounds()
312 }
313
314 /// Get a list of initial values for all free parameters in the [`Model`]. See
315 /// [`Model::get_initial`] for more information.
316 pub fn get_initial(&self) -> Vec<F> {
317 self.model.get_initial()
318 }
319
320 /// Get the number of free parameters in the [`Model`] See [`Model::get_n_free`] for
321 /// more information.
322 pub fn get_n_free(&self) -> usize {
323 self.model.get_n_free()
324 }
325
326 /// Activate an [`Amplitude`] by name. See [`Model::activate`] for more information.
327 ///
328 /// # Errors
329 ///
330 /// This function will return a [`RustitudeError::AmplitudeNotFoundError`] if the given
331 /// amplitude is not present in the [`Model`].
332 pub fn activate(&mut self, amplitude: &str) -> Result<(), RustitudeError> {
333 self.model.activate(amplitude)
334 }
335 /// Activate all [`Amplitude`]s by name. See [`Model::activate_all`] for more information.
336 pub fn activate_all(&mut self) {
337 self.model.activate_all()
338 }
339 /// Activate only the specified [`Amplitude`]s while deactivating the rest. See
340 /// [`Model::isolate`] for more information.
341 ///
342 /// # Errors
343 ///
344 /// This function will return a [`RustitudeError::AmplitudeNotFoundError`] if a given
345 /// amplitude is not present in the [`Model`].
346 pub fn isolate(&mut self, amplitudes: Vec<&str>) -> Result<(), RustitudeError> {
347 self.model.isolate(amplitudes)
348 }
349 /// Deactivate an [`Amplitude`] by name. See [`Model::deactivate`] for more information.
350 ///
351 /// # Errors
352 ///
353 /// This function will return a [`RustitudeError::AmplitudeNotFoundError`] if the given
354 /// amplitude is not present in the [`Model`].
355 pub fn deactivate(&mut self, amplitude: &str) -> Result<(), RustitudeError> {
356 self.model.deactivate(amplitude)
357 }
358 /// Deactivate all [`Amplitude`]s by name. See [`Model::deactivate_all`] for more information.
359 pub fn deactivate_all(&mut self) {
360 self.model.deactivate_all()
361 }
362}
363
364/// The [`ExtendedLogLikelihood`] stores two [`Manager`]s, one for data and one for a Monte-Carlo
365/// dataset used for acceptance correction. These should probably have the same [`Manager`] in
366/// practice, but this is left to the user.
367#[derive(Clone)]
368pub struct ExtendedLogLikelihood<F: Field + 'static> {
369 /// [`Manager`] for data
370 pub data_manager: Manager<F>,
371 /// [`Manager`] for Monte-Carlo
372 pub mc_manager: Manager<F>,
373}
374impl<F: Field> Debug for ExtendedLogLikelihood<F> {
375 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376 write!(f, "ExtendedLogLikelihood [ ")?;
377 write!(f, "{:?} ", self.data_manager)?;
378 write!(f, "{:?} ", self.mc_manager)?;
379 write!(f, "]")
380 }
381}
382impl<F: Field> Display for ExtendedLogLikelihood<F> {
383 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384 writeln!(f, "{}", self.data_manager)?;
385 writeln!(f, "{}", self.mc_manager)
386 }
387}
388impl<F: Field> ExtendedLogLikelihood<F> {
389 /// Create a new [`ExtendedLogLikelihood`] from a data and Monte-Carlo [`Manager`]s.
390 pub const fn new(data_manager: Manager<F>, mc_manager: Manager<F>) -> Self {
391 Self {
392 data_manager,
393 mc_manager,
394 }
395 }
396
397 /// Evaluate the [`ExtendedLogLikelihood`] over the [`Dataset`] with the given free parameters.
398 ///
399 /// # Errors
400 ///
401 /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
402 /// [`Model::compute`] for more information.
403 #[allow(clippy::suboptimal_flops)]
404 pub fn evaluate(&self, parameters: &[F]) -> Result<F, RustitudeError> {
405 let data_res = self.data_manager.evaluate(parameters)?;
406 let data_weights = self.data_manager.dataset.weights();
407 let n_data = data_weights.iter().copied().sum::<F>();
408 let mc_norm_int = self.mc_manager.evaluate(parameters)?;
409 let mc_weights = self.mc_manager.dataset.weights();
410 let n_mc = mc_weights.iter().copied().sum::<F>();
411 let ln_l = (data_res
412 .iter()
413 .zip(data_weights)
414 .map(|(l, w)| w * F::ln(*l))
415 .sum::<F>())
416 - (n_data / n_mc)
417 * (mc_norm_int
418 .iter()
419 .zip(mc_weights)
420 .map(|(l, w)| w * *l)
421 .sum::<F>());
422 Ok(convert!(-2, F) * ln_l)
423 }
424
425 /// Evaluate the [`ExtendedLogLikelihood`] over the [`Dataset`] with the given free parameters.
426 ///
427 /// This method allows the user to supply two lists of indices and will only evaluate events at
428 /// those indices. This can be used to evaluate only a subset of events or to resample events
429 /// with replacement, such as in a bootstrap.
430 ///
431 /// # Errors
432 ///
433 /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
434 /// [`Model::compute`] for more information.
435 #[allow(clippy::suboptimal_flops)]
436 pub fn evaluate_indexed(
437 &self,
438 parameters: &[F],
439 indices_data: &[usize],
440 indices_mc: &[usize],
441 ) -> Result<F, RustitudeError> {
442 let data_res = self
443 .data_manager
444 .evaluate_indexed(parameters, indices_data)?;
445 let data_weights = self.data_manager.dataset.weights_indexed(indices_data);
446 let n_data = data_weights.iter().copied().sum::<F>();
447 let mc_norm_int = self.mc_manager.evaluate_indexed(parameters, indices_mc)?;
448 let mc_weights = self.mc_manager.dataset.weights_indexed(indices_mc);
449 let n_mc = mc_weights.iter().copied().sum::<F>();
450 let ln_l = (data_res
451 .iter()
452 .zip(data_weights)
453 .map(|(l, w)| w * F::ln(*l))
454 .sum::<F>())
455 - (n_data / n_mc)
456 * (mc_norm_int
457 .iter()
458 .zip(mc_weights)
459 .map(|(l, w)| w * *l)
460 .sum::<F>());
461 Ok(convert!(-2, F) * ln_l)
462 }
463
464 /// Evaluate the [`ExtendedLogLikelihood`] over the [`Dataset`] with the given free parameters.
465 ///
466 /// This method also allows the user to input a maximum number of threads to use in the
467 /// calculation, as it uses a parallel loop over events.
468 ///
469 /// # Errors
470 ///
471 /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
472 /// [`Model::compute`] for more information.
473 #[allow(clippy::suboptimal_flops)]
474 pub fn par_evaluate(&self, parameters: &[F]) -> Result<F, RustitudeError> {
475 if self.data_manager.model.contains_python_amplitudes
476 || self.mc_manager.model.contains_python_amplitudes
477 {
478 return Err(RustitudeError::PythonError(
479 "Python amplitudes cannot be evaluated with Rust parallelism due to the GIL!"
480 .to_string(),
481 ));
482 }
483 let data_res = self.data_manager.par_evaluate(parameters)?;
484 let data_weights = self.data_manager.dataset.weights();
485 let n_data = data_weights.iter().copied().sum::<F>();
486 let mc_norm_int = self.mc_manager.par_evaluate(parameters)?;
487 let mc_weights = self.mc_manager.dataset.weights();
488 let n_mc = mc_weights.iter().copied().sum::<F>();
489 let ln_l = (data_res
490 .par_iter()
491 .zip(data_weights)
492 .map(|(l, w)| w * F::ln(*l))
493 .sum::<F>())
494 - (n_data / n_mc)
495 * (mc_norm_int
496 .par_iter()
497 .zip(mc_weights)
498 .map(|(l, w)| w * *l)
499 .sum::<F>());
500 Ok(convert!(-2, F) * ln_l)
501 }
502
503 /// Evaluate the [`ExtendedLogLikelihood`] over the [`Dataset`] with the given free parameters.
504 ///
505 /// This method allows the user to supply two lists of indices and will only evaluate events at
506 /// those indices. This can be used to evaluate only a subset of events or to resample events
507 /// with replacement, such as in a bootstrap.
508 ///
509 /// This method also allows the user to input a maximum number of threads to use in the
510 /// calculation, as it uses a parallel loop over events.
511 ///
512 /// # Errors
513 ///
514 /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
515 /// [`Model::compute`] for more information.
516 #[allow(clippy::suboptimal_flops)]
517 pub fn par_evaluate_indexed(
518 &self,
519 parameters: &[F],
520 indices_data: &[usize],
521 indices_mc: &[usize],
522 ) -> Result<F, RustitudeError> {
523 if self.data_manager.model.contains_python_amplitudes
524 || self.mc_manager.model.contains_python_amplitudes
525 {
526 return Err(RustitudeError::PythonError(
527 "Python amplitudes cannot be evaluated with Rust parallelism due to the GIL!"
528 .to_string(),
529 ));
530 }
531 let data_res = self
532 .data_manager
533 .par_evaluate_indexed(parameters, indices_data)?;
534 let data_weights = self.data_manager.dataset.weights_indexed(indices_data);
535 let n_data = data_weights.iter().copied().sum::<F>();
536 let mc_norm_int = self
537 .mc_manager
538 .par_evaluate_indexed(parameters, indices_mc)?;
539 let mc_weights = self.mc_manager.dataset.weights_indexed(indices_mc);
540 let n_mc = mc_weights.iter().copied().sum::<F>();
541 let ln_l = (data_res
542 .par_iter()
543 .zip(data_weights)
544 .map(|(l, w)| w * F::ln(*l))
545 .sum::<F>())
546 - (n_data / n_mc)
547 * (mc_norm_int
548 .par_iter()
549 .zip(mc_weights)
550 .map(|(l, w)| w * *l)
551 .sum::<F>());
552 Ok(convert!(-2, F) * ln_l)
553 }
554
555 /// Evaluate the normalized intensity function over the given Monte-Carlo [`Dataset`] with the
556 /// given free parameters. This is intended to be used to plot a model over the dataset, usually
557 /// with the generated or accepted Monte-Carlo as the input.
558 ///
559 /// # Errors
560 ///
561 /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
562 /// [`Model::compute`] for more information.
563 #[allow(clippy::suboptimal_flops)]
564 pub fn intensity(
565 &self,
566 parameters: &[F],
567 dataset_mc: &Dataset<F>,
568 ) -> Result<Vec<F>, RustitudeError> {
569 let mc_manager = Manager::new(&self.data_manager.model, dataset_mc)?;
570 let data_len_weighted: F = self.data_manager.dataset.weights().iter().copied().sum();
571 let mc_len_weighted: F = dataset_mc.weights().iter().copied().sum();
572 mc_manager.evaluate(parameters).map(|r_vec| {
573 r_vec
574 .into_iter()
575 .zip(dataset_mc.events.iter())
576 .map(|(r, e)| r * data_len_weighted / mc_len_weighted * e.weight)
577 .collect()
578 })
579 }
580
581 /// Evaluate the normalized intensity function over the given Monte-Carlo [`Dataset`] with the
582 /// given free parameters. This is intended to be used to plot a model over the dataset, usually
583 /// with the generated or accepted Monte-Carlo as the input.
584 ///
585 /// This method allows the user to supply a list of indices and will only evaluate events at
586 /// those indices. This can be used to evaluate only a subset of events or to resample events
587 /// with replacement, such as in a bootstrap.
588 ///
589 /// # Errors
590 ///
591 /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
592 /// [`Model::compute`] for more information.
593 #[allow(clippy::suboptimal_flops)]
594 pub fn intensity_indexed(
595 &self,
596 parameters: &[F],
597 dataset_mc: &Dataset<F>,
598 indices_data: &[usize],
599 indices_mc: &[usize],
600 ) -> Result<Vec<F>, RustitudeError> {
601 let mc_manager = Manager::new(&self.data_manager.model, dataset_mc)?;
602 let data_len_weighted = self
603 .data_manager
604 .dataset
605 .weights_indexed(indices_data)
606 .iter()
607 .copied()
608 .sum::<F>();
609 let mc_len_weighted = dataset_mc
610 .weights_indexed(indices_mc)
611 .iter()
612 .copied()
613 .sum::<F>();
614 let view: Vec<&Event<F>> = indices_mc
615 .par_iter()
616 .map(|&index| &mc_manager.dataset.events[index])
617 .collect();
618 mc_manager
619 .evaluate_indexed(parameters, indices_mc)
620 .map(|r_vec| {
621 r_vec
622 .into_iter()
623 .zip(view.iter())
624 .map(|(r, e)| r * data_len_weighted / mc_len_weighted * e.weight)
625 .collect()
626 })
627 }
628 /// Evaluate the normalized intensity function over the given [`Dataset`] with the given
629 /// free parameters. This is intended to be used to plot a model over the dataset, usually
630 /// with the generated or accepted Monte-Carlo as the input.
631 ///
632 /// This method also allows the user to input a maximum number of threads to use in the
633 /// calculation, as it uses a parallel loop over events.
634 ///
635 /// # Errors
636 ///
637 /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
638 /// [`Model::compute`] for more information.
639 #[allow(clippy::suboptimal_flops)]
640 pub fn par_intensity(
641 &self,
642 parameters: &[F],
643 dataset_mc: &Dataset<F>,
644 ) -> Result<Vec<F>, RustitudeError> {
645 if self.data_manager.model.contains_python_amplitudes
646 || self.mc_manager.model.contains_python_amplitudes
647 {
648 return Err(RustitudeError::PythonError(
649 "Python amplitudes cannot be evaluated with Rust parallelism due to the GIL!"
650 .to_string(),
651 ));
652 }
653 let mc_manager = Manager::new(&self.data_manager.model, dataset_mc)?;
654 let data_len_weighted: F = self.data_manager.dataset.weights().iter().copied().sum();
655 let mc_len_weighted: F = dataset_mc.weights().iter().copied().sum();
656 mc_manager.par_evaluate(parameters).map(|r_vec| {
657 r_vec
658 .into_iter()
659 .zip(dataset_mc.events.iter())
660 .map(|(r, e)| r * data_len_weighted / mc_len_weighted * e.weight)
661 .collect()
662 })
663 }
664
665 /// Evaluate the normalized intensity function over the given Monte-Carlo [`Dataset`] with the
666 /// given free parameters. This is intended to be used to plot a model over the dataset, usually
667 /// with the generated or accepted Monte-Carlo as the input.
668 ///
669 /// This method allows the user to supply a list of indices and will only evaluate events at
670 /// those indices. This can be used to evaluate only a subset of events or to resample events
671 /// with replacement, such as in a bootstrap.
672 ///
673 /// This method also allows the user to input a maximum number of threads to use in the
674 /// calculation, as it uses a parallel loop over events.
675 ///
676 /// # Errors
677 ///
678 /// This method will return a [`RustitudeError`] if the amplitude calculation fails. See
679 /// [`Model::compute`] for more information.
680 #[allow(clippy::suboptimal_flops)]
681 pub fn par_intensity_indexed(
682 &self,
683 parameters: &[F],
684 dataset_mc: &Dataset<F>,
685 indices_data: &[usize],
686 indices_mc: &[usize],
687 ) -> Result<Vec<F>, RustitudeError> {
688 let mc_manager = Manager::new(&self.data_manager.model, dataset_mc)?;
689 let data_len_weighted: F = self
690 .data_manager
691 .dataset
692 .weights_indexed(indices_data)
693 .iter()
694 .copied()
695 .sum();
696 let mc_len_weighted: F = dataset_mc.weights_indexed(indices_mc).iter().copied().sum();
697 let view: Vec<&Event<F>> = indices_mc
698 .par_iter()
699 .map(|&index| &mc_manager.dataset.events[index])
700 .collect();
701 mc_manager
702 .par_evaluate_indexed(parameters, indices_mc)
703 .map(|r_vec| {
704 r_vec
705 .into_par_iter()
706 .zip(view.par_iter())
707 .map(|(r, e)| r * data_len_weighted / mc_len_weighted * e.weight)
708 .collect()
709 })
710 }
711
712 /// Get a copy of an [`Amplitude`] in the [`Model`] by name.
713 ///
714 /// # Errors
715 ///
716 /// This method will return a [`RustitudeError`] if there is no amplitude found with the given
717 /// name in the parent [`Model`]. See [`Model::get_amplitude`] for more information.
718 pub fn get_amplitude(&self, amplitude_name: &str) -> Result<Amplitude<F>, RustitudeError> {
719 self.data_manager.get_amplitude(amplitude_name)
720 }
721
722 /// Get a copy of a [`Parameter`] in a [`Model`] by name and the name of the parent
723 /// [`Amplitude`].
724 ///
725 /// # Errors
726 ///
727 /// This method will return a [`RustitudeError`] if there is no parameter found with the given
728 /// name in the parent [`Model`]. It will also first check if the given amplitude exists, and
729 /// this method can also fail in the same way (see [`Model::get_amplitude`] and
730 /// [`Model::get_parameter`]).
731 pub fn get_parameter(
732 &self,
733 amplitude_name: &str,
734 parameter_name: &str,
735 ) -> Result<Parameter<F>, RustitudeError> {
736 self.data_manager
737 .get_parameter(amplitude_name, parameter_name)
738 }
739
740 /// Print the free parameters in the [`Model`]. See [`Model::print_parameters`] for more
741 /// information.
742 pub fn print_parameters(&self) {
743 self.data_manager.print_parameters()
744 }
745
746 /// Returns a [`Vec<Parameter<F>>`] containing the free parameters in the data [`Manager`].
747 ///
748 /// See [`Model::free_parameters`] for more information.
749 pub fn free_parameters(&self) -> Vec<Parameter<F>> {
750 self.data_manager.free_parameters()
751 }
752
753 /// Returns a [`Vec<Parameter<F>>`] containing the fixed parameters in the data [`Manager`].
754 ///
755 /// See [`Model::fixed_parameters`] for more information.
756 pub fn fixed_parameters(&self) -> Vec<Parameter<F>> {
757 self.data_manager.fixed_parameters()
758 }
759
760 /// Constrain two parameters by name, reducing the number of free parameters by one.
761 ///
762 /// # Errors
763 ///
764 /// This method will fail if any of the given amplitude or parameter names don't correspond to
765 /// a valid amplitude-parameter pair. See [`Model::constrain`] for more information.
766 pub fn constrain(
767 &mut self,
768 amplitude_1: &str,
769 parameter_1: &str,
770 amplitude_2: &str,
771 parameter_2: &str,
772 ) -> Result<(), RustitudeError> {
773 self.data_manager
774 .constrain(amplitude_1, parameter_1, amplitude_2, parameter_2)?;
775 self.mc_manager
776 .constrain(amplitude_1, parameter_1, amplitude_2, parameter_2)
777 }
778
779 /// Fix a parameter by name to the given value.
780 ///
781 /// # Errors
782 ///
783 /// This method will fail if the given amplitude-parameter pair does not exist. See
784 /// [`Model::fix`] for more information.
785 pub fn fix(
786 &mut self,
787 amplitude: &str,
788 parameter: &str,
789 value: F,
790 ) -> Result<(), RustitudeError> {
791 self.data_manager.fix(amplitude, parameter, value)?;
792 self.mc_manager.fix(amplitude, parameter, value)
793 }
794
795 /// Free a fixed parameter by name.
796 ///
797 /// # Errors
798 ///
799 /// This method will fail if the given amplitude-parameter pair does not exist. See
800 /// [`Model::free`] for more information.
801 pub fn free(&mut self, amplitude: &str, parameter: &str) -> Result<(), RustitudeError> {
802 self.data_manager.free(amplitude, parameter)?;
803 self.mc_manager.free(amplitude, parameter)
804 }
805
806 /// Set the bounds of a parameter by name.
807 ///
808 /// # Errors
809 ///
810 /// This method will fail if the given amplitude-parameter pair does not exist. See
811 /// [`Model::set_bounds`] for more information.
812 pub fn set_bounds(
813 &mut self,
814 amplitude: &str,
815 parameter: &str,
816 bounds: (F, F),
817 ) -> Result<(), RustitudeError> {
818 self.data_manager.set_bounds(amplitude, parameter, bounds)?;
819 self.mc_manager.set_bounds(amplitude, parameter, bounds)
820 }
821
822 /// Set the initial value of a parameter by name.
823 ///
824 /// # Errors
825 ///
826 /// This method will fail if the given amplitude-parameter pair does not exist. See
827 /// [`Model::set_initial`] for more information.
828 pub fn set_initial(
829 &mut self,
830 amplitude: &str,
831 parameter: &str,
832 initial: F,
833 ) -> Result<(), RustitudeError> {
834 self.data_manager
835 .set_initial(amplitude, parameter, initial)?;
836 self.mc_manager.set_initial(amplitude, parameter, initial)
837 }
838
839 /// Get a list of bounds for all free parameters in the [`Model`]. See
840 /// [`Model::get_bounds`] for more information.
841 pub fn get_bounds(&self) -> Vec<(F, F)> {
842 self.data_manager.get_bounds();
843 self.mc_manager.get_bounds()
844 }
845
846 /// Get a list of initial values for all free parameters in the [`Model`]. See
847 /// [`Model::get_initial`] for more information.
848 pub fn get_initial(&self) -> Vec<F> {
849 self.data_manager.get_initial();
850 self.mc_manager.get_initial()
851 }
852
853 /// Get the number of free parameters in the [`Model`] See [`Model::get_n_free`] for
854 /// more information.
855 pub fn get_n_free(&self) -> usize {
856 self.data_manager.get_n_free();
857 self.mc_manager.get_n_free()
858 }
859
860 /// Activate an [`Amplitude`] by name. See [`Model::activate`] for more information.
861 ///
862 /// # Errors
863 ///
864 /// This function will return a [`RustitudeError::AmplitudeNotFoundError`] if the given
865 /// amplitude is not present in the [`Model`].
866 pub fn activate(&mut self, amplitude: &str) -> Result<(), RustitudeError> {
867 self.data_manager.activate(amplitude)?;
868 self.mc_manager.activate(amplitude)
869 }
870 /// Activates all [`Amplitude`]s by name. See [`Model::activate_all`] for more information.
871 pub fn activate_all(&mut self) {
872 self.data_manager.activate_all();
873 self.mc_manager.activate_all()
874 }
875 /// Activate only the specified [`Amplitude`]s while deactivating the rest. See
876 /// [`Model::isolate`] for more information.
877 ///
878 /// # Errors
879 ///
880 /// This function will return a [`RustitudeError::AmplitudeNotFoundError`] if a given
881 /// amplitude is not present in the [`Model`].
882 pub fn isolate(&mut self, amplitudes: Vec<&str>) -> Result<(), RustitudeError> {
883 self.data_manager.isolate(amplitudes.clone())?;
884 self.mc_manager.isolate(amplitudes)
885 }
886 /// Deactivate an [`Amplitude`] by name. See [`Model::deactivate`] for more information.
887 ///
888 /// # Errors
889 ///
890 /// This function will return a [`RustitudeError::AmplitudeNotFoundError`] if the given
891 /// amplitude is not present in the [`Model`].
892 pub fn deactivate(&mut self, amplitude: &str) -> Result<(), RustitudeError> {
893 self.data_manager.deactivate(amplitude)?;
894 self.mc_manager.deactivate(amplitude)
895 }
896 /// Deactivates all [`Amplitude`]s by name. See [`Model::deactivate_all`] for more information.
897 pub fn deactivate_all(&mut self) {
898 self.data_manager.deactivate_all();
899 self.mc_manager.deactivate_all()
900 }
901}
902
903impl<F: Field + ganesh::core::Field> Function<F, (), RustitudeError> for ExtendedLogLikelihood<F> {
904 fn evaluate(&self, x: &DVector<F>, _args: Option<&()>) -> Result<F, RustitudeError> {
905 self.par_evaluate(x.as_slice())
906 }
907}